"""Frame-interpolation pipeline: split a video into chunks, interpolate, re-encode.

The pipeline extracts frames chunk by chunk, generates one interpolated frame
between each pair of neighbouring source frames, encodes every chunk to its
own video part at twice the source frame rate, and finally concatenates the
parts into a single output file.
"""

import logging
from pathlib import Path
from typing import TYPE_CHECKING

import tqdm

from src.config import presets
from src.utils.fs import FileSystem
from src.utils.video import VideoMaker
from src.interpolator import (
    ImageInterpolator,
    Anchor,
    get_device,
    get_vram_available,
    ModelRunner,
)

if TYPE_CHECKING:
    import torch


def performing_warning_message(device: "torch.device") -> None:
    """Log a performance warning when *device* is not a CUDA device.

    Args:
        device: The device the model will run on; only ``device.type`` is read.

    Raises:
        Exception: If ``device.type`` is not one of ``"cpu"``, ``"mps"`` or
            ``"cuda"``.
    """
    if device.type == "mps":
        logging.warning(
            "Running on Apple Silicon GPU (MPS) may have limited performance. Consider using a CUDA-enabled GPU for better performance."
        )
    elif device.type == "cpu":
        logging.warning(
            "Running on CPU may be very slow. Consider using a GPU for better performance."
        )
    elif device.type != "cuda":
        raise Exception(f"Unsupported device type: {device.type}")


def init_fs(base_path: Path) -> FileSystem:
    """Create the ``FileSystem`` rooted at *base_path* and reset its work dirs.

    ``clear_directory`` is called on the frames, interpolated, moved and
    video-part directories so that a new run starts from a clean state.
    """
    fs = FileSystem(base_path)
    for directory in (
        fs.frames_path,
        fs.interpolated_path,
        fs.moved_path,
        fs.video_part_path,
    ):
        fs.clear_directory(directory)
    return fs


def init_video_maker() -> VideoMaker:
    """Return a new ``VideoMaker`` instance."""
    return VideoMaker()


def init_device() -> "torch.device":
    """Select the compute device, warn about slow devices and log free VRAM.

    Returns:
        The device returned by ``get_device()``.

    Raises:
        Exception: If the selected device type is unsupported (raised by
            ``performing_warning_message``).
    """
    device = get_device()
    performing_warning_message(device)
    vram_available = get_vram_available(device)
    logging.info(f"Available VRAM: {vram_available / (1024 ** 3):.2f} GB")
    return device


def init_anchor(device: "torch.device") -> Anchor:
    """Build the ``Anchor`` preset matching ``device.type``.

    CPU and MPS devices share one preset, CUDA devices get another. The
    meaning of the individual ``Anchor`` fields is defined by
    ``src.interpolator`` and is not visible from this module.

    Raises:
        Exception: If ``device.type`` is not ``"cpu"``, ``"mps"`` or ``"cuda"``.
    """
    if device.type in ("cpu", "mps"):
        return Anchor(resolution=8192 * 8192, memory=1, memory_bias=0)
    if device.type == "cuda":
        return Anchor(
            resolution=1024 * 512,
            memory=1500 * 1024**2,
            memory_bias=2500 * 1024**2,
        )
    raise Exception(f"Unsupported device type: {device.type}")


def init_model_runner(
    config: Path, checkpoint_path: Path, device: "torch.device"
) -> ModelRunner:
    """Return a ``ModelRunner`` built from *config*, *checkpoint_path* and *device*."""
    return ModelRunner(config, checkpoint_path, device)


def init_interpolator(
    model_runner: ModelRunner, device: "torch.device"
) -> ImageInterpolator:
    """Return an ``ImageInterpolator`` using the anchor preset for *device*."""
    anchor = init_anchor(device)
    return ImageInterpolator(device, anchor, model_runner)


