Files
AMT-Apple/interpolator.py

136 lines
5.4 KiB
Python

import logging
from pathlib import Path
import torch
import numpy as np
from omegaconf import OmegaConf, DictConfig
from src.utils import utils
from src.utils.torch import img2tensor, check_dim_and_resize, tensor2img
from src.utils.build import build_from_cfg
from src.utils.padder import InputPadder
class Anchor:
    """Tuning constants anchoring the VRAM-based scale estimate.

    The three values are read by the interpolator's `scale` computation to
    decide how much an input frame must be downscaled to fit in VRAM.
    """

    def __init__(self, resolution: int, memory: int, memory_bias: int) -> None:
        # resolution: pixel count (h * w) of the reference resolution
        # memory: VRAM in bytes required at the reference resolution
        #   (assumed from its use as a divisor in the scale formula — TODO confirm)
        # memory_bias: fixed VRAM overhead in bytes subtracted from the budget
        self.resolution = resolution
        self.memory = memory
        self.memory_bias = memory_bias

    def __str__(self) -> str:
        return f"Anchor(resolution={self.resolution}, memory={self.memory}, memory_bias={self.memory_bias})"

    # Improvement: reuse the readable representation for repr() so debuggers
    # and log formatting show field values instead of the default object repr.
    __repr__ = __str__
class ModelRunner:
    """Builds the interpolation network from a YAML config and loads its weights."""

    def __init__(self, config: Path, ckpt_path: Path, device: torch.device) -> None:
        """Initializes the ModelRunner with configuration and checkpoint.

        Args:
            config (Path): Path to model configuration in YAML format
            ckpt_path (Path): Path to model checkpoint in .pth format
            device (torch.device): Device to load the model on
        """
        omega_config = OmegaConf.load(config)
        network_config: DictConfig = omega_config.network
        logging.info(
            f"Loaded network configuration: {network_config} from [{ckpt_path}]"
        )
        model = build_from_cfg(network_config)
        # NOTE(security): weights_only=False unpickles arbitrary objects from the
        # checkpoint file — only load checkpoints from trusted sources.
        checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False)
        model.load_state_dict(checkpoint["state_dict"])
        # Bug fix: honor the caller-supplied `device` instead of re-detecting one
        # with get_device(); the old call could disagree with map_location above
        # and silently ignored the constructor argument.
        model = model.to(device)
        model.eval()
        self.model = model
def get_vram_available(device: torch.device) -> int:
    """Estimate the accelerator memory currently available, in bytes.

    CUDA reports total device memory minus what is already allocated. MPS
    cannot report free memory directly, so the backend's recommended
    working-set maximum is used instead. Any other device yields a 1-byte
    placeholder.
    """
    if device.type == "cuda" and torch.cuda.is_available():
        properties = torch.cuda.get_device_properties(device)
        allocated = torch.cuda.memory_allocated(device)
        return properties.total_memory - allocated
    if device.type == "mps" and torch.mps.is_available():
        # MPS does not provide a way to query available memory, so we return a large number to avoid issues
        return torch.mps.recommended_max_memory()
    return 1
def get_device():
"""Detects and returns the best available device for PyTorch computation.
Returns:
torch.device: CUDA device if available, MPS device for Apple Silicon if available, otherwise CPU.
"""
if torch.cuda.is_available():
logging.info("Using CUDA-enabled GPU")
return torch.device("cuda")
elif torch.mps.is_available():
logging.info("Using Apple Silicon GPU (MPS)")
return torch.device("mps")
logging.info("No GPU available, using CPU")
return torch.device("cpu")
class ImageInterpolator:
def __init__(self, device: torch.device, anchor: Anchor, model_runner: ModelRunner):
    """Prepare interpolation state: target device, VRAM budget and model handle."""
    self.device = device
    self.anchor = anchor
    self.model_runner = model_runner
    self.vram_available = get_vram_available(device)
    # Fixed temporal embedding t = 0.5, shaped (1, 1, 1, 1) — presumably
    # selects the midpoint frame; confirm against the model's embt semantics.
    self.embt = torch.full((1, 1, 1, 1), 0.5, device=device)
    logging.debug(
        f"Initialized ImageInterpolator with device: {device}, anchor: {anchor}, available VRAM: {self.vram_available} bytes"
    )
