# Source code for openhands.models.encoder.graph.sgn

import torch
from torch import nn

import math

# Adapted from: https://github.com/microsoft/SGN

class norm_data(nn.Module):
    """Batch-normalise a 4-D tensor over its folded (channel, joint) axes.

    Args:
        n_joints (int): Size of the third axis of the tensors passed to
            :meth:`forward`.
        dim (int): Size of the second (channel) axis. Default: 3.
    """

    def __init__(self, n_joints, dim=3):
        super().__init__()
        # One normalised feature per (channel, joint) pair.
        self.bn = nn.BatchNorm1d(dim * n_joints)

    def forward(self, x):
        """Apply batch normalisation and restore the original 4-D layout.

        Args:
            x (torch.Tensor): 4-D tensor; axes 1 and 2 are flattened together
                before normalisation, so their product must equal
                ``dim * n_joints``.

        Returns:
            torch.Tensor: Contiguous tensor reshaped back to
            ``(batch, -1, joints, frames)``.
        """
        batch, _, joints, frames = x.size()
        # Fold axes 1 and 2 so BatchNorm1d sees (batch, features, frames).
        normalised = self.bn(x.view(batch, -1, frames))
        return normalised.view(batch, -1, joints, frames).contiguous()


class cnn1x1(nn.Module):
    """Thin wrapper around a 2-D convolution with kernel size 1.

    Args:
        dim1 (int): Number of input channels. Default: 3.
        dim2 (int): Number of output channels. Default: 3.
        bias (bool): Whether the convolution has a bias term. Default: ``True``.
    """

    def __init__(self, dim1=3, dim2=3, bias=True):
        super().__init__()
        self.cnn = nn.Conv2d(
            in_channels=dim1, out_channels=dim2, kernel_size=1, bias=bias
        )

    def forward(self, x):
        """Return the result of the wrapped 1x1 convolution applied to ``x``."""
        return self.cnn(x)


