Source code for vak.train.parametric_umap

"""Function that trains models in the Parametric UMAP family."""

from __future__ import annotations

import datetime
import logging
import pathlib

import lightning
import pandas as pd
import torch.utils.data

from .. import datapipes, models
from ..common import validators
from ..common.paths import generate_results_dir_name_as_path
from ..datapipes.parametric_umap import Datapipe

# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)


def get_split_dur(df: pd.DataFrame, split: str) -> float:
    """Return the total duration of one split of a dataset.

    Parameters
    ----------
    df : pandas.DataFrame
        Table representing the dataset. Must have a ``"split"`` column
        and a ``"duration"`` column.
    split : str
        Name of the split whose durations are summed, e.g. ``"train"``.

    Returns
    -------
    float
        Sum of the ``"duration"`` column over all rows whose ``"split"``
        equals ``split``; 0 when no row matches.
    """
    rows_in_split = df["split"] == split
    return df.loc[rows_in_split, "duration"].sum()
def get_trainer(
    accelerator: str,
    devices: int | list[int],
    max_epochs: int,
    ckpt_root: str | pathlib.Path,
    ckpt_step: int | None,
    log_save_dir: str | pathlib.Path,
    val_step: int | None = None,
) -> lightning.pytorch.Trainer:
    """Return a ``lightning.pytorch.Trainer`` with a default set of callbacks.

    Used by :func:`train_parametric_umap_model`.

    Two :class:`~lightning.pytorch.callbacks.ModelCheckpoint` callbacks are
    attached, both writing files with a ``.pt`` extension into ``ckpt_root``:

    * one named ``checkpoint``, saved every ``ckpt_step`` training steps,
      whose "last" checkpoint also uses the name ``checkpoint``;
    * one named ``min-val-loss-checkpoint``, which keeps the single
      checkpoint with the lowest logged ``val_loss``.

    Metrics are logged with a TensorBoard logger rooted at ``log_save_dir``.

    Parameters
    ----------
    accelerator : str
        Passed to ``Trainer(accelerator=...)``.
    devices : int, list of int
        Passed to ``Trainer(devices=...)``.
    max_epochs : int
        Maximum number of training epochs.
    ckpt_root : str, pathlib.Path
        Directory where checkpoint files are written.
    ckpt_step : int, optional
        Passed as ``every_n_train_steps`` to the periodic checkpoint
        callback. If None, Lightning's default trigger for
        ``ModelCheckpoint`` applies.
    log_save_dir : str, pathlib.Path
        Root directory for TensorBoard logs.
    val_step : int, optional
        If given (and non-zero), validation runs every ``val_step`` epochs,
        via ``Trainer(check_val_every_n_epoch=val_step)``. Default is None,
        which keeps Lightning's default of validating every epoch whenever
        validation dataloaders are supplied.

    Returns
    -------
    trainer : lightning.pytorch.Trainer
    """
    ckpt_callback = lightning.pytorch.callbacks.ModelCheckpoint(
        dirpath=ckpt_root,
        filename="checkpoint",
        every_n_train_steps=ckpt_step,
        save_last=True,
        verbose=True,
    )
    # Give the "last" checkpoint the same name as the periodic one,
    # and save with a ``.pt`` extension.
    ckpt_callback.CHECKPOINT_NAME_LAST = "checkpoint"
    ckpt_callback.FILE_EXTENSION = ".pt"

    val_ckpt_callback = lightning.pytorch.callbacks.ModelCheckpoint(
        monitor="val_loss",
        dirpath=ckpt_root,
        save_top_k=1,
        mode="min",
        filename="min-val-loss-checkpoint",
        auto_insert_metric_name=False,
        verbose=True,
    )
    val_ckpt_callback.FILE_EXTENSION = ".pt"

    callbacks = [
        ckpt_callback,
        val_ckpt_callback,
    ]

    # Named ``tb_logger`` so it does not shadow the module-level ``logger``.
    tb_logger = lightning.pytorch.loggers.TensorBoardLogger(save_dir=log_save_dir)

    # Lightning expects an int here; 1 (validate every epoch) is its own
    # default and is used when ``val_step`` is not given.
    check_val_every_n_epoch = val_step if val_step else 1

    trainer = lightning.pytorch.Trainer(
        max_epochs=max_epochs,
        accelerator=accelerator,
        devices=devices,
        logger=tb_logger,
        callbacks=callbacks,
        check_val_every_n_epoch=check_val_every_n_epoch,
    )
    return trainer


