diff --git a/main.py b/main.py index 487808a..1ab35c7 100644 --- a/main.py +++ b/main.py @@ -1,92 +1,25 @@ import logging -import subprocess from pathlib import Path -from typing import Generator +from typing import TYPE_CHECKING -import cv2 -from tqdm import tqdm -from time import perf_counter +import tqdm -from interpolator import get_device -from interpolator import ImageInterpolator -from interpolator import ModelRunner, Anchor - - -logging.basicConfig( - level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" +from src.utils.fs import FileSystem +from src.utils.video import VideoMaker +from interpolator import ( + ImageInterpolator, + Anchor, + get_device, + get_vram_available, + ModelRunner, ) -from pathlib import Path + +if TYPE_CHECKING: + import torch -def move_images(src_dir: str, interpolated_dir: str, output_dir: str): - src_dir = Path(src_dir) - interpolated_dir = Path(interpolated_dir) - output_dir = Path(output_dir) - output_dir.mkdir(parents=True, exist_ok=True) - - index = 0 - src_frames = sorted(src_dir.glob("img_*.png")) - interp_frames = sorted(interpolated_dir.glob("img_*.png")) - for i in range(len(src_frames)): - output_frame = output_dir / f"img_{index:08d}.png" - src_frames[i].rename(output_frame) - index += 1 - - if i < len(interp_frames): - output_interp = output_dir / f"img_{index:08d}.png" - interp_frames[i].rename(output_interp) - index += 1 - - -def video_frames_to_disk_generator( - video_path: str | Path, - output_dir: str | Path, - chunk_seconds: int = 10 -) -> Generator[tuple[Path, ...], None, None]: - output_dir = Path(output_dir) - output_dir.mkdir(parents=True, exist_ok=True) - - cap = cv2.VideoCapture(str(video_path)) - - if not cap.isOpened(): - raise ValueError(f"Cannot open video: {video_path}") - - fps = cap.get(cv2.CAP_PROP_FPS) - frames_per_chunk = int(fps * chunk_seconds) - - frame_index = 0 - - while True: - paths = [] - - for _ in range(frames_per_chunk): - ret, frame = cap.read() - if not ret: - cap.release() - return - - frame_path = output_dir / f"img_{frame_index:08d}.png" - cv2.imwrite(str(frame_path), frame) - - paths.append(frame_path) - frame_index += 1 - - yield tuple(paths) - - -def main(): - start = perf_counter() - logging.info("Starting video interpolation process") - config_path = Path("src/config/AMT-G.yaml") - ckpt_path = Path("src/pretrained/amt-g.pth") - video_path = Path("example/video.mp4") - output_dir = Path("output/frames") - output_interpolated_dir = Path("output/interpolated") - output_interpolated_dir.mkdir(parents=True, exist_ok=True) - - device = get_device() - model_runner = ModelRunner(config_path, ckpt_path, device) +def performing_warning_message(device: "torch.device"): if device.type in ("cpu", "mps"): if device.type == "mps": logging.warning( @@ -96,87 +29,165 @@ def main(): logging.warning( "Running on CPU may be very slow. Consider using a GPU for better performance." ) - anchor = Anchor(resolution=8192 * 8192, memory=1, memory_bias=0) elif device.type == "cuda": - anchor = Anchor( + pass + else: + raise Exception(f"Unsupported device type: {device.type}") + + +def init_fs(base_path: Path) -> FileSystem: + fs = FileSystem(base_path) + fs.clear_directory(fs.frames_path) + fs.clear_directory(fs.interpolated_path) + fs.clear_directory(fs.moved_path) + fs.clear_directory(fs.video_part_path) + return fs + + +def init_video_maker() -> VideoMaker: + return VideoMaker() + + +def init_device() -> "torch.device": + device = get_device() + performing_warning_message(device) + vram_available = get_vram_available(device) + logging.info(f"Available VRAM: {vram_available / (1024 ** 3):.2f} GB") + return device + + +def init_anchor(device: "torch.device") -> Anchor: + if device.type in ("cpu", "mps"): + return Anchor(resolution=8192 * 8192, memory=1, memory_bias=0) + elif device.type == "cuda": + return Anchor( resolution=1024 * 512, memory=1500 * 1024**2, memory_bias=2500 * 1024**2 ) else: raise Exception(f"Unsupported device type: {device.type}") - - interpolator = ImageInterpolator(device, anchor, model_runner) - - loaded_time = perf_counter() - start - logging.info(f"Model loaded and initialized in {loaded_time:.2f} seconds") - - prev_frame_path = None - frame_count = 0 - for frame_paths in video_frames_to_disk_generator(video_path, output_dir): - logging.info(f"Processing frames: {len(frame_paths)}") - - if prev_frame_path is not None: - img1 = prev_frame_path[-1] - img2 = frame_paths[0] - output_path = output_interpolated_dir / f"img_{frame_count:08d}.png" - interpolator.interpolate(img1, img2, output_path) - logging.debug(f"Interpolated image saved to: {output_path}") - frame_count += 1 - for i in tqdm(range(len(frame_paths) - 1), desc="Interpolating frames"): - img1 = frame_paths[i] - img2 = frame_paths[i + 1] - output_path = output_interpolated_dir / f"img_{frame_count:08d}.png" - interpolator.interpolate(img1, img2, output_path) - logging.debug(f"Interpolated image saved to: {output_path}") - frame_count += 1 - prev_frame_path = frame_paths - total_time = perf_counter() - start - logging.info(f"Video interpolation completed in {total_time:.2f} seconds") -def builder(): - frames_dir = "output/frames" - interpolated_dir = "output/interpolated" - moved_dir = "output/moved" - video_path = "example/video.mp4" - output_video = "output/interpolated_video.mp4" - move_images(frames_dir, interpolated_dir, moved_dir) - - cap = cv2.VideoCapture(video_path) - - if not cap.isOpened(): - raise ValueError("Cannot open original video") - - fps = cap.get(cv2.CAP_PROP_FPS) - cmd = [ - "ffmpeg", - "-y", - "-framerate", str(fps * 2), - "-i", f"{moved_dir}/img_%08d.png", - "-i", video_path, - "-c:v", "libx264", - "-c:a", "copy", - "-shortest", - output_video, - ] - logging.info("Running ffmpeg command to build final video: " + " ".join(cmd)) - subprocess.run(cmd, check=True) +def init_model_runner( + config: Path, checkpoint_path: Path, device: "torch.device" +) -> ModelRunner: + return ModelRunner(config, checkpoint_path, device) -def cleanup(): - import os - import shutil - frames_dir = "output/frames" - interpolated_dir = "output/interpolated" - moved_dir = "output/moved" - os.makedirs(frames_dir, exist_ok=True) - os.makedirs(interpolated_dir, exist_ok=True) - os.makedirs(moved_dir, exist_ok=True) - shutil.rmtree(frames_dir) - shutil.rmtree(interpolated_dir) - shutil.rmtree(moved_dir) +def init_interpolator( + model_runner: ModelRunner, device: "torch.device" +) -> ImageInterpolator: + anchor = init_anchor(device) + return ImageInterpolator(device, anchor, model_runner) + + +class InterpolationPipeline: + def __init__( + self, + config: Path, + checkpoint_path: Path, + base_path: Path, + ): + self.fs = init_fs(base_path) + self.video_maker = init_video_maker() + self.device = init_device() + self.model_runner = init_model_runner(config, checkpoint_path, self.device) + self.interpolator = init_interpolator(self.model_runner, self.device) + + def run(self, video_path: Path, output_video: str): + prev_frame_path = None + frame_count = 0 + part = 0 + source_frame_length = 0 + chunk_seconds = 10 + length = self.video_maker.get_video_duration(video_path) + last_part_seconds = 1 if length % chunk_seconds else 0 + total_parts = int(length // chunk_seconds) + last_part_seconds + fps = self.video_maker.get_fps(video_path) + logging.info(f"Video FPS: {fps}") + fps *= 2 # Doubling FPS + for frame_paths in self.video_maker.video_to_frames_generator( + video_path, self.fs.frames_path, chunk_seconds + ): + logging.info(f"Processing frames: {len(frame_paths)}") + if prev_frame_path is not None: + img1 = prev_frame_path[-1] + img2 = frame_paths[0] + output_path = self.fs.interpolated_path / f"img_{frame_count:08d}.png" + self.interpolator.interpolate(img1, img2, output_path) + logging.debug(f"Interpolated image saved to: {output_path}") + self._merge_frames_to_video( + self.fs.video_part_path / f"video_{part:08d}.mp4", + fps, + source_frame_length=source_frame_length, + ) + logging.info(f"Finished processing part {part:08d}") + frame_count += 1 + part += 1 + for i in tqdm.tqdm( + range(len(frame_paths) - 1), + desc=f"Processing video frames {part} / {total_parts}", + ): + img1 = frame_paths[i] + img2 = frame_paths[i + 1] + output_path = self.fs.interpolated_path / f"img_{i:08d}.png" + self.interpolator.interpolate(img1, img2, output_path) + logging.debug(f"Interpolated image saved to: {output_path}") + frame_count += 1 + source_frame_length = len(frame_paths) + prev_frame_path = frame_paths + + self._merge_frames_to_video( + self.fs.video_part_path / f"video_{part:08d}.mp4", + fps, + source_frame_length=source_frame_length, + ) + logging.info(f"Finished processing part {part:08d}") + self._merge_video_parts(self.fs.output_path / output_video) + logging.info( + f"Video interpolation completed. Output saved to: {self.fs.output_path / output_video}" + ) + + def _merge_frames_to_video( + self, output_video: Path, fps: float, source_frame_length: int = 0 + ): + self._move_frames(source_frame_length) + self.video_maker.images_to_video(self.fs.moved_path, output_video, fps) + + def _merge_video_parts(self, output_video: Path): + self.video_maker.concatenate_videos(self.fs.video_part_path, output_video) + self.fs.clear_directory(self.fs.video_part_path) + + def _move_frames(self, source_frame_length: int = 0): + self.fs.clear_directory(self.fs.moved_path) + src_frames = sorted(self.fs.frames_path.glob("*.png")) + interpolated_frames = sorted(self.fs.interpolated_path.glob("*.png")) + index = 0 + for i in range(source_frame_length): + moved_frame_path = self.fs.moved_path / f"img_{index:08d}.png" + src_frames[i].rename(moved_frame_path) + index += 1 + if i < len(interpolated_frames): + moved_interpolated_path = self.fs.moved_path / f"img_{index:08d}.png" + interpolated_frames[i].rename(moved_interpolated_path) + index += 1 + logging.info( + f"Moved {len(src_frames)} source frames and {len(interpolated_frames)} interpolated frames to {self.fs.moved_path}" + ) + + +def main(): + config = Path("src/config/AMT-G.yaml") + checkpoint_path = Path("src/pretrained/amt-g.pth") + base_path = Path("output") + video_path = Path("example/video.mp4") + output_video = "interpolated_video.mp4" + + pipeline = InterpolationPipeline(config, checkpoint_path, base_path) + pipeline.run(video_path, output_video) + if __name__ == "__main__": - cleanup() + logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" + ) main() - builder() - cleanup() diff --git a/src/utils/fs.py b/src/utils/fs.py new file mode 100644 index 0000000..d6e3016 --- /dev/null +++ b/src/utils/fs.py @@ -0,0 +1,53 @@ +from pathlib import Path + + +class FileSystem: + SOURCE_PATH = "source" + OUTPUT_PATH = "output" + FRAMES_PATH = "frames" + INTERPOLATED_PATH = "interpolated" + MOVED_PATH = "moved" + VIDEO_PART_PATH = "video_parts" + + def __init__(self, base_path: Path): + self.base_path = base_path + self.base_path.mkdir(parents=True, exist_ok=True) + + def create_directory(self, dir_name: str) -> Path: + """Creates a directory under the base path.""" + dir_path = self.base_path / dir_name + dir_path.mkdir(parents=True, exist_ok=True) + return dir_path + + def clear_directory(self, dir_path: Path): + """Clears all files in the specified directory.""" + for item in dir_path.iterdir(): + if item.is_file(): + item.unlink() + elif item.is_dir(): + self.clear_directory(item) + item.rmdir() + + @property + def source_path(self) -> Path: + return self.create_directory(self.SOURCE_PATH) + + @property + def output_path(self) -> Path: + return self.create_directory(self.OUTPUT_PATH) + + @property + def frames_path(self) -> Path: + return self.create_directory(self.FRAMES_PATH) + + @property + def interpolated_path(self) -> Path: + return self.create_directory(self.INTERPOLATED_PATH) + + @property + def moved_path(self) -> Path: + return self.create_directory(self.MOVED_PATH) + + @property + def video_part_path(self) -> Path: + return self.create_directory(self.VIDEO_PART_PATH) \ No newline at end of file diff --git a/src/utils/video.py b/src/utils/video.py new file mode 100644 index 0000000..cd295fe --- /dev/null +++ b/src/utils/video.py @@ -0,0 +1,96 @@ +import logging +from pathlib import Path + +import cv2 +from typing import Generator + + +class VideoMaker: + def images_to_video( + self, + images_path: Path, + output_path: Path, + fps: float, + image_numerator: str = "img_%08d.png", + ): + """Converts a sequence of images to a video using ffmpeg.""" + cmd = f"ffmpeg -framerate {fps} -i {images_path / image_numerator} -c:v libx264 -pix_fmt yuv420p {output_path}" + logging.info(f"Running command: {cmd}") + result = self.run_command(cmd) + if result != 0: + logging.error(f"Failed to create video. Command returned {result}") + + def concatenate_videos( + self, + videos_path: Path, + output_path: Path, + video_numerator: str = "video_%08d.mp4", + ): + """Concatenates a sequence of videos using ffmpeg.""" + cmd = f"ffmpeg -f concat -safe 0 -i <(for f in {videos_path / video_numerator}; do echo \"file '$f'\"; done) -c copy {output_path}" + logging.info(f"Running command: {cmd}") + result = self.run_command(cmd) + if result != 0: + logging.error(f"Failed to concatenate videos. Command returned {result}") + + def get_fps(self, video_path: Path) -> float: + """Gets the frames per second (FPS) of a video.""" + cap = cv2.VideoCapture(str(video_path)) + if not cap.isOpened(): + raise ValueError(f"Cannot open video: {video_path}") + fps = cap.get(cv2.CAP_PROP_FPS) + cap.release() + logging.debug(f"FPS of video {video_path}: {fps}") + return fps + + def get_video_duration(self, video_path: Path) -> float: + """Gets the duration of a video in seconds.""" + cap = cv2.VideoCapture(str(video_path)) + if not cap.isOpened(): + raise ValueError(f"Cannot open video: {video_path}") + fps = cap.get(cv2.CAP_PROP_FPS) + frame_count = cap.get(cv2.CAP_PROP_FRAME_COUNT) + cap.release() + duration = frame_count / fps + logging.debug(f"Duration of video {video_path}: {duration:.2f} seconds") + return duration + + def run_command(self, cmd: str) -> int: + import subprocess + + try: + subprocess.run(cmd, shell=True, check=True, stdout=subprocess.DEVNULL) + return 0 + except subprocess.CalledProcessError as e: + logging.error(f"Command failed with error: {e}") + return e.returncode + + def video_to_frames_generator(self, video_path: Path, output_dir: Path, chunk_seconds: int = 10) -> Generator[tuple[Path, ...], None, None]: + """Extracts frames from a video and saves them to disk, yielding paths to the saved frames.""" + + cap = cv2.VideoCapture(str(video_path)) + + if not cap.isOpened(): + raise ValueError(f"Cannot open video: {video_path}") + + fps = cap.get(cv2.CAP_PROP_FPS) + frames_per_chunk = int(fps * chunk_seconds) + + frame_index = 0 + + while True: + paths = [] + + for _ in range(frames_per_chunk): + ret, frame = cap.read() + if not ret: + cap.release() + return + + frame_path = output_dir / f"img_{frame_index:08d}.png" + cv2.imwrite(str(frame_path), frame) + + paths.append(frame_path) + frame_index += 1 + + yield tuple(paths) \ No newline at end of file