Source code for vak.train.parametric_umap

"""Function that trains models in the Parametric UMAP family."""

from __future__ import annotations

import datetime
import logging
import pathlib

import lightning
import pandas as pd
import torch.utils.data

from .. import datasets, models, transforms
from ..common import validators
from ..common.paths import generate_results_dir_name_as_path
from ..datasets.parametric_umap import ParametricUMAPDataset

logger = logging.getLogger(__name__)


def get_split_dur(df: pd.DataFrame, split: str) -> float:
    """Get duration of a split in a dataset
    from a pandas DataFrame representing the dataset."""
    return df[df["split"] == split]["duration"].sum()
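
A minimal usage sketch, assuming only the two columns ``get_split_dur`` relies on
(``split`` and ``duration``); the values below are made up for illustration:

    import pandas as pd

    # hypothetical dataset DataFrame; real ones are read from the dataset csv
    df = pd.DataFrame(
        {
            "split": ["train", "train", "val"],
            "duration": [2.5, 3.0, 1.5],  # durations in seconds
        }
    )
    assert get_split_dur(df, "train") == 5.5  # 2.5 + 3.0
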
def get_trainer(
    accelerator: str,
    devices: int | list[int],
    max_epochs: int,
    ckpt_root: str | pathlib.Path,
    ckpt_step: int,
    log_save_dir: str | pathlib.Path,
) -> lightning.pytorch.Trainer:
    """Returns an instance of ``lightning.pytorch.Trainer``
    with a default set of callbacks.

    Used by ``vak.core`` functions.
    """
    # checkpoint that is saved every ``ckpt_step`` training steps,
    # and again as "last" at the end of training
    ckpt_callback = lightning.pytorch.callbacks.ModelCheckpoint(
        dirpath=ckpt_root,
        filename="checkpoint",
        every_n_train_steps=ckpt_step,
        save_last=True,
        verbose=True,
    )
    ckpt_callback.CHECKPOINT_NAME_LAST = "checkpoint"
    ckpt_callback.FILE_EXTENSION = ".pt"

    # checkpoint that tracks the single best model by minimum validation loss
    val_ckpt_callback = lightning.pytorch.callbacks.ModelCheckpoint(
        monitor="val_loss",
        dirpath=ckpt_root,
        save_top_k=1,
        mode="min",
        filename="min-val-loss-checkpoint",
        auto_insert_metric_name=False,
        verbose=True,
    )
    val_ckpt_callback.FILE_EXTENSION = ".pt"

    callbacks = [
        ckpt_callback,
        val_ckpt_callback,
    ]

    logger = lightning.pytorch.loggers.TensorBoardLogger(save_dir=log_save_dir)

    trainer = lightning.pytorch.Trainer(
        max_epochs=max_epochs,
        accelerator=accelerator,
        devices=devices,
        logger=logger,
        callbacks=callbacks,
    )
    return trainer
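
For illustration, a hedged sketch of calling ``get_trainer`` directly;
the paths below are placeholders, not locations the library requires:

    # sketch: checkpoint every 100 training steps into a results directory
    trainer = get_trainer(
        accelerator="cpu",                  # or "gpu" when one is available
        devices=1,
        max_epochs=50,
        ckpt_root="results/checkpoints",    # placeholder path
        ckpt_step=100,
        log_save_dir="results",             # TensorBoard logs are written here
    )
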
def train_parametric_umap_model(
    model_config: dict,
    dataset_config: dict,
    trainer_config: dict,
    batch_size: int,
    num_epochs: int,
    num_workers: int,
    checkpoint_path: str | pathlib.Path | None = None,
    root_results_dir: str | pathlib.Path | None = None,
    results_path: str | pathlib.Path | None = None,
    shuffle: bool = True,
    val_step: int | None = None,
    ckpt_step: int | None = None,
    subset: str | None = None,
) -> None:
    """Train a model from the parametric UMAP family and save results.

    Saves checkpoint files for the model,
    either in ``results_path`` if specified,
    or in a new directory made inside ``root_results_dir``.

    Parameters
    ----------
    model_config : dict
        Model configuration in a :class:`dict`.
        Can be obtained by calling :meth:`vak.config.ModelConfig.asdict`.
    dataset_config : dict
        Dataset configuration in a :class:`dict`.
        Can be obtained by calling :meth:`vak.config.DatasetConfig.asdict`.
    trainer_config : dict
        Configuration for :class:`lightning.pytorch.Trainer` in a :class:`dict`.
        Can be obtained by calling :meth:`vak.config.TrainerConfig.asdict`.
    batch_size : int
        Number of samples per batch presented to models during training.
    num_epochs : int
        Number of training epochs.
        One epoch = one iteration through the entire training set.
    num_workers : int
        Number of processes to use for parallel loading of data.
        Argument to :class:`torch.utils.data.DataLoader`.
    checkpoint_path : str, pathlib.Path, optional
        Path to a checkpoint file,
        e.g., one generated by a previous run of ``vak.core.train``.
        If specified, this checkpoint will be loaded into the model.
        Used when continuing training.
        Default is None, in which case a new model is initialized.
    root_results_dir : str, pathlib.Path, optional
        Root directory in which a new directory will be created
        where results will be saved.
    results_path : str, pathlib.Path, optional
        Directory where results will be saved.
        If specified, this parameter overrides ``root_results_dir``.
    shuffle : bool
        If True, shuffle training data before each epoch. Default is True.
    val_step : int, optional
        Computes the loss using the validation set every ``val_step`` epochs.
        Default is None, in which case no validation is done.
    ckpt_step : int, optional
        Step on which to save to checkpoint file.
        If ckpt_step is n, then a checkpoint is saved every time
        the global step is a whole-number multiple of n,
        i.e., when global step modulo ckpt_step is 0.
        Default is None, in which case a checkpoint is only saved
        at the last epoch.
    subset : str, optional
        Name of a subset of the training split to use when training the model.
        Default is None, in which case the entire training split is used.
        This parameter is used by :func:`vak.learncurve.learncurve`
        to specify subsets of the training set
        when training models for a learning curve.
""" for path, path_name in zip( (checkpoint_path,), ("checkpoint_path",), ): if path is not None: if not validators.is_a_file(path): raise FileNotFoundError( f"value for ``{path_name}`` not recognized as a file: {path}" ) dataset_path = pathlib.Path(dataset_config["path"]) if not dataset_path.exists() or not dataset_path.is_dir(): raise NotADirectoryError( f"`dataset_path` not found or not recognized as a directory: {dataset_path}" ) logger.info( f"Loading dataset from path: {dataset_path}", ) metadata = datasets.parametric_umap.Metadata.from_dataset_path( dataset_path ) dataset_csv_path = dataset_path / metadata.dataset_csv_filename dataset_df = pd.read_csv(dataset_csv_path) # ---------------- pre-conditions ---------------------------------------------------------------------------------- if val_step and not dataset_df["split"].str.contains("val").any(): raise ValueError( f"val_step set to {val_step} but dataset does not contain a validation set; " f"please run `vak prep` with a config.toml file that specifies a duration for the validation set." ) # ---- set up directory to save output ----------------------------------------------------------------------------- if results_path: results_path = pathlib.Path(results_path).expanduser().resolve() if not results_path.is_dir(): raise NotADirectoryError( f"results_path not recognized as a directory: {results_path}" ) else: results_path = generate_results_dir_name_as_path(root_results_dir) results_path.mkdir() # ---------------- load training data ----------------------------------------------------------------------------- logger.info(f"using training dataset from {dataset_path}") # below, if we're going to train network to predict unlabeled segments, then # we need to include a class for those unlabeled segments in labelmap, # the mapping from labelset provided by user to a set of consecutive # integers that the network learns to predict train_dur = get_split_dur(dataset_df, "train") logger.info( f"Total duration of training split from dataset (in s): {train_dur}", ) model_name = model_config["name"] train_transform = transforms.defaults.get_default_transform( model_name, "train" ) dataset_params = dataset_config["params"] train_dataset = ParametricUMAPDataset.from_dataset_path( dataset_path=dataset_path, split="train", subset=subset, transform=train_transform, **dataset_params, ) logger.info( f"Duration of ParametricUMAPDataset used for training, in seconds: {train_dataset.duration}", ) train_loader = torch.utils.data.DataLoader( dataset=train_dataset, shuffle=shuffle, batch_size=batch_size, num_workers=num_workers, ) # ---------------- load validation set (if there is one) ----------------------------------------------------------- if val_step: transform = transforms.defaults.get_default_transform( model_name, "eval" ) val_dataset = ParametricUMAPDataset.from_dataset_path( dataset_path=dataset_path, split="val", transform=transform, **dataset_params, ) logger.info( f"Duration of ParametricUMAPDataset used for validation, in seconds: {val_dataset.duration}", ) val_loader = torch.utils.data.DataLoader( dataset=val_dataset, shuffle=False, batch_size=batch_size, num_workers=num_workers, ) else: val_loader = None model = models.get( model_name, model_config, input_shape=train_dataset.shape, ) if checkpoint_path is not None: logger.info( f"loading checkpoint for {model_name} from path: {checkpoint_path}", ) model.load_state_dict_from_path(checkpoint_path) results_model_root = results_path.joinpath(model_name) results_model_root.mkdir() ckpt_root = 
results_model_root.joinpath("checkpoints") ckpt_root.mkdir() logger.info(f"training {model_name}") trainer = get_trainer( accelerator=trainer_config["accelerator"], devices=trainer_config["devices"], max_epochs=num_epochs, log_save_dir=results_model_root, ckpt_root=ckpt_root, ckpt_step=ckpt_step, ) train_time_start = datetime.datetime.now() logger.info(f"Training start time: {train_time_start.isoformat()}") trainer.fit( model=model, train_dataloaders=train_loader, val_dataloaders=val_loader, ) train_time_stop = datetime.datetime.now() logger.info(f"Training stop time: {train_time_stop.isoformat()}") elapsed = train_time_stop - train_time_start logger.info(f"Elapsed training time: {elapsed}")
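
A hedged end-to-end sketch of calling ``train_parametric_umap_model``.
In practice the config dicts come from vak's config classes
(e.g. :meth:`vak.config.ModelConfig.asdict`); the values below,
including the model name and dataset path, are illustrative assumptions,
not a tested configuration:

    model_config = {"name": "ConvEncoderUMAP"}  # placeholder; real dicts carry more keys
    dataset_config = {"path": "prep/umap_dataset", "params": {}}  # placeholder path
    trainer_config = {"accelerator": "cpu", "devices": 1}

    train_parametric_umap_model(
        model_config=model_config,
        dataset_config=dataset_config,
        trainer_config=trainer_config,
        batch_size=64,
        num_epochs=50,
        num_workers=4,
        root_results_dir="results",  # a new results directory is created inside
        val_step=1,                  # requires the dataset to have a "val" split
        ckpt_step=200,
    )
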