vak.config.train.TrainConfig

class vak.config.train.TrainConfig(model, num_epochs, batch_size, root_results_dir, dataset: DatasetConfig, trainer: TrainerConfig, results_dirname=None, standardize_frames=False, num_workers=2, shuffle=True, val_step=None, ckpt_step=None, patience=None, checkpoint_path=None, frames_standardizer_path=None)

Bases: object

Class that represents the [vak.train] table of a configuration file.

model

The model to use: its name and the parameters used to configure it. Must be an instance of vak.config.ModelConfig.

Type:

vak.config.ModelConfig

num_epochs

number of training epochs. One epoch = one iteration through the entire training set.

Type:

int

batch_size

number of samples per batch presented to models during training.

Type:

int
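
Because val_step and ckpt_step are counted in global steps (batches), it can help to relate them to num_epochs and batch_size. A minimal sketch, assuming the training DataLoader keeps the final partial batch (PyTorch's default drop_last=False), so each epoch has ceil(n_samples / batch_size) steps; n_samples and both helper functions are hypothetical.

```python
import math


def steps_per_epoch(n_samples: int, batch_size: int) -> int:
    """Number of batches (global steps) in one epoch.

    Assumes the final, partial batch is kept (drop_last=False).
    """
    return math.ceil(n_samples / batch_size)


def total_steps(n_samples: int, batch_size: int, num_epochs: int) -> int:
    """Total global steps if training runs for all num_epochs."""
    return num_epochs * steps_per_epoch(n_samples, batch_size)


# Hypothetical example: 1000 training samples, batch_size=8, num_epochs=50
print(steps_per_epoch(1000, 8))  # 125
print(total_steps(1000, 8, 50))  # 6250
```

With 1001 samples instead, the extra sample forms a partial batch, giving 126 steps per epoch.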

root_results_dir

directory in which results will be created. The vak.cli.train function will create a subdirectory in this directory each time it runs.

Type:

str
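
A sketch of the behavior described above, where each run creates a fresh subdirectory inside root_results_dir. The "results_" prefix, the timestamp format, and make_results_dir are assumptions for illustration, not necessarily what vak.cli.train uses.

```python
import datetime
import pathlib
import tempfile


def make_results_dir(root_results_dir: str) -> pathlib.Path:
    """Create and return a new, timestamped subdirectory of root_results_dir.

    The naming scheme ("results_" + timestamp) is a hypothetical stand-in.
    """
    timestamp = datetime.datetime.now().strftime("%y%m%d_%H%M%S")
    results_dir = pathlib.Path(root_results_dir) / f"results_{timestamp}"
    # exist_ok=False so two runs never silently share one directory
    results_dir.mkdir(parents=True, exist_ok=False)
    return results_dir


with tempfile.TemporaryDirectory() as root:
    run_dir = make_results_dir(root)
    print(run_dir.name)
```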

dataset

The dataset to use: its path, optionally a path to a file defining splits, and its name if it is a built-in dataset. Must be an instance of vak.config.DatasetConfig.

Type:

vak.config.DatasetConfig

trainer

Configuration for lightning.pytorch.Trainer. Must be an instance of vak.config.TrainerConfig.

Type:

vak.config.TrainerConfig

num_workers

Number of processes to use for parallel data loading. Passed as the num_workers argument to torch.utils.data.DataLoader.

Type:

int

shuffle

if True, shuffle training data before each epoch. Default is True.

Type:

bool

standardize_frames

if True, use vak.transforms.FramesStandardizer to standardize the frames. Normalization is done by subtracting the mean of each row (frequency bin), computed from the training set, and then dividing by the standard deviation for that frequency bin. The same normalization is then applied to the validation and test data. Default is False.

Type:

bool
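
The normalization described above amounts to per-frequency-bin z-scoring with statistics fit on the training set only. The pure-Python sketch below illustrates the arithmetic; it is not vak.transforms.FramesStandardizer itself, and fit_standardizer and apply_standardizer are hypothetical names. Each spectrogram is assumed to be a list of rows (frequency bins), each row a list of values over time bins.

```python
import statistics


def fit_standardizer(train_frames):
    """Compute per-row (frequency-bin) mean and std from training frames."""
    n_rows = len(train_frames[0])
    means, stds = [], []
    for row in range(n_rows):
        # Pool this frequency bin's values across all training spectrograms
        values = [v for spect in train_frames for v in spect[row]]
        means.append(statistics.fmean(values))
        stds.append(statistics.pstdev(values))
    return means, stds


def apply_standardizer(frames, means, stds):
    """Subtract each row's training mean, then divide by its training std."""
    return [
        # Guard against a zero std (constant row); a choice made for this sketch
        [(v - mu) / sd if sd > 0 else v - mu for v in row]
        for row, mu, sd in zip(frames, means, stds)
    ]


train = [[[1.0, 3.0], [10.0, 10.0]], [[5.0, 7.0], [20.0, 20.0]]]
means, stds = fit_standardizer(train)
# The same training statistics are reused for validation and test data
print(apply_standardizer([[4.0, 6.0], [15.0, 25.0]], means, stds))
```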

val_step

Interval, in global steps, at which to estimate accuracy on the validation set. If val_step is n, validation is carried out every time the global step divided by n is a whole number, i.e., when the global step modulo val_step is 0. Default is None, in which case no validation is done.

Type:

int
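
A minimal sketch of the val_step rule as stated: validation runs whenever the global step is a whole multiple of val_step, and never when val_step is None. Whether vak's training loop counts global steps from 0 or 1 is not stated here; the sketch assumes counting from 1, and should_validate is a hypothetical helper.

```python
from typing import Optional


def should_validate(global_step: int, val_step: Optional[int]) -> bool:
    """True when validation should run at this global step.

    val_step=None disables validation entirely.
    """
    if val_step is None:
        return False
    # "global step / n is a whole number" <=> global_step % n == 0
    return global_step % val_step == 0


# With val_step=250, validation runs at steps 250, 500, 750, 1000
print([step for step in range(1, 1001) if should_validate(step, 250)])
```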

ckpt_step

Interval, in global steps, at which to save a checkpoint file. If ckpt_step is n, a checkpoint is saved every time the global step divided by n is a whole number, i.e., when the global step modulo ckpt_step is 0. Default is None, in which case a checkpoint is only saved at the last epoch.

Type:

int
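
The checkpoint rule, including the documented fallback for ckpt_step=None, can be sketched as follows. The end of the last epoch is approximated as the final global step, steps are assumed to be counted from 1, and checkpoint_steps is hypothetical; it does not model any extra end-of-training save that vak may perform when ckpt_step is set.

```python
from typing import List, Optional


def checkpoint_steps(total_steps: int, ckpt_step: Optional[int]) -> List[int]:
    """Global steps (counted from 1) at which a checkpoint would be saved.

    With ckpt_step=None, only the final step (end of the last epoch) is used.
    """
    if ckpt_step is None:
        return [total_steps]
    # Save whenever global_step % ckpt_step == 0
    return [step for step in range(1, total_steps + 1) if step % ckpt_step == 0]


print(checkpoint_steps(1000, 300))   # [300, 600, 900]
print(checkpoint_steps(1000, None))  # [1000]
```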

patience

Number of validation steps to wait without improvement in performance on the validation set before stopping training. Default is None, in which case training only stops after the specified number of epochs.

Type:

int
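
Early stopping with patience can be sketched as below. The sketch assumes a higher validation score is better (e.g. accuracy) and that the counter resets on any strict improvement; stopping_index is hypothetical and is not how vak or Lightning implement it internally.

```python
from typing import Optional, Sequence


def stopping_index(val_scores: Sequence[float], patience: Optional[int]) -> Optional[int]:
    """Index of the validation step at which training would stop early.

    Returns None if training runs to completion (including patience=None).
    """
    if patience is None:
        return None
    best = float("-inf")
    steps_without_improvement = 0
    for i, score in enumerate(val_scores):
        if score > best:
            # Strict improvement: record it and reset the counter
            best = score
            steps_without_improvement = 0
        else:
            steps_without_improvement += 1
            if steps_without_improvement >= patience:
                return i
    return None


scores = [0.60, 0.70, 0.72, 0.71, 0.72, 0.70]
print(stopping_index(scores, patience=3))     # 5
print(stopping_index(scores, patience=None))  # None
```

A tie with the best score (0.72 at index 4) counts as no improvement here; other implementations may use a minimum-delta threshold instead.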

checkpoint_path

Path to a directory with checkpoint files saved by Torch, used to reload a model. Default is None, in which case a new model is initialized.

Type:

str

frames_standardizer_path

Path to a saved vak.transforms.FramesStandardizer object used to standardize (normalize) frames. If the spectrograms were normalized and this is not provided, results will be incorrect. Default is None.

Type:

str

__init__(model, num_epochs, batch_size, root_results_dir, dataset: DatasetConfig, trainer: TrainerConfig, results_dirname=None, standardize_frames=False, num_workers=2, shuffle=True, val_step=None, ckpt_step=None, patience=None, checkpoint_path=None, frames_standardizer_path=None) → None

Method generated by attrs for class TrainConfig.

Methods

__init__(model, num_epochs, batch_size, ...)

Method generated by attrs for class TrainConfig.

from_config_dict(config_dict)

Return TrainConfig instance from a dict.


classmethod from_config_dict(config_dict: dict) → TrainConfig

Return TrainConfig instance from a dict.

The dict passed in should be the one obtained by loading a valid configuration TOML file with vak.config.parse.from_toml_path() and then indexing it with the key 'train', i.e., TrainConfig.from_config_dict(config_dict['train']).
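
To illustrate the shape of this pattern without depending on vak, the sketch below uses a hypothetical, much smaller MiniTrainConfig dataclass: the 'train' table of a parsed config dict is unpacked into keyword arguments, and optional keys that are absent fall back to their defaults. Since model, dataset, and trainer must be instances of vak.config.ModelConfig, DatasetConfig, and TrainerConfig, the real method must also convert those nested tables; that step is omitted here, and the dict below is an assumed example rather than a verified vak configuration.

```python
import dataclasses
from typing import Optional


@dataclasses.dataclass
class MiniTrainConfig:
    """Hypothetical, reduced stand-in for vak.config.train.TrainConfig."""

    num_epochs: int
    batch_size: int
    root_results_dir: str
    num_workers: int = 2
    shuffle: bool = True
    val_step: Optional[int] = None
    patience: Optional[int] = None

    @classmethod
    def from_config_dict(cls, config_dict: dict) -> "MiniTrainConfig":
        """Build an instance from the 'train' table of a parsed config."""
        # Unknown keys raise TypeError, which surfaces typos in the config file
        return cls(**config_dict)


# Assumed contents of a parsed config; only the 'train' table is used
config_dict = {
    "train": {
        "num_epochs": 50,
        "batch_size": 8,
        "root_results_dir": "./results",
        "val_step": 250,
    }
}
train_config = MiniTrainConfig.from_config_dict(config_dict["train"])
print(train_config)
```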