vak.predict.frame_classification.predict_with_frame_classification_model

vak.predict.frame_classification.predict_with_frame_classification_model(model_config: dict, dataset_config: dict, trainer_config: dict, checkpoint_path: str | Path, labelmap_path: str | Path, num_workers: int = 2, timebins_key: str = 't', frames_standardizer_path: str | Path | None = None, annot_csv_filename: str | None = None, output_dir: str | Path | None = None, min_segment_dur: float | None = None, majority_vote: bool = False, save_net_outputs: bool = False, background_label: str = 'background') -> None

Make predictions on a dataset with a trained FrameClassificationModel.

model_config: dict

Model configuration in a dict. Can be obtained by calling vak.config.ModelConfig.asdict().

dataset_config: dict

Dataset configuration in a dict. Can be obtained by calling vak.config.DatasetConfig.asdict().

trainer_config: dict

Configuration for lightning.pytorch.Trainer. Can be obtained by calling vak.config.TrainerConfig.asdict().

checkpoint_path: str, Path

path to the checkpoint file saved by Torch, used to reload the trained model.

labelmap_path: str, Path

path to ‘labelmap.json’ file.

num_workers: int

Number of processes to use for parallel loading of data. Passed to torch.utils.data.DataLoader. Default is 2.

timebins_key: str

key for accessing vector of time bins in files. Default is ‘t’.

frames_standardizer_path: str, Path

path to a saved vak.transforms.FramesStandardizer object used to standardize (normalize) frames. Default is None. If the frames were standardized when the model was trained and this path is not provided, the predictions will be incorrect.
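To show why the saved standardizer matters, here is a minimal sketch of per-frequency-bin z-score standardization. It assumes that statistics are fit on training frames and then reused unchanged at prediction time. ToyFramesStandardizer is a hypothetical stand-in written for illustration. It is not the vak.transforms.FramesStandardizer API.

```python
from statistics import fmean, pstdev


class ToyFramesStandardizer:
    """Hypothetical stand-in for a frames standardizer (not vak's API).

    Learns a mean and standard deviation for each row (frequency bin)
    from training frames, then applies z-score standardization.
    """

    def fit(self, frames):
        # frames: list of rows, one row per frequency bin, one value per time bin
        self.mean_ = [fmean(row) for row in frames]
        # fall back to 1.0 for constant rows so we never divide by zero
        self.std_ = [pstdev(row) or 1.0 for row in frames]
        return self

    def transform(self, frames):
        # reuse the statistics learned from the training frames
        return [
            [(x - m) / s for x in row]
            for row, m, s in zip(frames, self.mean_, self.std_)
        ]


train_frames = [[1.0, 3.0], [10.0, 30.0]]
standardizer = ToyFramesStandardizer().fit(train_frames)
print(standardizer.transform([[2.0, 4.0], [20.0, 40.0]]))
# → [[0.0, 2.0], [0.0, 2.0]]
```

If prediction-time frames skip this step, the network receives values on a different scale from the ones it was trained on.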

annot_csv_filename: str

name of the .csv file that will contain the predicted annotations. Default is None, in which case the name of the dataset .csv is used, with ‘.annot.csv’ appended to it.
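The sketch below shows one reading of the default-name rule. It assumes the dataset .csv's stem (its name without the ".csv" suffix) is combined with ".annot.csv". The helper default_annot_csv_filename is hypothetical, and vak's exact behavior may differ.

```python
from pathlib import Path


def default_annot_csv_filename(dataset_csv_path):
    """Hypothetical helper: default name for the predicted-annotation .csv.

    Assumes the dataset .csv stem is used, so 'x.csv' becomes 'x.annot.csv'.
    """
    return Path(dataset_csv_path).stem + ".annot.csv"


print(default_annot_csv_filename("/data/bird1_prep.csv"))
# → bird1_prep.annot.csv
```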

output_dir: str, Path

path to the directory where the .csv containing predicted annotations should be saved. Default is None, in which case the current working directory is used.

min_segment_dur: float

minimum duration of a segment, in seconds. If specified, any segment with a duration less than min_segment_dur is removed from lbl_tb, the vector of labels predicted for each time bin. Default is None, in which case no segments are removed.
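The following is a minimal sketch of the filtering described above. It assumes that a "segment" is a maximal run of consecutive time bins not labeled as background, and that removing a segment means relabeling its bins as background. The function remove_short_segments_sketch is hypothetical and is not vak's implementation.

```python
from itertools import groupby


def remove_short_segments_sketch(lbl_tb, timebin_dur, min_segment_dur,
                                 background="background"):
    """Hypothetical sketch: relabel segments shorter than min_segment_dur.

    lbl_tb: list with one predicted label per time bin.
    timebin_dur: duration of one time bin, in seconds.
    """
    out = []
    # group consecutive bins by whether they are background or not
    for is_background, run in groupby(lbl_tb, key=lambda lbl: lbl == background):
        run = list(run)
        duration = len(run) * timebin_dur  # seconds
        if not is_background and duration < min_segment_dur:
            # too short: every bin in this segment becomes background
            out.extend([background] * len(run))
        else:
            out.extend(run)
    return out


labels = ["background", "a", "background", "b", "b", "b", "background"]
print(remove_short_segments_sketch(labels, timebin_dur=0.002, min_segment_dur=0.005))
# → ['background', 'background', 'background', 'b', 'b', 'b', 'background']
```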

majority_vote: bool

if True, transform segments containing multiple labels into segments with a single label by taking a “majority vote”, i.e., assigning every time bin in the segment the label that occurs most frequently in that segment. This transform can only be applied if the labelmap contains a background label (see background_label), because background segments make it possible to identify the labeled segments. Default is False.
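The sketch below illustrates the majority vote. It makes the same assumption as the text: segments are delimited by background bins, so a segment is a maximal run of non-background bins. The function majority_vote_sketch is hypothetical and not vak's implementation.

```python
from collections import Counter
from itertools import groupby


def majority_vote_sketch(lbl_tb, background="background"):
    """Hypothetical sketch: give each non-background segment its most common label."""
    out = []
    for is_background, run in groupby(lbl_tb, key=lambda lbl: lbl == background):
        run = list(run)
        if is_background:
            out.extend(run)  # background bins are left unchanged
        else:
            # ties go to the label seen first in the segment, because
            # Counter.most_common keeps insertion order for equal counts
            winner, _ = Counter(run).most_common(1)[0]
            out.extend([winner] * len(run))
    return out


labels = ["background", "a", "a", "b", "a", "background", "c", "c"]
print(majority_vote_sketch(labels))
# → ['background', 'a', 'a', 'a', 'a', 'background', 'c', 'c']
```

Without background bins, the whole sequence would form a single run, and every bin would receive one label.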

save_net_outputs: bool

if True, save the ‘raw’ outputs of the neural network before they are converted to annotations. Default is False. Typically the output will be “logits”, to which a softmax transform might be applied. For each item in the dataset (each row in the dataset .csv), the output is saved in a separate file in output_dir, with the extension {MODEL_NAME}.output.npz. E.g., if the input is a spectrogram with spect_path filename gy6or6_032312_081416.npz and the network is TweetyNet, then the net output file will be gy6or6_032312_081416.tweetynet.output.npz.
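The sketch below reproduces the naming convention from the example above and shows a numerically stable softmax that could be applied to one time bin's logits. Both net_output_filename and softmax are hypothetical helpers written for illustration. The filename helper is only assumed to match vak's behavior for inputs like the documented example.

```python
import math
from pathlib import Path


def net_output_filename(spect_path, model_name):
    # hypothetical helper: "<spect stem>.<model name, lowercased>.output.npz"
    return f"{Path(spect_path).stem}.{model_name.lower()}.output.npz"


def softmax(logits):
    # subtract the max before exponentiating to avoid overflow
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


print(net_output_filename("gy6or6_032312_081416.npz", "TweetyNet"))
# → gy6or6_032312_081416.tweetynet.output.npz
probs = softmax([2.0, 1.0, 0.1])  # class scores for one time bin
print(probs.index(max(probs)))  # index of the most likely class
# → 0
```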