143 lines
5.8 KiB
Python
143 lines
5.8 KiB
Python
import logging
|
|
from pathlib import Path
|
|
|
|
import torch
|
|
import numpy as np
|
|
from omegaconf import OmegaConf, DictConfig
|
|
from imageio import imread, imwrite
|
|
|
|
from src.utils.torch import img2tensor, check_dim_and_resize, tensor2img
|
|
from src.utils.build import build_from_cfg
|
|
from src.utils.padder import InputPadder
|
|
|
|
|
|
class Anchor:
    """Reference figures used to estimate how far an input must be downscaled.

    ``ImageInterpolator.scale`` combines them as
    ``resolution / (h * w) * sqrt((available_memory - memory_bias) / memory)``,
    where ``available_memory`` is in bytes.
    """

    def __init__(self, resolution: int, memory: int, memory_bias: int) -> None:
        # Reference pixel count (height * width), compared against the input size.
        self.resolution = resolution
        # Divisor applied to the memory headroom (same unit as available memory, bytes).
        self.memory = memory
        # Fixed overhead subtracted from the available memory (bytes).
        self.memory_bias = memory_bias

    def __str__(self) -> str:
        return f"Anchor(resolution={self.resolution}, memory={self.memory}, memory_bias={self.memory_bias})"

    # Use the same readable form for repr() (containers, debuggers, %r logging).
    __repr__ = __str__
|
|
|
|
|
|
class ModelRunner:
    """Builds a network from a YAML config and loads its checkpoint weights.

    The ready-to-use model is exposed as ``self.model`` (on the requested
    device, in eval mode).
    """

    def __init__(self, config: Path, ckpt_path: Path, device: torch.device) -> None:
        """Initializes the ModelRunner with configuration and checkpoint.

        Args:
            config (Path): Path to model configuration in YAML format; its
                ``network`` section is passed to ``build_from_cfg``.
            ckpt_path (Path): Path to model checkpoint in .pth format; it must
                contain a ``"state_dict"`` entry.
            device (torch.device): Device to load the model on
        """
        omega_config = OmegaConf.load(config)
        network_config: DictConfig = omega_config.network
        logging.info(
            f"Loaded network configuration: {network_config} from [{ckpt_path}]"
        )
        model = build_from_cfg(network_config)
        # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
        # Python objects. Only load checkpoints from trusted sources.
        checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False)
        model.load_state_dict(checkpoint["state_dict"])
        # Honour the caller-supplied device so the model agrees with the
        # map_location used above, instead of auto-detecting a different one.
        model = model.to(device)
        model.eval()
        self.model = model
|
|
|
|
|
|
def get_vram_available(device: torch.device) -> int:
    """Returns the available accelerator memory on ``device`` in bytes.

    - CUDA: total device memory minus memory currently allocated by PyTorch
      on that device.
    - MPS: Metal's recommended maximum working-set size minus memory currently
      allocated by PyTorch on MPS (MPS exposes no "free memory" query).
    - Anything else (e.g. CPU, or an accelerator that is not available):
      returns the placeholder ``1``. This is not a real measurement, so the
      Anchor used with such a device must account for it (its
      ``memory_bias`` must be below 1 for the scale estimate to be valid).

    Args:
        device (torch.device): Device to query.

    Returns:
        int: Estimated available memory in bytes (or the placeholder ``1``).
    """
    if device.type == "cuda" and torch.cuda.is_available():
        total = torch.cuda.get_device_properties(device).total_memory
        return total - torch.cuda.memory_allocated(device)
    if device.type == "mps" and torch.backends.mps.is_available():
        # Subtract what PyTorch already holds, mirroring the CUDA branch.
        return torch.mps.recommended_max_memory() - torch.mps.current_allocated_memory()
    return 1
|
|
|
|
|
|
def get_device() -> torch.device:
    """Detects and returns the best available device for PyTorch computation.

    Preference order: CUDA, then Apple Silicon (MPS), then CPU. The choice is
    logged at INFO level.

    Returns:
        torch.device: CUDA device if available, MPS device for Apple Silicon if available, otherwise CPU.
    """
    if torch.cuda.is_available():
        logging.info("Using CUDA-enabled GPU")
        return torch.device("cuda")
    # torch.backends.mps.is_available() is the documented MPS availability check.
    if torch.backends.mps.is_available():
        logging.info("Using Apple Silicon GPU (MPS)")
        return torch.device("mps")
    logging.info("No GPU available, using CPU")
    return torch.device("cpu")
