"""Function that generates new inferences from trained models in the frame classification family."""
from __future__ import annotations
import json
import logging
import os
import pathlib
from attrs import define
import crowsetta
import joblib
import lightning
import pandas as pd
import numpy as np
import torch.utils.data
from tqdm import tqdm
from .. import common, datapipes, datasets, models, transforms
from ..common import constants, files, validators
from ..datapipes.frame_classification import InferDatapipe
logger = logging.getLogger(__name__)
@define
class AnnotationDataFrame:
    """Data class bundling the annotations for one audio file,
    held in a :class:`pandas.DataFrame`, with the path to that file.

    Used to save annotations that currently can't be saved
    with :mod:`crowsetta`, e.g. boundary times.
    """

    # annotations for one audio file, e.g. one row per predicted boundary time
    df: pd.DataFrame
    # path to the audio file that ``df`` annotates
    audio_path: str | pathlib.Path
def predict_with_frame_classification_model(
    model_config: dict,
    dataset_config: dict,
    trainer_config: dict,
    checkpoint_path: str | pathlib.Path,
    labelmap_path: str | pathlib.Path,
    num_workers: int = 2,
    timebins_key: str = "t",
    frames_standardizer_path: str | pathlib.Path | None = None,
    annot_csv_filename: str | None = None,
    output_dir: str | pathlib.Path | None = None,
    min_segment_dur: float | None = None,
    majority_vote: bool = False,
    save_net_outputs: bool = False,
    background_label: str = common.constants.DEFAULT_BACKGROUND_LABEL,
) -> None:
    """Make predictions on a dataset with a trained
    :class:`~vak.models.FrameClassificationModel`.

    Parameters
    ----------
    model_config : dict
        Model configuration in a :class:`dict`.
        Can be obtained by calling :meth:`vak.config.ModelConfig.asdict`.
    dataset_config : dict
        Dataset configuration in a :class:`dict`.
        Can be obtained by calling :meth:`vak.config.DatasetConfig.asdict`.
        Note this function mutates ``dataset_config["params"]`` in place
        (removes 'split', and for built-in datasets sets
        'return_padding_mask' and 'return_frames_path').
    trainer_config : dict
        Configuration for :class:`lightning.pytorch.Trainer`.
        Can be obtained by calling :meth:`vak.config.TrainerConfig.asdict`.
    checkpoint_path : str
        path to directory with checkpoint files saved by Torch, to reload model
    labelmap_path : str
        path to 'labelmap.json' file.
    num_workers : int
        Number of processes to use for parallel loading of data.
        Argument to torch.DataLoader. Default is 2.
    timebins_key : str
        key for accessing vector of time bins in files. Default is 't'.
    frames_standardizer_path : str
        path to a saved :class:`vak.transforms.FramesStandardizer` object used to standardize (normalize) frames.
        If spectrograms were normalized and this is not provided, will give
        incorrect results.
    annot_csv_filename : str
        name of .csv file containing predicted annotations.
        Default is None, in which case the name of the dataset .csv
        is used, with '.annot.csv' appended to it.
    output_dir : str, Path
        path to location where .csv containing predicted annotation
        should be saved. Defaults to current working directory.
    min_segment_dur : float
        minimum duration of segment, in seconds. If specified, then
        any segment with a duration less than min_segment_dur is
        removed from lbl_tb. Default is None, in which case no
        segments are removed.
    majority_vote : bool
        if True, transform segments containing multiple labels
        into segments with a single label by taking a "majority vote",
        i.e. assign all time bins in the segment the most frequently
        occurring label in the segment. This transform can only be
        applied if the labelmap contains an 'unlabeled' label,
        because unlabeled segments makes it possible to identify
        the labeled segments. Default is False.
    save_net_outputs : bool
        if True, save 'raw' outputs of neural networks
        before they are converted to annotations. Default is False.
        Typically the output will be "logits"
        to which a softmax transform might be applied.
        For each item in the dataset--each row in the `dataset_path` .csv--
        the output will be saved in a separate file in `output_dir`,
        with the extension `{MODEL_NAME}.output.npz`. E.g., if the input is a
        spectrogram with `spect_path` filename `gy6or6_032312_081416.npz`,
        and the network is `TweetyNet`, then the net output file
        will be `gy6or6_032312_081416.tweetynet.output.npz`.
    background_label : str
        Name of the "background" (unlabeled) class in ``labelmap``,
        used by post-processing transforms. Default is
        :const:`vak.common.constants.DEFAULT_BACKGROUND_LABEL`.
    """
    # ---- pre-conditions ----------------------------------------------------------------------------------------------
    for path, path_name in zip(
        (checkpoint_path, labelmap_path, frames_standardizer_path),
        ("checkpoint_path", "labelmap_path", "frames_standardizer_path"),
    ):
        if path is not None:
            if not validators.is_a_file(path):
                raise FileNotFoundError(
                    f"value for ``{path_name}`` not recognized as a file: {path}"
                )

    model_name = model_config["name"]  # we use this var again below
    if "window_size" not in dataset_config["params"]:
        raise KeyError(
            f"The `dataset_config` for frame classification model '{model_name}' must include a 'params' sub-table "
            f"that sets a value for 'window_size', but received a `dataset_config` that did not:\n{dataset_config}"
        )
    dataset_path = pathlib.Path(dataset_config["path"])
    if not dataset_path.exists() or not dataset_path.is_dir():
        raise NotADirectoryError(
            f"`dataset_path` not found or not recognized as a directory: {dataset_path}"
        )

    # ---- set up directory to save output -----------------------------------------------------------------------------
    # we do this first to make sure we can save things
    if output_dir is None:
        output_dir = pathlib.Path(os.getcwd())
    else:
        output_dir = pathlib.Path(output_dir)
    if not output_dir.is_dir():
        raise NotADirectoryError(
            f"value specified for output_dir is not recognized as a directory: {output_dir}"
        )

    # ---- load what we need to transform data -------------------------------------------------------------------------
    if frames_standardizer_path:
        logger.info(
            f"loading FramesStandardizer from path: {frames_standardizer_path}"
        )
        frames_standardizer = joblib.load(frames_standardizer_path)
    else:
        logger.info("Not loading FramesStandardizer, no path was specified")
        frames_standardizer = None

    logger.info(f"loading labelmap from path: {labelmap_path}")
    # coerce to Path: the signature allows a plain str, which has no ``.open``
    with pathlib.Path(labelmap_path).open("r") as f:
        labelmap = json.load(f)

    # ---------------- load data for prediction ------------------------------------------------------------------------
    if "split" in dataset_config["params"]:
        split = dataset_config["params"]["split"]
        # we do this convoluted thing to avoid 'TypeError: Dataset got multiple values for split`
        del dataset_config["params"]["split"]
    else:
        split = "predict"

    # ---- *not* using a built-in dataset ------------------------------------------------------------------------------
    if dataset_config["name"] is None:
        # NOTE: the original code loaded this metadata twice via two
        # equivalent ``Metadata.from_dataset_path`` calls; once suffices
        metadata = datapipes.frame_classification.Metadata.from_dataset_path(
            dataset_path
        )
        dataset_csv_path = dataset_path / metadata.dataset_csv_filename
        # we use this below to convert annotations from frames to seconds
        frame_dur = metadata.frame_dur
        logger.info(
            f"loading dataset to predict from csv path: {dataset_csv_path}"
        )
        pred_dataset = InferDatapipe.from_dataset_path(
            dataset_path=dataset_path,
            split=split,
            window_size=dataset_config["params"]["window_size"],
            frames_standardizer=frames_standardizer,
            return_padding_mask=True,
        )
    # ---- *yes* using a built-in dataset ------------------------------------------------------------------------------
    else:
        # we need "target_type" below when converting predictions to annotations,
        # but fail early here if we don't have it
        if "target_type" not in dataset_config["params"]:
            from ..datasets.biosoundsegbench import VALID_TARGET_TYPES

            raise ValueError(
                "The dataset table in the configuration file requires a 'target_type' "
                "when running predictions on built-in datasets. "
                "Please add a key to the table whose value is a valid target type: "
                f"{VALID_TARGET_TYPES}"
            )
        dataset_config["params"]["return_padding_mask"] = True
        # next line, required to be true regardless of split so we set it here
        dataset_config["params"]["return_frames_path"] = True
        pred_dataset = datasets.get(
            dataset_config,
            split=split,
            frames_standardizer=frames_standardizer,
        )
        # we use this below to convert annotations from frames to seconds
        frame_dur = pred_dataset.frame_dur

    pred_loader = torch.utils.data.DataLoader(
        dataset=pred_dataset,
        shuffle=False,
        # batch size 1 because each spectrogram reshaped into a batch of windows
        batch_size=1,
        num_workers=num_workers,
    )
    logger.info(
        f"Duration of a frame in dataset, in seconds: {frame_dur}",
    )

    # ---------------- do the actual predicting + converting to annotations --------------------------------------------
    input_shape = pred_dataset.shape
    # if dataset returns spectrogram reshaped into windows,
    # throw out the window dimension; just want to tell network (channels, height, width) shape
    if len(input_shape) == 4:
        input_shape = input_shape[1:]
    logger.info(
        f"Shape of input to networks used for predictions: {input_shape}"
    )

    # fixed: original message contained a literal "/n" instead of a newline
    logger.info(f"instantiating model from config:\n{model_name}")
    model = models.get(
        model_name,
        model_config,
        num_classes=len(labelmap),
        input_shape=input_shape,
        labelmap=labelmap,
    )

    # ---------------- do the actual predicting --------------------------------------------------------------------
    logger.info(
        f"loading checkpoint for {model_name} from path: {checkpoint_path}"
    )
    model.load_state_dict_from_path(checkpoint_path)

    trainer_logger = lightning.pytorch.loggers.TensorBoardLogger(
        save_dir=output_dir
    )
    trainer = lightning.pytorch.Trainer(
        accelerator=trainer_config["accelerator"],
        devices=trainer_config["devices"],
        logger=trainer_logger,
    )

    logger.info(f"running predict method of {model_name}")
    results = trainer.predict(model, pred_loader)
    # TODO: figure out how to overload `on_predict_epoch_end` to return dict
    pred_dict = {
        frames_path: y_pred
        for result in results
        for frames_path, y_pred in result.items()
    }

    # ---------------- set up to convert predictions to annotation files -----------------------------------------------
    if dataset_config["name"] is None:
        # we assume this default for now -- prep'd datasets are always multi-class frame label
        target_type = "multi_frame_labels"
    else:
        # we made sure we have this above when determining the kind of dataset
        target_type = dataset_config["params"]["target_type"]
    if isinstance(target_type, (list, tuple)):
        # normalize to a sorted tuple so equality checks below are order-independent
        target_type = tuple(sorted(target_type))

    if annot_csv_filename is None:
        annot_csv_filename = (
            pathlib.Path(dataset_path).stem + constants.ANNOT_CSV_SUFFIX
        )
    annot_csv_path = pathlib.Path(output_dir).joinpath(annot_csv_filename)
    logger.info(f"will save annotations in .csv file: {annot_csv_path}")

    # ---------------- converting to annotations ------------------------------------------------------------------
    progress_bar = tqdm(pred_loader)

    if dataset_config["name"] is None:
        # we're using a user-prepped dataset, not a built-in dataset
        # so assume we have metadata from above
        input_type = (
            metadata.input_type
        )  # we use this to get frame_times inside loop
        if input_type == "audio":
            audio_format = metadata.audio_format
        elif input_type == "spect":
            spect_format = metadata.spect_format
    else:
        input_type = "spect"  # assume this for now
        spect_format = common.constants.DEFAULT_SPECT_FORMAT

    # resolve the integer class index of the background label once, *before*
    # the loop. The original code resolved it inside the loop, overwriting
    # the string ``background_label`` with an int, so every iteration after
    # the first missed the ``in labelmap`` lookup and silently fell back to
    # the "unlabeled" key or 0.
    if background_label in labelmap:
        background_int = labelmap[background_label]
    elif "unlabeled" in labelmap:  # some backward compatibility here
        background_int = labelmap["unlabeled"]
    else:
        background_int = 0  # set a default value anyway just to not throw an error

    annots = []
    logger.info("converting predictions to annotations")
    for batch in progress_bar:
        padding_mask, frames_path = batch["padding_mask"], batch["frames_path"]
        padding_mask = np.squeeze(padding_mask)
        if isinstance(frames_path, list) and len(frames_path) == 1:
            frames_path = frames_path[0]
        # we do all this basically to have clear naming below
        if target_type in ("multi_frame_labels", "binary_frame_labels"):
            class_logits = pred_dict[frames_path]
            boundary_logits = None
        elif target_type == "boundary_frame_labels":
            boundary_logits = pred_dict[frames_path]
            class_logits = None
        elif target_type == ("boundary_frame_labels", "multi_frame_labels"):
            class_logits, boundary_logits = pred_dict[frames_path]

        if save_net_outputs:
            # not sure if there's a better way to get outputs into right shape;
            # can't just call y_pred.reshape() because that basically flattens the whole array first
            # meaning we end up with elements in the wrong order
            # so instead we convert to sequence then stack horizontally, on column axis
            # NOTE(review): this assumes ``class_logits`` is not None; a
            # "boundary_frame_labels"-only target would raise here -- confirm intended
            net_output = torch.hstack(class_logits.unbind())
            net_output = net_output[:, padding_mask]
            net_output = net_output.cpu().numpy()
            net_output_path = output_dir.joinpath(
                pathlib.Path(frames_path).stem
                + f"{model_name}{constants.NET_OUTPUT_SUFFIX}"
            )
            np.savez(net_output_path, net_output)

        if class_logits is not None:
            class_preds = torch.argmax(class_logits, dim=1)  # assumes class dimension is 1
            class_preds = torch.flatten(class_preds).cpu().numpy()[padding_mask]
        if boundary_logits is not None:
            boundary_preds = torch.argmax(boundary_logits, dim=1)  # assumes class dimension is 1
            boundary_preds = torch.flatten(boundary_preds).cpu().numpy()[padding_mask]

        if input_type == "audio":
            frames, samplefreq = constants.AUDIO_FORMAT_FUNC_MAP[audio_format](
                frames_path
            )
            frame_times = np.arange(frames.shape[-1]) / samplefreq
        elif input_type == "spect":
            spect_dict = files.spect.load(
                frames_path, spect_format=spect_format
            )
            frame_times = spect_dict[timebins_key]
        # audio_fname is used for audio_path attribute of crowsetta.Annotation below
        audio_fname = files.spect.find_audio_fname(frames_path)

        if target_type in ("multi_frame_labels", "binary_frame_labels"):
            if majority_vote or min_segment_dur:
                class_preds = transforms.frame_labels.postprocess(
                    class_preds,
                    timebin_dur=frame_dur,
                    min_segment_dur=min_segment_dur,
                    majority_vote=majority_vote,
                    background_label=background_int,
                )
            labels, onsets_s, offsets_s = transforms.frame_labels.to_segments(
                class_preds,
                labelmap=labelmap,
                frame_times=frame_times,
            )
            if labels is None and onsets_s is None and offsets_s is None:
                # handle the case when all time bins are predicted to be unlabeled
                # see https://github.com/NickleDave/vak/issues/383
                continue
            seq = crowsetta.Sequence.from_keyword(
                labels=labels, onsets_s=onsets_s, offsets_s=offsets_s
            )
            annot = crowsetta.Annotation(
                seq=seq, notated_path=audio_fname, annot_path=annot_csv_path.name
            )
            annots.append(annot)
        elif target_type == "boundary_frame_labels":
            boundary_inds = transforms.frame_labels.boundary_inds_from_boundary_labels(
                boundary_preds,
                force_boundary_first_ind=True,
            )
            boundary_times = frame_times[boundary_inds]  # fancy indexing
            df = pd.DataFrame.from_records({'boundary_time': boundary_times})
            annots.append(
                AnnotationDataFrame(df=df, audio_path=audio_fname)
            )
        elif target_type == ("boundary_frame_labels", "multi_frame_labels"):
            if majority_vote is False:
                # ``warning``, not the deprecated ``warn`` alias
                logger.warning(
                    "`majority_vote` was set to False but `vak.predict.predict_with_frame_classification_model` "
                    "determined that this model predicts both multi-class labels and boundary labels, "
                    "so `majority_vote` will be set to True (to assign a single label to each segment determined by "
                    "a boundary)"
                )
            # Notice here we *always* call post-process, with majority_vote=True
            # because we are using boundary labels
            class_preds = transforms.frame_labels.postprocess(
                frame_labels=class_preds,
                timebin_dur=frame_dur,
                min_segment_dur=min_segment_dur,
                majority_vote=True,
                background_label=background_int,
                boundary_labels=boundary_preds,
            )
            labels, onsets_s, offsets_s = transforms.frame_labels.to_segments(
                class_preds,
                labelmap=labelmap,
                frame_times=frame_times,
            )
            if labels is None and onsets_s is None and offsets_s is None:
                # handle the case when all time bins are predicted to be unlabeled
                # see https://github.com/NickleDave/vak/issues/383
                # (the original code repeated this guard twice; once suffices)
                continue
            seq = crowsetta.Sequence.from_keyword(
                labels=labels, onsets_s=onsets_s, offsets_s=offsets_s
            )
            annot = crowsetta.Annotation(
                seq=seq, notated_path=audio_fname, annot_path=annot_csv_path.name
            )
            annots.append(annot)

    if all([isinstance(annot, crowsetta.Annotation) for annot in annots]):
        generic_seq = crowsetta.formats.seq.GenericSeq(annots=annots)
        generic_seq.to_file(annot_path=annot_csv_path)
    elif all([isinstance(annot, AnnotationDataFrame) for annot in annots]):
        df_out = []
        for sample_num, annot_df in enumerate(annots):
            df = annot_df.df
            df['audio_path'] = str(annot_df.audio_path)
            df['sample_num'] = sample_num
            df_out.append(df)
        df_out = pd.concat(df_out)
        df_out.to_csv(annot_csv_path, index=False)