def interpolate(self, image1: Path, image2: Path, output_path: Path):
    """Generate a single in-between frame from two images and write it to disk.

    Pipeline: read both images, resize to a common shape, choose a VRAM-aware
    scale factor, pad so the dimensions divide evenly, run model inference,
    unpad, and save the result.

    Args:
        image1 (Path): Path to the first input frame.
        image2 (Path): Path to the second input frame.
        output_path (Path): Destination path for the interpolated frame.
    """
    logging.debug(f"Reading images: {image1} and {image2}")
    tensor1 = img2tensor(utils.read(image1)).to(self.device)
    tensor2 = img2tensor(utils.read(image2)).to(self.device)
    logging.debug(
        f"Image shapes after conversion to tensors: {tensor1.shape}, {tensor2.shape}"
    )
    # Bring both tensors to matching spatial dimensions before padding.
    tensor1, tensor2 = check_dim_and_resize(tensor1, tensor2)
    logging.debug(f"Image shapes after resizing: {tensor1.shape}, {tensor2.shape}")
    # NOTE(review): assumes NCHW layout — height/width at dims 2 and 3; confirm
    # against img2tensor's output.
    h, w = tensor1.shape[2], tensor1.shape[3]
    logging.debug(f"Interpolating images of size: {h}x{w}")
    scale = self.scale(h, w)
    logging.debug(f"Calculated scale factor: {scale:.2f}")
    # Padding divisor grows as the scale shrinks, so the scaled dimensions
    # still divide evenly inside the network.
    padding = int(16 / scale)
    logging.debug(f"Calculated padding: {padding} pixels")
    padder = InputPadder(tensor1.shape, divisor=padding)
    tensor1_padded, tensor2_padded = padder.pad(tensor1, tensor2)
    logging.debug(
        f"Image shapes after padding: {tensor1_padded.shape}, {tensor2_padded.shape}"
    )
    tensor1_padded = tensor1_padded.to(self.device)
    tensor2_padded = tensor2_padded.to(self.device)
    logging.debug("Running model inference for interpolation")
    # Inference only — skip building the autograd graph.
    with torch.no_grad():
        interpolated = self.model_runner.model(
            tensor1_padded, tensor2_padded, self.embt, scale_factor=scale, eval=True
        )["imgt_pred"]
    logging.debug(f"Interpolated image shape before unpadding: {interpolated.shape}")
    # Trailing comma unpacks the single element returned by unpad.
    (interpolated,) = padder.unpad(interpolated)
    logging.debug(f"Interpolated image shape after unpadding: {interpolated.shape}")
    utils.write(output_path, tensor2img(interpolated.cpu()))
    logging.debug(f"Saved interpolated image to: {output_path}")
def scale(self, height: int, width: int) -> float:
    """Compute the downscale factor needed to fit a (height, width) frame in VRAM.

    The raw factor is anchored to a known-good (resolution, memory) pair: it
    shrinks with frame area and with the square root of the VRAM headroom. It
    is capped at 1 (never upscale) and then snapped to 16 / n (n integer) so
    downstream padding divides evenly.

    Args:
        height (int): Frame height in pixels.
        width (int): Frame width in pixels.

    Returns:
        float: Scale factor in (0, 1]; 1.0 means no downscaling.
    """
    headroom = self.vram_available - self.anchor.memory_bias
    if headroom <= 0:
        # Bug fix: non-positive headroom (e.g. the CPU path, where available
        # VRAM is reported as 1 byte) previously fed a negative value to
        # np.sqrt, yielding NaN and crashing later in int(16 / scale).
        # Fall back to full resolution instead.
        logging.warning(
            "Available VRAM is below the configured memory bias; "
            "skipping VRAM-based downscaling"
        )
        return 1.0
    scale = (
        self.anchor.resolution
        / (height * width)
        * np.sqrt(headroom / self.anchor.memory)
    )
    scale = 1 if scale > 1 else scale
    # Snap down to a 16 / n grid so padded dimensions stay divisible.
    scale = 1 / np.floor(1 / np.sqrt(scale) * 16) * 16
    if scale < 1:
        logging.info(
            f"Due to the limited VRAM, the video will be scaled by {scale:.2f}"
        )
    return scale