import logging
from pathlib import Path

import torch
import numpy as np
from omegaconf import OmegaConf, DictConfig

from src.utils import utils
from src.utils.torch import img2tensor, check_dim_and_resize, tensor2img
from src.utils.build import build_from_cfg
from src.utils.padder import InputPadder

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)


class Anchor:
    """Reference point used by ``ImageInterpolator.scale`` to size inputs.

    Attributes:
        resolution: Reference pixel count; it is divided by the input's
            ``height * width`` when computing the scale factor.
        memory: Divisor applied to the memory headroom. It uses the same
            units as ``get_vram_available`` (bytes).
        memory_bias: Amount subtracted from the available memory before
            dividing by ``memory`` (same units).
    """

    def __init__(self, resolution: int, memory: int, memory_bias: int) -> None:
        self.resolution = resolution
        self.memory = memory
        self.memory_bias = memory_bias

    def __str__(self) -> str:
        return f"Anchor(resolution={self.resolution}, memory={self.memory}, memory_bias={self.memory_bias})"


class ModelRunner:
    """Builds the network from a YAML config and loads checkpoint weights."""

    def __init__(self, config: Path, ckpt_path: Path, device: torch.device) -> None:
        """Initializes the ModelRunner with configuration and checkpoint.

        Args:
            config (Path): Path to model configuration in YAML format
            ckpt_path (Path): Path to model checkpoint in .pth format
            device (torch.device): Device to load the model on
        """
        omega_config = OmegaConf.load(config)
        network_config: DictConfig = omega_config.network
        logging.info(f"Loaded network configuration from [{config}]: {network_config}")

        model = build_from_cfg(network_config)
        # NOTE(security): weights_only=False unpickles arbitrary Python
        # objects from the checkpoint; only load checkpoints you trust.
        checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False)
        model.load_state_dict(checkpoint["state_dict"])
        logging.info(f"Loaded checkpoint weights from [{ckpt_path}]")

        # Use the caller-supplied device (the checkpoint was already mapped to
        # it above) rather than re-detecting one, which could disagree.
        model = model.to(device)
        model.eval()
        self.model = model


def _mps_available() -> bool:
    """Return True when PyTorch's Apple Silicon (MPS) backend is usable."""
    # torch.backends.mps.is_available() is the documented check; the getattr
    # guard keeps very old PyTorch builds (without the mps backend) working.
    mps_backend = getattr(torch.backends, "mps", None)
    return mps_backend is not None and mps_backend.is_available()


def get_vram_available(device: torch.device) -> int:
    """Returns the memory budget for ``device`` in bytes.

    - CUDA: total device memory minus the memory currently occupied by
      PyTorch tensors on that device (``torch.cuda.memory_allocated``).
    - MPS: ``torch.mps.recommended_max_memory()``, the recommended maximum
      working-set size, used as the budget.
    - Anything else (e.g. CPU): the placeholder value ``1``.
    """
    if device.type == "cuda" and torch.cuda.is_available():
        return torch.cuda.get_device_properties(
            device
        ).total_memory - torch.cuda.memory_allocated(device)
    elif device.type == "mps" and _mps_available():
        return torch.mps.recommended_max_memory()
    else:
        return 1


def get_device():
    """Detects and returns the best available device for PyTorch computation.

    Returns:
        torch.device: CUDA device if available, MPS device for Apple Silicon if available, otherwise CPU.
    """
    if torch.cuda.is_available():
        logging.info("Using CUDA-enabled GPU")
        return torch.device("cuda")
    elif _mps_available():
        logging.info("Using Apple Silicon GPU (MPS)")
        return torch.device("mps")
    logging.info("No GPU available, using CPU")
    return torch.device("cpu")


