vak.eval.frame_classification.eval_frame_classification_model
- vak.eval.frame_classification.eval_frame_classification_model(model_name: str, model_config: dict, dataset_path: str | Path, checkpoint_path: str | Path, labelmap_path: str | Path, output_dir: str | Path, num_workers: int, transform_params: dict | None = None, dataset_params: dict | None = None, split: str = 'test', spect_scaler_path: str | Path | None = None, post_tfm_kwargs: dict | None = None, device: str | None = None) -> None
Evaluate a trained model.
- Parameters:
model_name (str) – Model name, must be one of vak.models.registry.MODEL_NAMES.
model_config (dict) – Model configuration in a dict, as loaded from a .toml file, and used by the model method from_config.
dataset_path (str, pathlib.Path) – Path to dataset, e.g., a csv file generated by running vak prep.
checkpoint_path (str, pathlib.Path) – Path to directory with checkpoint files saved by Torch, used to reload the model.
output_dir (str, pathlib.Path) – Path to location where .csv files with evaluation metrics should be saved.
labelmap_path (str, pathlib.Path) – Path to ‘labelmap.json’ file.
num_workers (int) – Number of processes to use for parallel loading of data. Argument to torch.utils.data.DataLoader. Required; the signature gives no default.
transform_params (dict, optional) – Parameters for data transform. Passed as keyword arguments. Optional, default is None.
dataset_params (dict, optional) – Parameters for dataset. Passed as keyword arguments. Optional, default is None.
split (str) – Split of dataset on which model should be evaluated. One of {‘train’, ‘val’, ‘test’}. Default is ‘test’.
spect_scaler_path (str, pathlib.Path, optional) – Path to a saved SpectScaler object used to normalize spectrograms. If spectrograms were normalized during training and this is not provided, evaluation will give incorrect results. Default is None.
post_tfm_kwargs (dict) – Keyword arguments to post-processing transform. If None, then no additional clean-up is applied when transforming labeled timebins to segments, the default behavior. The transform used is vak.transforms.frame_labels.PostProcess. Valid keyword argument names are 'majority_vote' and 'min_segment_dur', and the values should be appropriate for those arguments: a Boolean for majority_vote, a float for min_segment_dur. See the docstring of the transform for more details on these arguments and how they work.
device (str) – Device on which to work with model + data. Defaults to ‘cuda’ if torch.cuda.is_available() returns True.
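The parameters above might be assembled into a call as in the following sketch. Only the parameter names and the import path come from the signature; the model name "TweetyNet", every path, the empty model_config, and the helper check_post_tfm_kwargs are placeholders or assumptions made for illustration, and the helper is not part of vak.

```python
from pathlib import Path


def check_post_tfm_kwargs(post_tfm_kwargs):
    """Hypothetical helper (not part of vak): check the key names and
    value types that the parameter description says are valid."""
    if post_tfm_kwargs is None:
        return None  # no additional clean-up, the default behavior
    for name, value in post_tfm_kwargs.items():
        if name == "majority_vote":
            ok = isinstance(value, bool)
        elif name == "min_segment_dur":
            ok = isinstance(value, float)
        else:
            raise ValueError(f"invalid post_tfm_kwargs key: {name!r}")
        if not ok:
            raise TypeError(f"{name} has unexpected type {type(value).__name__}")
    return dict(post_tfm_kwargs)


# All values below are placeholders; substitute your own.
eval_kwargs = dict(
    model_name="TweetyNet",  # assumed to be in vak.models.registry.MODEL_NAMES
    model_config={},  # in practice, the model configuration loaded from a .toml file
    dataset_path=Path("path/to/dataset"),
    checkpoint_path=Path("path/to/checkpoints"),
    labelmap_path=Path("path/to/labelmap.json"),
    output_dir=Path("path/to/output"),
    num_workers=2,
    split="test",
    post_tfm_kwargs=check_post_tfm_kwargs(
        {"majority_vote": True, "min_segment_dur": 0.02}
    ),
)

# With vak installed, the call itself would then be:
# from vak.eval.frame_classification import eval_frame_classification_model
# eval_frame_classification_model(**eval_kwargs)
```

Checking post_tfm_kwargs before the call is a design choice of this sketch: a misspelled key fails immediately instead of after the dataset and checkpoint have been loaded.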
Notes
Note that unlike core.predict, this function can modify labelmap so that metrics like edit distance are correctly computed, by converting any string labels in labelmap with multiple characters to (mock) single-character labels, with vak.labels.multi_char_labels_to_single_char.
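The idea behind this conversion can be sketched as below. This is an illustration only, not vak's implementation: the function to_single_char_labelmap, the choice of replacement characters, and the assumption that labelmap maps string labels to integers are all hypothetical. It shows why the conversion matters: a character-level edit distance would count a label such as "ch" as two symbols unless it is first mapped to a single character.

```python
def to_single_char_labelmap(labelmap):
    """Hypothetical stand-in for the conversion described above: replace
    each multi-character label with an unused single character, keeping
    the integer that the label maps to."""
    used = {label for label in labelmap if len(label) == 1}
    # Candidate replacements: printable ASCII characters not already in use.
    candidates = (chr(c) for c in range(33, 127) if chr(c) not in used)
    converted = {}
    for label, index in labelmap.items():
        if len(label) > 1:
            label = next(candidates)  # mock single-character label
        converted[label] = index
    return converted


# Assumed labelmap layout: string label -> integer class index.
labelmap = {"a": 0, "b": 1, "ch": 2, "intro": 3}
single = to_single_char_labelmap(labelmap)
```

After the conversion every key has length one, so a sequence of labels can be joined into a string in which each character stands for exactly one segment.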