"""Function that trains models in the Parametric UMAP family."""
from __future__ import annotations
import datetime
import logging
import pathlib
import pandas as pd
import lightning
import torch.utils.data
from .. import datasets, models, transforms
from ..common import validators
from ..common.paths import generate_results_dir_name_as_path
from ..datasets.parametric_umap import ParametricUMAPDataset
logger = logging.getLogger(__name__)
def get_split_dur(df: pd.DataFrame, split: str) -> float:
    """Get total duration of a split in a dataset.

    Parameters
    ----------
    df : pandas.DataFrame
        A dataset in DataFrame form,
        with ``"split"`` and ``"duration"`` columns.
    split : str
        Name of split, e.g. ``"train"`` or ``"val"``.

    Returns
    -------
    duration : float
        Sum of the ``"duration"`` column for all rows
        where the ``"split"`` column equals ``split``.
        Zero if no rows belong to ``split``.
    """
    return df[df["split"] == split]["duration"].sum()
def get_trainer(
    accelerator: str,
    devices: int | list[int],
    max_epochs: int,
    ckpt_root: str | pathlib.Path,
    ckpt_step: int,
    log_save_dir: str | pathlib.Path,
) -> lightning.pytorch.Trainer:
    """Returns an instance of ``lightning.pytorch.Trainer``
    with a default set of callbacks.

    Used by ``vak.core`` functions.

    Parameters
    ----------
    accelerator : str
        Passed to :class:`lightning.pytorch.Trainer`,
        e.g. ``"cpu"`` or ``"gpu"``.
    devices : int, list of int
        Number of devices to use, or which device IDs to use.
        Passed to :class:`lightning.pytorch.Trainer`.
    max_epochs : int
        Maximum number of training epochs.
    ckpt_root : str, pathlib.Path
        Directory where checkpoint files will be saved.
    ckpt_step : int
        Save a checkpoint every ``ckpt_step`` global training steps.
    log_save_dir : str, pathlib.Path
        Directory where TensorBoard logs will be saved.

    Returns
    -------
    trainer : lightning.pytorch.Trainer
    """
    # callback that periodically saves the most recent state of training;
    # ``save_last=True`` also keeps a "last" checkpoint for resuming
    ckpt_callback = lightning.pytorch.callbacks.ModelCheckpoint(
        dirpath=ckpt_root,
        filename="checkpoint",
        every_n_train_steps=ckpt_step,
        save_last=True,
        verbose=True,
    )
    # override class-level defaults so files are named "checkpoint.pt"
    ckpt_callback.CHECKPOINT_NAME_LAST = "checkpoint"
    ckpt_callback.FILE_EXTENSION = ".pt"

    # callback that keeps only the single checkpoint with the lowest
    # validation loss seen so far
    val_ckpt_callback = lightning.pytorch.callbacks.ModelCheckpoint(
        monitor="val_loss",
        dirpath=ckpt_root,
        save_top_k=1,
        mode="min",
        filename="min-val-loss-checkpoint",
        auto_insert_metric_name=False,
        verbose=True,
    )
    val_ckpt_callback.FILE_EXTENSION = ".pt"

    callbacks = [
        ckpt_callback,
        val_ckpt_callback,
    ]

    # use a distinct name so we don't shadow the module-level ``logger``
    tensorboard_logger = lightning.pytorch.loggers.TensorBoardLogger(
        save_dir=log_save_dir
    )

    trainer = lightning.pytorch.Trainer(
        max_epochs=max_epochs,
        accelerator=accelerator,
        devices=devices,
        logger=tensorboard_logger,
        callbacks=callbacks,
    )
    return trainer
