"""Video frame-interpolation pipeline.

Splits a video into chunks of frames, inserts one interpolated frame between
every pair of consecutive source frames, renders each chunk at twice the
source FPS and concatenates the rendered parts into a single output video.
"""

import logging
from pathlib import Path
from typing import TYPE_CHECKING

from cv2 import imwrite
import tqdm

from src.config import presets
from src.utils.fs import FileSystem
from src.utils.video import VideoMaker
from src.interpolator import (
    ImageInterpolator,
    Anchor,
    get_device,
    get_vram_available,
    ModelRunner,
)

if TYPE_CHECKING:
    import torch
    import numpy as np


def performing_warning_message(device: "torch.device") -> None:
    """Log a performance warning for slow devices.

    Args:
        device: Device the model will run on.

    Raises:
        Exception: If ``device.type`` is not one of ``cpu``, ``mps``, ``cuda``.
    """
    if device.type == "mps":
        logging.warning(
            "Running on Apple Silicon GPU (MPS) may have limited performance. Consider using a CUDA-enabled GPU for better performance."
        )
    elif device.type == "cpu":
        logging.warning(
            "Running on CPU may be very slow. Consider using a GPU for better performance."
        )
    elif device.type != "cuda":
        raise Exception(f"Unsupported device type: {device.type}")


def init_fs(base_path: Path) -> FileSystem:
    """Create the ``FileSystem`` for ``base_path`` and clear its working directories.

    The frames, interpolated, moved and video-part directories are cleared.
    """
    fs = FileSystem(base_path)
    for directory in (
        fs.frames_path,
        fs.interpolated_path,
        fs.moved_path,
        fs.video_part_path,
    ):
        fs.clear_directory(directory)
    return fs


def init_video_maker() -> VideoMaker:
    """Return a new ``VideoMaker``."""
    return VideoMaker()


def init_device() -> "torch.device":
    """Select the compute device, warn about slow devices and log available VRAM.

    Raises:
        Exception: If the selected device type is unsupported
            (see ``performing_warning_message``).
    """
    device = get_device()
    performing_warning_message(device)
    # NOTE(review): presumably get_vram_available returns bytes - confirm.
    vram_available = get_vram_available(device)
    logging.info("Available VRAM: %.2f GB", vram_available / (1024**3))
    return device


def init_anchor(device: "torch.device") -> Anchor:
    """Build the ``Anchor`` settings for ``device``.

    CPU and MPS share one set of values; CUDA uses another.

    Raises:
        Exception: If ``device.type`` is not one of ``cpu``, ``mps``, ``cuda``.
    """
    # NOTE(review): the meaning/units of resolution, memory and memory_bias are
    # defined by Anchor, which is not visible here - confirm before tuning.
    if device.type in ("cpu", "mps"):
        return Anchor(resolution=8192 * 8192, memory=1, memory_bias=0)
    if device.type == "cuda":
        return Anchor(
            resolution=1024 * 512,
            memory=1500 * 1024**2,
            memory_bias=2500 * 1024**2,
        )
    raise Exception(f"Unsupported device type: {device.type}")


def init_model_runner(
    config: Path, checkpoint_path: Path, device: "torch.device"
) -> ModelRunner:
    """Return a ``ModelRunner`` built from the config, checkpoint and device."""
    return ModelRunner(config, checkpoint_path, device)


def init_interpolator(
    model_runner: ModelRunner, device: "torch.device"
) -> ImageInterpolator:
    """Return an ``ImageInterpolator`` using the device-specific anchor."""
    anchor = init_anchor(device)
    return ImageInterpolator(device, anchor, model_runner)


