Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0c871c2314 |
@@ -1,11 +1,12 @@
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from omegaconf import OmegaConf, DictConfig
|
||||
|
||||
from src.utils.torch import img2tensor, check_dim_and_resize, tensor2img
|
||||
from src.utils.torch import img2tensor, tensor2img
|
||||
from src.utils.build import build_from_cfg
|
||||
from src.utils.padder import InputPadder
|
||||
|
||||
@@ -39,7 +40,7 @@ class ModelRunner:
|
||||
model.load_state_dict(checkpoint["state_dict"])
|
||||
model = model.to(get_device())
|
||||
model.eval()
|
||||
self.model = model
|
||||
self.model = torch.compile(model)
|
||||
|
||||
|
||||
def get_vram_available(device: torch.device) -> int:
|
||||
@@ -91,35 +92,19 @@ class ImageInterpolator:
|
||||
output_path (Path): Path to save the interpolated image (only png and jpg formats are supported)
|
||||
"""
|
||||
logging.debug(f"Reading images: {image1} and {image2}")
|
||||
tensor1 = img2tensor(image1).to(self.device)
|
||||
tensor2 = img2tensor(image2).to(self.device)
|
||||
tensor1 = img2tensor(image1, self.device)
|
||||
tensor2 = img2tensor(image2, self.device)
|
||||
logging.debug(
|
||||
f"Image shapes after conversion to tensors: {tensor1.shape}, {tensor2.shape}"
|
||||
)
|
||||
tensor1, tensor2 = check_dim_and_resize(tensor1, tensor2)
|
||||
logging.debug(f"Image shapes after resizing: {tensor1.shape}, {tensor2.shape}")
|
||||
h, w = tensor1.shape[2], tensor1.shape[3]
|
||||
logging.debug(f"Interpolating images of size: {h}x{w}")
|
||||
|
||||
scale = self.scale(h, w)
|
||||
logging.debug(f"Calculated scale factor: {scale:.2f}")
|
||||
padding = int(16 / scale)
|
||||
logging.debug(f"Calculated padding: {padding} pixels")
|
||||
padder = InputPadder(tensor1.shape, divisor=padding)
|
||||
tensor1_padded, tensor2_padded = padder.pad(tensor1, tensor2)
|
||||
logging.debug(
|
||||
f"Image shapes after padding: {tensor1_padded.shape}, {tensor2_padded.shape}"
|
||||
)
|
||||
|
||||
tensor1_padded = tensor1_padded.to(self.device)
|
||||
tensor2_padded = tensor2_padded.to(self.device)
|
||||
logging.debug("Running model inference for interpolation")
|
||||
with torch.no_grad():
|
||||
with torch.amp.autocast(self.device.type):
|
||||
interpolated = self.model_runner.model(
|
||||
tensor1_padded, tensor2_padded, self.embt, scale_factor=scale, eval=True
|
||||
tensor1, tensor2, self.embt
|
||||
)["imgt_pred"]
|
||||
logging.debug(f"Interpolated image shape before unpadding: {interpolated.shape}")
|
||||
(interpolated,) = padder.unpad(interpolated)
|
||||
logging.debug(f"Interpolated image shape after unpadding: {interpolated.shape}")
|
||||
return tensor2img(interpolated.cpu())
|
||||
|
||||
|
||||
@@ -5,23 +5,26 @@ import numpy as np
|
||||
|
||||
|
||||
def tensor2img(tensor: torch.Tensor):
|
||||
return (
|
||||
(tensor * 255.0)
|
||||
.detach()
|
||||
tensor = (
|
||||
tensor.mul(255.0)
|
||||
.clamp_(0, 255)
|
||||
.to(torch.uint8)
|
||||
.squeeze(0)
|
||||
.permute(1, 2, 0)
|
||||
.cpu()
|
||||
.numpy()
|
||||
.clip(0, 255)
|
||||
.astype(np.uint8)
|
||||
)
|
||||
|
||||
return tensor.cpu().numpy()
|
||||
|
||||
def img2tensor(img: np.ndarray) -> torch.Tensor:
|
||||
|
||||
def img2tensor(img: np.ndarray, device: torch.device) -> torch.Tensor:
|
||||
logging.debug(f"Converting image of shape {img.shape} to tensor")
|
||||
if img.shape[-1] > 3:
|
||||
img = img[:, :, :3]
|
||||
return torch.tensor(img).permute(2, 0, 1).unsqueeze(0) / 255.0
|
||||
tensor = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0)
|
||||
if device.type != "cuda":
|
||||
return tensor.float() / 255.0
|
||||
|
||||
return tensor.cuda(non_blocking=True).float().div_(255.0)
|
||||
|
||||
|
||||
def check_dim_and_resize(*args: torch.Tensor) -> list[torch.Tensor]:
|
||||
|
||||
Reference in New Issue
Block a user