Source code for vak.cli.train

import logging
import shutil
from pathlib import Path

from .. import config
from .. import train as train_module
from ..common.logging import config_logging_for_cli, log_version
from ..common.paths import generate_results_dir_name_as_path

logger = logging.getLogger(__name__)


def train(toml_path):
    """Train models using the training set specified in a config.toml file.

    Function called by the command-line interface.

    Creates a new results directory within the ``root_results_dir``
    given in the TRAIN section of the config, copies the config file
    into that directory, configures logging to write there, and then
    calls ``vak.train.train`` with the options from the config.

    Parameters
    ----------
    toml_path : str, Path
        path to a configuration file in TOML format.

    Raises
    ------
    ValueError
        If the config file has no TRAIN section, or if no dataset path
        is specified in it (i.e., ``vak prep`` has not been run yet).
    """
    toml_path = Path(toml_path)
    cfg = config.Config.from_toml_path(toml_path)

    if cfg.train is None:
        raise ValueError(
            f"train called with a config.toml file that does not have a TRAIN section: {toml_path}"
        )

    # Validate the dataset path *before* touching the filesystem, so that a
    # config that has not been prepped does not leave an empty results
    # directory behind each time the command fails.
    if cfg.train.dataset.path is None:
        raise ValueError(
            "No value is specified for 'dataset_path' in this .toml config file. "
            "To generate a .csv file that represents the dataset, "
            f"please run the following command:\n'vak prep {toml_path}'"
        )

    # ---- set up directory to save output ----------------------------------
    results_path = generate_results_dir_name_as_path(
        cfg.train.root_results_dir
    )
    # no ``exist_ok``: an already-existing directory raises instead of
    # being silently reused
    results_path.mkdir(parents=True)
    # copy config file into results dir now that we've made the dir
    shutil.copy(toml_path, results_path)

    # ---- set up logging ---------------------------------------------------
    config_logging_for_cli(
        log_dst=results_path, log_stem="train", level="INFO", force=True
    )
    log_version(logger)
    logger.info("Logging results to %s", results_path)

    train_cfg = cfg.train
    train_module.train(
        model_config=train_cfg.model.asdict(),
        dataset_config=train_cfg.dataset.asdict(),
        trainer_config=train_cfg.trainer.asdict(),
        batch_size=train_cfg.batch_size,
        num_epochs=train_cfg.num_epochs,
        num_workers=train_cfg.num_workers,
        checkpoint_path=train_cfg.checkpoint_path,
        frames_standardizer_path=train_cfg.frames_standardizer_path,
        results_path=results_path,
        standardize_frames=train_cfg.standardize_frames,
        shuffle=train_cfg.shuffle,
        val_step=train_cfg.val_step,
        ckpt_step=train_cfg.ckpt_step,
        patience=train_cfg.patience,
    )