Source code for openhands.models.encoder.cnn3d

import torch
import torch.nn as nn
import pytorchvideo

AVAILABLE_3D_BACKBONES = [
    "i3d_r50",
    "c2d_r50",
    "csn_r101",
    "r2plus1d_r50",
    "slow_r50",
    "slowfast_r50",
    "slowfast_r101",
    "slowfast_16x8_r101_50_50",
    "x3d_xs",
    "x3d_s",
    "x3d_m",
    "x3d_l",
]
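
# A minimal sketch (not from the original module) of how each name above maps
# onto pytorchvideo: every entry is the name of a builder function exposed by
# pytorchvideo.models.hub, so a backbone can be constructed by attribute lookup:
#
#     import pytorchvideo.models.hub as ptv_hub
#     backbone = getattr(ptv_hub, "slow_r50")(pretrained=False, progress=True)
#
# get_3d_backbone below wraps exactly this lookup.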


class CNN3D(nn.Module):
    """
    Initializes the 3D convolution backbone.

    **Supported Backbones**

    - `i3d_r50`
    - `c2d_r50`
    - `csn_r101`
    - `r2plus1d_r50`
    - `slow_r50`
    - `slowfast_r50`
    - `slowfast_r101`
    - `slowfast_16x8_r101_50_50`
    - `x3d_xs`
    - `x3d_s`
    - `x3d_m`
    - `x3d_l`

    Args:
        in_channels (int): Number of input channels.
        backbone (str): Backbone to use.
        pretrained (bool, optional): Whether to use a pretrained backbone. Default: ``True``
        **kwargs (optional): Will be passed to the ``pytorchvideo.models.hub`` model builder.
    """

    def __init__(self, in_channels, backbone, pretrained=True, **kwargs):
        super().__init__()
        self.backbone = self.get_3d_backbone(
            backbone, in_channels, pretrained, **kwargs
        )
        # All supported hub backbones end in a 400-class (Kinetics-400) head.
        self.n_out_features = 400  # list(self.backbone.modules())[-2].out_features
    def forward(self, x):
        """
        forward step
        """
        x = self.backbone(x)
        return x.transpose(0, 1)  # Batch-first
    def get_3d_backbone(
        self, name, in_channels=3, pretrained: bool = False, progress: bool = True, **kwargs
    ):
        assert name in AVAILABLE_3D_BACKBONES, "Please use any backbone from " + str(
            AVAILABLE_3D_BACKBONES
        )

        # Look up the requested backbone in pytorchvideo's model hub.
        import pytorchvideo.models.hub as ptv_hub

        model = getattr(ptv_hub, name)(
            pretrained=pretrained, progress=progress, **kwargs
        )

        if in_channels != 3:
            reshape_conv_input_size(in_channels, model)

        return model
def reshape_conv_input_size(in_channels, model):
    """
    Change convolution layer to adapt to various input channels
    """
    assert in_channels == 1 or in_channels >= 4

    # Find the first 3D convolution in the backbone.
    for module in model.modules():
        if isinstance(module, nn.Conv3d):
            break

    module.in_channels = in_channels
    weight = module.weight.detach()

    if in_channels == 1:
        # Collapse the RGB kernels into a single-channel kernel.
        module.weight = nn.Parameter(weight.sum(1, keepdim=True))
    else:
        # Append zero-initialized kernels for the extra input channels.
        curr_in_channels = weight.shape[1]
        to_concat = torch.zeros(
            module.out_channels,
            module.in_channels - curr_in_channels,
            *module.kernel_size,
        )
        module.weight = nn.Parameter(torch.cat([weight, to_concat], dim=1))
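
if __name__ == "__main__":
    # Usage sketch, not part of the original module. Assumptions: an untrained
    # `slow_r50` accepts clips shaped (batch, channels, frames, height, width)
    # with 8 frames at 224x224 and ends in a 400-way Kinetics-400 head.
    encoder = CNN3D(in_channels=3, backbone="slow_r50", pretrained=False)
    clip = torch.randn(2, 3, 8, 224, 224)  # (B, C, T, H, W)
    print(encoder(clip).shape)  # torch.Size([400, 2]) after the transpose in forward()

    # A single-channel input triggers reshape_conv_input_size, which sums the
    # first Conv3d's RGB kernels into one grayscale kernel.
    gray_encoder = CNN3D(in_channels=1, backbone="slow_r50", pretrained=False)
    print(gray_encoder(torch.randn(2, 1, 8, 224, 224)).shape)  # torch.Size([400, 2])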