def train_parametric_umap_model(
    model_config: dict,
    dataset_config: dict,
    trainer_config: dict,
    batch_size: int,
    num_epochs: int,
    num_workers: int,
    checkpoint_path: str | pathlib.Path | None = None,
    root_results_dir: str | pathlib.Path | None = None,
    results_path: str | pathlib.Path | None = None,
    shuffle: bool = True,
    val_step: int | None = None,
    ckpt_step: int | None = None,
    subset: str | None = None,
) -> None:
    """Train a model from the parametric UMAP family
    and save results.

    Saves checkpoint files for model,
    label map, and spectrogram scaler.
    These are saved either in ``results_path``
    if specified, or a new directory
    made inside ``root_results_dir``.

    Parameters
    ----------
    model_config : dict
        Model configuration in a :class:`dict`.
        Can be obtained by calling :meth:`vak.config.ModelConfig.asdict`.
    dataset_config: dict
        Dataset configuration in a :class:`dict`.
        Can be obtained by calling :meth:`vak.config.DatasetConfig.asdict`.
    trainer_config: dict
        Configuration for :class:`lightning.pytorch.Trainer` in a :class:`dict`.
        Can be obtained by calling :meth:`vak.config.TrainerConfig.asdict`.
    batch_size : int
        number of samples per batch presented to models during training.
    num_epochs : int
        number of training epochs. One epoch = one iteration through the entire
        training set.
    num_workers : int
        Number of processes to use for parallel loading of data.
        Argument to torch.DataLoader.
    checkpoint_path : str, pathlib.Path, optional
        path to a checkpoint file,
        e.g., one generated by a previous run of ``vak.core.train``.
        If specified, this checkpoint will be loaded into model.
        Used when continuing training.
        Default is None, in which case a new model is initialized.
    root_results_dir : str, pathlib.Path, optional
        Root directory in which a new directory will be created
        where results will be saved.
    results_path : str, pathlib.Path, optional
        Directory where results will be saved.
        If specified, this parameter overrides ``root_results_dir``.
    shuffle: bool
        if True, shuffle training data before each epoch. Default is True.
    val_step : int
        Computes the loss using validation set every ``val_step`` epochs.
        Default is None, in which case no validation is done.
    ckpt_step : int
        Step on which to save to checkpoint file.
        If ckpt_step is n, then a checkpoint is saved every time
        the global step / n is a whole number, i.e., when ckpt_step modulo the global step is 0.
        Default is None, in which case checkpoint is only saved at the last epoch.
    subset : str, optional
        Name of a subset of the training split to use when training the model.
        Default is None, in which case the entire training split is used.
        This parameter is used by `vak.learncurve.learncurve` to specify
        specific subsets of the training set to use when training models
        for a learning curve.
    """
    # ---- validate file paths passed in as arguments ------------------------------------------------------------------
    if checkpoint_path is not None:
        if not validators.is_a_file(checkpoint_path):
            raise FileNotFoundError(
                f"value for ``checkpoint_path`` not recognized as a file: {checkpoint_path}"
            )

    dataset_path = pathlib.Path(dataset_config["path"])
    if not dataset_path.exists() or not dataset_path.is_dir():
        raise NotADirectoryError(
            f"`dataset_path` not found or not recognized as a directory: {dataset_path}"
        )
    logger.info(
        f"Loading dataset from path: {dataset_path}",
    )
    metadata = datasets.parametric_umap.Metadata.from_dataset_path(
        dataset_path
    )
    dataset_csv_path = dataset_path / metadata.dataset_csv_filename
    dataset_df = pd.read_csv(dataset_csv_path)

    # ---------------- pre-conditions ----------------------------------------------------------------------------------
    # fail early if validation was requested but the dataset has no validation split
    if val_step and not dataset_df["split"].str.contains("val").any():
        raise ValueError(
            f"val_step set to {val_step} but dataset does not contain a validation set; "
            f"please run `vak prep` with a config.toml file that specifies a duration for the validation set."
        )

    # ---- set up directory to save output -----------------------------------------------------------------------------
    if results_path:
        # user-specified directory must already exist
        results_path = pathlib.Path(results_path).expanduser().resolve()
        if not results_path.is_dir():
            raise NotADirectoryError(
                f"results_path not recognized as a directory: {results_path}"
            )
    else:
        # otherwise make a new timestamped directory inside ``root_results_dir``
        results_path = generate_results_dir_name_as_path(root_results_dir)
        results_path.mkdir()

    # ---------------- load training data -----------------------------------------------------------------------------
    logger.info(f"using training dataset from {dataset_path}")
    train_dur = get_split_dur(dataset_df, "train")
    logger.info(
        f"Total duration of training split from dataset (in s): {train_dur}",
    )

    model_name = model_config["name"]
    train_transform = transforms.defaults.get_default_transform(
        model_name, "train"
    )
    dataset_params = dataset_config["params"]
    train_dataset = ParametricUMAPDataset.from_dataset_path(
        dataset_path=dataset_path,
        split="train",
        subset=subset,
        transform=train_transform,
        **dataset_params,
    )
    logger.info(
        f"Duration of ParametricUMAPDataset used for training, in seconds: {train_dataset.duration}",
    )
    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        shuffle=shuffle,
        batch_size=batch_size,
        num_workers=num_workers,
    )

    # ---------------- load validation set (if there is one) -----------------------------------------------------------
    if val_step:
        # validation uses the "eval" transform, which does not augment/shuffle
        transform = transforms.defaults.get_default_transform(
            model_name, "eval"
        )
        val_dataset = ParametricUMAPDataset.from_dataset_path(
            dataset_path=dataset_path,
            split="val",
            transform=transform,
            **dataset_params,
        )
        logger.info(
            f"Duration of ParametricUMAPDataset used for validation, in seconds: {val_dataset.duration}",
        )
        val_loader = torch.utils.data.DataLoader(
            dataset=val_dataset,
            shuffle=False,
            batch_size=batch_size,
            num_workers=num_workers,
        )
    else:
        val_loader = None

    model = models.get(
        model_name,
        model_config,
        input_shape=train_dataset.shape,
    )

    if checkpoint_path is not None:
        # continue training from a previous checkpoint
        logger.info(
            f"loading checkpoint for {model_name} from path: {checkpoint_path}",
        )
        model.load_state_dict_from_path(checkpoint_path)

    # per-model sub-directory of results, with a "checkpoints" directory inside it
    results_model_root = results_path.joinpath(model_name)
    results_model_root.mkdir()
    ckpt_root = results_model_root.joinpath("checkpoints")
    ckpt_root.mkdir()

    logger.info(f"training {model_name}")
    trainer = get_trainer(
        accelerator=trainer_config["accelerator"],
        devices=trainer_config["devices"],
        max_epochs=num_epochs,
        log_save_dir=results_model_root,
        ckpt_root=ckpt_root,
        ckpt_step=ckpt_step,
    )
    train_time_start = datetime.datetime.now()
    logger.info(f"Training start time: {train_time_start.isoformat()}")
    trainer.fit(
        model=model,
        train_dataloaders=train_loader,
        val_dataloaders=val_loader,
    )
    train_time_stop = datetime.datetime.now()
    logger.info(f"Training stop time: {train_time_stop.isoformat()}")
    elapsed = train_time_stop - train_time_start
    logger.info(f"Elapsed training time: {elapsed}")