dev #2

Merged
lovinervy merged 2 commits from dev into main 2026-04-03 18:28:31 +05:00
4 changed files with 56 additions and 62 deletions
Showing only changes of commit faf7aa8e81 - Show all commits

88
main.py
View File

@@ -2,6 +2,7 @@ import logging
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from cv2 import imwrite
import tqdm import tqdm
from src.config import presets from src.config import presets
@@ -18,6 +19,7 @@ from src.interpolator import (
if TYPE_CHECKING: if TYPE_CHECKING:
import torch import torch
import numpy as np
def performing_warning_message(device: "torch.device"): def performing_warning_message(device: "torch.device"):
@@ -53,7 +55,7 @@ def init_device() -> "torch.device":
device = get_device() device = get_device()
performing_warning_message(device) performing_warning_message(device)
vram_available = get_vram_available(device) vram_available = get_vram_available(device)
logging.info(f"Available VRAM: {vram_available / (1024 ** 3):.2f} GB") logging.info(f"Available VRAM: {vram_available / (1024**3):.2f} GB")
return device return device
@@ -95,8 +97,8 @@ class InterpolationPipeline:
self.interpolator = init_interpolator(self.model_runner, self.device) self.interpolator = init_interpolator(self.model_runner, self.device)
def run(self, video_path: Path, output_video: str): def run(self, video_path: Path, output_video: str):
prev_frame_path = None prev_frames = tuple()
frame_count = 0 interpolated_frames = []
part = 0 part = 0
source_frame_length = 0 source_frame_length = 0
chunk_seconds = 10 chunk_seconds = 10
@@ -106,75 +108,77 @@ class InterpolationPipeline:
fps = self.video_maker.get_fps(video_path) fps = self.video_maker.get_fps(video_path)
logging.info(f"Video FPS: {fps}") logging.info(f"Video FPS: {fps}")
fps *= 2 # Doubling FPS fps *= 2 # Doubling FPS
for frame_paths in self.video_maker.video_to_frames_generator( for frames in self.video_maker.video_to_frames_generator(
video_path, self.fs.frames_path, chunk_seconds video_path, self.fs.frames_path, chunk_seconds
): ):
logging.info(f"Processing frames: {len(frame_paths)}") logging.info(f"Processing frames: {len(frames)}")
if prev_frame_path is not None: if prev_frames:
img1 = prev_frame_path[-1] img1 = prev_frames[-1]
img2 = frame_paths[0] img2 = frames[0]
output_path = self.fs.interpolated_path / f"img_{frame_count:08d}.png" img1_2 = self.interpolator.interpolate(img1, img2)
self.interpolator.interpolate(img1, img2, output_path) interpolated_frames.append(img1_2)
logging.debug(f"Interpolated image saved to: {output_path}") self.fs.clear_directory(self.fs.moved_path)
self._save_images(prev_frames, interpolated_frames)
self._merge_frames_to_video( self._merge_frames_to_video(
self.fs.video_part_path / f"video_{part:08d}.mp4", self.fs.video_part_path / f"video_{part:08d}.mp4",
fps, fps,
source_frame_length=source_frame_length, source_frame_length,
) )
interpolated_frames = []
logging.info(f"Finished processing part {part:08d}") logging.info(f"Finished processing part {part:08d}")
frame_count += 1
part += 1 part += 1
for i in tqdm.tqdm( for i in tqdm.tqdm(
range(len(frame_paths) - 1), range(len(frames) - 1),
desc=f"Processing video frames {part + 1} / {total_parts}", desc=f"Processing video frames {part + 1} / {total_parts}",
): ):
img1 = frame_paths[i] img1 = frames[i]
img2 = frame_paths[i + 1] img2 = frames[i + 1]
output_path = self.fs.interpolated_path / f"img_{i:08d}.png" img1_2 = self.interpolator.interpolate(img1, img2)
self.interpolator.interpolate(img1, img2, output_path) interpolated_frames.append(img1_2)
logging.debug(f"Interpolated image saved to: {output_path}") source_frame_length = len(frames)
frame_count += 1 prev_frames = frames
source_frame_length = len(frame_paths)
prev_frame_path = frame_paths
self.fs.clear_directory(self.fs.moved_path)
self._save_images(prev_frames, interpolated_frames)
self._merge_frames_to_video( self._merge_frames_to_video(
self.fs.video_part_path / f"video_{part:08d}.mp4", self.fs.video_part_path / f"video_{part:08d}.mp4",
fps, fps,
source_frame_length=source_frame_length, source_frame_length,
) )
self.fs.clear_directory(self.fs.moved_path)
logging.info(f"Finished processing part {part:08d}") logging.info(f"Finished processing part {part:08d}")
self._merge_video_parts(self.fs.output_path / output_video) self._merge_video_parts(self.fs.output_path / output_video)
logging.info( logging.info(
f"Video interpolation completed. Output saved to: {self.fs.output_path / output_video}" f"Video interpolation completed. Output saved to: {self.fs.output_path / output_video}"
) )
def _save_images(
self,
source: tuple["np.ndarray", ...],
interpolated: list["np.ndarray"],
):
logging.info("Saving images...")
index = 0
for i, frame in enumerate(source):
name = self.fs.moved_path / f"img_{index:08d}.png"
index += 1
imwrite(name, frame)
if i < len(interpolated):
name = self.fs.moved_path / f"img_{index:08d}.png"
index += 1
imwrite(name, interpolated[i])
logging.info("Success...")
def _merge_frames_to_video( def _merge_frames_to_video(
self, output_video: Path, fps: float, source_frame_length: int = 0 self, output_video: Path, fps: float, source_frame_length: int = 0
): ):
self._move_frames(source_frame_length)
self.video_maker.images_to_video(self.fs.moved_path, output_video, fps) self.video_maker.images_to_video(self.fs.moved_path, output_video, fps)
def _merge_video_parts(self, output_video: Path): def _merge_video_parts(self, output_video: Path):
self.video_maker.concatenate_videos(self.fs.video_part_path, output_video) self.video_maker.concatenate_videos(self.fs.video_part_path, output_video)
self.fs.clear_directory(self.fs.video_part_path) self.fs.clear_directory(self.fs.video_part_path)
def _move_frames(self, source_frame_length: int = 0):
self.fs.clear_directory(self.fs.moved_path)
src_frames = sorted(self.fs.frames_path.glob("*.png"))
interpolated_frames = sorted(self.fs.interpolated_path.glob("*.png"))
index = 0
for i in range(source_frame_length):
moved_frame_path = self.fs.moved_path / f"img_{index:08d}.png"
src_frames[i].rename(moved_frame_path)
index += 1
if i < len(interpolated_frames):
moved_interpolated_path = self.fs.moved_path / f"img_{index:08d}.png"
interpolated_frames[i].rename(moved_interpolated_path)
index += 1
logging.info(
f"Moved {len(src_frames)} source frames and {len(interpolated_frames)} interpolated frames to {self.fs.moved_path}"
)
def runner( def runner(
base_path: Path, base_path: Path,
@@ -220,7 +224,7 @@ def main():
base_path=Path(args.base_path), base_path=Path(args.base_path),
video_path=Path(args.video_path), video_path=Path(args.video_path),
output_video=args.output, output_video=args.output,
preset=getattr(presets, args.preset.upper()) preset=getattr(presets, args.preset.upper()),
) )

View File

@@ -19,6 +19,6 @@ LARGE = Preset(
) )
GLOBAL = Preset( GLOBAL = Preset(
config=Path("src/config/AMT-g.yaml"), config=Path("src/config/AMT-G.yaml"),
checkpoint=Path("src/pretrained/amt-g.pth"), checkpoint=Path("src/pretrained/amt-g.pth"),
) )

