Source code for vak.predict.parametric_umap

"""Function that generates new inferences from trained models in the frame classification family."""

from __future__ import annotations

import logging
import os
import pathlib

import lightning
import torch.utils.data

from .. import datasets, models, transforms
from ..common import validators
from ..datasets.parametric_umap import ParametricUMAPDataset

logger = logging.getLogger(__name__)


def predict_with_parametric_umap_model(
    model_config: dict,
    dataset_config: dict,
    trainer_config: dict,
    checkpoint_path,
    num_workers=2,
    transform_params: dict | None = None,
    dataset_params: dict | None = None,
    timebins_key="t",
    output_dir=None,
):
    """Make predictions on a dataset with a trained
    :class:`vak.models.ParametricUMAPModel`.

    Parameters
    ----------
    model_config : dict
        Model configuration in a :class:`dict`.
        Can be obtained by calling :meth:`vak.config.ModelConfig.asdict`.
    dataset_config : dict
        Dataset configuration in a :class:`dict`.
        Can be obtained by calling :meth:`vak.config.DatasetConfig.asdict`.
    trainer_config : dict
        Configuration for :class:`lightning.pytorch.Trainer`.
        Can be obtained by calling :meth:`vak.config.TrainerConfig.asdict`.
    checkpoint_path : str
        Path to checkpoint file saved by Torch, used to reload the model.
    num_workers : int
        Number of processes to use for parallel loading of data.
        Argument to :class:`torch.utils.data.DataLoader`. Default is 2.
    transform_params : dict, optional
        Parameters for data transform.
        Passed as keyword arguments.
        Optional, default is None.
    dataset_params : dict, optional
        Parameters for dataset.
        Passed as keyword arguments.
        Optional, default is None.
    timebins_key : str
        Key for accessing vector of time bins in files. Default is 't'.
    output_dir : str, pathlib.Path
        Path to location where outputs should be saved.
        Defaults to the current working directory.
    """
    # validate that any paths we were passed actually point at files
    for path, path_name in zip(
        (checkpoint_path,),
        ("checkpoint_path",),
    ):
        if path is not None:
            if not validators.is_a_file(path):
                raise FileNotFoundError(
                    f"value for ``{path_name}`` not recognized as a file: {path}"
                )

    dataset_path = pathlib.Path(dataset_config["path"])
    if not dataset_path.exists() or not dataset_path.is_dir():
        raise NotADirectoryError(
            f"`dataset_path` not found or not recognized as a directory: {dataset_path}"
        )

    logger.info(
        f"Loading metadata from dataset path: {dataset_path}",
    )
    metadata = datasets.frame_classification.Metadata.from_dataset_path(
        dataset_path
    )

    if output_dir is None:
        output_dir = pathlib.Path(os.getcwd())
    else:
        output_dir = pathlib.Path(output_dir)

    if not output_dir.is_dir():
        raise NotADirectoryError(
            f"value specified for output_dir is not recognized as a directory: {output_dir}"
        )

    # ---------------- load data for prediction -----------------------------------------------------------------
    model_name = model_config["name"]
    # TODO: fix this when we build transforms into datasets
    transform_params = {
        "padding": dataset_config["params"].get(
            "padding",
            models.convencoder_umap.get_default_padding(metadata.shape),
        )
    }
    item_transform = transforms.defaults.get_default_transform(
        model_name, "predict", transform_params
    )

    dataset_csv_path = dataset_path / metadata.dataset_csv_filename
    logger.info(
        f"loading dataset to predict from csv path: {dataset_csv_path}"
    )

    pred_dataset = ParametricUMAPDataset.from_dataset_path(
        dataset_path=dataset_path,
        split="predict",
        transform=item_transform,
        **dataset_config["params"],
    )

    pred_loader = torch.utils.data.DataLoader(
        dataset=pred_dataset,
        shuffle=False,
        # batch size 1 because each spectrogram is reshaped into a batch of windows
        batch_size=1,
        num_workers=num_workers,
    )

    # ---------------- determine shape of input to network ------------------------------------------------------
    input_shape = pred_dataset.shape
    # if the dataset returns a spectrogram reshaped into windows,
    # throw out the window dimension; we just want to tell the network the (channels, height, width) shape
    if len(input_shape) == 4:
        input_shape = input_shape[1:]
    logger.info(
        f"Shape of input to networks used for predictions: {input_shape}"
    )

    logger.info(f"instantiating model from config:\n{model_name}")
    model = models.get(
        model_name,
        model_config,
        input_shape=input_shape,
    )

    # ---------------- do the actual predicting -----------------------------------------------------------------
    logger.info(
        f"loading checkpoint for {model_name} from path: {checkpoint_path}"
    )
    model.load_state_dict_from_path(checkpoint_path)

    trainer_logger = lightning.pytorch.loggers.TensorBoardLogger(
        save_dir=output_dir
    )
    trainer = lightning.pytorch.Trainer(
        accelerator=trainer_config["accelerator"],
        devices=trainer_config["devices"],
        logger=trainer_logger,
    )

    logger.info(f"running predict method of {model_name}")
    results = trainer.predict(model, pred_loader)  # noqa: F841
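
# ------------------------------------------------------------------------------------------------------------------
# Usage sketch: a minimal, hypothetical call to the function above, assuming a
# dataset directory prepared by ``vak prep`` with a "predict" split and a
# checkpoint saved while training a :class:`vak.models.ParametricUMAPModel`.
# Every path and config value here is a placeholder, not part of this module;
# in practice the config dicts come from the ``asdict`` methods named in the
# docstring (e.g. :meth:`vak.config.ModelConfig.asdict`).
#
# from vak.predict.parametric_umap import predict_with_parametric_umap_model
#
# predict_with_parametric_umap_model(
#     model_config={"name": "ConvEncoderUMAP"},                  # vak.config.ModelConfig.asdict()
#     dataset_config={"path": "prep/output/dir", "params": {}},  # vak.config.DatasetConfig.asdict()
#     trainer_config={"accelerator": "cpu", "devices": 1},       # vak.config.TrainerConfig.asdict()
#     checkpoint_path="results/checkpoints/checkpoint.pt",
#     output_dir="results/predict",
# )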