class embed(nn.Module):
    """Two-layer 1x1-convolution embedding with optional input normalisation.

    The pipeline is: ``norm_data`` (or identity) -> 1x1 conv to 64 channels
    -> ReLU -> 1x1 conv to ``dim1`` channels -> ReLU.

    Args:
        n_joints (int): Joint count handed to ``norm_data`` when ``norm`` is set.
        dim (int): Number of input channels. Default: 3.
        dim1 (int): Number of output channels. Default: 128.
        norm (bool): Normalise the input with ``norm_data`` first.
            Default: ``True``.
        bias (bool): Whether the convolutions have a bias term.
            Default: ``False``.
    """

    def __init__(self, n_joints, dim=3, dim1=128, norm=True, bias=False):
        super().__init__()

        if norm:
            head = norm_data(n_joints, dim)
        else:
            head = nn.Identity()

        stages = [
            head,
            cnn1x1(dim, 64, bias=bias),
            nn.ReLU(),
            cnn1x1(64, dim1, bias=bias),
            nn.ReLU(),
        ]
        self.cnn = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the embedding pipeline and return the result."""
        return self.cnn(x)


class local(nn.Module):
    """Pool to a fixed ``(1, 20)`` grid, then apply two conv/BN/ReLU stages.

    Args:
        dim1 (int): Number of input channels (also the width of the first
            convolution). Default: 3.
        dim2 (int): Number of output channels. Default: 3.
        bias (bool): Whether the convolutions have a bias term.
            Default: ``False``.
    """

    def __init__(self, dim1=3, dim2=3, bias=False):
        super().__init__()
        # Adaptive pooling collapses axis 2 to size 1 and axis 3 to size 20.
        self.maxpool = nn.AdaptiveMaxPool2d((1, 20))
        # (1, 3) kernel with (0, 1) padding keeps the last axis length unchanged.
        self.cnn1 = nn.Conv2d(
            dim1, dim1, kernel_size=(1, 3), padding=(0, 1), bias=bias
        )
        self.bn1 = nn.BatchNorm2d(dim1)
        self.relu = nn.ReLU()
        self.cnn2 = nn.Conv2d(dim1, dim2, kernel_size=1, bias=bias)
        self.bn2 = nn.BatchNorm2d(dim2)
        self.dropout = nn.Dropout2d(0.2)

    def forward(self, x1):
        """Return the features produced from the pooled input.

        Args:
            x1 (torch.Tensor): Tensor accepted by ``nn.AdaptiveMaxPool2d`` with
                ``dim1`` channels.

        Returns:
            torch.Tensor: Tensor with ``dim2`` channels and trailing size
            ``(1, 20)``.
        """
        pooled = self.maxpool(x1)
        # First stage: conv -> BN -> ReLU, followed by channel dropout.
        hidden = self.dropout(self.relu(self.bn1(self.cnn1(pooled))))
        # Second stage: 1x1 conv -> BN -> ReLU.
        return self.relu(self.bn2(self.cnn2(hidden)))


class gcn_spa(nn.Module):
    """Graph-convolution layer that mixes features with a supplied matrix ``g``.

    The output is ``ReLU(BN(w(g @ x) + w1(x)))``, where ``w`` and ``w1`` are
    1x1 convolutions. ``w`` never has a bias; ``w1`` follows ``bias``.

    Args:
        in_feature (int): Number of input channels.
        out_feature (int): Number of output channels.
        bias (bool): Whether ``w1`` has a bias term. Default: ``False``.
    """

    def __init__(self, in_feature, out_feature, bias=False):
        super().__init__()
        self.bn = nn.BatchNorm2d(out_feature)
        self.relu = nn.ReLU()
        self.w = cnn1x1(in_feature, out_feature, bias=False)
        self.w1 = cnn1x1(in_feature, out_feature, bias=bias)

    def forward(self, x1, g):
        """Aggregate ``x1`` through ``g`` and add a direct 1x1-conv branch.

        Args:
            x1 (torch.Tensor): 4-D input with ``in_feature`` channels on axis 1.
            g (torch.Tensor): Matrix (batch) left-multiplied onto ``x1`` after
                axes 1 and 3 of ``x1`` have been swapped.

        Returns:
            torch.Tensor: Tensor with ``out_feature`` channels on axis 1.
        """
        # Swap axes 1 and 3 so matmul acts on the trailing (axis 2, channel) pair.
        swapped = x1.permute(0, 3, 2, 1).contiguous()
        aggregated = g.matmul(swapped)
        # Swap back to the channel-first layout expected by the convolutions.
        aggregated = aggregated.permute(0, 3, 2, 1).contiguous()
        combined = self.w(aggregated) + self.w1(x1)
        return self.relu(self.bn(combined))


class compute_g_spa(nn.Module):
    """Compute a row-softmaxed matrix from two 1x1-conv projections of the input.

    Args:
        dim1 (int): Number of input channels. Default: ``64 * 3``.
        dim2 (int): Number of channels of each projection. Default: ``64 * 3``.
        bias (bool): Whether the projections have a bias term.
            Default: ``False``.
    """

    def __init__(self, dim1=64 * 3, dim2=64 * 3, bias=False):
        super().__init__()
        self.dim1 = dim1
        self.dim2 = dim2
        self.g1 = cnn1x1(self.dim1, self.dim2, bias=bias)
        self.g2 = cnn1x1(self.dim1, self.dim2, bias=bias)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x1):
        """Return ``softmax(g1(x1) @ g2(x1))`` taken over the last axis.

        Args:
            x1 (torch.Tensor): 4-D input with ``dim1`` channels on axis 1.

        Returns:
            torch.Tensor: Product of the two permuted projections, normalised
            with a softmax along the last axis.
        """
        # Left factor: axes 1 and 3 swapped, so channels end up last.
        left = self.g1(x1).permute(0, 3, 2, 1).contiguous()
        # Right factor: axis 3 moved to position 1, channels second to last.
        right = self.g2(x1).permute(0, 3, 1, 2).contiguous()
        return self.softmax(left.matmul(right))


class SGN(nn.Module):
    """
    SGN model proposed in
    `Semantics-Guided Neural Networks for Efficient Skeleton-Based Human Action Recognition
    <https://arxiv.org/pdf/1904.01189.pdf>`_

    Note:
        The model supports inputs only with fixed number of frames.

    Args:
        n_frames (int): Number of frames in the input sequence.
        num_points (int): Number of spatial points in a graph.
        in_channels (int): Number of channels in the input data. Default: 2.
        bias (bool): Whether to use bias or not. Default: ``True``.
    """

    def __init__(self, n_frames, num_points, in_channels=2, bias=True):
        super(SGN, self).__init__()

        self.dim1 = 256
        self.n_frames = n_frames
        self.n_joints = num_points

        # Constant one-hot index encodings, stored as buffers so that they
        # follow the module across devices and are saved in the state dict.
        # spa_oh: (1, n_joints, n_joints, n_frames)
        self.register_buffer(
            "spa_oh",
            self.one_hot(1, self.n_joints, self.n_frames).permute(0, 3, 2, 1),
        )
        # tem_oh: (1, n_frames, n_joints, n_frames)
        self.register_buffer(
            "tem_oh",
            self.one_hot(1, self.n_frames, self.n_joints).permute(0, 3, 1, 2),
        )

        self.tem_embed = embed(
            self.n_joints, self.n_frames, 64 * 4, norm=False, bias=bias
        )
        self.spa_embed = embed(self.n_joints, self.n_joints, 64, norm=False, bias=bias)
        self.joint_embed = embed(self.n_joints, in_channels, 64, norm=True, bias=bias)
        self.dif_embed = embed(self.n_joints, in_channels, 64, norm=True, bias=bias)
        self.maxpool = nn.AdaptiveMaxPool2d((1, 1))
        self.cnn = local(self.dim1, self.dim1 * 2, bias=bias)
        self.compute_g1 = compute_g_spa(self.dim1 // 2, self.dim1, bias=bias)
        self.gcn1 = gcn_spa(self.dim1 // 2, self.dim1 // 2, bias=bias)
        self.gcn2 = gcn_spa(self.dim1 // 2, self.dim1, bias=bias)
        self.gcn3 = gcn_spa(self.dim1, self.dim1, bias=bias)

        self.n_out_features = self.dim1 * 2

        # Re-initialise every Conv2d weight from N(0, sqrt(2 / n)) with
        # n = kernel_height * kernel_width * out_channels.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                nn.init.normal_(m.weight, 0, math.sqrt(2.0 / n))

        # The graph-aggregation branch of each GCN layer starts at zero, so
        # initially only the direct ``w1`` branch contributes.
        nn.init.constant_(self.gcn1.w.cnn.weight, 0)
        nn.init.constant_(self.gcn2.w.cnn.weight, 0)
        nn.init.constant_(self.gcn3.w.cnn.weight, 0)

    def forward(self, input):
        r"""
        Args:
            input (torch.Tensor): Input tensor of shape :math:`(N, in\_channels, T_{in}, V_{in})`

        Returns:
            torch.Tensor: Output embedding of shape :math:`(N, n\_out\_features)`

            where
                - :math:`N` is a batch size,
                - :math:`T_{in}` is a length of input sequence,
                - :math:`V_{in}` is the number of graph nodes,
                - :math:`n\_out\_features` is the output embedding dimension.
        """
        bs, _, _, num_joints = input.size()

        # (B, C, T, V) -> (B, C, V, T): the sub-modules expect frames last.
        x = input.permute(0, 1, 3, 2).contiguous()

        # Frame-to-frame difference along the last (frame) axis, zero-padded
        # at the first frame so it keeps the same length as ``x``.
        dif = x[:, :, :, 1:] - x[:, :, :, :-1]
        dif = torch.cat(
            [dif.new_zeros(bs, dif.size(1), num_joints, 1), dif], dim=-1
        )

        pos = self.joint_embed(x)
        tem1 = self.tem_embed(self.tem_oh).repeat(bs, 1, 1, 1)
        spa1 = self.spa_embed(self.spa_oh).repeat(bs, 1, 1, 1)
        dif = self.dif_embed(dif)
        dy = pos + dif

        # Joint-level Module
        x = torch.cat([dy, spa1], 1)
        g = self.compute_g1(x)
        x = self.gcn1(x, g)
        x = self.gcn2(x, g)
        x = self.gcn3(x, g)

        # Frame-level Module
        x = x + tem1
        x = self.cnn(x)

        # Classification head
        output = self.maxpool(x)
        output = torch.flatten(output, 1)

        return output

    def one_hot(self, bs, spa, tem):
        """
        get one-hot encodings

        Args:
            bs (int): Number of repetitions along the first axis.
            spa (int): Number of classes; each encoding is ``spa x spa``.
            tem (int): Number of repetitions along the second axis.

        Returns:
            torch.Tensor: float32 tensor of shape ``(bs, tem, spa, spa)`` in
            which every ``[i, j]`` slice is the ``spa x spa`` identity matrix.
        """
        # Row k of the identity matrix is the one-hot encoding of index k.
        y_onehot = torch.eye(spa, dtype=torch.float32)
        y_onehot = y_onehot.unsqueeze(0).unsqueeze(0)
        return y_onehot.repeat(bs, tem, 1, 1)