# Source code for vak.models.tweetynet

"""TweetyNet model [1]_.

.. [1] TweetyNet was described in:
   Cohen, Y., Nicholson, D. A., Sanchioni, A., Mallaber, E. K., Skidanova, V., & Gardner, T. J. (2022).
   Automated annotation of birdsong with a neural network that segments spectrograms. eLife 11: e63853.
   Paper: https://elifesciences.org/articles/63853
   Code: https://github.com/yardencsGitHub/tweetynet
"""
from __future__ import annotations

import torch

from .. import metrics, nets
from .decorator import model
from .frame_classification_model import FrameClassificationModel


@model(family=FrameClassificationModel)
class TweetyNet:
    """TweetyNet model, as described in
    Cohen, Y., Nicholson, D. A., Sanchioni, A., Mallaber, E. K., Skidanova, V., & Gardner, T. J. (2022).
    Automated annotation of birdsong with a neural network that segments spectrograms.
    eLife, 11, e63853. https://elifesciences.org/articles/63853

    Code adapted from https://github.com/yardencsGitHub/tweetynet.

    Attributes
    ----------
    network : vak.nets.TweetyNet
        Convolutional-bidirectional LSTM neural network architecture.
    loss : torch.nn.CrossEntropyLoss
        Standard cross-entropy loss.
    optimizer : torch.optim.Adam
        Adam optimizer.
    metrics : dict
        Maps string names to the following metrics:
        ``vak.metrics.Accuracy``, ``vak.metrics.Levenshtein``,
        ``vak.metrics.CharacterErrorRate``, and ``torch.nn.CrossEntropyLoss``
        (tracked under the name ``"loss"``).

    Notes
    -----
    TweetyNet was described in [1]_.

    ``TweetyNet`` is a type of windowed frame classification model,
    and this version built into ``vak`` relies on the ``FrameClassificationModel`` class.

    References
    ----------
    .. [1] Cohen, Y., Nicholson, D. A., Sanchioni, A., Mallaber, E. K., Skidanova, V., & Gardner, T. J. (2022).
       Automated annotation of birdsong with a neural network that segments spectrograms. eLife 11: e63853.
       Paper: https://elifesciences.org/articles/63853
       Code: https://github.com/yardencsGitHub/tweetynet
    """

    network = nets.TweetyNet
    loss = torch.nn.CrossEntropyLoss
    optimizer = torch.optim.Adam
    # Metric names mapped to the classes that compute them.
    metrics = {
        "acc": metrics.Accuracy,
        "levenshtein": metrics.Levenshtein,
        "character_error_rate": metrics.CharacterErrorRate,
        "loss": torch.nn.CrossEntropyLoss,
    }
    # Default hyperparameters: keyword arguments for the optimizer (learning rate).
    default_config = {"optimizer": {"lr": 0.003}}