vak.predict.predict_.predict

vak.predict.predict_.predict(model_config: dict, dataset_config: dict, trainer_config: dict, checkpoint_path: str | Path, labelmap_path: str | Path, num_workers: int = 2, timebins_key: str = 't', frames_standardizer_path: str | Path | None = None, device: str | None = None, annot_csv_filename: str | None = None, output_dir: str | Path | None = None, min_segment_dur: float | None = None, majority_vote: bool = False, save_net_outputs: bool = False)

Make predictions on a dataset with a trained model.

Parameters:
  • model_config (dict) – Model configuration in a dict, as loaded from a .toml file, and passed to the model's from_config method.

  • dataset_config (dict) – Dataset configuration in a dict. Can be obtained by calling vak.config.DatasetConfig.asdict().

  • trainer_config (dict) – Configuration for lightning.pytorch.Trainer. Can be obtained by calling vak.config.TrainerConfig.asdict().

  • checkpoint_path (str, Path) – Path to a checkpoint file saved by Torch, used to reload the trained model.

  • labelmap_path (str, Path) – Path to the ‘labelmap.json’ file.

  • num_workers (int) – Number of processes to use for parallel loading of data. Passed to torch.utils.data.DataLoader. Default is 2.

  • timebins_key (str) – Key for accessing the vector of time bins in spectrogram files. Default is ‘t’.

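  • frames_standardizer_path (str, Path) – Path to a saved vak.transforms.FramesStandardizer object used to standardize (normalize) frames. If spectrograms were standardized when the model was trained and this path is not provided, predictions will be incorrect.

  • device (str) – Device on which to run the model, e.g. ‘cuda’ or ‘cpu’. Default is None, in which case a default device is selected automatically.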

  • annot_csv_filename (str) – Name of the .csv file that will contain the predicted annotations. Default is None, in which case the name of the dataset .csv is used, with ‘.annot.csv’ appended to it.

  • output_dir (str, Path) – Path to the directory where the .csv containing predicted annotations should be saved. Default is None, in which case the current working directory is used.

  • min_segment_dur (float) – Minimum duration of a segment, in seconds. If specified, any segment with a duration less than min_segment_dur is removed from lbl_tb, the vector of predicted labels for each time bin. Default is None, in which case no segments are removed.

  • majority_vote (bool) – If True, transform segments containing multiple labels into segments with a single label by taking a “majority vote”, i.e., assign all time bins in the segment the most frequently occurring label in that segment. This transform can only be applied if the labelmap contains an ‘unlabeled’ label, because the unlabeled time bins between segments are what make it possible to identify the labeled segments. Default is False.

  • save_net_outputs (bool) – If True, save the ‘raw’ outputs of the neural network before they are converted to annotations. Default is False. Typically the output will be “logits”, to which a softmax transform might be applied. For each item in the dataset (i.e., each row in the dataset .csv), the output is saved in a separate file in output_dir, with the extension {MODEL_NAME}.output.npz. For example, if the input is a spectrogram whose spect_path filename is gy6or6_032312_081416.npz and the network is TweetyNet, then the net output file will be gy6or6_032312_081416.tweetynet.output.npz.
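The frames_standardizer_path parameter matters because a model trained on standardized spectrograms expects inputs on the same scale. As a rough illustration, here is a minimal NumPy sketch of standardization, not vak's FramesStandardizer implementation. It assumes that rows are frequency bins and columns are time bins, and that a mean and standard deviation are fit for each frequency bin and then reused on every new spectrogram. The helper names fit_standardizer and standardize are hypothetical, and replacing a zero standard deviation with 1 is a choice made only for this sketch.

```python
import numpy as np


def fit_standardizer(spects):
    """Hypothetical helper: fit per-frequency-bin mean and std.

    spects: list of 2-D arrays with shape (n_freq_bins, n_time_bins).
    Statistics are computed across all time bins of all spectrograms.
    """
    all_frames = np.concatenate(spects, axis=1)  # (n_freq_bins, total_time_bins)
    mean = all_frames.mean(axis=1, keepdims=True)
    std = all_frames.std(axis=1, keepdims=True)
    # Sketch-only choice: avoid division by zero for constant frequency bins.
    std = np.where(std == 0, 1.0, std)
    return mean, std


def standardize(spect, mean, std):
    """Hypothetical helper: apply previously fit statistics to one spectrogram."""
    return (spect - mean) / std
```

The key point is that the statistics are fit once, on the training data, and the same mean and std are applied at prediction time. Predicting on raw spectrograms instead would feed the network values on a scale it never saw during training.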
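The min_segment_dur and majority_vote options both operate on lbl_tb, the vector of predicted labels per time bin. The following is a hypothetical illustration, not vak's own code. It treats a segment as a run of consecutive time bins whose label is not the ‘unlabeled’ label. Segments shorter than min_segment_dur are relabeled as unlabeled, and every bin in each remaining segment then receives that segment's most frequent label. Several details are assumptions of this sketch: that "removed" means relabeled as unlabeled, that the two steps run in this order, and that ties go to the smallest label id. The sketch also shows why majority_vote needs an ‘unlabeled’ label: without one, the whole vector would form a single segment.

```python
import numpy as np


def find_segments(frame_labels, unlabeled):
    """Return (start, stop) index pairs of runs of labeled time bins; stop is exclusive."""
    is_labeled = np.asarray(frame_labels) != unlabeled
    # Pad with False on both ends so every run has a rising and a falling edge.
    padded = np.concatenate(([False], is_labeled, [False])).astype(np.int8)
    edges = np.flatnonzero(np.diff(padded))
    # Edges alternate: start of run, end of run, start, end, ...
    return list(zip(edges[0::2], edges[1::2]))


def postprocess(frame_labels, unlabeled, timebin_dur,
                min_segment_dur=None, majority_vote=False):
    """Hypothetical post-processing of per-time-bin labels (durations in seconds)."""
    out = np.array(frame_labels, copy=True)
    if min_segment_dur is not None:
        for start, stop in find_segments(out, unlabeled):
            if (stop - start) * timebin_dur < min_segment_dur:
                out[start:stop] = unlabeled  # drop the short segment
    if majority_vote:
        for start, stop in find_segments(out, unlabeled):
            values, counts = np.unique(out[start:stop], return_counts=True)
            # np.unique sorts values, so argmax breaks ties toward the smallest label id.
            out[start:stop] = values[np.argmax(counts)]
    return out
```

For example, with unlabeled = 0 and 10 ms time bins, the labels [0, 1, 1, 2, 1, 0, 0, 3, 0, 0] become [0, 1, 1, 1, 1, 0, 0, 0, 0, 0] with min_segment_dur=0.02 and majority_vote=True. The single-bin segment labeled 3 lasts 10 ms and is removed, and the segment [1, 1, 2, 1] is assigned its majority label, 1.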
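When save_net_outputs is True, each output file can be located from its spectrogram filename and its logits converted to probabilities. This sketch rests on stated assumptions. net_output_filename is a hypothetical helper that reproduces the naming pattern in the save_net_outputs example (lowercasing the model name is inferred from that example). The softmax is the standard, numerically stable form, and which axis of the saved array indexes classes is an assumption to check against the actual array shape.

```python
import pathlib

import numpy as np


def net_output_filename(spect_path, model_name):
    """Hypothetical helper: '<spect stem>.<model name, lowercased>.output.npz'."""
    stem = pathlib.Path(spect_path).stem  # 'gy6or6_032312_081416.npz' -> 'gy6or6_032312_081416'
    return f"{stem}.{model_name.lower()}.output.npz"


def softmax(logits, axis):
    """Convert logits to probabilities along `axis` (assumed to be the class axis)."""
    # Subtracting the max does not change the result but avoids overflow in exp.
    shifted = logits - logits.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)
```

To inspect a saved file, numpy.load returns an NpzFile whose files attribute lists the names of the arrays stored in it. The names vak uses are not assumed here.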