Source code for vak.prep.split.algorithms.bruteforce

from __future__ import annotations

import logging
import random
from typing import Union

import numpy as np

from .validate import validate_split_durations


def unique_set_from_labels(labels: list[np.ndarray]) -> set:
    """Helper function to generate the set of unique labels
    in a list of label arrays.

    Used for comparison with ``labelset``.

    Parameters
    ----------
    labels : list
        Of label arrays, one array of labels per vocalization.

    Returns
    -------
    uniq_labels : set
        Set of all unique label classes found across ``labels``.
    """
    # set comprehension flattens the arrays without materializing
    # an intermediate list, unlike the two-step list-then-set version
    return {lbl for lbl_arr in labels for lbl in lbl_arr}
def validate_labels(labels: list[np.ndarray], labelset: set) -> None:
    """Validate that the unique set of label classes from ``labels``
    equals ``labelset``.

    Parameters
    ----------
    labels : list
        Of label arrays, one per vocalization.
    labelset : set
        Of label classes the dataset is expected to contain.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If the set of unique labels found in ``labels`` does not equal
        ``labelset``: classes missing from ``labels``, extra classes
        present in ``labels``, the two sets disjoint, or the two sets
        only partially overlapping.
    """
    # inline equivalent of ``unique_set_from_labels``, kept here so this
    # function is self-contained
    uniq_labels = {lbl for lbl_arr in labels for lbl in lbl_arr}
    if uniq_labels == labelset:
        return

    if uniq_labels < labelset:
        missing = labelset - uniq_labels
        raise ValueError(
            f"Unable to split using this labelset: {labelset}. "
            f"The following classes of label do not appear in the list of label arrays: {missing}\n"
            "To fix, either remove those classes from the labelset, "
            "or add vocalizations to the dataset containing the missing labels."
        )
    elif uniq_labels > labelset:
        extra = uniq_labels - labelset
        raise ValueError(
            f"Unable to split using this labelset: {labelset}. "
            f"The following classes of label that are not in labelset are found "
            f"in the list of label arrays: {extra}\n"
            "To fix, either add these classes to the labelset, "
            "or remove the vocalizations from the dataset that contain these labels."
        )
    elif uniq_labels & labelset == set():
        raise ValueError(
            f"Unable to split using this labelset: {labelset}. "
            f"None of the label classes are found in the set of "
            f"unique labels from the list of label arrays: {uniq_labels}."
        )
    else:
        # Bug fix: previously, a partial overlap (neither subset, superset,
        # nor disjoint) fell through every branch and the function silently
        # returned, accepting an invalid labelset. Report both directions
        # of the mismatch instead.
        missing = labelset - uniq_labels
        extra = uniq_labels - labelset
        raise ValueError(
            f"Unable to split using this labelset: {labelset}. "
            f"The following classes of label do not appear in the list of label arrays: {missing}\n"
            f"The following classes of label that are not in labelset are found "
            f"in the list of label arrays: {extra}"
        )
