# Source code for openhands.models.encoder.graph.st_gcn

import torch
import torch.nn as nn
import torch.nn.functional as F
from omegaconf import OmegaConf
from .graph_utils import GraphWithPartition


class ConvTemporalGraphical(nn.Module):
    """The basic module for applying a graph convolution.

    A ``(t_kernel_size, 1)`` convolution first maps the input to
    ``out_channels * kernel_size`` channels. The result is split into
    ``kernel_size`` groups of ``out_channels`` channels. Each group ``k`` is
    aggregated over the graph with ``A[k]``, and the groups are summed.

    Args:
        in_channels (int): Number of channels in the input sequence data.
        out_channels (int): Number of channels produced by the convolution.
        kernel_size (int): Size of the graph convolving kernel, i.e. the number
            of adjacency partitions :math:`K` that ``A`` must provide.
        t_kernel_size (int, optional): Size of the temporal convolving kernel.
            Default: 1.
        t_stride (int, optional): Stride of the temporal convolution. Default: 1.
        t_padding (int, optional): Temporal zero-padding added to both sides
            of the input. Default: 0.
        t_dilation (int, optional): Spacing between temporal kernel elements.
            Default: 1.
        bias (bool, optional): If ``True``, adds a learnable bias to the
            output. Default: ``True``.

    Shape:
        - Input[0]: Input graph sequence in :math:`(N, in_channels, T_{in}, V)`
            format
        - Input[1]: Input graph adjacency matrix in :math:`(K, V, V)` format
        - Output[0]: Output graph sequence in
            :math:`(N, out_channels, T_{out}, V)` format
        - Output[1]: Graph adjacency matrix for output data in
            :math:`(K, V, V)` format (the input ``A``, returned unchanged)

        where
            :math:`N` is a batch size,
            :math:`K` is the spatial kernel size, as :math:`K == kernel_size`,
            :math:`T_{in}/T_{out}` is a length of input/output sequence,
            :math:`V` is the number of graph nodes.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        t_kernel_size=1,
        t_stride=1,
        t_padding=0,
        t_dilation=1,
        bias=True,
    ):
        super().__init__()

        self.kernel_size = kernel_size
        # One conv produces the channels for all K partitions at once; they
        # are separated again in forward() with a view.
        self.conv = nn.Conv2d(
            in_channels,
            out_channels * kernel_size,
            kernel_size=(t_kernel_size, 1),
            padding=(t_padding, 0),
            stride=(t_stride, 1),
            dilation=(t_dilation, 1),
            bias=bias,
        )

    def forward(self, x, A):
        """Apply the temporal conv and then the partitioned graph aggregation.

        Args:
            x (torch.Tensor): Input of shape :math:`(N, in_channels, T_{in}, V)`.
            A (torch.Tensor): Adjacency tensor of shape :math:`(K, V, V)`.

        Returns:
            tuple: ``(output, A)``. ``output`` has shape
            :math:`(N, out_channels, T_{out}, V)` and ``A`` is the input
            adjacency tensor, returned unchanged.

        Raises:
            ValueError: If ``A.size(0)`` differs from ``kernel_size``.
        """
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if A.size(0) != self.kernel_size:
            raise ValueError(
                f"Expected A with {self.kernel_size} partitions in dim 0, "
                f"got shape {tuple(A.size())}"
            )

        x = self.conv(x)
        n, kc, t, v = x.size()
        # Split the channel axis into K groups of out_channels each.
        x = x.view(n, self.kernel_size, kc // self.kernel_size, t, v)
        # out[n, c, t, w] = sum_k sum_v x[n, k, c, t, v] * A[k, v, w]
        x = torch.einsum("nkctv,kvw->nctw", x, A)

        return x.contiguous(), A


class STGCN_BLOCK(nn.Module):
    """
    Applies a spatial temporal graph convolution over an input graph
    sequence.

    The input goes through a graph convolution (``ConvTemporalGraphical``),
    then BatchNorm -> ReLU -> temporal convolution -> BatchNorm -> Dropout.
    The residual branch is then added and a final ReLU is applied.

    Args:
        in_channels (int): Number of channels in the input sequence data.
        out_channels (int): Number of channels produced by the convolution.
        kernel_size (tuple): ``(temporal_kernel_size, graph_kernel_size)``.
            The temporal size must be odd; it is padded by
            ``(temporal_kernel_size - 1) // 2`` on each side.
        stride (int, optional): Stride of the temporal convolution. Default: 1.
        dropout (float, optional): Dropout rate of the final output. Default: 0.
        residual (bool, optional): If ``True``, applies a residual mechanism. Default: ``True``.

    Raises:
        ValueError: If ``kernel_size`` does not have exactly two entries, or
            if its temporal entry is not odd.

    Shape:
        - Input[0]: Input graph sequence in :math:`(N, in_channels, T_{in}, V)`
            format.
        - Input[1]: Input graph adjacency matrix in :math:`(K, V, V)` format
        - Output[0]: Output graph sequence in :math:`(N, out_channels, T_{out},
            V)` format.
        - Output[1]: Graph adjacency matrix for output data in :math:`(K, V,
            V)` format.
        where
            :math:`N` is a batch size,
            :math:`K` is the spatial kernel size, as :math:`K == kernel_size[1]`,
            :math:`T_{in}/T_{out}` is a length of input/output sequence,
            :math:`V` is the number of graph nodes.
    """

    @staticmethod
    def _zero_residual(x):
        # Residual branch used when ``residual=False``: contributes nothing.
        return 0

    @staticmethod
    def _identity_residual(x):
        # Residual branch used when the input can be added back unchanged.
        return x

    def __init__(
        self, in_channels, out_channels, kernel_size, stride=1, dropout=0, residual=True
    ):
        super().__init__()

        # Explicit checks instead of ``assert`` so they survive ``python -O``.
        if len(kernel_size) != 2:
            raise ValueError(
                f"kernel_size must have exactly 2 entries, got {kernel_size!r}"
            )
        if kernel_size[0] % 2 != 1:
            raise ValueError(
                f"temporal kernel size must be odd, got {kernel_size[0]!r}"
            )
        # Symmetric temporal padding; no padding along the node axis.
        padding = ((kernel_size[0] - 1) // 2, 0)

        self.gcn = ConvTemporalGraphical(in_channels, out_channels, kernel_size[1])

        self.tcn = nn.Sequential(
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                out_channels,
                out_channels,
                (kernel_size[0], 1),
                (stride, 1),
                padding,
            ),
            nn.BatchNorm2d(out_channels),
            nn.Dropout(dropout, inplace=True),
        )

        # Residual branches are class-level functions, not lambdas created
        # here. Lambdas cannot be pickled, which breaks pickling of the whole
        # module (e.g. ``torch.save(model)``). Plain functions add no
        # submodules, so state_dict and module tree stay the same.
        if not residual:
            self.residual = self._zero_residual
        elif in_channels == out_channels and stride == 1:
            self.residual = self._identity_residual
        else:
            # Project the input to the output channels/temporal stride.
            self.residual = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=(stride, 1)),
                nn.BatchNorm2d(out_channels),
            )

        self.relu = nn.ReLU(inplace=True)

    def forward(self, x, A):
        """Apply the block.

        Returns:
            tuple: ``(output, A)`` where ``A`` is the adjacency tensor returned
            by the graph convolution.
        """
        # The residual is computed from the block input, before it is transformed.
        res = self.residual(x)
        x, A = self.gcn(x, A)
        x = self.tcn(x) + res

        return self.relu(x), A


class STGCN(nn.Module):
    """Spatial temporal graph convolutional network backbone.

    This module is proposed in
    `Spatial Temporal Graph Convolutional Networks for Skeleton-Based Action
    Recognition <https://arxiv.org/pdf/1801.07455.pdf>`_

    Args:
        in_channels (int): Number of channels in the input data.
        graph_args (dict): The arguments for building the graph, passed as
            keyword arguments to ``GraphWithPartition``. May be an OmegaConf
            config or a plain ``dict``.
        edge_importance_weighting (bool): If ``True``, adds a learnable
            importance weighting to the edges of the graph (one weight tensor
            per ST-GCN block). Required; there is no default.
        n_out_features (int): Output Embedding dimension. Default: 256.
        kwargs (dict): Other parameters for graph convolution units, forwarded
            to every ``STGCN_BLOCK``. ``dropout`` is not forwarded to the
            first block.
    """

    def __init__(
        self,
        in_channels,
        graph_args,
        edge_importance_weighting,
        n_out_features=256,
        **kwargs,
    ):
        super().__init__()

        # OmegaConf containers are not ``dict`` instances; convert them so they
        # can be unpacked as keyword arguments. Plain dicts are used as-is
        # (``OmegaConf.to_container`` rejects them).
        if not isinstance(graph_args, dict):
            graph_args = OmegaConf.to_container(graph_args)
        self.graph = GraphWithPartition(**graph_args)

        # Registered as a buffer: follows the module across devices but is
        # not a trainable parameter.
        A = torch.tensor(self.graph.A, dtype=torch.float32, requires_grad=False)
        self.register_buffer("A", A)

        # The first dimension of A (number of partitions) is the graph kernel size.
        spatial_kernel_size = A.size(0)
        temporal_kernel_size = 9
        self.n_out_features = n_out_features
        kernel_size = (temporal_kernel_size, spatial_kernel_size)

        # forward() flattens (V, C) into V * C channels before this BatchNorm,
        # so A.size(1) must equal the number of graph nodes V.
        self.data_bn = nn.BatchNorm1d(in_channels * A.size(1))

        # The first block is built without dropout and without residual.
        kwargs0 = {k: v for k, v in kwargs.items() if k != "dropout"}
        self.st_gcn_networks = nn.ModuleList(
            [
                STGCN_BLOCK(in_channels, 64, kernel_size, 1, residual=False, **kwargs0),
                STGCN_BLOCK(64, 64, kernel_size, 1, **kwargs),
                STGCN_BLOCK(64, 64, kernel_size, 1, **kwargs),
                STGCN_BLOCK(64, 64, kernel_size, 1, **kwargs),
                STGCN_BLOCK(64, 128, kernel_size, 2, **kwargs),
                STGCN_BLOCK(128, 128, kernel_size, 1, **kwargs),
                STGCN_BLOCK(128, 128, kernel_size, 1, **kwargs),
                STGCN_BLOCK(128, 256, kernel_size, 2, **kwargs),
                STGCN_BLOCK(256, 256, kernel_size, 1, **kwargs),
                STGCN_BLOCK(256, self.n_out_features, kernel_size, 1, **kwargs),
            ]
        )

        if edge_importance_weighting:
            # One learnable multiplicative mask over A per block.
            self.edge_importance = nn.ParameterList(
                [nn.Parameter(torch.ones(self.A.size())) for _ in self.st_gcn_networks]
            )
        else:
            # Multiplying A by 1 leaves it unchanged.
            self.edge_importance = [1] * len(self.st_gcn_networks)

    def forward(self, x):
        r"""
        Args:
            x (torch.Tensor): Input tensor of shape :math:`(N, in\_channels, T_{in}, V_{in})`

        Returns:
            torch.Tensor: Output embedding of shape :math:`(N, n\_out\_features)`

        where
            - :math:`N` is a batch size,
            - :math:`T_{in}` is a length of input sequence,
            - :math:`V_{in}` is the number of graph nodes,
            - :math:`n\_out\_features` is the output embedding dimension.
        """
        N, C, T, V = x.size()
        # Normalize each (node, channel) pair: flatten to (N, V * C, T) for BatchNorm1d.
        x = x.permute(0, 3, 1, 2).contiguous()  # NCTV -> NVCT
        x = x.view(N, V * C, T)
        x = self.data_bn(x)
        x = x.view(N, V, C, T)
        x = x.permute(0, 2, 3, 1).contiguous()  # NVCT -> NCTV

        for gcn, importance in zip(self.st_gcn_networks, self.edge_importance):
            x, _ = gcn(x, self.A * importance)

        # Average over the remaining time and node axes, then flatten.
        x = F.avg_pool2d(x, x.size()[2:])
        x = x.view(N, -1)

        return x