vak.train.frame_classification.train_frame_classification_model

vak.train.frame_classification.train_frame_classification_model(model_name: str, model_config: dict, dataset_path: str | Path, batch_size: int, num_epochs: int, num_workers: int, train_transform_params: dict | None = None, train_dataset_params: dict | None = None, val_transform_params: dict | None = None, val_dataset_params: dict | None = None, checkpoint_path: str | Path | None = None, spect_scaler_path: str | Path | None = None, results_path: str | Path | None = None, normalize_spectrograms: bool = True, shuffle: bool = True, val_step: int | None = None, ckpt_step: int | None = None, patience: int | None = None, device: str | None = None, subset: str | None = None) -> None

Train a model from the frame classification family and save results.

Saves checkpoint files for the model, the label map, and the spectrogram scaler. These are saved either in results_path, if specified, or in a new directory created inside root_results_dir.

Parameters:
  • model_name (str) – Model name, must be one of vak.models.registry.MODEL_NAMES.

  • model_config (dict) – Model configuration in a dict, as loaded from a .toml file, and used by the model method from_config.

  • dataset_path (str, pathlib.Path) – Path to dataset, a directory generated by running vak prep.

  • batch_size (int) – Number of samples per batch presented to models during training.

  • num_epochs (int) – Number of training epochs. One epoch is one full pass through the entire training set.

  • num_workers (int) – Number of processes to use for parallel loading of data. Argument to torch.utils.data.DataLoader.

  • train_transform_params (dict, optional) – Parameters for training data transform. Passed as keyword arguments. Optional, default is None.

  • train_dataset_params (dict, optional) – Parameters for training dataset. Passed as keyword arguments to vak.datasets.frame_classification.WindowDataset. Optional, default is None.

  • val_transform_params (dict, optional) – Parameters for validation data transform. Passed as keyword arguments. Optional, default is None.

  • val_dataset_params (dict, optional) – Parameters for validation dataset. Passed as keyword arguments to vak.datasets.frame_classification.FramesDataset. Optional, default is None.

  • dataset_csv_path – Path to csv file representing splits of dataset, e.g., such a file generated by running vak prep. This parameter is used by vak.core.learncurve() to specify different splits to use, when generating results for a learning curve. If this argument is specified, the csv file must be inside the directory dataset_path.

  • checkpoint_path (str, pathlib.Path) – Path to a checkpoint file, e.g., one generated by a previous run of vak.core.train. If specified, this checkpoint will be loaded into the model. Used when continuing training. Default is None, in which case a new model is initialized.

  • spect_scaler_path (str, pathlib.Path) – Path to a saved SpectScaler used to normalize spectrograms, e.g., one generated by a previous run of vak.core.train. Used when continuing training, for example on the same dataset. Default is None.

  • root_results_dir (str, pathlib.Path) – Root directory in which a new directory will be created where results will be saved.

  • results_path (str, pathlib.Path) – Directory where results will be saved. If specified, this parameter overrides root_results_dir.

  • device (str) – Device on which to work with model + data. Default is None. If None, then a device will be selected with vak.split.get_default. That function defaults to ‘cuda’ if torch.cuda.is_available() returns True.

  • shuffle (bool) – If True, shuffle training data before each epoch. Default is True.

  • normalize_spectrograms (bool) – If True, use spect.utils.data.SpectScaler to normalize the spectrograms. Normalization is done by subtracting the mean of each frequency bin, computed on the training set, and then dividing by the standard deviation of that frequency bin. The same normalization is then applied to validation + test data. Default is True.

  • val_step (int) – Step on which to estimate accuracy using validation set. If val_step is n, then validation is carried out every time the global step is a multiple of n, i.e., when the global step modulo val_step is 0. Default is None, in which case no validation is done.

  • ckpt_step (int) – Step on which to save to checkpoint file. If ckpt_step is n, then a checkpoint is saved every time the global step is a multiple of n, i.e., when the global step modulo ckpt_step is 0. Default is None, in which case a checkpoint is only saved at the last epoch.

  • patience (int) – Number of validation steps to wait without performance on the validation set improving before stopping training. Default is None, in which case training only stops after the specified number of epochs.

  • subset (str) – Name of a subset from the training split of the dataset to use when training model. This parameter is used by vak.learncurve.learncurve() to specify subsets when training models for a learning curve.
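
The interaction of val_step, ckpt_step, and patience described above can be sketched with a toy loop that does only the step bookkeeping and trains no model. This is an illustrative sketch, not vak's implementation: the function simulate_schedule, its arguments total_steps and val_accuracies, and the accuracy values are all hypothetical. It also assumes that the global step starts at 1, that "improving" means strictly higher validation accuracy, and that an early stop takes effect before any checkpoint due on the same step; the library may differ on each of these points.

```python
from __future__ import annotations


def simulate_schedule(
    total_steps: int,
    val_accuracies: list[float],
    val_step: int | None = None,
    ckpt_step: int | None = None,
    patience: int | None = None,
) -> dict:
    """Hypothetical helper: record when validation, checkpointing,
    and early stopping would occur. No model is trained.

    `val_accuracies` supplies one made-up accuracy per validation run,
    so it must have at least as many entries as validation runs.
    """
    validated_at: list[int] = []
    checkpointed_at: list[int] = []
    best_acc = float("-inf")
    runs_without_improvement = 0
    stopped_early_at = None

    # Assumption: the global step counts batches and starts at 1.
    for global_step in range(1, total_steps + 1):
        # ... one optimizer step on one batch would happen here ...

        # Validate when the global step is a multiple of val_step.
        if val_step is not None and global_step % val_step == 0:
            acc = val_accuracies[len(validated_at)]
            validated_at.append(global_step)
            if acc > best_acc:  # assumption: strict improvement resets the count
                best_acc = acc
                runs_without_improvement = 0
            else:
                runs_without_improvement += 1
            if patience is not None and runs_without_improvement >= patience:
                stopped_early_at = global_step
                break

        # Checkpoint when the global step is a multiple of ckpt_step.
        if ckpt_step is not None and global_step % ckpt_step == 0:
            checkpointed_at.append(global_step)

    return {
        "validated_at": validated_at,
        "checkpointed_at": checkpointed_at,
        "best_acc": best_acc,
        "stopped_early_at": stopped_early_at,
    }


result = simulate_schedule(
    total_steps=12,
    val_accuracies=[0.5, 0.6, 0.6, 0.55, 0.7, 0.7],  # made-up values
    val_step=2,
    ckpt_step=4,
    patience=2,
)
print(result)
```

With these made-up accuracies, validation runs at steps 2, 4, 6, and 8; the runs at steps 6 and 8 do not improve on 0.6, so with patience=2 the loop stops at step 8. With patience=None the same loop would run all 12 steps, and with val_step=None no validation would occur, so patience would never trigger.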