"""High-level function that trains models."""
from __future__ import annotations
import logging
import pathlib
from .. import models
from ..common import validators
from .frame_classification import train_frame_classification_model
from .parametric_umap import train_parametric_umap_model
# Module-level logger, used to report non-fatal issues (e.g. ignored arguments).
logger = logging.getLogger(__name__)
def train(
    model_config: dict,
    dataset_config: dict,
    trainer_config: dict,
    batch_size: int,
    num_epochs: int,
    num_workers: int,
    checkpoint_path: str | pathlib.Path | None = None,
    spect_scaler_path: str | pathlib.Path | None = None,
    results_path: str | pathlib.Path | None = None,
    normalize_spectrograms: bool = True,
    shuffle: bool = True,
    val_step: int | None = None,
    ckpt_step: int | None = None,
    patience: int | None = None,
    subset: str | None = None,
):
    """Train a model and save results.

    Looks up the model's family from its name in
    ``models.registry.MODEL_FAMILY_FROM_NAME`` and dispatches to the
    family-specific training function:
    :func:`train_frame_classification_model` for
    ``"FrameClassificationModel"`` and :func:`train_parametric_umap_model`
    for ``"ParametricUMAPModel"``. Those functions save results
    (e.g. checkpoint files) in ``results_path``.

    Parameters
    ----------
    model_config : dict
        Model configuration in a :class:`dict`. Must have a ``"name"`` key.
        Can be obtained by calling :meth:`vak.config.ModelConfig.asdict`.
    dataset_config : dict
        Dataset configuration in a :class:`dict`. Must have a ``"path"`` key
        whose value is the path to an existing dataset directory.
        Can be obtained by calling :meth:`vak.config.DatasetConfig.asdict`.
    trainer_config : dict
        Configuration for :class:`lightning.pytorch.Trainer` in a :class:`dict`.
        Can be obtained by calling :meth:`vak.config.TrainerConfig.asdict`.
    batch_size : int
        Number of samples per batch presented to models during training.
    num_epochs : int
        Number of training epochs. One epoch = one iteration through the
        entire training set.
    num_workers : int
        Number of processes to use for parallel loading of data.
        Argument to ``torch.utils.data.DataLoader``.
    checkpoint_path : str, pathlib.Path, optional
        Path to a checkpoint file, e.g. one generated by a previous training
        run. If specified, this checkpoint will be loaded into the model;
        used when continuing training. Must be an existing file.
        Default is None, in which case a new model is initialized.
    spect_scaler_path : str, pathlib.Path, optional
        Path to a ``SpectScaler`` used to normalize spectrograms, e.g. one
        generated by a previous training run. Used when continuing training,
        for example on the same dataset. Must be an existing file.
        Only forwarded for ``"FrameClassificationModel"``; ignored (with a
        warning) for ``"ParametricUMAPModel"``. Default is None.
    results_path : str, pathlib.Path, optional
        Directory where results will be saved.
    normalize_spectrograms : bool
        If True, normalize spectrograms by subtracting the mean for each
        frequency bin of the training set and then dividing by the std for
        that frequency bin; the same normalization is applied to validation
        and test data. Only forwarded for ``"FrameClassificationModel"``;
        ignored for ``"ParametricUMAPModel"``. Default is True.
    shuffle : bool
        If True, shuffle training data before each epoch. Default is True.
    val_step : int, optional
        Interval, in global steps, at which the model is evaluated on the
        validation set. If ``val_step`` is n, validation is carried out
        whenever the global step modulo n is 0.
        Default is None, in which case no validation is done.
    ckpt_step : int, optional
        Interval, in global steps, at which a checkpoint file is saved.
        If ``ckpt_step`` is n, a checkpoint is saved whenever the global
        step modulo n is 0. Default is None, in which case a checkpoint is
        only saved at the last epoch.
    patience : int, optional
        Number of validation steps to wait without performance on the
        validation set improving before stopping training.
        Only forwarded for ``"FrameClassificationModel"``; ignored (with a
        warning) for ``"ParametricUMAPModel"``. Default is None, in which
        case training only stops after the specified number of epochs.
    subset : str, optional
        Name of a subset of the training data to use when training the
        model. Used by ``vak.learncurve.learncurve`` to specify subsets of
        the training set when training models for a learning curve.
        Passed through unchanged to the family-specific training function.
        Default is None.

    Raises
    ------
    FileNotFoundError
        If ``checkpoint_path`` or ``spect_scaler_path`` is given but is not
        recognized as a file.
    NotADirectoryError
        If ``dataset_config["path"]`` does not exist or is not a directory.
    KeyError
        If ``dataset_config`` has no ``"path"`` key or ``model_config`` has
        no ``"name"`` key.
    ValueError
        If no model family is registered for ``model_config["name"]``, or
        the family is not one this function knows how to train.
    """
    # Validate optional file paths up front so a bad path fails fast,
    # before any (potentially slow) dataset or model setup happens.
    for path, path_name in (
        (checkpoint_path, "checkpoint_path"),
        (spect_scaler_path, "spect_scaler_path"),
    ):
        if path is not None and not validators.is_a_file(path):
            raise FileNotFoundError(
                f"value for ``{path_name}`` not recognized as a file: {path}"
            )

    dataset_path = pathlib.Path(dataset_config["path"])
    # ``Path.is_dir()`` returns False for paths that do not exist, so this
    # single check covers both "missing" and "exists but is not a directory".
    if not dataset_path.is_dir():
        raise NotADirectoryError(
            f"`dataset_path` not found or not recognized as a directory: {dataset_path}"
        )

    model_name = model_config["name"]
    try:
        model_family = models.registry.MODEL_FAMILY_FROM_NAME[model_name]
    except KeyError as e:
        raise ValueError(
            f"No model family found for the model name specified: {model_name}"
        ) from e

    if model_family == "FrameClassificationModel":
        train_frame_classification_model(
            model_config=model_config,
            dataset_config=dataset_config,
            trainer_config=trainer_config,
            batch_size=batch_size,
            num_epochs=num_epochs,
            num_workers=num_workers,
            checkpoint_path=checkpoint_path,
            spect_scaler_path=spect_scaler_path,
            results_path=results_path,
            normalize_spectrograms=normalize_spectrograms,
            shuffle=shuffle,
            val_step=val_step,
            ckpt_step=ckpt_step,
            patience=patience,
            subset=subset,
        )
    elif model_family == "ParametricUMAPModel":
        # The parametric UMAP trainer is not given these arguments; warn
        # instead of silently dropping values the caller set explicitly.
        # (``normalize_spectrograms`` is also dropped, but it defaults to
        # True, so an explicit setting cannot be told apart from the default.)
        for arg_name, arg_value in (
            ("spect_scaler_path", spect_scaler_path),
            ("patience", patience),
        ):
            if arg_value is not None:
                logger.warning(
                    "Argument `%s` is ignored when training a model of family "
                    "'ParametricUMAPModel' (value was: %s)",
                    arg_name,
                    arg_value,
                )
        train_parametric_umap_model(
            model_config=model_config,
            dataset_config=dataset_config,
            trainer_config=trainer_config,
            batch_size=batch_size,
            num_epochs=num_epochs,
            num_workers=num_workers,
            checkpoint_path=checkpoint_path,
            results_path=results_path,
            shuffle=shuffle,
            val_step=val_step,
            ckpt_step=ckpt_step,
            subset=subset,
        )
    else:
        raise ValueError(f"Model family not recognized: {model_family}")