class InterpolationPipeline:
    """Interpolate a video chunk by chunk and assemble the doubled-FPS result."""

    def __init__(
        self,
        config: Path,
        checkpoint_path: Path,
        base_path: Path,
    ):
        """Prepare the working directories, device, model runner and interpolator.

        Args:
            config: Model configuration passed to ``ModelRunner``.
            checkpoint_path: Model checkpoint passed to ``ModelRunner``.
            base_path: Root passed to ``FileSystem``; its working directories
                are cleared during initialisation.
        """
        self.fs = init_fs(base_path)
        self.video_maker = init_video_maker()
        self.device = init_device()
        self.model_runner = init_model_runner(config, checkpoint_path, self.device)
        self.interpolator = init_interpolator(self.model_runner, self.device)

    def run(self, video_path: Path, output_video: str):
        """Interpolate ``video_path`` and write the result as ``output_video``.

        The video is read in 10-second chunks. Each chunk is rendered as its
        own part once the next chunk is available, because the frame bridging
        the two chunks needs the first frame of the next one.

        Args:
            video_path: Source video.
            output_video: File name of the result, created under
                ``self.fs.output_path``.
        """
        prev_frames = tuple()
        interpolated_frames = []
        part = 0
        chunk_seconds = 10

        length = self.video_maker.get_video_duration(video_path)
        last_part_seconds = 1 if length % chunk_seconds else 0
        total_parts = int(length // chunk_seconds) + last_part_seconds

        fps = self.video_maker.get_fps(video_path)
        logging.info("Video FPS: %s", fps)
        fps *= 2  # Doubling FPS

        for frames in self.video_maker.video_to_frames_generator(
            video_path, self.fs.frames_path, chunk_seconds
        ):
            logging.info("Processing frames: %d", len(frames))
            if not frames:
                # An empty chunk has no first frame to bridge to; skip it so
                # the previous chunk is bridged to the next non-empty one.
                logging.warning("Skipping empty frame chunk")
                continue

            if prev_frames:
                # Bridge the previous chunk to this one, then render the
                # previous chunk as a finished part.
                interpolated_frames.append(
                    self.interpolator.interpolate(prev_frames[-1], frames[0])
                )
                self._flush_part(part, fps, prev_frames, interpolated_frames)
                interpolated_frames = []
                part += 1

            for img1, img2 in tqdm.tqdm(
                zip(frames, frames[1:]),
                total=len(frames) - 1,
                desc=f"Processing video frames {part + 1} / {total_parts}",
            ):
                interpolated_frames.append(self.interpolator.interpolate(img1, img2))

            prev_frames = frames

        # The last chunk has no successor, so it is rendered without a bridge frame.
        self._flush_part(part, fps, prev_frames, interpolated_frames)
        self.fs.clear_directory(self.fs.moved_path)

        self._merge_video_parts(self.fs.output_path / output_video)
        logging.info(
            "Video interpolation completed. Output saved to: %s",
            self.fs.output_path / output_video,
        )

    def _flush_part(
        self,
        part: int,
        fps: float,
        source_frames: tuple["np.ndarray", ...],
        interpolated: list["np.ndarray"],
    ) -> None:
        """Write one chunk's frames to disk and render them as video part ``part``."""
        self.fs.clear_directory(self.fs.moved_path)
        self._save_images(source_frames, interpolated)
        self._merge_frames_to_video(
            self.fs.video_part_path / f"video_{part:08d}.mp4",
            fps,
            len(source_frames),
        )
        logging.info("Finished processing part %08d", part)

    def _write_image(self, index: int, image: "np.ndarray") -> None:
        """Write ``image`` as ``img_<index>.png`` in the moved directory.

        Raises:
            OSError: If ``cv2.imwrite`` reports that the file was not written.
        """
        path = self.fs.moved_path / f"img_{index:08d}.png"
        # imwrite is given a plain string because not every OpenCV build
        # accepts path-like objects, and it signals failure by returning False.
        if not imwrite(str(path), image):
            raise OSError(f"Failed to write image: {path}")

    def _save_images(
        self,
        source: tuple["np.ndarray", ...],
        interpolated: list["np.ndarray"],
    ):
        """Save source and interpolated frames interleaved, numbered consecutively.

        Frame ``i`` of ``source`` is followed by ``interpolated[i]`` when that
        entry exists; entries of ``interpolated`` beyond ``len(source)`` are
        not written.

        Raises:
            OSError: If any frame cannot be written.
        """
        logging.info("Saving images...")
        index = 0
        for i, frame in enumerate(source):
            self._write_image(index, frame)
            index += 1
            if i < len(interpolated):
                self._write_image(index, interpolated[i])
                index += 1
        logging.info("Success...")

    def _merge_frames_to_video(
        self, output_video: Path, fps: float, source_frame_length: int = 0
    ):
        """Render the images in the moved directory into ``output_video`` at ``fps``.

        ``source_frame_length`` is accepted for interface compatibility and is
        currently unused.
        """
        self.video_maker.images_to_video(self.fs.moved_path, output_video, fps)

    def _merge_video_parts(self, output_video: Path):
        """Concatenate the rendered parts into ``output_video``, then clear the parts."""
        self.video_maker.concatenate_videos(self.fs.video_part_path, output_video)
        self.fs.clear_directory(self.fs.video_part_path)


def runner(
    base_path: Path,
    video_path: Path,
    output_video: str,
    preset: presets.Preset = presets.LARGE,
):
    """Build an ``InterpolationPipeline`` from ``preset`` and run it on ``video_path``.

    Args:
        base_path: Working/output root for the pipeline.
        video_path: Source video.
        output_video: File name of the resulting video.
        preset: Supplies the model ``config`` and ``checkpoint``.
    """
    pipeline = InterpolationPipeline(
        config=preset.config,
        checkpoint_path=preset.checkpoint,
        base_path=base_path,
    )
    pipeline.run(video_path, output_video)


def main():
    """Parse the command-line arguments and run the interpolation."""
    import argparse

    logging.basicConfig(
        level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
    )
    parser = argparse.ArgumentParser()
    parser.add_argument("-b", "--base_path", help="Base path", default="output")
    parser.add_argument(
        "-v", "--video_path", help="Video path", default="example/video.mp4"
    )
    parser.add_argument(
        "-o",
        "--output",
        help="Output video name (example: 'interpolated_video.mp4')",
        default="interpolated_video.mp4",
    )
    # NOTE(review): the CLI default is "global" while runner() defaults to
    # presets.LARGE - confirm the difference is intended.
    parser.add_argument(
        "-p",
        "--preset",
        help="Model preset",
        choices=["small", "large", "global"],
        default="global",
    )
    args = parser.parse_args()
    runner(
        base_path=Path(args.base_path),
        video_path=Path(args.video_path),
        output_video=args.output,
        # Preset choices map to upper-case attributes of the presets module.
        preset=getattr(presets, args.preset.upper()),
    )


if __name__ == "__main__":
    main()