57 lines
1.4 KiB
Python
57 lines
1.4 KiB
Python
import logging
|
|
|
|
import torch
|
|
import numpy as np
|
|
|
|
|
|
def tensor2img(tensor: torch.Tensor) -> np.ndarray:
    """Convert an image tensor with values in [0, 1] to an HWC ``uint8`` array.

    Args:
        tensor: Image tensor laid out as ``(1, C, H, W)`` or ``(C, H, W)``.
            A leading batch axis of size 1 on a 4-D tensor is removed.
            Values are scaled by 255, rounded to the nearest integer and
            clipped to ``[0, 255]``.

    Returns:
        A ``(H, W, C)`` ``np.uint8`` array.
    """
    img = tensor.detach()
    # Only drop a batch axis; on a 3-D (1, H, W) tensor, squeeze(0) would
    # remove the channel axis and break the permute below.
    if img.dim() == 4:
        img = img.squeeze(0)
    img = img.permute(1, 2, 0)
    # Scale in float32: numpy has no bfloat16, and it avoids half-precision
    # rounding error in the multiplication.
    img = img.float() * 255.0
    # Round to the nearest level instead of truncating (0.999 -> 255, not 254),
    # and quantize before .cpu() so only bytes are copied to the host.
    return img.round().clamp(0, 255).to(torch.uint8).cpu().numpy()
|
|
|
|
|
|
def img2tensor(img: np.ndarray) -> torch.Tensor:
    """Convert an HWC image array to a ``(1, C, H, W)`` tensor divided by 255.

    Args:
        img: Image array of shape ``(H, W, C)`` or ``(H, W)``. A 2-D array is
            treated as a single-channel image. Channels beyond the first three
            (e.g. an alpha channel) are dropped.

    Returns:
        Tensor of shape ``(1, C, H, W)`` with ``C <= 3`` holding ``img / 255``.
        Integer inputs produce torch's default float dtype; floating-point
        inputs keep their dtype.
    """
    logging.debug("Converting image of shape %s to tensor", img.shape)
    if img.ndim == 2:
        # Grayscale: add an explicit channel axis so the HWC -> CHW permute works.
        img = img[:, :, np.newaxis]
    if img.shape[-1] > 3:
        img = img[:, :, :3]
    # Views such as img[..., ::-1] (BGR -> RGB) have negative strides; make the
    # data contiguous before handing it to torch.
    img = np.ascontiguousarray(img)
    return torch.tensor(img).permute(2, 0, 1).unsqueeze(0) / 255.0
|
|
|
|
|
|
def check_dim_and_resize(*args: torch.Tensor) -> list[torch.Tensor]:
    """Make all tensors share the spatial size of the first tensor.

    Spatial size is compared on ``t.shape[2:]``. If all tensors agree, they
    are returned unchanged. Otherwise every tensor whose spatial size differs
    from the first tensor's is resized to it with bilinear interpolation.
    Tensors already at that size are passed through as-is.

    Args:
        *args: Tensors to compare. The resize path uses bilinear
            interpolation, which expects 4-D ``(N, C, H, W)`` float tensors.

    Returns:
        A new list with the (possibly resized) tensors in input order; an
        empty list when called with no arguments.
    """
    logging.debug("Checking dimensions of input tensors")
    shape_list = []
    for t in args:
        logging.debug("Tensor shape: %s", t.shape)
        shape_list.append(t.shape[2:])

    if len(set(shape_list)) <= 1:
        return list(args)

    logging.warning(
        "Inconsistent tensor shapes detected. Resizing tensors to the same shape."
    )
    desired_shape = shape_list[0]
    logging.info(
        "Inconsistent size of input video frames. All frames will be resized to %s",
        desired_shape,
    )

    resized = []
    for t, shape in zip(args, shape_list):
        if shape == desired_shape:
            # Already the target size: skip a redundant interpolate and the
            # new tensor it would allocate.
            resized.append(t)
        else:
            resized.append(
                torch.nn.functional.interpolate(
                    input=t,
                    size=tuple(desired_shape),
                    mode="bilinear",
                    align_corners=False,  # explicit; matches the default
                )
            )
    return resized
|