Source code for openhands.datasets.isolated.base

import torch
import torch.nn.functional as F
import torchvision
import pickle
import albumentations as A
import numpy as np
import pandas as pd
import os, warnings
from ..video_transforms import *
from ..data_readers import *

class BaseIsolatedDataset(torch.utils.data.Dataset):
    """
    Base class providing the datasets for Isolated Sign Language Classification.
    Do not instantiate this class directly; subclass it and implement
    `read_glosses()` and `read_original_dataset()`.
    """

    lang_code = None
    # Get the language code from here:
    # https://iso639-3.sil.org/code_tables/639/data?title=&field_iso639_cd_st_mmbrshp_639_1_tid=94671&name_3=sign+language&field_iso639_element_scope_tid=All&field_iso639_language_type_tid=All&items_per_page=200

    ASSETS_FOLDER = os.path.join(os.path.dirname(os.path.dirname(__file__)), "assets")

    def __init__(
        self,
        root_dir,
        split_file=None,
        class_mappings_file_path=None,
        normalized_class_mappings_file=None,
        splits=["train"],
        modality="rgb",
        transforms="default",
        cv_resize_dims=(264, 264),
        pose_use_confidence_scores=False,
        pose_use_z_axis=False,
        inference_mode=False,
        only_metadata=False,  # Does not load data files if `True`
        multilingual=False,
        languages=None,
        language_set=None,
        # Windowing
        seq_len=1,  # No. of frames per window
        num_seq=1,  # No. of windows
    ):
        super().__init__()

        self.split_file = split_file
        self.root_dir = root_dir
        self.class_mappings_file_path = class_mappings_file_path
        self.splits = splits
        self.modality = modality
        self.multilingual = multilingual
        self.seq_len = seq_len
        self.num_seq = num_seq
        self.languages = languages
        self.language_set = language_set

        self.normalized_class_mappings_file = normalized_class_mappings_file
        if normalized_class_mappings_file:
            # `na_filter=False` because in German, "null" means "zero"
            df = pd.read_csv(normalized_class_mappings_file, na_filter=False)
            self.normalized_class_mappings = {
                df["actual_gloss"][i]: df["normalized_gloss"][i]
                for i in range(len(df))
            }
            # TODO: Also store the reverse mapping for inference in the original language

        self.glosses = []
        self.read_glosses()
        if not self.glosses:
            raise RuntimeError("Unable to read glosses list")
        print(f"Found {len(self.glosses)} classes in {splits} splits")

        self.gloss_to_id = {gloss: i for i, gloss in enumerate(self.glosses)}
        self.id_to_gloss = {i: gloss for i, gloss in enumerate(self.glosses)}

        self.inference_mode = inference_mode
        self.only_metadata = only_metadata

        if not only_metadata:
            self.data = []
            if inference_mode:
                # Will have null labels
                self.enumerate_data_files(self.root_dir)
            else:
                self.read_original_dataset()
            if not self.data:
                raise RuntimeError("No data found")

        self.cv_resize_dims = cv_resize_dims
        self.pose_use_confidence_scores = pose_use_confidence_scores
        self.pose_use_z_axis = pose_use_z_axis

        if "rgb" in modality:
            self.in_channels = 3
            if modality == "rgbd":
                self.in_channels += 1
            self.__getitem = self.__getitem_video
        elif modality == "pose":
            self.in_channels = 4
            if not self.pose_use_confidence_scores:
                self.in_channels -= 1
            if not self.pose_use_z_axis:
                self.in_channels -= 1
            self.__getitem = self.__getitem_pose
        else:
            raise ValueError(f"Modality `{modality}` not supported")

        self.setup_transforms(modality, transforms)

    def setup_transforms(self, modality, transforms):
        if "rgb" in modality:
            if transforms == "default":
                albumentation_transforms = A.Compose(
                    [
                        A.ShiftScaleRotate(
                            shift_limit=0.05, scale_limit=0.05, rotate_limit=15, p=0.5
                        ),
                        A.ChannelDropout(p=0.1),
                        A.RandomRain(p=0.1),
                        A.GridDistortion(p=0.3),
                    ]
                )
                self.transforms = torchvision.transforms.Compose(
                    [
                        Albumentations2DTo3D(albumentation_transforms),
                        NumpyToTensor(),
                        RandomTemporalSubsample(16),
                        torchvision.transforms.Resize(
                            (self.cv_resize_dims[0], self.cv_resize_dims[1])
                        ),
                        torchvision.transforms.RandomCrop(
                            (self.cv_resize_dims[0], self.cv_resize_dims[1])
                        ),
                        torchvision.transforms.RandomHorizontalFlip(p=0.5),
                        # torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                        TCHW2CTHW(),
                    ]
                )
            elif transforms:
                self.transforms = transforms
            else:
                self.transforms = torchvision.transforms.Compose(
                    [
                        NumpyToTensor(),
                        # THWC2CTHW(),
                        THWC2TCHW(),
                        torchvision.transforms.Resize(
                            (self.cv_resize_dims[0], self.cv_resize_dims[1])
                        ),
                        TCHW2CTHW(),
                    ]
                )
        elif "pose" in modality:
            if transforms == "default":
                transforms = None
            self.transforms = transforms

    @property
    def num_class(self):
        return len(self.glosses)
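    # Note on `setup_transforms` above: instead of "default", a custom pipeline can be
    # passed via the `transforms` constructor argument, e.g. (illustrative sketch using
    # transforms already imported from ..video_transforms):
    #     transforms=torchvision.transforms.Compose(
    #         [NumpyToTensor(), THWC2TCHW(), TCHW2CTHW()]
    #     )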
    def read_glosses(self):
        """
        Implement this method to construct `self.glosses[]`
        """
        raise NotImplementedError
    def read_original_dataset(self):
        """
        Implement this method to read
        (video_name/video_folder, classification_label) into `self.data[]`
        """
        raise NotImplementedError
    def enumerate_data_files(self, dir):
        """
        List the video files from the given directory.
        - If pose modality, generate `.pkl` files for all videos in the folder.
        - If no videos are present, check whether some `.pkl` files already exist.
        """
        files = list_all_videos(dir)
        if self.modality == "pose":
            holistic = None
            pose_files = []
            for video_file in files:
                pose_file = os.path.splitext(video_file)[0] + ".pkl"
                if not os.path.isfile(pose_file):
                    # If the pose is not cached, generate and store it.
                    if not holistic:
                        # Create a MediaPipe instance
                        from ..pipelines.generate_pose import MediaPipePoseGenerator

                        holistic = MediaPipePoseGenerator()
                    # Dump keypoints
                    frames = load_frames_from_video(video_file)
                    holistic.generate_keypoints_for_frames(frames, pose_file)
                pose_files.append(pose_file)

            if not pose_files:
                pose_files = list_all_files(dir, extensions=[".pkl"])

            files = pose_files

        if not files:
            raise RuntimeError(f"No files found in {dir}")

        self.data = [(f, -1) for f in files]  # -1 means invalid label_id

    def __len__(self):
        return len(self.data)
    def load_pose_from_path(self, path):
        """
        Load dumped pose keypoints.
        Should contain: {
            "keypoints" of shape (T, V, C),
            "confidences" of shape (T, V)
        }
        """
        with open(path, "rb") as f:
            pose_data = pickle.load(f)
        return pose_data
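    # A compatible pose file can be produced with, e.g. (illustrative sketch;
    # `kps` of shape (T, V, C) and `conf` of shape (T, V) are assumed to exist):
    #     with open("video1.pkl", "wb") as f:
    #         pickle.dump({"keypoints": kps, "confidences": conf}, f)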
    def read_video_data(self, index):
        """
        Extend this method for dataset-specific formats
        """
        video_path = self.data[index][0]
        label = self.data[index][1]
        imgs = load_frames_from_video(video_path)
        # The path doubles as the video identifier returned to `__getitem_video`
        return imgs, label, video_path
    def __getitem_video(self, index):
        if self.inference_mode:
            # Use this base class's generic reader directly, since subclass
            # overrides may rely on annotations unavailable in inference mode
            imgs, label, video_id = BaseIsolatedDataset.read_video_data(self, index)
        else:
            imgs, label, video_id = self.read_video_data(index)
        # imgs shape: (T, H, W, C)

        if self.transforms is not None:
            imgs = self.transforms(imgs)
        return {
            "frames": imgs,
            "label": torch.tensor(label, dtype=torch.long),
            "file": video_id,
        }

    @staticmethod
    def collate_fn(batch_list):
        if "num_windows" in batch_list[0]:
            # Padding not required for windowed models
            frames = [x["frames"] for x in batch_list]
        else:
            max_frames = max([x["frames"].shape[1] for x in batch_list])
            # Pad the temporal dimension to `max_frames` for all videos
            # Assumes each instance of shape: (C, T, V)
            # TODO: Handle videos (C, T, H, W)
            frames = [
                F.pad(x["frames"], (0, 0, 0, max_frames - x["frames"].shape[1], 0, 0))
                for x in batch_list
            ]

        frames = torch.stack(frames, dim=0)
        labels = torch.stack([x["label"] for x in batch_list], dim=0)

        return dict(
            frames=frames,
            labels=labels,
            files=[x["file"] for x in batch_list],
            # `dataset_name` is only present for pose samples in multilingual mode
            dataset_names=[x.get("dataset_name") for x in batch_list],
        )

    def read_pose_data(self, index):
        label = self.data[index][1]
        if self.inference_mode:
            pose_path = self.data[index][0]
        else:
            video_name = self.data[index][0]
            video_path = os.path.join(self.root_dir, video_name)

            # If `video_path` is a folder of frames from which pose was dumped,
            # keep it as it is. Otherwise, just remove the video extension.
            pose_path = (
                video_path
                if os.path.isdir(video_path)
                else os.path.splitext(video_path)[0]
            )
            pose_path = pose_path + ".pkl"

        pose_data = self.load_pose_from_path(pose_path)
        pose_data["label"] = torch.tensor(label, dtype=torch.long)
        if self.multilingual:
            # If `ConcatDataset` is used, it has extra entries for the following:
            pose_data["lang_code"] = self.data[index][2]
            pose_data["dataset_name"] = self.data[index][3]
        return pose_data, pose_path

    def __getitem_pose(self, index):
        """
        Returns frames of shape (C, T, V), where
        C - num channels
        T - num frames
        V - num vertices
        """
        data, path = self.read_pose_data(index)
        # kps shape: (T, V, C)
        kps = data["keypoints"]
        scores = data["confidences"]

        if not self.pose_use_z_axis:
            kps = kps[:, :, :2]

        if self.pose_use_confidence_scores:
            kps = np.concatenate([kps, np.expand_dims(scores, axis=-1)], axis=-1)

        kps = np.asarray(kps, dtype=np.float32)
        data = {
            "frames": torch.tensor(kps).permute(2, 0, 1),  # (C, T, V)
            "label": data["label"],
            "file": path,
            # Required for prepending the language token
            "lang_code": data["lang_code"] if self.multilingual else None,
            # Required to calculate dataset-wise accuracy
            "dataset_name": data["dataset_name"] if self.multilingual else None,
        }

        if self.transforms is not None:
            data = self.transforms(data)

        if self.seq_len > 1 and self.num_seq > 1:
            data["num_windows"] = self.num_seq
            kps = data["frames"].permute(1, 2, 0).numpy()  # CTV -> TVC
            if kps.shape[0] < self.seq_len * self.num_seq:
                pad_kps = np.zeros(
                    ((self.seq_len * self.num_seq) - kps.shape[0], *kps.shape[1:])
                )
                kps = np.concatenate([pad_kps, kps])
            elif kps.shape[0] > self.seq_len * self.num_seq:
                kps = kps[: self.seq_len * self.num_seq, ...]

            SL = kps.shape[0]
            clips = []
            i = 0
            while i + self.seq_len <= SL:
                clips.append(
                    torch.tensor(kps[i : i + self.seq_len, ...], dtype=torch.float32)
                )
                i += self.seq_len
            t_seq = torch.stack(clips, 0)
            data["frames"] = t_seq.permute(0, 3, 1, 2)  # WTVC -> WCTV

        return data

    def __getitem__(self, index):
        return self.__getitem(index)
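# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): a minimal subclass,
# assuming a hypothetical `metadata.csv` inside `root_dir` with `video` and
# `gloss` columns. The concrete datasets in this package define their own
# formats by overriding `read_glosses()` and `read_original_dataset()`.
# ---------------------------------------------------------------------------
class ExampleCSVIsolatedDataset(BaseIsolatedDataset):
    def read_glosses(self):
        df = pd.read_csv(os.path.join(self.root_dir, "metadata.csv"))
        self.glosses = sorted(df["gloss"].unique())

    def read_original_dataset(self):
        df = pd.read_csv(os.path.join(self.root_dir, "metadata.csv"))
        for video, gloss in zip(df["video"], df["gloss"]):
            # Each entry is (relative_video_path, label_id)
            self.data.append((video, self.gloss_to_id[gloss]))


# Example usage (hypothetical paths):
#     dataset = ExampleCSVIsolatedDataset(root_dir="/path/to/data", modality="pose")
#     loader = torch.utils.data.DataLoader(
#         dataset, batch_size=4, collate_fn=BaseIsolatedDataset.collate_fn
#     )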