"""Helpers for converting between NumPy images and PyTorch image tensors."""

import logging

import numpy as np
import torch

# Use a module-level logger: the root-level ``logging.debug``/``logging.info``
# helpers call ``logging.basicConfig()`` when the root logger has no handlers,
# which silently reconfigures logging for the whole host application.
logger = logging.getLogger(__name__)


def tensor2img(tensor: torch.Tensor) -> np.ndarray:
    """Convert a single-image tensor to an 8-bit channels-last NumPy array.

    Args:
        tensor: Image tensor of shape ``(1, C, H, W)`` or ``(C, H, W)`` whose
            values are expected in ``[0, 1]`` (they are scaled by 255). It may
            live on any device and may require grad; it is detached and copied
            to the CPU. A batch dimension larger than 1 is not supported:
            ``squeeze(0)`` leaves it in place and the 3-axis permute fails.

    Returns:
        ``np.uint8`` array of shape ``(H, W, C)``. Values are scaled by 255,
        rounded to the nearest integer and clipped to ``[0, 255]``.
    """
    return (
        (tensor.detach() * 255.0)
        .squeeze(0)
        .permute(1, 2, 0)
        .cpu()
        .numpy()
        # Round before casting: ``astype`` truncates toward zero, so 254.7, or
        # 254.99999 from float error in an img2tensor round trip, would
        # otherwise become 254 instead of 255.
        .round()
        .clip(0, 255)
        .astype(np.uint8)
    )


def img2tensor(img: np.ndarray) -> torch.Tensor:
    """Convert a channels-last image array to a ``(1, C, H, W)`` tensor.

    Args:
        img: Image of shape ``(H, W, C)``, or a single-channel ``(H, W)``
            array. Values are divided by 255, so they are presumably 8-bit
            (``0..255``); confirm for non-uint8 inputs. Only the first three
            channels are kept, so a fourth (e.g. alpha) channel is dropped.

    Returns:
        Tensor of shape ``(1, C, H, W)`` with ``C <= 3``, divided by 255.
        For a ``uint8`` input the result has torch's default float dtype.
    """
    logger.debug("Converting image of shape %s to tensor", img.shape)
    if img.ndim == 2:
        # Grayscale (H, W): add a channel axis. Without it, ``shape[-1]`` is
        # the width, and the channel slice / permute below would fail.
        img = img[:, :, np.newaxis]
    if img.shape[-1] > 3:
        img = img[:, :, :3]
    return torch.tensor(img).permute(2, 0, 1).unsqueeze(0) / 255.0


def check_dim_and_resize(*args: torch.Tensor) -> list[torch.Tensor]:
    """Make all tensors share the spatial size of the first one.

    Args:
        *args: Tensors whose dimensions from index 2 onward are compared.
            They are presumably ``(N, C, H, W)`` frames, because
            ``interpolate(mode="bilinear")`` only accepts 4-D input.

    Returns:
        A new list of the tensors, in input order. If all spatial shapes
        already match (or no tensors are given), the inputs are returned
        as-is. Otherwise, every tensor whose shape differs from the first
        tensor's is bilinearly resized to that shape. Tensors that already
        match are passed through unchanged.
    """
    logger.debug("Checking dimensions of input tensors")
    shape_list = []
    for t in args:
        logger.debug("Tensor shape: %s", t.shape)
        shape_list.append(t.shape[2:])

    if len(set(shape_list)) <= 1:
        return list(args)

    desired_shape = shape_list[0]
    logger.warning(
        "Inconsistent tensor shapes detected. Resizing tensors to the same shape."
    )
    logger.info(
        "Inconsistent size of input video frames. All frames will be resized to %s",
        desired_shape,
    )
    # Tensors already at the target size are returned untouched; resizing
    # them would spend an interpolation pass for no effect.
    return [
        t
        if shape == desired_shape
        else torch.nn.functional.interpolate(
            input=t,
            size=tuple(desired_shape),
            mode="bilinear",
        )
        for t, shape in zip(args, shape_list)
    ]