class InterpolationPipeline:
    """Doubles a video's frame rate by inserting interpolated frames.

    Attributes set by ``__init__``: ``fs``, ``video_maker``, ``device``,
    ``model_runner`` and ``interpolator``.
    """

    def __init__(
        self,
        config: Path,
        checkpoint_path: Path,
        base_path: Path,
    ):
        """Initialise every collaborator needed by :meth:`run`.

        Args:
            config: Model configuration passed to ``ModelRunner``.
            checkpoint_path: Model checkpoint passed to ``ModelRunner``.
            base_path: Root directory handed to ``FileSystem``; its working
                directories are cleared during initialisation.
        """
        self.fs = init_fs(base_path)
        self.video_maker = init_video_maker()
        self.device = init_device()
        self.model_runner = init_model_runner(config, checkpoint_path, self.device)
        self.interpolator = init_interpolator(self.model_runner, self.device)

    def run(
        self, video_path: Path, output_video: str, chunk_seconds: int = 10
    ) -> None:
        """Interpolate *video_path* and write the result as *output_video*.

        The video is processed in chunks of *chunk_seconds*. For every chunk
        an interpolated frame is generated between each pair of consecutive
        frames, plus one frame bridging the last frame of the previous chunk
        and the first frame of the current one. Each finished chunk is
        encoded to its own part file; the parts are concatenated at the end.

        Args:
            video_path: Source video.
            output_video: File name of the result, created under
                ``self.fs.output_path``.
            chunk_seconds: Chunk length in seconds passed to the frame
                generator. Defaults to 10.

        Raises:
            ValueError: If *chunk_seconds* is not positive.
        """
        if chunk_seconds <= 0:
            raise ValueError(f"chunk_seconds must be positive, got {chunk_seconds}")

        prev_frame_paths = None
        frame_count = 0
        part = 0
        source_frame_length = 0

        length = self.video_maker.get_video_duration(video_path)
        # A trailing partial chunk still counts as one part.
        last_part_seconds = 1 if length % chunk_seconds else 0
        total_parts = int(length // chunk_seconds) + last_part_seconds

        fps = self.video_maker.get_fps(video_path)
        logging.info(f"Video FPS: {fps}")
        # One interpolated frame follows each source frame, so the output
        # plays at twice the source frame rate.
        fps *= 2

        for frame_paths in self.video_maker.video_to_frames_generator(
            video_path, self.fs.frames_path, chunk_seconds
        ):
            logging.info(f"Processing frames: {len(frame_paths)}")
            if not frame_paths:
                # Nothing to interpolate; indexing [0] / [-1] below would fail.
                logging.warning("Skipping empty frame chunk")
                continue

            if prev_frame_paths is not None:
                # Bridge the previous chunk to this one *before* the previous
                # chunk's frames are moved away by the merge. The bridge is
                # named with the running frame_count, which is never smaller
                # than the per-chunk indices used below, so it sorts last.
                self._interpolate_pair(
                    prev_frame_paths[-1], frame_paths[0], frame_count
                )
                self._finish_part(part, fps, source_frame_length)
                frame_count += 1
                part += 1

            for i in tqdm.tqdm(
                range(len(frame_paths) - 1),
                desc=f"Processing video frames {part + 1} / {total_parts}",
            ):
                self._interpolate_pair(frame_paths[i], frame_paths[i + 1], i)
                frame_count += 1

            source_frame_length = len(frame_paths)
            prev_frame_paths = frame_paths

        # The last chunk has no successor, so it is merged after the loop.
        self._finish_part(part, fps, source_frame_length)

        self._merge_video_parts(self.fs.output_path / output_video)
        logging.info(
            f"Video interpolation completed. Output saved to: {self.fs.output_path / output_video}"
        )

    def _interpolate_pair(self, first: Path, second: Path, index: int) -> None:
        """Interpolate between two frames and save the result as ``img_<index>.png``."""
        output_path = self.fs.interpolated_path / f"img_{index:08d}.png"
        self.interpolator.interpolate(first, second, output_path)
        logging.debug(f"Interpolated image saved to: {output_path}")

    def _finish_part(self, part: int, fps: float, source_frame_length: int) -> None:
        """Encode the pending frames as video part number *part* and log it."""
        self._merge_frames_to_video(
            self.fs.video_part_path / f"video_{part:08d}.mp4",
            fps,
            source_frame_length=source_frame_length,
        )
        logging.info(f"Finished processing part {part:08d}")

    def _merge_frames_to_video(
        self, output_video: Path, fps: float, source_frame_length: int = 0
    ) -> None:
        """Interleave pending frames into the moved directory and encode them.

        Args:
            output_video: Path of the video part to create.
            fps: Frame rate passed to ``images_to_video``.
            source_frame_length: Number of source frames belonging to this part.
        """
        self._move_frames(source_frame_length)
        self.video_maker.images_to_video(self.fs.moved_path, output_video, fps)

    def _merge_video_parts(self, output_video: Path) -> None:
        """Concatenate all video parts into *output_video*, then clear the parts dir."""
        self.video_maker.concatenate_videos(self.fs.video_part_path, output_video)
        self.fs.clear_directory(self.fs.video_part_path)

    def _move_frames(self, source_frame_length: int = 0) -> None:
        """Move one part's frames into the moved directory, interleaved.

        The first *source_frame_length* PNGs (in sorted order) of the frames
        directory are renamed into the moved directory; each is followed by
        the interpolated PNG with the same position, when one exists. Files
        are renumbered consecutively as ``img_<index>.png``.

        Args:
            source_frame_length: Number of source frames to move.

        Raises:
            IndexError: If fewer source frames exist than requested. The check
                happens before any file is renamed.
        """
        self.fs.clear_directory(self.fs.moved_path)
        src_frames = sorted(self.fs.frames_path.glob("*.png"))
        interpolated_frames = sorted(self.fs.interpolated_path.glob("*.png"))

        wanted = max(source_frame_length, 0)
        if wanted > len(src_frames):
            raise IndexError(
                f"Expected {wanted} source frames in {self.fs.frames_path}, "
                f"found {len(src_frames)}"
            )

        index = 0
        moved_interpolated = 0
        for position, source_frame in enumerate(src_frames[:wanted]):
            source_frame.rename(self.fs.moved_path / f"img_{index:08d}.png")
            index += 1
            if position < len(interpolated_frames):
                interpolated_frames[position].rename(
                    self.fs.moved_path / f"img_{index:08d}.png"
                )
                index += 1
                moved_interpolated += 1

        # Report what was actually moved; the frames directory may already
        # contain frames of the next chunk.
        logging.info(
            f"Moved {wanted} source frames and {moved_interpolated} interpolated frames to {self.fs.moved_path}"
        )


def main(preset: presets.Preset = presets.LARGE):
    """Run the pipeline on the bundled example video with the given *preset*."""
    base_path = Path("output")
    video_path = Path("example/video.mp4")
    output_video = "interpolated_video.mp4"
    pipeline = InterpolationPipeline(
        config=preset.config,
        checkpoint_path=preset.checkpoint,
        base_path=base_path,
    )
    pipeline.run(video_path, output_video)


if __name__ == "__main__":
    logging.basicConfig(
        level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
    )
    main()