|
|
|
|
|
|
class ImageInterpolator:
    """Synthesizes the midpoint frame between two images with a loaded model.

    The model may be asked to work at a reduced internal scale when the memory
    estimate derived from ``anchor`` and the device's available memory says
    the full-resolution input would not fit.
    """

    def __init__(
        self, device: torch.device, anchor: "Anchor", model_runner: "ModelRunner"
    ):
        """
        Args:
            device: Device on which input tensors and the timestep are placed.
            anchor: Reference resolution/memory figures used by ``scale``.
            model_runner: Holder of the already-loaded model (``.model``).
        """
        self.device = device
        self.anchor = anchor
        # Queried once here; allocations made afterwards are not reflected.
        self.vram_available = get_vram_available(device)
        # Fixed timestep t = 0.5 (midpoint), shaped (1, 1, 1, 1) — presumably so
        # it broadcasts over (N, C, H, W) inputs inside the model.
        self.embt = torch.tensor(1 / 2).float().view(1, 1, 1, 1).to(device)
        self.model_runner = model_runner
        logging.debug(
            f"Initialized ImageInterpolator with device: {device}, anchor: {anchor}, available VRAM: {self.vram_available} bytes"
        )

    def interpolate(self, image1: Path, image2: Path, output_path: Path):
        """
        Interpolates between two images and saves the result.

        Args:
            image1 (Path): Path to the first input image (only png and jpg formats are supported)
            image2 (Path): Path to the second input image (only png and jpg formats are supported)
            output_path (Path): Path to save the interpolated image (only png and jpg formats are supported)

        Raises:
            ValueError: propagated from ``scale`` when no valid scale factor
                can be computed for the input size and available memory.
        """
        logging.debug(f"Reading images: {image1} and {image2}")
        tensor1 = img2tensor(imread(image1)).to(self.device)
        tensor2 = img2tensor(imread(image2)).to(self.device)
        logging.debug(
            f"Image shapes after conversion to tensors: {tensor1.shape}, {tensor2.shape}"
        )
        tensor1, tensor2 = check_dim_and_resize(tensor1, tensor2)
        logging.debug(f"Image shapes after resizing: {tensor1.shape}, {tensor2.shape}")
        # Dimensions 2 and 3 are read as height and width (NCHW layout assumed).
        h, w = tensor1.shape[2], tensor1.shape[3]
        logging.debug(f"Interpolating images of size: {h}x{w}")

        scale = self.scale(h, w)
        logging.debug(f"Calculated scale factor: {scale:.2f}")
        # scale is snapped so that 16 / scale is integral; pad to that multiple.
        padding = int(16 / scale)
        logging.debug(f"Calculated padding: {padding} pixels")
        padder = InputPadder(tensor1.shape, divisor=padding)
        tensor1_padded, tensor2_padded = padder.pad(tensor1, tensor2)
        logging.debug(
            f"Image shapes after padding: {tensor1_padded.shape}, {tensor2_padded.shape}"
        )

        # Defensive: make sure the padded tensors live on the inference device.
        tensor1_padded = tensor1_padded.to(self.device)
        tensor2_padded = tensor2_padded.to(self.device)
        logging.debug("Running model inference for interpolation")
        with torch.no_grad():
            interpolated = self.model_runner.model(
                tensor1_padded, tensor2_padded, self.embt, scale_factor=scale, eval=True
            )["imgt_pred"]
        logging.debug(f"Interpolated image shape before unpadding: {interpolated.shape}")
        (interpolated,) = padder.unpad(interpolated)
        logging.debug(f"Interpolated image shape after unpadding: {interpolated.shape}")
        imwrite(output_path, tensor2img(interpolated.cpu()))
        logging.debug(f"Saved interpolated image to: {output_path}")

    def scale(self, height: int, width: int) -> float:
        """Returns the scale factor to use for a ``height`` x ``width`` input.

        The raw estimate is
        ``anchor.resolution / (height * width)
        * sqrt((vram_available - anchor.memory_bias) / anchor.memory)``,
        clamped to at most 1 and then snapped to ``16 / k`` for an integer
        ``k``, so that ``16 / scale`` is a whole number (used as the padding
        divisor by ``interpolate``). An INFO message is logged when the result
        is below 1.

        Raises:
            ValueError: if ``height`` or ``width`` is not positive, if
                ``anchor.memory`` is not positive, or if ``vram_available``
                does not exceed ``anchor.memory_bias``. In those cases the
                estimate would be undefined (division by zero, NaN, or 0) and
                would otherwise fail later with an obscure error.
        """
        if height <= 0 or width <= 0:
            raise ValueError(f"Image dimensions must be positive, got {height}x{width}")
        if self.anchor.memory <= 0:
            raise ValueError(f"Anchor memory must be positive, got {self.anchor.memory}")
        headroom = self.vram_available - self.anchor.memory_bias
        if headroom <= 0:
            raise ValueError(
                f"Available memory ({self.vram_available} bytes) does not exceed the "
                f"anchor memory bias ({self.anchor.memory_bias} bytes); "
                "cannot compute a valid scale factor"
            )

        scale = (
            self.anchor.resolution
            / (height * width)
            * np.sqrt(headroom / self.anchor.memory)
        )
        scale = 1 if scale > 1 else scale
        # Snap so that 16 / scale == floor(16 / sqrt(scale)) is an integer.
        scale = 1 / np.floor(1 / np.sqrt(scale) * 16) * 16
        if scale < 1:
            logging.info(
                f"Due to the limited VRAM, the video will be scaled by {scale:.2f}"
            )
        return scale
|