vak.train.frame_classification.train_frame_classification_model

vak.train.frame_classification.train_frame_classification_model(model_config: dict, dataset_config: dict, trainer_config: dict, batch_size: int, num_epochs: int, num_workers: int, checkpoint_path: str | Path | None = None, frames_standardizer_path: str | Path | None = None, results_path: str | Path | None = None, standardize_frames: bool = True, shuffle: bool = True, val_step: int | None = None, ckpt_step: int | None = None, patience: int | None = None, subset: str | None = None) -> None

Train a model from the frame classification family and save results.

Saves checkpoint files for model, label map, and spectrogram scaler. These are saved either in results_path if specified, or a new directory made inside root_results_dir.

Parameters:
  • model_config (dict) – Model configuration in a dict. Can be obtained by calling vak.config.ModelConfig.asdict().

  • dataset_config (dict) – Dataset configuration in a dict. Can be obtained by calling vak.config.DatasetConfig.asdict().

  • trainer_config (dict) – Configuration for lightning.pytorch.Trainer in a dict. Can be obtained by calling vak.config.TrainerConfig.asdict().

  • batch_size (int) – number of samples per batch presented to models during training.

  • num_epochs (int) – number of training epochs. One epoch = one iteration through the entire training set.

  • num_workers (int) – Number of processes to use for parallel loading of data. Passed to torch.utils.data.DataLoader.

  • checkpoint_path (str, pathlib.Path) – path to a checkpoint file, e.g., one generated by a previous run of vak.core.train. If specified, this checkpoint will be loaded into model. Used when continuing training. Default is None, in which case a new model is initialized.

  • frames_standardizer_path (str, pathlib.Path) – path to a saved FramesStandardizer, e.g., one generated by a previous run of vak.core.train(). The FramesStandardizer is used to standardize (normalize) frames, the input to a frame classification model. Used when continuing training, for example on the same dataset. Default is None.

  • root_results_dir (str, pathlib.Path) – Root directory in which a new directory will be created where results will be saved.

  • results_path (str, pathlib.Path) – Directory where results will be saved. If specified, this parameter overrides root_results_dir.

  • shuffle (bool) – if True, shuffle training data before each epoch. Default is True.

  • standardize_frames (bool) – if True, use vak.transforms.FramesStandardizer to standardize the frames. Normalization is done by subtracting the mean of each row (frequency bin), computed over the training set, and then dividing by the standard deviation of that frequency bin. The same normalization is then applied to validation and test data. Default is True.

  • val_step (int) – Step on which to estimate accuracy using validation set. If val_step is n, then validation is carried out every time the global step / n is a whole number, i.e., when the global step modulo val_step is 0. Default is None, in which case no validation is done.

  • ckpt_step (int) – Step on which to save to checkpoint file. If ckpt_step is n, then a checkpoint is saved every time the global step / n is a whole number, i.e., when the global step modulo ckpt_step is 0. Default is None, in which case a checkpoint is only saved at the last epoch.

  • patience (int) – number of validation steps to wait without performance on the validation set improving before stopping the training. Default is None, in which case training only stops after the specified number of epochs.

  • subset (str) – Name of a subset from the training split of the dataset to use when training model. This parameter is used by vak.learncurve.learncurve() to specify subsets when training models for a learning curve.
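
The step-based parameters above (val_step, ckpt_step, patience) can be illustrated with a small, self-contained sketch. This is not vak's implementation: the helper names is_scheduled_step and early_stop_index are hypothetical, and the sketch assumes that a tie with the best validation accuracy does not count as an improvement.

```python
from __future__ import annotations


def is_scheduled_step(global_step: int, every_n: int | None) -> bool:
    """Return True when an action scheduled every `every_n` steps is due.

    Mirrors the rule described for val_step and ckpt_step:
    the action runs when global_step modulo every_n is 0.
    (Hypothetical helper, not part of vak.)
    """
    if every_n is None:
        # None means the action is never triggered by step count
        return False
    return global_step % every_n == 0


def early_stop_index(val_accuracies: list[float], patience: int | None) -> int | None:
    """Return the 1-based validation step at which training would stop.

    Training stops once `patience` consecutive validation steps pass
    without the accuracy exceeding the best value seen so far.
    Returns None if training would run to the last epoch.
    (Hypothetical helper; assumes ties do not count as improvement.)
    """
    if patience is None:
        return None
    best = float("-inf")
    steps_without_improvement = 0
    for index, accuracy in enumerate(val_accuracies, start=1):
        if accuracy > best:
            best = accuracy
            steps_without_improvement = 0
        else:
            steps_without_improvement += 1
            if steps_without_improvement >= patience:
                return index
    return None
```

With val_step=5, for instance, is_scheduled_step is True at global steps 5, 10, 15, and so on. With patience=2, a run whose validation accuracy stalls for two consecutive validation steps would stop at the second stalled step.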
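
The normalization described for standardize_frames can likewise be sketched in plain Python. This is only an illustration of the arithmetic, not the code in vak.transforms.FramesStandardizer: the function names fit_standardizer and standardize are hypothetical, rows are assumed to be frequency bins and columns time bins, the population standard deviation is an assumption, and every bin is assumed to have a nonzero standard deviation.

```python
from statistics import fmean, pstdev


def fit_standardizer(train_frames):
    """Compute the per-row (frequency bin) mean and std from training frames.

    `train_frames` is a list of rows, one row per frequency bin.
    (Hypothetical helper; uses population std as an assumption.)
    """
    means = [fmean(row) for row in train_frames]
    stds = [pstdev(row) for row in train_frames]
    return means, stds


def standardize(frames, means, stds):
    """Subtract each row's training mean, then divide by its training std.

    The same `means` and `stds`, fit on the training set, are reused
    for validation and test frames.
    """
    return [
        [(value - mean) / std for value in row]
        for row, mean, std in zip(frames, means, stds)
    ]


# Fit on training frames only, then apply to both train and validation data.
train = [[1.0, 3.0], [2.0, 6.0]]
val = [[2.0, 4.0], [4.0, 8.0]]
means, stds = fit_standardizer(train)
train_standardized = standardize(train, means, stds)
val_standardized = standardize(val, means, stds)
```

The point of fitting on the training set alone is that validation and test data are scaled by statistics the model has already seen, so no information from held-out data leaks into training.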