def train_parametric_umap_model(
    model_config: dict,
    dataset_config: dict,
    trainer_config: dict,
    batch_size: int,
    num_epochs: int,
    num_workers: int,
    checkpoint_path: str | pathlib.Path | None = None,
    root_results_dir: str | pathlib.Path | None = None,
    results_path: str | pathlib.Path | None = None,
    shuffle: bool = True,
    val_step: int | None = None,
    ckpt_step: int | None = None,
    subset: str | None = None,
) -> None:
    """Train a model from the parametric UMAP family and save results.

    Model checkpoints and TensorBoard logs are written to a sub-directory
    named after the model (``model_config["name"]``); checkpoints go in a
    ``checkpoints`` directory inside it. That sub-directory is created
    inside ``results_path`` if specified, or otherwise inside a new
    directory made within ``root_results_dir``.

    Parameters
    ----------
    model_config : dict
        Model configuration in a :class:`dict`. Can be obtained by calling
        :meth:`vak.config.ModelConfig.asdict`. Its ``"name"`` value selects
        the model.
    dataset_config : dict
        Dataset configuration in a :class:`dict`. Can be obtained by calling
        :meth:`vak.config.DatasetConfig.asdict`. Its ``"path"`` value is the
        dataset directory; its ``"params"`` value is passed as keyword
        arguments to ``Datapipe.from_dataset_path``.
    trainer_config : dict
        Configuration for :class:`lightning.pytorch.Trainer` in a
        :class:`dict`. Can be obtained by calling
        :meth:`vak.config.TrainerConfig.asdict`. Its ``"accelerator"`` and
        ``"devices"`` values are used.
    batch_size : int
        Number of samples per batch presented to models during training.
    num_epochs : int
        Number of training epochs. One epoch = one iteration through the
        entire training set.
    num_workers : int
        Number of processes to use for parallel loading of data.
        Argument to :class:`torch.utils.data.DataLoader`.
    checkpoint_path : str, pathlib.Path, optional
        Path to a checkpoint file, e.g., one generated by a previous run of
        ``vak.core.train``. If specified, this checkpoint will be loaded
        into the model. Used when continuing training. Default is None, in
        which case a new model is initialized.
    root_results_dir : str, pathlib.Path, optional
        Root directory in which a new directory will be created where
        results will be saved. Only used when ``results_path`` is not
        specified.
    results_path : str, pathlib.Path, optional
        Existing directory where results will be saved. If specified, this
        parameter overrides ``root_results_dir``.
    shuffle : bool
        If True, shuffle training data before each epoch. Default is True.
    val_step : int, optional
        Computes the loss using the validation set every ``val_step``
        epochs. Default is None, in which case no validation is done.
    ckpt_step : int, optional
        Step on which to save a checkpoint file. If ``ckpt_step`` is n, a
        checkpoint is saved whenever the global step modulo n is 0.
        Default is None, in which case Lightning's default checkpoint
        trigger for ``ModelCheckpoint`` applies.
    subset : str, optional
        Name of a subset of the training split to use when training the
        model. This parameter is used by ``vak.learncurve.learncurve`` to
        specify subsets of the training set to use when training models
        for a learning curve. Default is None (passed through to
        ``Datapipe.from_dataset_path``).

    Raises
    ------
    FileNotFoundError
        If ``checkpoint_path`` is given but is not recognized as a file.
    NotADirectoryError
        If the dataset path, or ``results_path`` when given, is not a
        directory.
    ValueError
        If ``val_step`` is set but the dataset has no validation split.
    """
    if checkpoint_path is not None and not validators.is_a_file(checkpoint_path):
        raise FileNotFoundError(
            f"value for ``checkpoint_path`` not recognized as a file: {checkpoint_path}"
        )

    dataset_path = pathlib.Path(dataset_config["path"])
    # ``is_dir()`` is False for paths that do not exist, so this covers both cases.
    if not dataset_path.is_dir():
        raise NotADirectoryError(
            f"`dataset_path` not found or not recognized as a directory: {dataset_path}"
        )

    logger.info(
        f"Loading dataset from path: {dataset_path}",
    )
    metadata = datapipes.parametric_umap.Metadata.from_dataset_path(dataset_path)
    dataset_csv_path = dataset_path / metadata.dataset_csv_filename
    dataset_df = pd.read_csv(dataset_csv_path)

    # ---------------- pre-conditions ----------------------------------------
    # Compare split names exactly: ``str.contains("val")`` would also accept
    # any other split name that merely contains the substring "val", while
    # the validation datapipe below loads ``split="val"`` exactly.
    if val_step and not (dataset_df["split"] == "val").any():
        raise ValueError(
            f"val_step set to {val_step} but dataset does not contain a validation set; "
            f"please run `vak prep` with a config.toml file that specifies a duration for the validation set."
        )

    # ---- set up directory to save output -----------------------------------
    if results_path:
        results_path = pathlib.Path(results_path).expanduser().resolve()
        if not results_path.is_dir():
            raise NotADirectoryError(
                f"results_path not recognized as a directory: {results_path}"
            )
    else:
        results_path = generate_results_dir_name_as_path(root_results_dir)
        results_path.mkdir()

    # ---------------- load training data ------------------------------------
    logger.info(f"using training dataset from {dataset_path}")
    train_dur = get_split_dur(dataset_df, "train")
    logger.info(
        f"Total duration of training split from dataset (in s): {train_dur}",
    )

    dataset_params = dataset_config["params"]
    train_dataset = Datapipe.from_dataset_path(
        dataset_path=dataset_path,
        split="train",
        subset=subset,
        **dataset_params,
    )
    logger.info(
        f"Duration of ParametricUMAPDataset used for training, in seconds: {train_dataset.duration}",
    )
    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        shuffle=shuffle,
        batch_size=batch_size,
        num_workers=num_workers,
    )

    # ---------------- load validation set (if there is one) -----------------
    if val_step:
        # The validation set is always the full "val" split; ``subset`` only
        # applies to training data.
        val_dataset = Datapipe.from_dataset_path(
            dataset_path=dataset_path,
            split="val",
            **dataset_params,
        )
        logger.info(
            f"Duration of ParametricUMAPDataset used for validation, in seconds: {val_dataset.duration}",
        )
        val_loader = torch.utils.data.DataLoader(
            dataset=val_dataset,
            shuffle=False,
            batch_size=batch_size,
            num_workers=num_workers,
        )
    else:
        val_loader = None

    # ---------------- build (and optionally restore) the model --------------
    model_name = model_config["name"]
    model = models.get(
        model_name,
        model_config,
        input_shape=train_dataset.shape,
    )

    if checkpoint_path is not None:
        logger.info(
            f"loading checkpoint for {model_name} from path: {checkpoint_path}",
        )
        model.load_state_dict_from_path(checkpoint_path)

    results_model_root = results_path.joinpath(model_name)
    results_model_root.mkdir()
    ckpt_root = results_model_root.joinpath("checkpoints")
    ckpt_root.mkdir()

    # ---------------- train -------------------------------------------------
    logger.info(f"training {model_name}")
    trainer = get_trainer(
        accelerator=trainer_config["accelerator"],
        devices=trainer_config["devices"],
        max_epochs=num_epochs,
        log_save_dir=results_model_root,
        ckpt_root=ckpt_root,
        ckpt_step=ckpt_step,
        val_step=val_step,
    )
    train_time_start = datetime.datetime.now()
    logger.info(f"Training start time: {train_time_start.isoformat()}")
    trainer.fit(
        model=model,
        train_dataloaders=train_loader,
        val_dataloaders=val_loader,
    )
    train_time_stop = datetime.datetime.now()
    logger.info(f"Training stop time: {train_time_stop.isoformat()}")
    elapsed = train_time_stop - train_time_start
    logger.info(f"Elapsed training time: {elapsed}")