Init code
This commit is contained in:
56
src/utils/torch.py
Normal file
56
src/utils/torch.py
Normal file
@@ -0,0 +1,56 @@
|
||||
import logging
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
|
||||
def tensor2img(tensor: torch.Tensor):
|
||||
return (
|
||||
(tensor * 255.0)
|
||||
.detach()
|
||||
.squeeze(0)
|
||||
.permute(1, 2, 0)
|
||||
.cpu()
|
||||
.numpy()
|
||||
.clip(0, 255)
|
||||
.astype(np.uint8)
|
||||
)
|
||||
|
||||
|
||||
def img2tensor(img: np.ndarray) -> torch.Tensor:
|
||||
logging.debug(f"Converting image of shape {img.shape} to tensor")
|
||||
if img.shape[-1] > 3:
|
||||
img = img[:, :, :3]
|
||||
return torch.tensor(img).permute(2, 0, 1).unsqueeze(0) / 255.0
|
||||
|
||||
|
||||
def check_dim_and_resize(*args: torch.Tensor) -> list[torch.Tensor]:
|
||||
logging.debug("Checking dimensions of input tensors")
|
||||
shape_list = []
|
||||
result = list(args)
|
||||
for t in args:
|
||||
logging.debug(f"Tensor shape: {t.shape}")
|
||||
shape_list.append(t.shape[2:])
|
||||
|
||||
if len(set(shape_list)) > 1:
|
||||
logging.warning(
|
||||
"Inconsistent tensor shapes detected. Resizing tensors to the same shape."
|
||||
)
|
||||
desired_shape = shape_list[0]
|
||||
logging.info(
|
||||
f"Inconsistent size of input video frames. All frames will be resized to {desired_shape}"
|
||||
)
|
||||
|
||||
resize_tensor_list = []
|
||||
for t in args:
|
||||
resize_tensor_list.append(
|
||||
torch.nn.functional.interpolate(
|
||||
input=t,
|
||||
size=tuple(desired_shape),
|
||||
mode="bilinear",
|
||||
)
|
||||
)
|
||||
|
||||
result = resize_tensor_list
|
||||
|
||||
return result
|
||||
Reference in New Issue
Block a user