Files
AMT-Apple/src/utils/torch.py
Viner Abubakirov cf9f0350ce Init code
2026-03-31 09:35:42 +05:00

57 lines
1.4 KiB
Python

import logging
import torch
import numpy as np
def tensor2img(tensor: torch.Tensor) -> np.ndarray:
    """Convert an image tensor with values in [0, 1] to a uint8 channels-last array.

    The tensor is scaled by 255 and detached. The leading dimension is squeezed
    (only removed if it has size 1), and the remaining 3-D tensor is permuted
    from channels-first to channels-last. It is then moved to CPU, rounded to
    the nearest integer, clipped to [0, 255] and cast to ``np.uint8``.

    Rounding (instead of letting ``astype`` truncate) removes a systematic
    downward bias of up to one intensity level. It also keeps values that float
    error leaves just below an integer (e.g. after ``x / 255 * 255``) from
    dropping by one.

    Args:
        tensor: presumably ``(1, C, H, W)`` with values in [0, 1]; a 3-D
            ``(C, H, W)`` tensor with ``C != 1`` also works — TODO confirm
            with callers.

    Returns:
        A ``np.uint8`` array of shape ``(H, W, C)``.
    """
    return (
        (tensor * 255.0)
        .detach()
        .squeeze(0)
        .permute(1, 2, 0)
        .cpu()
        .numpy()
        .round()  # nearest-integer quantization; astype alone would truncate
        .clip(0, 255)
        .astype(np.uint8)
    )
def img2tensor(img: np.ndarray) -> torch.Tensor:
    """Turn a 3-D channels-last image array into a batched channels-first tensor.

    At most the first three channels are kept, so e.g. a fourth (alpha)
    channel is dropped. The data is copied with ``torch.tensor``, moved to
    channels-first, given a leading batch dimension of size 1 and divided by
    255 (presumably the input is 8-bit, so the result lands in [0, 1] —
    TODO confirm with callers).
    """
    logging.debug(f"Converting image of shape {img.shape} to tensor")
    trimmed = img[:, :, :3] if img.shape[-1] > 3 else img
    batched = torch.tensor(trimmed).permute(2, 0, 1).unsqueeze(0)
    return batched / 255.0
def check_dim_and_resize(*args: torch.Tensor) -> list[torch.Tensor]:
    """Make every tensor share the spatial size of the first tensor.

    The trailing dimensions ``t.shape[2:]`` of all inputs are compared. If they
    agree (or there are fewer than two inputs), the tensors are returned
    unchanged in a new list. Otherwise each tensor whose trailing dimensions
    differ from the first tensor's is resized to that size with bilinear
    interpolation. Tensors that already match are passed through untouched,
    which avoids a redundant interpolation (and its dtype restrictions).

    ``mode="bilinear"`` only accepts 4-D input, so the tensors are presumably
    ``(N, C, H, W)`` — TODO confirm against callers.

    Args:
        *args: Tensors to check; may be empty.

    Returns:
        A new list with one tensor per input, in the original order.
    """
    logging.debug("Checking dimensions of input tensors")
    spatial_shapes = []
    for t in args:
        logging.debug(f"Tensor shape: {t.shape}")
        spatial_shapes.append(t.shape[2:])

    if len(set(spatial_shapes)) <= 1:
        return list(args)

    # The first tensor defines the reference size; log the event once.
    desired_shape = tuple(spatial_shapes[0])
    logging.warning(
        f"Inconsistent size of input video frames. All frames will be resized to {desired_shape}"
    )
    return [
        t
        if tuple(shape) == desired_shape
        else torch.nn.functional.interpolate(
            input=t,
            size=desired_shape,
            mode="bilinear",
        )
        for t, shape in zip(args, spatial_shapes)
    ]