Начал менять логику pipeline
This commit is contained in:
88
main.py
88
main.py
@@ -2,6 +2,7 @@ import logging
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from cv2 import imwrite
|
||||
import tqdm
|
||||
|
||||
from src.config import presets
|
||||
@@ -18,6 +19,7 @@ from src.interpolator import (
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
|
||||
def performing_warning_message(device: "torch.device"):
|
||||
@@ -53,7 +55,7 @@ def init_device() -> "torch.device":
|
||||
device = get_device()
|
||||
performing_warning_message(device)
|
||||
vram_available = get_vram_available(device)
|
||||
logging.info(f"Available VRAM: {vram_available / (1024 ** 3):.2f} GB")
|
||||
logging.info(f"Available VRAM: {vram_available / (1024**3):.2f} GB")
|
||||
return device
|
||||
|
||||
|
||||
@@ -95,8 +97,8 @@ class InterpolationPipeline:
|
||||
self.interpolator = init_interpolator(self.model_runner, self.device)
|
||||
|
||||
def run(self, video_path: Path, output_video: str):
|
||||
prev_frame_path = None
|
||||
frame_count = 0
|
||||
prev_frames = tuple()
|
||||
interpolated_frames = []
|
||||
part = 0
|
||||
source_frame_length = 0
|
||||
chunk_seconds = 10
|
||||
@@ -106,75 +108,77 @@ class InterpolationPipeline:
|
||||
fps = self.video_maker.get_fps(video_path)
|
||||
logging.info(f"Video FPS: {fps}")
|
||||
fps *= 2 # Doubling FPS
|
||||
for frame_paths in self.video_maker.video_to_frames_generator(
|
||||
for frames in self.video_maker.video_to_frames_generator(
|
||||
video_path, self.fs.frames_path, chunk_seconds
|
||||
):
|
||||
logging.info(f"Processing frames: {len(frame_paths)}")
|
||||
if prev_frame_path is not None:
|
||||
img1 = prev_frame_path[-1]
|
||||
img2 = frame_paths[0]
|
||||
output_path = self.fs.interpolated_path / f"img_{frame_count:08d}.png"
|
||||
self.interpolator.interpolate(img1, img2, output_path)
|
||||
logging.debug(f"Interpolated image saved to: {output_path}")
|
||||
logging.info(f"Processing frames: {len(frames)}")
|
||||
if prev_frames:
|
||||
img1 = prev_frames[-1]
|
||||
img2 = frames[0]
|
||||
img1_2 = self.interpolator.interpolate(img1, img2)
|
||||
interpolated_frames.append(img1_2)
|
||||
self.fs.clear_directory(self.fs.moved_path)
|
||||
self._save_images(prev_frames, interpolated_frames)
|
||||
|
||||
self._merge_frames_to_video(
|
||||
self.fs.video_part_path / f"video_{part:08d}.mp4",
|
||||
fps,
|
||||
source_frame_length=source_frame_length,
|
||||
source_frame_length,
|
||||
)
|
||||
interpolated_frames = []
|
||||
logging.info(f"Finished processing part {part:08d}")
|
||||
frame_count += 1
|
||||
part += 1
|
||||
for i in tqdm.tqdm(
|
||||
range(len(frame_paths) - 1),
|
||||
range(len(frames) - 1),
|
||||
desc=f"Processing video frames {part + 1} / {total_parts}",
|
||||
):
|
||||
img1 = frame_paths[i]
|
||||
img2 = frame_paths[i + 1]
|
||||
output_path = self.fs.interpolated_path / f"img_{i:08d}.png"
|
||||
self.interpolator.interpolate(img1, img2, output_path)
|
||||
logging.debug(f"Interpolated image saved to: {output_path}")
|
||||
frame_count += 1
|
||||
source_frame_length = len(frame_paths)
|
||||
prev_frame_path = frame_paths
|
||||
img1 = frames[i]
|
||||
img2 = frames[i + 1]
|
||||
img1_2 = self.interpolator.interpolate(img1, img2)
|
||||
interpolated_frames.append(img1_2)
|
||||
source_frame_length = len(frames)
|
||||
prev_frames = frames
|
||||
|
||||
self.fs.clear_directory(self.fs.moved_path)
|
||||
self._save_images(prev_frames, interpolated_frames)
|
||||
self._merge_frames_to_video(
|
||||
self.fs.video_part_path / f"video_{part:08d}.mp4",
|
||||
fps,
|
||||
source_frame_length=source_frame_length,
|
||||
source_frame_length,
|
||||
)
|
||||
self.fs.clear_directory(self.fs.moved_path)
|
||||
logging.info(f"Finished processing part {part:08d}")
|
||||
self._merge_video_parts(self.fs.output_path / output_video)
|
||||
logging.info(
|
||||
f"Video interpolation completed. Output saved to: {self.fs.output_path / output_video}"
|
||||
)
|
||||
|
||||
def _save_images(
|
||||
self,
|
||||
source: tuple["np.ndarray", ...],
|
||||
interpolated: list["np.ndarray"],
|
||||
):
|
||||
logging.info("Saving images...")
|
||||
index = 0
|
||||
for i, frame in enumerate(source):
|
||||
name = self.fs.moved_path / f"img_{index:08d}.png"
|
||||
index += 1
|
||||
imwrite(name, frame)
|
||||
if i < len(interpolated):
|
||||
name = self.fs.moved_path / f"img_{index:08d}.png"
|
||||
index += 1
|
||||
imwrite(name, interpolated[i])
|
||||
logging.info("Success...")
|
||||
|
||||
def _merge_frames_to_video(
|
||||
self, output_video: Path, fps: float, source_frame_length: int = 0
|
||||
):
|
||||
self._move_frames(source_frame_length)
|
||||
self.video_maker.images_to_video(self.fs.moved_path, output_video, fps)
|
||||
|
||||
def _merge_video_parts(self, output_video: Path):
|
||||
self.video_maker.concatenate_videos(self.fs.video_part_path, output_video)
|
||||
self.fs.clear_directory(self.fs.video_part_path)
|
||||
|
||||
def _move_frames(self, source_frame_length: int = 0):
|
||||
self.fs.clear_directory(self.fs.moved_path)
|
||||
src_frames = sorted(self.fs.frames_path.glob("*.png"))
|
||||
interpolated_frames = sorted(self.fs.interpolated_path.glob("*.png"))
|
||||
index = 0
|
||||
for i in range(source_frame_length):
|
||||
moved_frame_path = self.fs.moved_path / f"img_{index:08d}.png"
|
||||
src_frames[i].rename(moved_frame_path)
|
||||
index += 1
|
||||
if i < len(interpolated_frames):
|
||||
moved_interpolated_path = self.fs.moved_path / f"img_{index:08d}.png"
|
||||
interpolated_frames[i].rename(moved_interpolated_path)
|
||||
index += 1
|
||||
logging.info(
|
||||
f"Moved {len(src_frames)} source frames and {len(interpolated_frames)} interpolated frames to {self.fs.moved_path}"
|
||||
)
|
||||
|
||||
|
||||
def runner(
|
||||
base_path: Path,
|
||||
@@ -220,7 +224,7 @@ def main():
|
||||
base_path=Path(args.base_path),
|
||||
video_path=Path(args.video_path),
|
||||
output_video=args.output,
|
||||
preset=getattr(presets, args.preset.upper())
|
||||
preset=getattr(presets, args.preset.upper()),
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user