class ImageInterpolator:
    """Interpolates an image between two inputs using a ModelRunner's model."""

    def __init__(self, device: torch.device, anchor: Anchor, model_runner: ModelRunner):
        self.device = device
        self.anchor = anchor
        # Sampled once here; allocations made later are not reflected.
        self.vram_available = get_vram_available(device)
        # Presumably the interpolation time step (0.5 = halfway between the
        # inputs) -- TODO confirm against the model's definition.
        self.embt = torch.tensor(1 / 2).float().view(1, 1, 1, 1).to(device)
        self.model_runner = model_runner
        logging.debug(
            f"Initialized ImageInterpolator with device: {device}, anchor: {anchor}, available VRAM: {self.vram_available} bytes"
        )

    def interpolate(self, image1: Path, image2: Path, output_path: Path):
        """Interpolate between two images and write the result to disk.

        Args:
            image1: Path of the first input image.
            image2: Path of the second input image.
            output_path: Destination of the interpolated image. Missing parent
                directories are created before writing.

        Raises:
            RuntimeError: propagated from ``scale`` when the available memory
                does not exceed the anchor's ``memory_bias``.
        """
        logging.debug(f"Reading images: {image1} and {image2}")
        tensor1 = img2tensor(utils.read(image1)).to(self.device)
        tensor2 = img2tensor(utils.read(image2)).to(self.device)
        logging.debug(
            f"Image shapes after conversion to tensors: {tensor1.shape}, {tensor2.shape}"
        )

        tensor1, tensor2 = check_dim_and_resize(tensor1, tensor2)
        logging.debug(f"Image shapes after resizing: {tensor1.shape}, {tensor2.shape}")

        # Indices 2 and 3 are read as height and width.
        h, w = tensor1.shape[2], tensor1.shape[3]
        logging.debug(f"Interpolating images of size: {h}x{w}")

        scale = self.scale(h, w)
        logging.debug(f"Calculated scale factor: {scale:.2f}")

        # scale() returns 16 / k for an integer k, so 16 / scale is k in exact
        # arithmetic; round instead of truncating so floating-point error can
        # never produce k - 1.
        padding = int(round(16 / scale))
        logging.debug(f"Calculated padding: {padding} pixels")

        padder = InputPadder(tensor1.shape, divisor=padding)
        tensor1_padded, tensor2_padded = padder.pad(tensor1, tensor2)
        logging.debug(
            f"Image shapes after padding: {tensor1_padded.shape}, {tensor2_padded.shape}"
        )
        tensor1_padded = tensor1_padded.to(self.device)
        tensor2_padded = tensor2_padded.to(self.device)

        logging.debug("Running model inference for interpolation")
        with torch.no_grad():
            interpolated = self.model_runner.model(
                tensor1_padded, tensor2_padded, self.embt, scale_factor=scale, eval=True
            )["imgt_pred"]
        logging.debug(f"Interpolated image shape before unpadding: {interpolated.shape}")

        (interpolated,) = padder.unpad(interpolated)
        logging.debug(f"Interpolated image shape after unpadding: {interpolated.shape}")

        # Ensure the destination directory exists; output_path itself is
        # passed to utils.write unchanged.
        Path(output_path).parent.mkdir(parents=True, exist_ok=True)
        utils.write(output_path, tensor2img(interpolated.cpu()))
        logging.debug(f"Saved interpolated image to: {output_path}")

    def scale(self, height: int, width: int) -> float:
        """Compute the model's scale factor for a ``height`` x ``width`` input.

        The raw factor is
        ``resolution / (height * width) * sqrt((vram - memory_bias) / memory)``.
        It is clamped to at most 1, and then its square root is quantized to
        ``16 / k`` for an integer ``k`` (a result of 1 means no downscaling).

        Raises:
            RuntimeError: if the available memory does not exceed
                ``anchor.memory_bias``. Without this check the formula would
                produce NaN or 0 and fail later with an obscure error.
        """
        headroom = self.vram_available - self.anchor.memory_bias
        if headroom <= 0:
            raise RuntimeError(
                f"Insufficient memory for inference: {self.vram_available} bytes "
                f"available, but more than {self.anchor.memory_bias} bytes are required"
            )

        scale = (
            self.anchor.resolution
            / (height * width)
            * np.sqrt(headroom / self.anchor.memory)
        )
        scale = min(scale, 1)
        scale = 1 / np.floor(1 / np.sqrt(scale) * 16) * 16
        if scale < 1:
            logging.info(
                f"Due to the limited VRAM, the video will be scaled by {scale:.2f}"
            )
        return scale


def main():
    """Interpolate img0/img1 and img1/img2 from ``source/`` into ``output/``."""
    config_path = Path("src/config/AMT-G.yaml")
    ckpt_path = Path("src/pretrained/amt-g.pth")
    image1_path = Path("source/img0.png")
    image2_path = Path("source/img1.png")
    image3_path = Path("source/img2.png")
    output_path1 = Path("output/interpolated_image1.png")
    output_path2 = Path("output/interpolated_image2.png")

    device = get_device()
    model_runner = ModelRunner(config_path, ckpt_path, device)

    if device.type in ("cpu", "mps"):
        if device.type == "mps":
            logging.warning(
                "Running on Apple Silicon GPU (MPS) may have limited performance. Consider using a CUDA-enabled GPU for better performance."
            )
        else:
            logging.warning(
                "Running on CPU may be very slow. Consider using a GPU for better performance."
            )
        anchor = Anchor(resolution=8192 * 8192, memory=1, memory_bias=0)
    elif device.type == "cuda":
        anchor = Anchor(
            resolution=1024 * 512, memory=1500 * 1024**2, memory_bias=2500 * 1024**2
        )
    else:
        raise RuntimeError(f"Unsupported device type: {device.type}")

    interpolator = ImageInterpolator(device, anchor, model_runner)
    interpolator.interpolate(image1_path, image2_path, output_path1)
    interpolator.interpolate(image2_path, image3_path, output_path2)


if __name__ == "__main__":
    main()