diff --git a/main.py b/main.py index c4bf696..170171b 100644 --- a/main.py +++ b/main.py @@ -2,6 +2,7 @@ import logging from pathlib import Path from typing import TYPE_CHECKING +from cv2 import imwrite import tqdm from src.config import presets @@ -18,6 +19,7 @@ from src.interpolator import ( if TYPE_CHECKING: import torch + import numpy as np def performing_warning_message(device: "torch.device"): @@ -53,7 +55,7 @@ def init_device() -> "torch.device": device = get_device() performing_warning_message(device) vram_available = get_vram_available(device) - logging.info(f"Available VRAM: {vram_available / (1024 ** 3):.2f} GB") + logging.info(f"Available VRAM: {vram_available / (1024**3):.2f} GB") return device @@ -95,10 +97,9 @@ class InterpolationPipeline: self.interpolator = init_interpolator(self.model_runner, self.device) def run(self, video_path: Path, output_video: str): - prev_frame_path = None - frame_count = 0 + prev_frames = tuple() + interpolated_frames: list["np.ndarray"] = [] part = 0 - source_frame_length = 0 chunk_seconds = 10 length = self.video_maker.get_video_duration(video_path) last_part_seconds = 1 if length % chunk_seconds else 0 @@ -106,41 +107,38 @@ class InterpolationPipeline: fps = self.video_maker.get_fps(video_path) logging.info(f"Video FPS: {fps}") fps *= 2 # Doubling FPS - for frame_paths in self.video_maker.video_to_frames_generator( + width, height = self.video_maker.get_size(video_path) + for frames in self.video_maker.video_to_frames_generator( video_path, self.fs.frames_path, chunk_seconds ): - logging.info(f"Processing frames: {len(frame_paths)}") - if prev_frame_path is not None: - img1 = prev_frame_path[-1] - img2 = frame_paths[0] - output_path = self.fs.interpolated_path / f"img_{frame_count:08d}.png" - self.interpolator.interpolate(img1, img2, output_path) - logging.debug(f"Interpolated image saved to: {output_path}") - self._merge_frames_to_video( - self.fs.video_part_path / f"video_{part:08d}.mp4", - fps, - source_frame_length=source_frame_length, + logging.info(f"Processing frames: {len(frames)}") + if prev_frames: + img1 = prev_frames[-1] + img2 = frames[0] + img1_2 = self.interpolator.interpolate(img1, img2) + interpolated_frames.append(img1_2) + generator = self._frame_generator(prev_frames, interpolated_frames) + part_path = self.fs.video_part_path / f"video_{part:08d}.mp4" + self.video_maker.images_to_video_pipeline( + generator, part_path, width, height, fps ) + interpolated_frames = [] logging.info(f"Finished processing part {part:08d}") - frame_count += 1 part += 1 for i in tqdm.tqdm( - range(len(frame_paths) - 1), + range(len(frames) - 1), desc=f"Processing video frames {part + 1} / {total_parts}", ): - img1 = frame_paths[i] - img2 = frame_paths[i + 1] - output_path = self.fs.interpolated_path / f"img_{i:08d}.png" - self.interpolator.interpolate(img1, img2, output_path) - logging.debug(f"Interpolated image saved to: {output_path}") - frame_count += 1 - source_frame_length = len(frame_paths) - prev_frame_path = frame_paths + img1 = frames[i] + img2 = frames[i + 1] + img1_2 = self.interpolator.interpolate(img1, img2) + interpolated_frames.append(img1_2) + prev_frames = frames - self._merge_frames_to_video( - self.fs.video_part_path / f"video_{part:08d}.mp4", - fps, - source_frame_length=source_frame_length, + generator = self._frame_generator(prev_frames, interpolated_frames) + part_path = self.fs.video_part_path / f"video_{part:08d}.mp4" + self.video_maker.images_to_video_pipeline( + generator, part_path, width, height, fps ) logging.info(f"Finished processing part {part:08d}") self._merge_video_parts(self.fs.output_path / output_video) @@ -148,32 +146,40 @@ class InterpolationPipeline: f"Video interpolation completed. Output saved to: {self.fs.output_path / output_video}" ) - def _merge_frames_to_video( - self, output_video: Path, fps: float, source_frame_length: int = 0 + def _save_images( + self, + source: tuple["np.ndarray", ...], + interpolated: list["np.ndarray"], ): - self._move_frames(source_frame_length) + logging.info("Saving images...") + self.fs.clear_directory(self.fs.moved_path) + index = 0 + for i, frame in enumerate(source): + name = self.fs.moved_path / f"img_{index:08d}.png" + index += 1 + imwrite(name, frame) + if i < len(interpolated): + name = self.fs.moved_path / f"img_{index:08d}.png" + index += 1 + imwrite(name, interpolated[i]) + logging.info("Success...") + + def _merge_frames_to_video(self, output_video: Path, fps: float): self.video_maker.images_to_video(self.fs.moved_path, output_video, fps) def _merge_video_parts(self, output_video: Path): self.video_maker.concatenate_videos(self.fs.video_part_path, output_video) self.fs.clear_directory(self.fs.video_part_path) - def _move_frames(self, source_frame_length: int = 0): - self.fs.clear_directory(self.fs.moved_path) - src_frames = sorted(self.fs.frames_path.glob("*.png")) - interpolated_frames = sorted(self.fs.interpolated_path.glob("*.png")) - index = 0 - for i in range(source_frame_length): - moved_frame_path = self.fs.moved_path / f"img_{index:08d}.png" - src_frames[i].rename(moved_frame_path) - index += 1 - if i < len(interpolated_frames): - moved_interpolated_path = self.fs.moved_path / f"img_{index:08d}.png" - interpolated_frames[i].rename(moved_interpolated_path) - index += 1 - logging.info( - f"Moved {len(src_frames)} source frames and {len(interpolated_frames)} interpolated frames to {self.fs.moved_path}" - ) + def _frame_generator( + self, + source: tuple["np.ndarray", ...], + interpolated: list["np.ndarray"], + ): + for i, frame in enumerate(source): + yield frame + if i < len(interpolated): + yield interpolated[i] def runner( @@ -220,7 +226,7 @@ def main(): base_path=Path(args.base_path), video_path=Path(args.video_path), output_video=args.output, - preset=getattr(presets, args.preset.upper()) + preset=getattr(presets, args.preset.upper()), ) diff --git a/src/interpolator.py b/src/interpolator.py index 41a5848..49d0b94 100644 --- a/src/interpolator.py +++ b/src/interpolator.py @@ -4,7 +4,6 @@ from pathlib import Path import torch import numpy as np from omegaconf import OmegaConf, DictConfig -from imageio import imread, imwrite from src.utils.torch import img2tensor, check_dim_and_resize, tensor2img from src.utils.build import build_from_cfg @@ -83,7 +82,7 @@ class ImageInterpolator: f"Initialized ImageInterpolator with device: {device}, anchor: {anchor}, available VRAM: {self.vram_available} bytes" ) - def interpolate(self, image1: Path, image2: Path, output_path: Path): + def interpolate(self, image1: np.ndarray, image2: np.ndarray) -> np.ndarray: """ Interpolates between two images and saves the result. Args: @@ -92,8 +91,8 @@ class ImageInterpolator: output_path (Path): Path to save the interpolated image (only png and jpg formats are supported) """ logging.debug(f"Reading images: {image1} and {image2}") - tensor1 = img2tensor(imread(image1)).to(self.device) - tensor2 = img2tensor(imread(image2)).to(self.device) + tensor1 = img2tensor(image1).to(self.device) + tensor2 = img2tensor(image2).to(self.device) logging.debug( f"Image shapes after conversion to tensors: {tensor1.shape}, {tensor2.shape}" ) @@ -122,8 +121,7 @@ class ImageInterpolator: logging.debug(f"Interpolated image shape before unpadding: {interpolated.shape}") (interpolated,) = padder.unpad(interpolated) logging.debug(f"Interpolated image shape after unpadding: {interpolated.shape}") - imwrite(output_path, tensor2img(interpolated.cpu())) - logging.debug(f"Saved interpolated image to: {output_path}") + return tensor2img(interpolated.cpu()) def scale(self, height: int, width: int) -> float: scale = ( diff --git a/src/utils/video.py b/src/utils/video.py index c9ab60d..9ca64bd 100644 --- a/src/utils/video.py +++ b/src/utils/video.py @@ -2,9 +2,10 @@ import os import logging import subprocess from pathlib import Path +from typing import Generator, Iterable import cv2 -from typing import Generator +import numpy as np class VideoMaker: @@ -35,7 +36,7 @@ class VideoMaker: with open(file, "w") as f: for video in videos: f.write(f"file '{video}'\n") - cmd = f"ffmpeg -f concat -safe 0 -i {file} -c copy {output_path}" + cmd = f"ffmpeg -y -f concat -safe 0 -i {file} -c copy {output_path}" logging.info(f"Running command: {cmd}") result = self.run_command(cmd) if result != 0: @@ -66,7 +67,13 @@ class VideoMaker: def run_command(self, cmd: str) -> int: try: - subprocess.run(cmd, shell=True, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) + subprocess.run( + cmd, + shell=True, + check=True, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) return 0 except subprocess.CalledProcessError as e: logging.error(f"Command failed with error: {e}") @@ -74,7 +81,7 @@ class VideoMaker: def video_to_frames_generator( self, video_path: Path, output_dir: Path, chunk_seconds: int = 10 - ) -> Generator[tuple[Path, ...], None, None]: + ) -> Generator[tuple[np.ndarray, ...], None, None]: """Extracts frames from a video and saves them to disk, yielding paths to the saved frames.""" cap = cv2.VideoCapture(str(video_path)) @@ -85,21 +92,56 @@ class VideoMaker: fps = cap.get(cv2.CAP_PROP_FPS) frames_per_chunk = int(fps * chunk_seconds) - frame_index = 0 - while True: paths = [] - for _ in range(frames_per_chunk): ret, frame = cap.read() if not ret: cap.release() return - - frame_path = output_dir / f"img_{frame_index:08d}.png" - cv2.imwrite(str(frame_path), frame) - - paths.append(frame_path) - frame_index += 1 - + paths.append(frame) yield tuple(paths) + + def images_to_video_pipeline( + self, + frames: Iterable[np.ndarray], + output_path: Path, + width: int, + height: int, + fps: float, + ): + pipeline = subprocess.Popen( + [ + "ffmpeg", + "-y", + "-f", "rawvideo", + "-vcodec", "rawvideo", + "-pix_fmt", "bgr24", + "-s", f"{width}x{height}", + "-r", str(fps), + "-i", "-", + "-an", + "-vcodec", "libx264", + "-pix_fmt", "yuv420p", + str(output_path), + ], + stdin=subprocess.PIPE, + stderr=subprocess.DEVNULL + ) + if pipeline.stdin is None: + raise Exception("STDIN closed") + for frame in frames: + pipeline.stdin.write(frame.tobytes()) + + pipeline.stdin.close() + pipeline.wait() + + def get_size(self, video_path: Path) -> tuple[int, int]: + cap = cv2.VideoCapture(str(video_path)) + if not cap.isOpened(): + raise ValueError(f"Cannot open video: {video_path}") + width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + + cap.release() + return width, height