Source code for vak.eval.frame_classification

"""Function that evaluates trained models in the frame classification family."""

from __future__ import annotations

import json
import logging
import pathlib
from collections import OrderedDict
from datetime import datetime

import joblib
import lightning
import pandas as pd
import torch.utils.data

from .. import datapipes, datasets, models, transforms
from ..common import validators
from ..datapipes.frame_classification import InferDatapipe

# Module-level logger, named after this module's import path (``__name__``).
logger = logging.getLogger(__name__)


[docs] def eval_frame_classification_model( model_config: dict, dataset_config: dict, trainer_config: dict, checkpoint_path: str | pathlib.Path, labelmap_path: str | pathlib.Path, output_dir: str | pathlib.Path, num_workers: int, frames_standardizer_path: str | pathlib.Path = None, post_tfm_kwargs: dict | None = None, ) -> None: """Evaluate a trained model. Parameters ---------- model_config : dict Model configuration in a :class:`dict`. Can be obtained by calling :meth:`vak.config.ModelConfig.asdict`. dataset_config: dict Dataset configuration in a :class:`dict`. Can be obtained by calling :meth:`vak.config.DatasetConfig.asdict`. trainer_config: dict Configuration for :class:`lightning.pytorch.Trainer`. Can be obtained by calling :meth:`vak.config.TrainerConfig.asdict`. checkpoint_path : str, pathlib.Path Path to directory with checkpoint files saved by Torch, to reload model output_dir : str, pathlib.Path Path to location where .csv files with evaluation metrics should be saved. labelmap_path : str, pathlib.Path Path to 'labelmap.json' file. num_workers : int Number of processes to use for parallel loading of data. Argument to torch.DataLoader. Default is 2. frames_standardizer_path : str, pathlib.Path Path to a saved :class:`vak.transforms.FramesStandardizer` object used to standardize (normalize) frames. If frames were standardized during training and this is not provided, will give incorrect results. Default is None. post_tfm_kwargs : dict Keyword arguments to post-processing transform. If None, then no additional clean-up is applied when transforming labeled timebins to segments, the default behavior. The transform used is ``vak.transforms.frame_labels.PostProcess`. Valid keyword argument names are 'majority_vote' and 'min_segment_dur', and should be appropriate values for those arguments: Boolean for ``majority_vote``, a float value for ``min_segment_dur``. See the docstring of the transform for more details on these arguments and how they work. 
Notes ----- Note that unlike :func:`~vak.predict.predict`, this function can modify ``labelmap`` so that metrics like edit distance are correctly computed, by converting any string labels in ``labelmap`` with multiple characters to (mock) single-character labels, with :func:`vak.labels.multi_char_labels_to_single_char`. """ # ---- pre-conditions ---------------------------------------------------------------------------------------------- for path, path_name in zip( (checkpoint_path, labelmap_path, frames_standardizer_path), ("checkpoint_path", "labelmap_path", "frames_standardizer_path"), ): if path is not None: # because `frames_standardizer_path` is optional if not validators.is_a_file(path): raise FileNotFoundError( f"value for ``{path_name}`` not recognized as a file: {path}" ) model_name = model_config["name"] # we use this var again below if "window_size" not in dataset_config["params"]: raise KeyError( f"The `dataset_config` for frame classification model '{model_name}' must include a 'params' sub-table " f"that sets a value for 'window_size', but received a `dataset_config` that did not:\n{dataset_config}" ) if not validators.is_a_directory(output_dir): raise NotADirectoryError( f"value for ``output_dir`` not recognized as a directory: {output_dir}" ) # ---- get time for .csv file -------------------------------------------------------------------------------------- timenow = datetime.now().strftime("%y%m%d_%H%M%S") # ---- load what we need to transform data ------------------------------------------------------------------------- if frames_standardizer_path: logger.info( f"loading frames standardizer from path: {frames_standardizer_path}" ) frames_standardizer = joblib.load(frames_standardizer_path) else: logger.info("No `frames_standardizer_path` provided, not standardizing frames.") frames_standardizer = None logger.info(f"loading labelmap from path: {labelmap_path}") with labelmap_path.open("r") as f: labelmap = json.load(f) # ---------------- load 
data for evaluation ------------------------------------------------------------------------ if "split" in dataset_config["params"]: split = dataset_config["params"]["split"] # we do this convoluted thing to avoid 'TypeError: Dataset got multiple values for split` del dataset_config["params"]["split"] else: split = "test" # ---- *not* using a built-in dataset ------------------------------------------------------------------------------ if dataset_config["name"] is None: dataset_path = pathlib.Path(dataset_config["path"]) if not dataset_path.exists() or not dataset_path.is_dir(): raise NotADirectoryError( f"`dataset_path` not found or not recognized as a directory: {dataset_path}" ) # we unpack `frame_dur` to log it, regardless of whether we use it with post_tfm below metadata = datapipes.frame_classification.Metadata.from_dataset_path( dataset_path ) frame_dur = metadata.frame_dur logger.info( f"Duration of a frame in dataset, in seconds: {frame_dur}", ) val_dataset = InferDatapipe.from_dataset_path( dataset_path=dataset_path, split=split, window_size=dataset_config["params"]["window_size"], frames_standardizer=frames_standardizer, return_padding_mask=True, ) # ---- *yes* using a built-in dataset ------------------------------------------------------------------------------# ---- *yes* using a built-in dataset ------------------------------------------------------------------------------ else: dataset_config["params"]["return_padding_mask"] = True val_dataset = datasets.get( dataset_config, split=split, frames_standardizer=frames_standardizer, ) val_loader = torch.utils.data.DataLoader( dataset=val_dataset, shuffle=False, # batch size 1 because each spectrogram reshaped into a batch of windows batch_size=1, num_workers=num_workers, ) # ---------------- do the actual evaluating ------------------------------------------------------------------------ input_shape = val_dataset.shape # if dataset returns spectrogram reshaped into windows, # throw out the window 
dimension; just want to tell network (channels, height, width) shape if len(input_shape) == 4: input_shape = input_shape[1:] if post_tfm_kwargs: post_tfm = transforms.frame_labels.PostProcess( timebin_dur=frame_dur, **post_tfm_kwargs, ) else: post_tfm = None model = models.get( model_name, model_config, num_classes=len(labelmap), input_shape=input_shape, labelmap=labelmap, post_tfm=post_tfm, ) logger.info(f"running evaluation for model: {model_name}") model.load_state_dict_from_path(checkpoint_path) trainer_logger = lightning.pytorch.loggers.TensorBoardLogger( save_dir=output_dir ) trainer = lightning.pytorch.Trainer( accelerator=trainer_config["accelerator"], devices=trainer_config["devices"], logger=trainer_logger, ) # TODO: check for hasattr(model, test_step) and if so run test # below, [0] because validate returns list of dicts, length of no. of val loaders metric_vals = trainer.validate(model, dataloaders=val_loader)[0] metric_vals = {f"avg_{k}": v for k, v in metric_vals.items()} for metric_name, metric_val in metric_vals.items(): if metric_name.startswith("avg_"): logger.info(f"{metric_name}: {metric_val:0.5f}") # create a "DataFrame" with just one row which we will save as a csv; # the idea is to be able to concatenate csvs from multiple runs of eval row = OrderedDict( [ ("model_name", model_name), ("checkpoint_path", checkpoint_path), ("labelmap_path", labelmap_path), ("frames_standardizer_path", frames_standardizer_path), ("dataset_path", dataset_path), ] ) # TODO: is this still necessary after switching to Lightning? Stop saying "average"? 
# order metrics by name to be extra sure they will be consistent across runs row.update( sorted( [(k, v) for k, v in metric_vals.items() if k.startswith("avg_")] ) ) # pass index into dataframe, needed when using all scalar values (a single row) # throw away index below when saving to avoid extra column eval_df = pd.DataFrame(row, index=[0]) eval_csv_path = output_dir.joinpath(f"eval_{model_name}_{timenow}.csv") logger.info(f"saving csv with evaluation metrics at: {eval_csv_path}") eval_df.to_csv( eval_csv_path, index=False ) # index is False to avoid having "Unnamed: 0" column when loading