3 Commits

Author SHA1 Message Date
c91cf6b53a Merge pull request 'dev' (#2) from dev into main
Reviewed-on: #2
2026-04-03 18:28:31 +05:00
Viner Abubakirov
c72e34f9dc checkout presets.py from dev 2026-04-02 18:31:54 +05:00
359f20c3c4 Merge pull request 'dev' (#1) from dev into main
Reviewed-on: #1
2026-04-02 12:17:05 +05:00
2 changed files with 33 additions and 21 deletions

View File

@@ -1,12 +1,11 @@
import logging import logging
from pathlib import Path from pathlib import Path
from typing import Optional
import torch import torch
import numpy as np import numpy as np
from omegaconf import OmegaConf, DictConfig from omegaconf import OmegaConf, DictConfig
from src.utils.torch import img2tensor, tensor2img from src.utils.torch import img2tensor, check_dim_and_resize, tensor2img
from src.utils.build import build_from_cfg from src.utils.build import build_from_cfg
from src.utils.padder import InputPadder from src.utils.padder import InputPadder
@@ -40,7 +39,7 @@ class ModelRunner:
model.load_state_dict(checkpoint["state_dict"]) model.load_state_dict(checkpoint["state_dict"])
model = model.to(get_device()) model = model.to(get_device())
model.eval() model.eval()
self.model = torch.compile(model) self.model = model
def get_vram_available(device: torch.device) -> int: def get_vram_available(device: torch.device) -> int:
@@ -92,19 +91,35 @@ class ImageInterpolator:
output_path (Path): Path to save the interpolated image (only png and jpg formats are supported) output_path (Path): Path to save the interpolated image (only png and jpg formats are supported)
""" """
logging.debug(f"Reading images: {image1} and {image2}") logging.debug(f"Reading images: {image1} and {image2}")
tensor1 = img2tensor(image1, self.device) tensor1 = img2tensor(image1).to(self.device)
tensor2 = img2tensor(image2, self.device) tensor2 = img2tensor(image2).to(self.device)
logging.debug( logging.debug(
f"Image shapes after conversion to tensors: {tensor1.shape}, {tensor2.shape}" f"Image shapes after conversion to tensors: {tensor1.shape}, {tensor2.shape}"
) )
tensor1, tensor2 = check_dim_and_resize(tensor1, tensor2)
logging.debug(f"Image shapes after resizing: {tensor1.shape}, {tensor2.shape}")
h, w = tensor1.shape[2], tensor1.shape[3]
logging.debug(f"Interpolating images of size: {h}x{w}")
scale = self.scale(h, w)
logging.debug(f"Calculated scale factor: {scale:.2f}")
padding = int(16 / scale)
logging.debug(f"Calculated padding: {padding} pixels")
padder = InputPadder(tensor1.shape, divisor=padding)
tensor1_padded, tensor2_padded = padder.pad(tensor1, tensor2)
logging.debug(
f"Image shapes after padding: {tensor1_padded.shape}, {tensor2_padded.shape}"
)
tensor1_padded = tensor1_padded.to(self.device)
tensor2_padded = tensor2_padded.to(self.device)
logging.debug("Running model inference for interpolation") logging.debug("Running model inference for interpolation")
with torch.no_grad(): with torch.no_grad():
with torch.amp.autocast(self.device.type):
interpolated = self.model_runner.model( interpolated = self.model_runner.model(
tensor1, tensor2, self.embt tensor1_padded, tensor2_padded, self.embt, scale_factor=scale, eval=True
)["imgt_pred"] )["imgt_pred"]
logging.debug(f"Interpolated image shape before unpadding: {interpolated.shape}") logging.debug(f"Interpolated image shape before unpadding: {interpolated.shape}")
(interpolated,) = padder.unpad(interpolated)
logging.debug(f"Interpolated image shape after unpadding: {interpolated.shape}") logging.debug(f"Interpolated image shape after unpadding: {interpolated.shape}")
return tensor2img(interpolated.cpu()) return tensor2img(interpolated.cpu())

View File

@@ -5,26 +5,23 @@ import numpy as np
def tensor2img(tensor: torch.Tensor): def tensor2img(tensor: torch.Tensor):
tensor = ( return (
tensor.mul(255.0) (tensor * 255.0)
.clamp_(0, 255) .detach()
.to(torch.uint8)
.squeeze(0) .squeeze(0)
.permute(1, 2, 0) .permute(1, 2, 0)
.cpu()
.numpy()
.clip(0, 255)
.astype(np.uint8)
) )
return tensor.cpu().numpy()
def img2tensor(img: np.ndarray) -> torch.Tensor:
def img2tensor(img: np.ndarray, device: torch.device) -> torch.Tensor:
logging.debug(f"Converting image of shape {img.shape} to tensor") logging.debug(f"Converting image of shape {img.shape} to tensor")
if img.shape[-1] > 3: if img.shape[-1] > 3:
img = img[:, :, :3] img = img[:, :, :3]
tensor = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0) return torch.tensor(img).permute(2, 0, 1).unsqueeze(0) / 255.0
if device.type != "cuda":
return tensor.float() / 255.0
return tensor.cuda(non_blocking=True).float().div_(255.0)
def check_dim_and_resize(*args: torch.Tensor) -> list[torch.Tensor]: def check_dim_and_resize(*args: torch.Tensor) -> list[torch.Tensor]: