import torch
import torch.nn as nn
import torch.nn.functional as F
from .utils import AttentionBlock

class RNNClassifier(nn.Module):
    """
    RNN head for classification.

    Args:
        n_features (int): Number of features in the input.
        num_class (int): Number of classes for classification.
        rnn_type (str): Name of the ``torch.nn`` recurrent module to use,
            e.g. ``GRU`` or ``LSTM``. Default: ``GRU``.
        hidden_size (int): Hidden dim to use for RNN. Default: 512.
        num_layers (int): Number of layers of RNN to use. Default: 1.
        bidirectional (bool): Whether to use bidirectional RNN or not. Default: ``True``.
        use_attention (bool): Whether to use attention for pooling or not. Default: ``False``.
    """

    def __init__(
        self,
        n_features,
        num_class,
        rnn_type="GRU",
        hidden_size=512,
        num_layers=1,
        bidirectional=True,
        use_attention=False,
    ):
        super().__init__()
        self.use_attention = use_attention
        # ``rnn_type`` is resolved by name on ``torch.nn`` (e.g. nn.GRU, nn.LSTM).
        self.rnn = getattr(nn, rnn_type)(
            input_size=n_features,
            hidden_size=hidden_size,
            num_layers=num_layers,
            # forward() receives (batch_size, T, n_features) and selects time on
            # dim 1. Without batch_first the RNN would treat dim 0 as time and
            # recur across the samples of the batch.
            batch_first=True,
            bidirectional=bidirectional,
        )
        # A bidirectional RNN concatenates both directions on the feature dim.
        rnn_out_size = hidden_size * 2 if bidirectional else hidden_size
        if self.use_attention:
            self.attn_block = AttentionBlock(hidden_size=rnn_out_size)
        self.fc = nn.Linear(rnn_out_size, num_class)

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): Input tensor of shape: (batch_size, T, n_features)

        Returns:
            torch.Tensor: logits for classification. Without attention the
            shape is (batch_size, num_class).
        """
        # Reset the RNN weight pointers so the fused (cuDNN) fast path can be used.
        self.rnn.flatten_parameters()
        out, _ = self.rnn(x)  # (batch_size, T, rnn_out_size)
        if self.use_attention:
            # NOTE(review): assumes AttentionBlock pools over T and returns
            # (batch_size, rnn_out_size) - confirm against .utils.
            out = self.fc(self.attn_block(out))
        else:
            # Pool by taking the output at the final time step.
            # NOTE: for a bidirectional RNN the backward-direction half of this
            # vector has only processed the last time step.
            out = self.fc(out[:, -1, :])
        return out