def brute_force(
    durs: list[float],
    labels: list[np.ndarray],
    labelset: set,
    train_dur: Union[int, float],
    val_dur: Union[int, float],
    test_dur: Union[int, float],
    max_iter: int = 5000,
) -> tuple[list[int] | None, list[int] | None, list[int] | None]:
    """Generate indices that split a dataset into separate
    training, validation, and test subsets.

    Finds indices that split (labels, durations) tuples into training,
    test, and validation sets of specified durations, with the set of
    unique labels in each dataset equal to the specified labelset.

    The durations of the datasets created using the returned indices
    will be *greater than* or equal to the durations specified.

    Must specify a positive value for one of {train_dur, test_dur}.
    The other value can be specified as '-1' which is interpreted as
    "use the remainder of the dataset for this split,
    after finding indices for the set with a specified duration".

    Parameters
    ----------
    durs : list
        Of durations of vocalizations.
    labels : list
        Of labels from vocalizations.
    labelset : set
        Of labels.
    train_dur : int, float
        Target duration for training set, in seconds.
    val_dur : int, float
        Target duration for validation set, in seconds.
    test_dur : int, float
        Target duration for test set, in seconds.
    max_iter : int
        Maximum number of iterations to attempt to find indices.
        Default is 5000.

    Returns
    -------
    train_inds, val_inds, test_inds : list
        Of int, the indices that will split dataset
        into training, validation, and test subsets.
        A split that ends up with no indices is returned as None
        (see the dict comprehension before the return statement).

    Notes
    -----
    This is a "brute force" algorithm that just randomly assigns
    indices to a set, and iterates until it finds some partition
    where each set has instances of all classes of label.
    Starts by ensuring that each label is represented in each set
    and then adds files to reach the required durations.
    """
    logger = logging.getLogger(__name__)
    logger.setLevel("INFO")

    # raises ValueError if set of unique labels != labelset
    validate_labels(labels, labelset)

    # validate_split_durations normalizes the three durations
    # (e.g. handling the -1 "remainder" convention) against the total
    sum_durs = sum(durs)
    train_dur, val_dur, test_dur = validate_split_durations(
        train_dur, val_dur, test_dur, sum_durs
    )

    target_split_durs = dict(
        zip(("train", "val", "test"), (train_dur, val_dur, test_dur))
    )

    if not len(durs) == len(labels):
        raise ValueError(
            "length of list of durations did not equal length of list of labels; "
            "should be same length since "
            "each duration of a vocalization corresponds to the labels from its annotations.\n"
            f"Length of durations: {len(durs)}. Length of labels: {len(labels)}"
        )

    # NOTE(review): `iter` shadows the builtin; counts attempts across the
    # outer loop below and is compared against max_iter
    iter = 1
    all_labels_err = (
        "Did not successfully divide data into training, "
        "validation, and test sets of sufficient duration "
        f"after {max_iter} iterations. "
        "Try increasing the total size of the data set."
    )

    # ---- outer loop that repeats until we successfully split or reach max number of iters ----------------------------
    while 1:
        # list of indices we use to index into both `durs` and `labels`
        durs_labels_inds = list(
            range(len(labels))
        )  # we checked len(labels) == len(durs) above

        # when making `split_inds`, "initialize" the dict with all split names, by using target_split_durs
        # so we don't get an error when indexing into dict in return statement below
        split_inds = {
            split_name: [] for split_name in target_split_durs.keys()
        }
        total_split_durs = {
            split_name: 0 for split_name in target_split_durs.keys()
        }
        split_labelsets = {
            split_name: set() for split_name in target_split_durs.keys()
        }

        # list of split 'choices' we use when randomly adding indices to splits;
        # a split is a choice if it has a positive target duration or is -1 ("remainder")
        choice = []
        for split_name in target_split_durs.keys():
            if (
                target_split_durs[split_name] > 0
                or target_split_durs[split_name] == -1
            ):
                choice.append(split_name)

        # ---- make sure each split has at least one instance of each label --------------------------------------------
        for label_from_labelset in sorted(labelset):
            # candidate vocalizations still unassigned that contain this label
            label_inds = [
                ind
                for ind in durs_labels_inds
                if label_from_labelset in labels[ind]
            ]
            random.shuffle(label_inds)
            for split_name in target_split_durs.keys():
                if (
                    target_split_durs[split_name] > 0
                    or target_split_durs[split_name] == -1
                ) and label_from_labelset not in split_labelsets[split_name]:
                    try:
                        ind = label_inds.pop()
                        split_inds[split_name].append(ind)
                        total_split_durs[split_name] += durs[ind]
                        split_labelsets[split_name] = split_labelsets[
                            split_name
                        ].union(set(labels[ind]))
                        durs_labels_inds.remove(ind)
                    except IndexError:
                        if len(label_inds) == 0:
                            logger.debug(
                                "Ran out of elements while dividing dataset into subsets of specified durations."
                                f"Iteration {iter}"
                            )
                            iter += 1
                            # NOTE(review): this `break` only exits the inner
                            # per-split loop, not the per-label loop or the
                            # outer `while`; the "do next iteration" comment
                            # appears optimistic — confirm intended behavior
                            break  # do next iteration
                        else:
                            # something else happened, re-raise error
                            raise

        # drop any split already at/over its target after seeding labels
        for split_name in target_split_durs.keys():
            if (
                target_split_durs[split_name] > 0
                and total_split_durs[split_name]
                >= target_split_durs[split_name]
            ):
                choice.remove(split_name)
        if len(choice) == 0:
            finished = True
        else:
            finished = False

        # ---- inner loop that actually does split ----------------------------------------------------------------------
        random.shuffle(durs_labels_inds)
        while finished is False:
            # pop durations off list and append to randomly-chosen
            # list, either train, val, or test set.
            # Do this until the total duration for each data set is equal
            # to or greater than the target duration for each set.
            try:
                ind = durs_labels_inds.pop()
            except IndexError:
                if len(durs_labels_inds) == 0:
                    logger.debug(
                        "Ran out of elements while dividing dataset into subsets of specified durations."
                        f"Iteration {iter}"
                    )
                    iter += 1
                    break  # do next iteration
                else:
                    # something else happened, re-raise error
                    raise

            which_set = random.randint(0, len(choice) - 1)
            split_name = choice[which_set]
            split_inds[split_name].append(ind)
            total_split_durs[split_name] += durs[ind]
            if (
                target_split_durs[split_name] > 0
                and total_split_durs[split_name]
                >= target_split_durs[split_name]
            ):
                # this split reached its target; stop assigning to it
                choice.remove(split_name)
            elif target_split_durs[split_name] == -1:
                # if this split is -1 and other split is already "finished"
                if (split_name == "test" and "train" not in choice) or (
                    split_name == "train" and "test" not in choice
                ):
                    # just add all remaining inds to this split
                    split_inds[split_name].extend(durs_labels_inds)
                    choice.remove(split_name)

            if len(choice) < 1:  # list is empty, we popped off all the choices
                for split_name in target_split_durs.keys():
                    if target_split_durs[split_name] > 0:
                        if (
                            total_split_durs[split_name]
                            < target_split_durs[split_name]
                        ):
                            raise ValueError(
                                "Loop to find splits completed, "
                                f"but total duration of '{split_name}' split, "
                                f"{total_split_durs[split_name]} seconds, "
                                f"is less than target duration specified: {target_split_durs[split_name]} seconds."
                            )
                # for/else: runs when the loop above completes without
                # `break` (it never breaks; it either raises or finishes)
                else:
                    finished = True
                    break

        if iter > max_iter:
            raise ValueError(
                "Could not find subsets of sufficient duration in "
                f"less than {max_iter} iterations."
            )

        # make sure that each split contains all unique labels in labelset
        if finished is True:
            for split_name in target_split_durs.keys():
                if (
                    target_split_durs[split_name] > 0
                    or target_split_durs[split_name] == -1
                ):
                    split_labels = [
                        label
                        for ind in split_inds[split_name]
                        for label in labels[ind]
                    ]
                    split_labelset = set(split_labels)
                    if split_labelset != set(labelset):
                        iter += 1
                        if iter > max_iter:
                            raise ValueError(all_labels_err)
                        else:
                            logger.debug(
                                f"Set of unique labels in '{split_name}' split did not equal specified labelset. "
                                f"Getting new '{split_name}' split. Iteration: {iter}"
                            )
                            # NOTE(review): this `continue` advances the
                            # per-split `for` loop, not the outer `while`;
                            # after the `for` completes, the `break` below
                            # exits the `while` even if a split failed this
                            # check — confirm intended behavior
                            continue
            # successfully split
            break
        elif finished is False:
            continue

    # splits that received no indices are reported as None
    split_inds = {
        split_name: (inds if inds else None)
        for split_name, inds in split_inds.items()
    }

    return split_inds["train"], split_inds["val"], split_inds["test"]