View File

@@ -4,7 +4,6 @@ from pathlib import Path
import torch import torch
import numpy as np import numpy as np
from omegaconf import OmegaConf, DictConfig from omegaconf import OmegaConf, DictConfig
from imageio import imread, imwrite
from src.utils.torch import img2tensor, check_dim_and_resize, tensor2img from src.utils.torch import img2tensor, check_dim_and_resize, tensor2img
from src.utils.build import build_from_cfg from src.utils.build import build_from_cfg
@@ -83,7 +82,7 @@ class ImageInterpolator:
f"Initialized ImageInterpolator with device: {device}, anchor: {anchor}, available VRAM: {self.vram_available} bytes" f"Initialized ImageInterpolator with device: {device}, anchor: {anchor}, available VRAM: {self.vram_available} bytes"
) )
def interpolate(self, image1: Path, image2: Path, output_path: Path): def interpolate(self, image1: np.ndarray, image2: np.ndarray) -> np.ndarray:
""" """
Interpolates between two images and saves the result. Interpolates between two images and saves the result.
Args: Args:
@@ -92,8 +91,8 @@ class ImageInterpolator:
output_path (Path): Path to save the interpolated image (only png and jpg formats are supported) output_path (Path): Path to save the interpolated image (only png and jpg formats are supported)
""" """
logging.debug(f"Reading images: {image1} and {image2}") logging.debug(f"Reading images: {image1} and {image2}")
tensor1 = img2tensor(imread(image1)).to(self.device) tensor1 = img2tensor(image1).to(self.device)
tensor2 = img2tensor(imread(image2)).to(self.device) tensor2 = img2tensor(image2).to(self.device)
logging.debug( logging.debug(
f"Image shapes after conversion to tensors: {tensor1.shape}, {tensor2.shape}" f"Image shapes after conversion to tensors: {tensor1.shape}, {tensor2.shape}"
) )
@@ -122,8 +121,7 @@ class ImageInterpolator:
logging.debug(f"Interpolated image shape before unpadding: {interpolated.shape}") logging.debug(f"Interpolated image shape before unpadding: {interpolated.shape}")
(interpolated,) = padder.unpad(interpolated) (interpolated,) = padder.unpad(interpolated)
logging.debug(f"Interpolated image shape after unpadding: {interpolated.shape}") logging.debug(f"Interpolated image shape after unpadding: {interpolated.shape}")
imwrite(output_path, tensor2img(interpolated.cpu())) return tensor2img(interpolated.cpu())
logging.debug(f"Saved interpolated image to: {output_path}")
def scale(self, height: int, width: int) -> float: def scale(self, height: int, width: int) -> float:
scale = ( scale = (

View File

@@ -2,9 +2,10 @@ import os
import logging import logging
import subprocess import subprocess
from pathlib import Path from pathlib import Path
from typing import Generator
import cv2 import cv2
from typing import Generator import numpy as np
class VideoMaker: class VideoMaker:
@@ -35,7 +36,7 @@ class VideoMaker:
with open(file, "w") as f: with open(file, "w") as f:
for video in videos: for video in videos:
f.write(f"file '{video}'\n") f.write(f"file '{video}'\n")
cmd = f"ffmpeg -f concat -safe 0 -i {file} -c copy {output_path}" cmd = f"ffmpeg -y -f concat -safe 0 -i {file} -c copy {output_path}"
logging.info(f"Running command: {cmd}") logging.info(f"Running command: {cmd}")
result = self.run_command(cmd) result = self.run_command(cmd)
if result != 0: if result != 0:
@@ -74,7 +75,7 @@ class VideoMaker:
def video_to_frames_generator( def video_to_frames_generator(
self, video_path: Path, output_dir: Path, chunk_seconds: int = 10 self, video_path: Path, output_dir: Path, chunk_seconds: int = 10
) -> Generator[tuple[Path, ...], None, None]: ) -> Generator[tuple[np.ndarray, ...], None, None]:
"""Extracts frames from a video and saves them to disk, yielding paths to the saved frames.""" """Extracts frames from a video and saves them to disk, yielding paths to the saved frames."""
cap = cv2.VideoCapture(str(video_path)) cap = cv2.VideoCapture(str(video_path))
@@ -85,21 +86,12 @@ class VideoMaker:
fps = cap.get(cv2.CAP_PROP_FPS) fps = cap.get(cv2.CAP_PROP_FPS)
frames_per_chunk = int(fps * chunk_seconds) frames_per_chunk = int(fps * chunk_seconds)
frame_index = 0
while True: while True:
paths = [] paths = []
for _ in range(frames_per_chunk): for _ in range(frames_per_chunk):
ret, frame = cap.read() ret, frame = cap.read()
if not ret: if not ret:
cap.release() cap.release()
return return
paths.append(frame)
frame_path = output_dir / f"img_{frame_index:08d}.png"
cv2.imwrite(str(frame_path), frame)
paths.append(frame_path)
frame_index += 1
yield tuple(paths) yield tuple(paths)