Source code for vak.eval.parametric_umap

"""Function that evaluates trained models in the parametric UMAP family."""

from __future__ import annotations

import logging
import pathlib
from collections import OrderedDict
from datetime import datetime

import lightning
import pandas as pd
import torch.utils.data

from .. import models
from ..common import validators
from ..datapipes.parametric_umap import Datapipe

logger = logging.getLogger(__name__)


[docs]
def eval_parametric_umap_model(
    model_config: dict,
    dataset_config: dict,
    checkpoint_path: str | pathlib.Path,
    output_dir: str | pathlib.Path,
    batch_size: int,
    num_workers: int,
    trainer_config: dict,
) -> None:
    """Evaluate a trained model.

    Parameters
    ----------
    model_config : dict
        Model configuration in a :class:`dict`.
        Can be obtained by calling :meth:`vak.config.ModelConfig.asdict`.
    dataset_config : dict
        Dataset configuration in a :class:`dict`.
        Can be obtained by calling :meth:`vak.config.DatasetConfig.asdict`.
    checkpoint_path : str, pathlib.Path
        Path to checkpoint file saved by torch, used to reload the model.
    output_dir : str, pathlib.Path
        Path to location where .csv files with evaluation metrics
        should be saved.
    batch_size : int
        Number of samples per batch fed into the model.
    num_workers : int
        Number of processes to use for parallel loading of data.
        Argument to :class:`torch.utils.data.DataLoader`.
    trainer_config : dict
        Configuration for :class:`lightning.pytorch.Trainer`.
        Can be obtained by calling :meth:`vak.config.TrainerConfig.asdict`.

    Notes
    -----
    The split of the dataset that is evaluated is taken from
    ``dataset_config["params"]["split"]`` if that key is present;
    otherwise the 'test' split is used.
    """
    # ---- pre-conditions ----------------------------------------------------
    for path, path_name in zip(
        (checkpoint_path,),
        ("checkpoint_path",),
    ):
        if path is not None:  # optional path arguments may be None
            if not validators.is_a_file(path):
                raise FileNotFoundError(
                    f"value for ``{path_name}`` not recognized as a file: {path}"
                )

    dataset_path = pathlib.Path(dataset_config["path"])
    if not dataset_path.exists() or not dataset_path.is_dir():
        raise NotADirectoryError(
            f"`dataset_path` not found or not recognized as a directory: {dataset_path}"
        )

    logger.info(
        f"Loading metadata from dataset path: {dataset_path}",
    )

    if not validators.is_a_directory(output_dir):
        raise NotADirectoryError(
            f"value for ``output_dir`` not recognized as a directory: {output_dir}"
        )
    # make sure we have a Path so we can call ``joinpath`` below
    output_dir = pathlib.Path(output_dir)

    # ---- get time for .csv file ---------------------------------------------
    timenow = datetime.now().strftime("%y%m%d_%H%M%S")

    # ---------------- load data for evaluation --------------------------------
    # take 'split' from the dataset params if specified, defaulting to 'test';
    # copy the params and pop 'split' so it is not passed twice to the datapipe
    dataset_params = dict(dataset_config["params"])
    split = dataset_params.pop("split", "test")

    model_name = model_config["name"]

    val_dataset = Datapipe.from_dataset_path(
        dataset_path=dataset_path,
        split=split,
        **dataset_params,
    )
    val_loader = torch.utils.data.DataLoader(
        dataset=val_dataset,
        shuffle=False,
        batch_size=batch_size,
        num_workers=num_workers,
    )

    # ---------------- do the actual evaluating ---------------------------------
    model = models.get(
        model_name,
        model_config,
        input_shape=val_dataset.shape,
    )

    logger.info(f"running evaluation for model: {model_name}")

    model.load_state_dict_from_path(checkpoint_path)

    trainer_logger = lightning.pytorch.loggers.TensorBoardLogger(
        save_dir=output_dir
    )
    trainer = lightning.pytorch.Trainer(
        accelerator=trainer_config["accelerator"],
        devices=trainer_config["devices"],
        logger=trainer_logger,
    )
    # TODO: check for hasattr(model, test_step) and if so run test
    # below, [0] because validate returns a list of dicts, one per val dataloader
    metric_vals = trainer.validate(model, dataloaders=val_loader)[0]
    for metric_name, metric_val in metric_vals.items():
        logger.info(f"{metric_name}: {metric_val:0.5f}")

    # create a DataFrame with just one row which we will save as a csv;
    # the idea is to be able to concatenate csvs from multiple runs of eval
    row = OrderedDict(
        [
            ("model_name", model_name),
            ("checkpoint_path", checkpoint_path),
            ("dataset_path", dataset_path),
        ]
    )
    # order metrics by name to be extra sure they will be consistent across runs
    row.update(sorted([(k, v) for k, v in metric_vals.items()]))

    # pass an index into the dataframe, needed when all values are scalars (a single row);
    # the index is thrown away below when saving to avoid an extra column
    eval_df = pd.DataFrame(row, index=[0])
    eval_csv_path = output_dir.joinpath(f"eval_{model_name}_{timenow}.csv")
    logger.info(f"saving csv with evaluation metrics at: {eval_csv_path}")
    eval_df.to_csv(
        eval_csv_path, index=False
    )  # index is False to avoid having an "Unnamed: 0" column when loading
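For orientation, a minimal sketch of how this function might be called directly. The paths and the literal dicts below are placeholders: in practice the three config dicts are obtained from a parsed vak config via :meth:`vak.config.ModelConfig.asdict`, :meth:`vak.config.DatasetConfig.asdict`, and :meth:`vak.config.TrainerConfig.asdict`, and will contain more keys than the ones this function reads directly.

import pathlib

from vak.eval.parametric_umap import eval_parametric_umap_model

# Placeholder values for illustration only; real config dicts come from
# ModelConfig.asdict(), DatasetConfig.asdict(), and TrainerConfig.asdict().
model_config = {"name": "ConvEncoderUMAP"}  # passed whole to vak.models.get; only "name" is read in this module
dataset_config = {
    "path": "prep/umap_dataset",      # a dataset directory prepared beforehand with vak
    "params": {"split": "test"},      # optional; "split" defaults to "test",
                                      # other keys are forwarded to Datapipe.from_dataset_path
}
trainer_config = {"accelerator": "auto", "devices": 1}

output_dir = pathlib.Path("results/eval")
output_dir.mkdir(parents=True, exist_ok=True)

eval_parametric_umap_model(
    model_config=model_config,
    dataset_config=dataset_config,
    checkpoint_path="results/train/checkpoints/checkpoint.pt",  # placeholder checkpoint path
    output_dir=output_dir,
    batch_size=64,
    num_workers=2,
    trainer_config=trainer_config,
)

The call writes a single-row csv named ``eval_{model_name}_{timestamp}.csv`` in ``output_dir``, containing the checkpoint path, dataset path, and the validation metrics reported by the model.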