Refactor image tensor conversion and model inference in interpolator.py and torch.py

This commit is contained in:
Viner Abubakirov
2026-04-13 18:56:00 +05:00
parent 61f8e0abe1
commit 0c871c2314
2 changed files with 21 additions and 33 deletions

View File

@@ -1,11 +1,12 @@
import logging
from pathlib import Path
from typing import Optional
import torch
import numpy as np
from omegaconf import OmegaConf, DictConfig
from src.utils.torch import img2tensor, check_dim_and_resize, tensor2img
from src.utils.torch import img2tensor, tensor2img
from src.utils.build import build_from_cfg
from src.utils.padder import InputPadder
@@ -39,7 +40,7 @@ class ModelRunner:
model.load_state_dict(checkpoint["state_dict"])
model = model.to(get_device())
model.eval()
self.model = model
self.model = torch.compile(model)
def get_vram_available(device: torch.device) -> int:
@@ -91,35 +92,19 @@ class ImageInterpolator:
output_path (Path): Path to save the interpolated image (only png and jpg formats are supported)
"""
logging.debug(f"Reading images: {image1} and {image2}")
tensor1 = img2tensor(image1).to(self.device)
tensor2 = img2tensor(image2).to(self.device)
tensor1 = img2tensor(image1, self.device)
tensor2 = img2tensor(image2, self.device)
logging.debug(
f"Image shapes after conversion to tensors: {tensor1.shape}, {tensor2.shape}"
)
tensor1, tensor2 = check_dim_and_resize(tensor1, tensor2)
logging.debug(f"Image shapes after resizing: {tensor1.shape}, {tensor2.shape}")
h, w = tensor1.shape[2], tensor1.shape[3]
logging.debug(f"Interpolating images of size: {h}x{w}")
scale = self.scale(h, w)
logging.debug(f"Calculated scale factor: {scale:.2f}")
padding = int(16 / scale)
logging.debug(f"Calculated padding: {padding} pixels")
padder = InputPadder(tensor1.shape, divisor=padding)
tensor1_padded, tensor2_padded = padder.pad(tensor1, tensor2)
logging.debug(
f"Image shapes after padding: {tensor1_padded.shape}, {tensor2_padded.shape}"
)
tensor1_padded = tensor1_padded.to(self.device)
tensor2_padded = tensor2_padded.to(self.device)
logging.debug("Running model inference for interpolation")
with torch.no_grad():
with torch.amp.autocast(self.device.type):
interpolated = self.model_runner.model(
tensor1_padded, tensor2_padded, self.embt, scale_factor=scale, eval=True
tensor1, tensor2, self.embt
)["imgt_pred"]
logging.debug(f"Interpolated image shape before unpadding: {interpolated.shape}")
(interpolated,) = padder.unpad(interpolated)
logging.debug(f"Interpolated image shape after unpadding: {interpolated.shape}")
return tensor2img(interpolated.cpu())

View File

@@ -5,23 +5,26 @@ import numpy as np
def tensor2img(tensor: torch.Tensor) -> np.ndarray:
    """Convert a batched float image tensor to a host uint8 image array.

    Args:
        tensor: image tensor with a leading batch dimension of 1 and values
            nominally in [0, 1] (shape assumed (1, C, H, W) — TODO confirm
            against callers).

    Returns:
        np.ndarray: channels-last uint8 array (H, W, C) on the CPU.
    """
    tensor = (
        tensor.mul(255.0)       # out-of-place: the caller's tensor is untouched
        .clamp_(0, 255)         # in-place is safe here, mul() already copied
        .to(torch.uint8)        # integer cast also severs any autograd history
        .squeeze(0)             # drop batch dim: (1, C, H, W) -> (C, H, W)
        .permute(1, 2, 0)       # channels-last for image I/O: (H, W, C)
    )
    return tensor.cpu().numpy()
def img2tensor(img: np.ndarray, device: torch.device) -> torch.Tensor:
    """Convert a channels-last image array to a normalized batched tensor.

    Args:
        img: (H, W, C) array; channels beyond the first three (e.g. alpha)
            are discarded. Values assumed to be 0-255 — TODO confirm callers
            always pass uint8 data.
        device: target device for the resulting tensor.

    Returns:
        torch.Tensor: float tensor of shape (1, 3, H, W) with values in
        [0, 1], placed on ``device``.
    """
    logging.debug(f"Converting image of shape {img.shape} to tensor")
    if img.shape[-1] > 3:
        img = img[:, :, :3]  # drop alpha / extra channels
    tensor = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0)
    if device.type == "cuda":
        # Transfer the small integer tensor first, then normalize on the GPU.
        return tensor.cuda(non_blocking=True).float().div_(255.0)
    # Fix: honor non-CUDA devices (e.g. "mps") instead of silently returning
    # a CPU tensor — the previous code only special-cased CUDA. For "cpu"
    # this .to(device) is a no-op, so behavior there is unchanged.
    return tensor.to(device).float().div_(255.0)
def check_dim_and_resize(*args: torch.Tensor) -> list[torch.Tensor]: