Source code for vak.nets.tweetynet
"""TweetyNet model"""
from __future__ import annotations
import torch
from torch import nn
from ..nn.modules import Conv2dTF
[docs]
class TweetyNet(nn.Module):
    """Neural network architecture
    that assigns labels to time bins
    ("frames") in spectrogram windows,
    as described in
    https://elifesciences.org/articles/63853
    https://github.com/yardencsGitHub/tweetynet

    Cohen, Y., Nicholson, D. A., Sanchioni, A., Mallaber, E. K., Skidanova, V., & Gardner, T. J. (2022).
    Automated annotation of birdsong with a neural network that segments spectrograms. Elife, 11, e63853.

    Attributes
    ----------
    num_classes : int
        Number of classes.
        One of the non-batch dimensions of the output.
    num_input_channels : int
        Number of channels in the input.
    num_freqbins : int
        Number of frequency bins in the input spectrograms.
    cnn : torch.nn.Sequential
        Convolutional layers of model.
    rnn_input_size : int
        Size of input to TweetyNet.rnn.
        Will be the product of the second and third dimensions
        of the output of ``TweetyNet.cnn``,
        i.e. the number of output channels times
        the number of elements in the dimension
        that corresponds to frequency bins in the input.
    hidden_size : int
        Number of features in the hidden state of ``TweetyNet.rnn``.
    rnn : torch.nn.LSTM
        LSTM layer (bidirectional by default)
        that receives the output of ``TweetyNet.cnn``.
    fc : torch.nn.Linear
        Final fully-connected layer that maps
        the output of ``TweetyNet.rnn`` to
        class logits for every time bin in the window.

    Notes
    -----
    This is the network used by ``vak.models.TweetyNetModel``.
    """

    def __init__(
        self,
        num_classes: int,
        num_input_channels: int = 1,
        num_freqbins: int = 256,
        padding: str = "SAME",
        conv1_filters: int = 32,
        conv1_kernel_size: tuple[int, int] = (5, 5),
        conv2_filters: int = 64,
        conv2_kernel_size: tuple[int, int] = (5, 5),
        pool1_size: tuple[int, int] = (8, 1),
        pool1_stride: tuple[int, int] = (8, 1),
        pool2_size: tuple[int, int] = (8, 1),
        pool2_stride: tuple[int, int] = (8, 1),
        hidden_size: int | None = None,
        rnn_dropout: float = 0.0,
        num_layers: int = 1,
        bidirectional: bool = True,
    ):
        """Initialize TweetyNet model.

        Parameters
        ----------
        num_classes : int
            Number of classes to predict, e.g., number of syllable classes in an individual bird's song.
        num_input_channels : int
            Number of channels in input. Typically one, for a spectrogram.
            Default is 1.
        num_freqbins : int
            Number of frequency bins in spectrograms that will be input to model.
            Default is 256.
        padding : str
            Type of padding to use, one of {"VALID", "SAME"}. Default is "SAME".
        conv1_filters : int
            Number of filters in first convolutional layer. Default is 32.
        conv1_kernel_size : tuple
            Size of kernels, i.e. filters, in first convolutional layer. Default is (5, 5).
        conv2_filters : int
            Number of filters in second convolutional layer. Default is 64.
        conv2_kernel_size : tuple
            Size of kernels, i.e. filters, in second convolutional layer. Default is (5, 5).
        pool1_size : two element tuple of ints
            Size of sliding window for first max pooling layer. Default is (8, 1).
        pool1_stride : two element tuple of ints
            Step size for sliding window of first max pooling layer. Default is (8, 1).
        pool2_size : two element tuple of ints
            Size of sliding window for second max pooling layer. Default is (8, 1).
        pool2_stride : two element tuple of ints
            Step size for sliding window of second max pooling layer. Default is (8, 1).
        hidden_size : int, optional
            Number of features in the hidden state ``h``. Default is None,
            in which case ``hidden_size`` is set to the dimensionality of the
            output of the convolutional neural network. This default maintains
            the original behavior of the network.
        rnn_dropout : float
            If non-zero, introduces a Dropout layer on the outputs of each LSTM layer
            except the last layer, with dropout probability equal to ``rnn_dropout``.
            Default is 0.0.
        num_layers : int
            Number of recurrent layers. Default is 1.
        bidirectional : bool
            If True, make LSTM bidirectional. Default is True.
            The input size of the final fully-connected layer is adjusted
            to match the number of LSTM directions.
        """
        super().__init__()
        self.num_classes = num_classes
        self.num_input_channels = num_input_channels
        self.num_freqbins = num_freqbins

        self.cnn = nn.Sequential(
            Conv2dTF(
                in_channels=self.num_input_channels,
                out_channels=conv1_filters,
                kernel_size=conv1_kernel_size,
                padding=padding,
            ),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=pool1_size, stride=pool1_stride),
            Conv2dTF(
                in_channels=conv1_filters,
                out_channels=conv2_filters,
                kernel_size=conv2_kernel_size,
                padding=padding,
            ),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=pool2_size, stride=pool2_stride),
        )

        # Determine the number of features per time bin after stacking channels,
        # by pushing a dummy batch through the CNN and reading its output shape.
        # Note self.rnn_input_size is also used to reshape the output of cnn in self.forward.
        n_dummy_timebins = 256  # some not-small number; this dimension doesn't matter here
        batch_shape = (
            1,
            self.num_input_channels,
            self.num_freqbins,
            n_dummy_timebins,
        )
        # torch.rand (rather than e.g. zeros) is kept so the global RNG is consumed
        # exactly as before, leaving seeded weight initialization of later layers unchanged.
        # no_grad: only the output shape is needed, so no autograd graph is built.
        with torch.no_grad():
            tmp_out = self.cnn(torch.rand(batch_shape))
        channels_out, freqbins_out = tmp_out.shape[1], tmp_out.shape[2]
        self.rnn_input_size = channels_out * freqbins_out

        # by default, use the same number of features for the hidden state as for the input
        self.hidden_size = self.rnn_input_size if hidden_size is None else hidden_size

        self.rnn = nn.LSTM(
            input_size=self.rnn_input_size,
            hidden_size=self.hidden_size,
            num_layers=num_layers,
            dropout=rnn_dropout,
            bidirectional=bidirectional,
        )
        # A bidirectional LSTM concatenates forward and backward hidden states at
        # each time step, so its output has hidden_size * num_directions features.
        # (Previously hard-coded to * 2, which broke bidirectional=False.)
        num_directions = 2 if bidirectional else 1
        self.fc = nn.Linear(
            in_features=self.hidden_size * num_directions, out_features=num_classes
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute class logits for every time bin of a batch of windows.

        Parameters
        ----------
        x : torch.Tensor
            Batch of inputs with dimensions
            (batch, channels, num. frequency bins, num. time bins).

        Returns
        -------
        logits : torch.Tensor
            Tensor with dimensions (batch, num. classes, time steps),
            where time steps is the size of the last dimension of the
            output of ``self.cnn``.

        Raises
        ------
        ValueError
            If the output of ``self.cnn`` does not have ``self.rnn_input_size``
            features per time bin, e.g. because the input has a different number
            of frequency bins than the model was built for.
        """
        features = self.cnn(x)
        batch_size, channels_out, freqbins_out, _ = features.shape
        # Without this check, a mismatched number of frequency bins could still
        # be reshaped below (the time dimension is inferred with -1), silently
        # mixing frequency and time data into a bogus time axis.
        if channels_out * freqbins_out != self.rnn_input_size:
            raise ValueError(
                f"Output of convolutional layers has {channels_out} channels x "
                f"{freqbins_out} frequency bins = {channels_out * freqbins_out} "
                f"features per time bin, but rnn_input_size is {self.rnn_input_size}. "
                "Check that the number of frequency bins in the input "
                "matches the num_freqbins the model was built with."
            )
        # stack channels, to give tensor shape (batch, rnn_input_size, num time bins)
        features = features.reshape(batch_size, self.rnn_input_size, -1)
        # switch dimensions for feeding to rnn, to (num time bins, batch size, input size)
        features = features.permute(2, 0, 1)
        rnn_output, _ = self.rnn(features)
        # permute back to (batch, time bins, hidden_size * num_directions)
        # to project features down onto number of classes
        rnn_output = rnn_output.permute(1, 0, 2)
        logits = self.fc(rnn_output)
        # permute yet again so that dimension order is (batch, classes, time steps)
        # because this is order that loss function expects
        return logits.permute(0, 2, 1)