From 0c871c2314d8b03c20559a7a6e4162b86bc2311e Mon Sep 17 00:00:00 2001 From: Viner Abubakirov Date: Mon, 13 Apr 2026 18:56:00 +0500 Subject: [PATCH] Refactor image tensor conversion and model inference in interpolator.py and torch.py --- src/interpolator.py | 33 +++++++++------------------------ src/utils/torch.py | 21 ++++++++++++--------- 2 files changed, 21 insertions(+), 33 deletions(-) diff --git a/src/interpolator.py b/src/interpolator.py index 49d0b94..668acf6 100644 --- a/src/interpolator.py +++ b/src/interpolator.py @@ -1,11 +1,12 @@ import logging from pathlib import Path +from typing import Optional import torch import numpy as np from omegaconf import OmegaConf, DictConfig -from src.utils.torch import img2tensor, check_dim_and_resize, tensor2img +from src.utils.torch import img2tensor, tensor2img from src.utils.build import build_from_cfg from src.utils.padder import InputPadder @@ -39,7 +40,7 @@ class ModelRunner: model.load_state_dict(checkpoint["state_dict"]) model = model.to(get_device()) model.eval() - self.model = model + self.model = torch.compile(model) def get_vram_available(device: torch.device) -> int: @@ -91,35 +92,19 @@ class ImageInterpolator: output_path (Path): Path to save the interpolated image (only png and jpg formats are supported) """ logging.debug(f"Reading images: {image1} and {image2}") - tensor1 = img2tensor(image1).to(self.device) - tensor2 = img2tensor(image2).to(self.device) + tensor1 = img2tensor(image1, self.device) + tensor2 = img2tensor(image2, self.device) logging.debug( f"Image shapes after conversion to tensors: {tensor1.shape}, {tensor2.shape}" ) - tensor1, tensor2 = check_dim_and_resize(tensor1, tensor2) - logging.debug(f"Image shapes after resizing: {tensor1.shape}, {tensor2.shape}") - h, w = tensor1.shape[2], tensor1.shape[3] - logging.debug(f"Interpolating images of size: {h}x{w}") - scale = self.scale(h, w) - logging.debug(f"Calculated scale factor: {scale:.2f}") - padding = int(16 / scale) - logging.debug(f"Calculated padding: {padding} pixels") - padder = InputPadder(tensor1.shape, divisor=padding) - tensor1_padded, tensor2_padded = padder.pad(tensor1, tensor2) - logging.debug( - f"Image shapes after padding: {tensor1_padded.shape}, {tensor2_padded.shape}" - ) - - tensor1_padded = tensor1_padded.to(self.device) - tensor2_padded = tensor2_padded.to(self.device) logging.debug("Running model inference for interpolation") with torch.no_grad(): - interpolated = self.model_runner.model( - tensor1_padded, tensor2_padded, self.embt, scale_factor=scale, eval=True - )["imgt_pred"] + with torch.amp.autocast(self.device.type): + interpolated = self.model_runner.model( + tensor1, tensor2, self.embt + )["imgt_pred"] logging.debug(f"Interpolated image shape before unpadding: {interpolated.shape}") - (interpolated,) = padder.unpad(interpolated) logging.debug(f"Interpolated image shape after unpadding: {interpolated.shape}") return tensor2img(interpolated.cpu()) diff --git a/src/utils/torch.py b/src/utils/torch.py index 8d621da..a456680 100644 --- a/src/utils/torch.py +++ b/src/utils/torch.py @@ -5,23 +5,26 @@ import numpy as np def tensor2img(tensor: torch.Tensor): - return ( - (tensor * 255.0) - .detach() + tensor = ( + tensor.mul(255.0) + .clamp_(0, 255) + .to(torch.uint8) .squeeze(0) .permute(1, 2, 0) - .cpu() - .numpy() - .clip(0, 255) - .astype(np.uint8) ) + return tensor.cpu().numpy() -def img2tensor(img: np.ndarray) -> torch.Tensor: + +def img2tensor(img: np.ndarray, device: torch.device) -> torch.Tensor: logging.debug(f"Converting image of shape {img.shape} to tensor") if img.shape[-1] > 3: img = img[:, :, :3] - return torch.tensor(img).permute(2, 0, 1).unsqueeze(0) / 255.0 + tensor = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0) + if device.type != "cuda": + return tensor.float() / 255.0 + + return tensor.cuda(non_blocking=True).float().div_(255.0) def check_dim_and_resize(*args: torch.Tensor) -> list[torch.Tensor]: