Source code for vak.train.frame_classification

"""Function that trains models in the frame classification family."""

from __future__ import annotations

import datetime
import json
import logging
import pathlib
import shutil

import joblib
import pandas as pd
import torch.utils.data

from .. import datasets, models, transforms
from ..common import validators
from ..common.trainer import get_default_trainer
from ..datasets.frame_classification import FramesDataset, WindowDataset

logger = logging.getLogger(__name__)


def get_split_dur(df: pd.DataFrame, split: str) -> float:
    """Get duration of a split in a dataset from a pandas DataFrame representing the dataset."""
    return df[df["split"] == split]["duration"].sum()
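
# A minimal sketch of what ``get_split_dur`` computes, using hypothetical values:
#
#     >>> df = pd.DataFrame(
#     ...     {"split": ["train", "train", "val"], "duration": [10.0, 5.0, 2.5]}
#     ... )
#     >>> get_split_dur(df, "train")
#     15.0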


def train_frame_classification_model(
    model_config: dict,
    dataset_config: dict,
    trainer_config: dict,
    batch_size: int,
    num_epochs: int,
    num_workers: int,
    checkpoint_path: str | pathlib.Path | None = None,
    spect_scaler_path: str | pathlib.Path | None = None,
    results_path: str | pathlib.Path | None = None,
    normalize_spectrograms: bool = True,
    shuffle: bool = True,
    val_step: int | None = None,
    ckpt_step: int | None = None,
    patience: int | None = None,
    subset: str | None = None,
) -> None:
    """Train a model from the frame classification family and save results.

    Saves checkpoint files for the model, the label map,
    and the spectrogram scaler.
    These are saved in ``results_path``,
    which must be an existing directory.

    Parameters
    ----------
    model_config : dict
        Model configuration in a :class:`dict`.
        Can be obtained by calling :meth:`vak.config.ModelConfig.asdict`.
    dataset_config : dict
        Dataset configuration in a :class:`dict`.
        Can be obtained by calling :meth:`vak.config.DatasetConfig.asdict`.
    trainer_config : dict
        Configuration for :class:`lightning.pytorch.Trainer` in a :class:`dict`.
        Can be obtained by calling :meth:`vak.config.TrainerConfig.asdict`.
    batch_size : int
        Number of samples per batch presented to models during training.
    num_epochs : int
        Number of training epochs.
        One epoch = one iteration through the entire training set.
    num_workers : int
        Number of processes to use for parallel loading of data.
        Argument to :class:`torch.utils.data.DataLoader`.
    checkpoint_path : str, pathlib.Path
        Path to a checkpoint file,
        e.g., one generated by a previous run of ``vak.core.train``.
        If specified, this checkpoint will be loaded into the model.
        Used when continuing training.
        Default is None, in which case a new model is initialized.
    spect_scaler_path : str, pathlib.Path
        Path to a ``SpectScaler`` used to normalize spectrograms,
        e.g., one generated by a previous run of ``vak.core.train``.
        Used when continuing training, for example on the same dataset.
        Default is None.
    results_path : str, pathlib.Path
        Directory where results will be saved.
        Must be an existing directory;
        a sub-directory named after the model is created inside it.
    normalize_spectrograms : bool
        If True, use :class:`vak.transforms.StandardizeSpect` to normalize the spectrograms.
        Normalization is done by subtracting off the mean for each frequency bin
        of the training set and then dividing by the std for that frequency bin.
        This same normalization is then applied to validation + test data.
    shuffle : bool
        If True, shuffle training data before each epoch. Default is True.
    val_step : int
        Step on which to estimate accuracy using the validation set.
        If ``val_step`` is n, then validation is carried out every time
        the global step is a whole multiple of n,
        i.e., when the global step modulo ``val_step`` is 0.
        Default is None, in which case no validation is done.
    ckpt_step : int
        Step on which to save to checkpoint file.
        If ``ckpt_step`` is n, then a checkpoint is saved every time
        the global step is a whole multiple of n,
        i.e., when the global step modulo ``ckpt_step`` is 0.
        Default is None, in which case a checkpoint is only saved at the last epoch.
    patience : int
        Number of validation steps to wait without performance on the
        validation set improving before stopping the training.
        Default is None, in which case training only stops
        after the specified number of epochs.
    subset : str
        Name of a subset from the training split of the dataset
        to use when training model.
        This parameter is used by :func:`vak.learncurve.learncurve`
        to specify subsets when training models for a learning curve.
    """
    for path, path_name in zip(
        (checkpoint_path, spect_scaler_path),
        ("checkpoint_path", "spect_scaler_path"),
    ):
        if path is not None:
            if not validators.is_a_file(path):
                raise FileNotFoundError(
                    f"value for ``{path_name}`` not recognized as a file: {path}"
                )

    dataset_path = pathlib.Path(dataset_config["path"])
    if not dataset_path.exists() or not dataset_path.is_dir():
        raise NotADirectoryError(
            f"`dataset_path` not found or not recognized as a directory: {dataset_path}"
        )

    logger.info(
        f"Loading dataset from `dataset_path`: {dataset_path}",
    )
    metadata = datasets.frame_classification.Metadata.from_dataset_path(
        dataset_path
    )
    dataset_csv_path = dataset_path / metadata.dataset_csv_filename
    dataset_df = pd.read_csv(dataset_csv_path)

    # ---------------- pre-conditions ----------------------------------------
    if val_step and not dataset_df["split"].str.contains("val").any():
        raise ValueError(
            f"val_step set to {val_step} but dataset does not contain a validation set; "
            f"please run `vak prep` with a config.toml file that specifies a duration for the validation set."
        )

    # ---- set up directory to save output -----------------------------------
    results_path = pathlib.Path(results_path).expanduser().resolve()
    if not results_path.is_dir():
        raise NotADirectoryError(
            f"results_path not recognized as a directory: {results_path}"
        )

    frame_dur = metadata.frame_dur
    logger.info(
        f"Duration of a frame in dataset, in seconds: {frame_dur}",
    )

    # ---------------- load training data ------------------------------------
    logger.info(f"Using training split from dataset: {dataset_path}")
    # below, if we're going to train network to predict unlabeled segments, then
    # we need to include a class for those unlabeled segments in labelmap,
    # the mapping from labelset provided by user to a set of consecutive
    # integers that the network learns to predict
    train_dur = get_split_dur(dataset_df, "train")
    logger.info(
        f"Total duration of training split from dataset (in s): {train_dur}",
    )

    labelmap_path = dataset_path / "labelmap.json"
    logger.info(f"loading labelmap from path: {labelmap_path}")
    with labelmap_path.open("r") as f:
        labelmap = json.load(f)
    # copy to new results_path
    with open(results_path.joinpath("labelmap.json"), "w") as f:
        json.dump(labelmap, f)

    if spect_scaler_path is not None and normalize_spectrograms:
        logger.info(f"loading spect scaler from path: {spect_scaler_path}")
        spect_standardizer = joblib.load(spect_scaler_path)
        shutil.copy(spect_scaler_path, results_path)
    # get transforms just before creating datasets with them
    elif normalize_spectrograms and spect_scaler_path is None:
        logger.info(
            "no spect_scaler_path provided, not loading",
        )
        logger.info("will normalize spectrograms")
        spect_standardizer = transforms.StandardizeSpect.fit_dataset_path(
            dataset_path,
            split="train",
            subset=subset,
        )
        joblib.dump(
            spect_standardizer, results_path.joinpath("StandardizeSpect")
        )
    elif spect_scaler_path is not None and not normalize_spectrograms:
        raise ValueError(
            "spect_scaler_path provided but normalize_spectrograms was False, these options conflict"
        )
    else:
        # not normalize_spectrograms and spect_scaler_path is None
        logger.info(
            "normalize_spectrograms is False and no spect_scaler_path was provided, "
            "will not standardize spectrograms",
        )
        spect_standardizer = None
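
    # Illustrative sketch (hypothetical values): the ``dataset_config`` used below
    # is expected to provide the dataset path plus a "params" sub-table with a
    # "window_size", roughly like
    #
    #     {"path": "/path/to/prepped/dataset", "params": {"window_size": 176}}
    #
    # Only the keys are required by the code that follows; the values shown are made up.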
model_config["name"] # TODO: move this into datapipe once each datapipe uses a fixed set of transforms # that will require adding `spect_standardizer`` as a parameter to the datapipe, # maybe rename to `frames_standardizer`? try: window_size = dataset_config["params"]["window_size"] except KeyError as e: raise KeyError( f"The `dataset_config` for frame classification model '{model_name}' must include a 'params' sub-table " f"that sets a value for 'window_size', but received a `dataset_config` that did not:\n{dataset_config}" ) from e transform_kwargs = { "spect_standardizer": spect_standardizer, "window_size": window_size, } train_transform = transforms.defaults.get_default_transform( model_name, "train", transform_kwargs=transform_kwargs ) train_dataset = WindowDataset.from_dataset_path( dataset_path=dataset_path, split="train", subset=subset, item_transform=train_transform, **dataset_config["params"], ) logger.info( f"Duration of WindowDataset used for training, in seconds: {train_dataset.duration}", ) train_loader = torch.utils.data.DataLoader( dataset=train_dataset, shuffle=shuffle, batch_size=batch_size, num_workers=num_workers, ) # ---------------- load validation set (if there is one) ----------------------------------------------------------- if val_step: logger.info( f"Will measure error on validation set every {val_step} steps of training", ) logger.info(f"Using validation split from dataset:\n{dataset_path}") val_dur = get_split_dur(dataset_df, "val") logger.info( f"Total duration of validation split from dataset (in s): {val_dur}", ) # NOTE: we use same `transform_kwargs` here; will need to change to a `dataset_param` # when we factor transform *into* fixed DataPipes as above val_transform = transforms.defaults.get_default_transform( model_name, "eval", transform_kwargs ) val_dataset = FramesDataset.from_dataset_path( dataset_path=dataset_path, split="val", item_transform=val_transform, ) logger.info( f"Duration of FramesDataset used for evaluation, in seconds: {val_dataset.duration}", ) val_loader = torch.utils.data.DataLoader( dataset=val_dataset, shuffle=False, # batch size 1 because each spectrogram reshaped into a batch of windows batch_size=1, num_workers=num_workers, ) else: val_loader = None model = models.get( model_name, model_config, num_classes=len(labelmap), input_shape=train_dataset.shape, labelmap=labelmap, ) if checkpoint_path is not None: logger.info( f"loading checkpoint for {model_name} from path: {checkpoint_path}", ) model.load_state_dict_from_path(checkpoint_path) results_model_root = results_path.joinpath(model_name) results_model_root.mkdir() ckpt_root = results_model_root.joinpath("checkpoints") ckpt_root.mkdir() logger.info(f"training {model_name}") max_steps = num_epochs * len(train_loader) default_callback_kwargs = { "ckpt_root": ckpt_root, "ckpt_step": ckpt_step, "patience": patience, } trainer = get_default_trainer( accelerator=trainer_config["accelerator"], devices=trainer_config["devices"], max_steps=max_steps, log_save_dir=results_model_root, val_step=val_step, default_callback_kwargs=default_callback_kwargs, ) train_time_start = datetime.datetime.now() logger.info(f"Training start time: {train_time_start.isoformat()}") trainer.fit( model=model, train_dataloaders=train_loader, val_dataloaders=val_loader, ) train_time_stop = datetime.datetime.now() logger.info(f"Training stop time: {train_time_stop.isoformat()}") elapsed = train_time_stop - train_time_start logger.info(f"Elapsed training time: {elapsed}")