Compare commits
3 Commits
dev-cuda
...
c91cf6b53a
| Author | SHA1 | Date | |
|---|---|---|---|
| c91cf6b53a | |||
|
|
c72e34f9dc | ||
| 359f20c3c4 |
@@ -1,12 +1,11 @@
|
|||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from omegaconf import OmegaConf, DictConfig
|
from omegaconf import OmegaConf, DictConfig
|
||||||
|
|
||||||
from src.utils.torch import img2tensor, tensor2img
|
from src.utils.torch import img2tensor, check_dim_and_resize, tensor2img
|
||||||
from src.utils.build import build_from_cfg
|
from src.utils.build import build_from_cfg
|
||||||
from src.utils.padder import InputPadder
|
from src.utils.padder import InputPadder
|
||||||
|
|
||||||
@@ -40,7 +39,7 @@ class ModelRunner:
|
|||||||
model.load_state_dict(checkpoint["state_dict"])
|
model.load_state_dict(checkpoint["state_dict"])
|
||||||
model = model.to(get_device())
|
model = model.to(get_device())
|
||||||
model.eval()
|
model.eval()
|
||||||
self.model = torch.compile(model)
|
self.model = model
|
||||||
|
|
||||||
|
|
||||||
def get_vram_available(device: torch.device) -> int:
|
def get_vram_available(device: torch.device) -> int:
|
||||||
@@ -92,19 +91,35 @@ class ImageInterpolator:
|
|||||||
output_path (Path): Path to save the interpolated image (only png and jpg formats are supported)
|
output_path (Path): Path to save the interpolated image (only png and jpg formats are supported)
|
||||||
"""
|
"""
|
||||||
logging.debug(f"Reading images: {image1} and {image2}")
|
logging.debug(f"Reading images: {image1} and {image2}")
|
||||||
tensor1 = img2tensor(image1, self.device)
|
tensor1 = img2tensor(image1).to(self.device)
|
||||||
tensor2 = img2tensor(image2, self.device)
|
tensor2 = img2tensor(image2).to(self.device)
|
||||||
logging.debug(
|
logging.debug(
|
||||||
f"Image shapes after conversion to tensors: {tensor1.shape}, {tensor2.shape}"
|
f"Image shapes after conversion to tensors: {tensor1.shape}, {tensor2.shape}"
|
||||||
)
|
)
|
||||||
|
tensor1, tensor2 = check_dim_and_resize(tensor1, tensor2)
|
||||||
|
logging.debug(f"Image shapes after resizing: {tensor1.shape}, {tensor2.shape}")
|
||||||
|
h, w = tensor1.shape[2], tensor1.shape[3]
|
||||||
|
logging.debug(f"Interpolating images of size: {h}x{w}")
|
||||||
|
|
||||||
|
scale = self.scale(h, w)
|
||||||
|
logging.debug(f"Calculated scale factor: {scale:.2f}")
|
||||||
|
padding = int(16 / scale)
|
||||||
|
logging.debug(f"Calculated padding: {padding} pixels")
|
||||||
|
padder = InputPadder(tensor1.shape, divisor=padding)
|
||||||
|
tensor1_padded, tensor2_padded = padder.pad(tensor1, tensor2)
|
||||||
|
logging.debug(
|
||||||
|
f"Image shapes after padding: {tensor1_padded.shape}, {tensor2_padded.shape}"
|
||||||
|
)
|
||||||
|
|
||||||
|
tensor1_padded = tensor1_padded.to(self.device)
|
||||||
|
tensor2_padded = tensor2_padded.to(self.device)
|
||||||
logging.debug("Running model inference for interpolation")
|
logging.debug("Running model inference for interpolation")
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
with torch.amp.autocast(self.device.type):
|
interpolated = self.model_runner.model(
|
||||||
interpolated = self.model_runner.model(
|
tensor1_padded, tensor2_padded, self.embt, scale_factor=scale, eval=True
|
||||||
tensor1, tensor2, self.embt
|
)["imgt_pred"]
|
||||||
)["imgt_pred"]
|
|
||||||
logging.debug(f"Interpolated image shape before unpadding: {interpolated.shape}")
|
logging.debug(f"Interpolated image shape before unpadding: {interpolated.shape}")
|
||||||
|
(interpolated,) = padder.unpad(interpolated)
|
||||||
logging.debug(f"Interpolated image shape after unpadding: {interpolated.shape}")
|
logging.debug(f"Interpolated image shape after unpadding: {interpolated.shape}")
|
||||||
return tensor2img(interpolated.cpu())
|
return tensor2img(interpolated.cpu())
|
||||||
|
|
||||||
|
|||||||
@@ -5,26 +5,23 @@ import numpy as np
|
|||||||
|
|
||||||
|
|
||||||
def tensor2img(tensor: torch.Tensor):
|
def tensor2img(tensor: torch.Tensor):
|
||||||
tensor = (
|
return (
|
||||||
tensor.mul(255.0)
|
(tensor * 255.0)
|
||||||
.clamp_(0, 255)
|
.detach()
|
||||||
.to(torch.uint8)
|
|
||||||
.squeeze(0)
|
.squeeze(0)
|
||||||
.permute(1, 2, 0)
|
.permute(1, 2, 0)
|
||||||
|
.cpu()
|
||||||
|
.numpy()
|
||||||
|
.clip(0, 255)
|
||||||
|
.astype(np.uint8)
|
||||||
)
|
)
|
||||||
|
|
||||||
return tensor.cpu().numpy()
|
|
||||||
|
|
||||||
|
def img2tensor(img: np.ndarray) -> torch.Tensor:
|
||||||
def img2tensor(img: np.ndarray, device: torch.device) -> torch.Tensor:
|
|
||||||
logging.debug(f"Converting image of shape {img.shape} to tensor")
|
logging.debug(f"Converting image of shape {img.shape} to tensor")
|
||||||
if img.shape[-1] > 3:
|
if img.shape[-1] > 3:
|
||||||
img = img[:, :, :3]
|
img = img[:, :, :3]
|
||||||
tensor = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0)
|
return torch.tensor(img).permute(2, 0, 1).unsqueeze(0) / 255.0
|
||||||
if device.type != "cuda":
|
|
||||||
return tensor.float() / 255.0
|
|
||||||
|
|
||||||
return tensor.cuda(non_blocking=True).float().div_(255.0)
|
|
||||||
|
|
||||||
|
|
||||||
def check_dim_and_resize(*args: torch.Tensor) -> list[torch.Tensor]:
|
def check_dim_and_resize(*args: torch.Tensor) -> list[torch.Tensor]:
|
||||||
|
|||||||
Reference in New Issue
Block a user