3 Commits

Author SHA1 Message Date
Viner Abubakirov
0c871c2314 Refactor image tensor conversion and model inference in interpolator.py and torch.py 2026-04-13 18:56:00 +05:00
Viner Abubakirov
61f8e0abe1 Поменял, метод формирования видео 2026-04-03 17:03:10 +05:00
Viner Abubakirov
faf7aa8e81 Начал менять логику pipeline 2026-04-02 18:26:36 +05:00
5 changed files with 136 additions and 102 deletions

106
main.py
View File

@@ -2,6 +2,7 @@ import logging
from pathlib import Path
from typing import TYPE_CHECKING
from cv2 import imwrite
import tqdm
from src.config import presets
@@ -18,6 +19,7 @@ from src.interpolator import (
if TYPE_CHECKING:
import torch
import numpy as np
def performing_warning_message(device: "torch.device"):
@@ -53,7 +55,7 @@ def init_device() -> "torch.device":
device = get_device()
performing_warning_message(device)
vram_available = get_vram_available(device)
logging.info(f"Available VRAM: {vram_available / (1024 ** 3):.2f} GB")
logging.info(f"Available VRAM: {vram_available / (1024**3):.2f} GB")
return device
@@ -95,10 +97,9 @@ class InterpolationPipeline:
self.interpolator = init_interpolator(self.model_runner, self.device)
def run(self, video_path: Path, output_video: str):
prev_frame_path = None
frame_count = 0
prev_frames = tuple()
interpolated_frames: list["np.ndarray"] = []
part = 0
source_frame_length = 0
chunk_seconds = 10
length = self.video_maker.get_video_duration(video_path)
last_part_seconds = 1 if length % chunk_seconds else 0
@@ -106,41 +107,38 @@ class InterpolationPipeline:
fps = self.video_maker.get_fps(video_path)
logging.info(f"Video FPS: {fps}")
fps *= 2 # Doubling FPS
for frame_paths in self.video_maker.video_to_frames_generator(
width, height = self.video_maker.get_size(video_path)
for frames in self.video_maker.video_to_frames_generator(
video_path, self.fs.frames_path, chunk_seconds
):
logging.info(f"Processing frames: {len(frame_paths)}")
if prev_frame_path is not None:
img1 = prev_frame_path[-1]
img2 = frame_paths[0]
output_path = self.fs.interpolated_path / f"img_{frame_count:08d}.png"
self.interpolator.interpolate(img1, img2, output_path)
logging.debug(f"Interpolated image saved to: {output_path}")
self._merge_frames_to_video(
self.fs.video_part_path / f"video_{part:08d}.mp4",
fps,
source_frame_length=source_frame_length,
logging.info(f"Processing frames: {len(frames)}")
if prev_frames:
img1 = prev_frames[-1]
img2 = frames[0]
img1_2 = self.interpolator.interpolate(img1, img2)
interpolated_frames.append(img1_2)
generator = self._frame_generator(prev_frames, interpolated_frames)
part_path = self.fs.video_part_path / f"video_{part:08d}.mp4"
self.video_maker.images_to_video_pipeline(
generator, part_path, width, height, fps
)
interpolated_frames = []
logging.info(f"Finished processing part {part:08d}")
frame_count += 1
part += 1
for i in tqdm.tqdm(
range(len(frame_paths) - 1),
range(len(frames) - 1),
desc=f"Processing video frames {part + 1} / {total_parts}",
):
img1 = frame_paths[i]
img2 = frame_paths[i + 1]
output_path = self.fs.interpolated_path / f"img_{i:08d}.png"
self.interpolator.interpolate(img1, img2, output_path)
logging.debug(f"Interpolated image saved to: {output_path}")
frame_count += 1
source_frame_length = len(frame_paths)
prev_frame_path = frame_paths
img1 = frames[i]
img2 = frames[i + 1]
img1_2 = self.interpolator.interpolate(img1, img2)
interpolated_frames.append(img1_2)
prev_frames = frames
self._merge_frames_to_video(
self.fs.video_part_path / f"video_{part:08d}.mp4",
fps,
source_frame_length=source_frame_length,
generator = self._frame_generator(prev_frames, interpolated_frames)
part_path = self.fs.video_part_path / f"video_{part:08d}.mp4"
self.video_maker.images_to_video_pipeline(
generator, part_path, width, height, fps
)
logging.info(f"Finished processing part {part:08d}")
self._merge_video_parts(self.fs.output_path / output_video)
@@ -148,32 +146,40 @@ class InterpolationPipeline:
f"Video interpolation completed. Output saved to: {self.fs.output_path / output_video}"
)
def _merge_frames_to_video(
self, output_video: Path, fps: float, source_frame_length: int = 0
def _save_images(
self,
source: tuple["np.ndarray", ...],
interpolated: list["np.ndarray"],
):
self._move_frames(source_frame_length)
logging.info("Saving images...")
self.fs.clear_directory(self.fs.moved_path)
index = 0
for i, frame in enumerate(source):
name = self.fs.moved_path / f"img_{index:08d}.png"
index += 1
imwrite(name, frame)
if i < len(interpolated):
name = self.fs.moved_path / f"img_{index:08d}.png"
index += 1
imwrite(name, interpolated[i])
logging.info("Success...")
def _merge_frames_to_video(self, output_video: Path, fps: float):
self.video_maker.images_to_video(self.fs.moved_path, output_video, fps)
def _merge_video_parts(self, output_video: Path):
self.video_maker.concatenate_videos(self.fs.video_part_path, output_video)
self.fs.clear_directory(self.fs.video_part_path)
def _move_frames(self, source_frame_length: int = 0):
self.fs.clear_directory(self.fs.moved_path)
src_frames = sorted(self.fs.frames_path.glob("*.png"))
interpolated_frames = sorted(self.fs.interpolated_path.glob("*.png"))
index = 0
for i in range(source_frame_length):
moved_frame_path = self.fs.moved_path / f"img_{index:08d}.png"
src_frames[i].rename(moved_frame_path)
index += 1
if i < len(interpolated_frames):
moved_interpolated_path = self.fs.moved_path / f"img_{index:08d}.png"
interpolated_frames[i].rename(moved_interpolated_path)
index += 1
logging.info(
f"Moved {len(src_frames)} source frames and {len(interpolated_frames)} interpolated frames to {self.fs.moved_path}"
)
def _frame_generator(
self,
source: tuple["np.ndarray", ...],
interpolated: list["np.ndarray"],
):
for i, frame in enumerate(source):
yield frame
if i < len(interpolated):
yield interpolated[i]
def runner(
@@ -220,7 +226,7 @@ def main():
base_path=Path(args.base_path),
video_path=Path(args.video_path),
output_video=args.output,
preset=getattr(presets, args.preset.upper())
preset=getattr(presets, args.preset.upper()),
)

View File

@@ -19,6 +19,6 @@ LARGE = Preset(
)
GLOBAL = Preset(
config=Path("src/config/AMT-g.yaml"),
config=Path("src/config/AMT-G.yaml"),
checkpoint=Path("src/pretrained/amt-g.pth"),
)

View File

@@ -1,12 +1,12 @@
import logging
from pathlib import Path
from typing import Optional
import torch
import numpy as np
from omegaconf import OmegaConf, DictConfig
from imageio import imread, imwrite
from src.utils.torch import img2tensor, check_dim_and_resize, tensor2img
from src.utils.torch import img2tensor, tensor2img
from src.utils.build import build_from_cfg
from src.utils.padder import InputPadder
@@ -40,7 +40,7 @@ class ModelRunner:
model.load_state_dict(checkpoint["state_dict"])
model = model.to(get_device())
model.eval()
self.model = model
self.model = torch.compile(model)
def get_vram_available(device: torch.device) -> int:
@@ -83,7 +83,7 @@ class ImageInterpolator:
f"Initialized ImageInterpolator with device: {device}, anchor: {anchor}, available VRAM: {self.vram_available} bytes"
)
def interpolate(self, image1: Path, image2: Path, output_path: Path):
def interpolate(self, image1: np.ndarray, image2: np.ndarray) -> np.ndarray:
"""
Interpolates between two images and saves the result.
Args:
@@ -92,38 +92,21 @@ class ImageInterpolator:
output_path (Path): Path to save the interpolated image (only png and jpg formats are supported)
"""
logging.debug(f"Reading images: {image1} and {image2}")
tensor1 = img2tensor(imread(image1)).to(self.device)
tensor2 = img2tensor(imread(image2)).to(self.device)
tensor1 = img2tensor(image1, self.device)
tensor2 = img2tensor(image2, self.device)
logging.debug(
f"Image shapes after conversion to tensors: {tensor1.shape}, {tensor2.shape}"
)
tensor1, tensor2 = check_dim_and_resize(tensor1, tensor2)
logging.debug(f"Image shapes after resizing: {tensor1.shape}, {tensor2.shape}")
h, w = tensor1.shape[2], tensor1.shape[3]
logging.debug(f"Interpolating images of size: {h}x{w}")
scale = self.scale(h, w)
logging.debug(f"Calculated scale factor: {scale:.2f}")
padding = int(16 / scale)
logging.debug(f"Calculated padding: {padding} pixels")
padder = InputPadder(tensor1.shape, divisor=padding)
tensor1_padded, tensor2_padded = padder.pad(tensor1, tensor2)
logging.debug(
f"Image shapes after padding: {tensor1_padded.shape}, {tensor2_padded.shape}"
)
tensor1_padded = tensor1_padded.to(self.device)
tensor2_padded = tensor2_padded.to(self.device)
logging.debug("Running model inference for interpolation")
with torch.no_grad():
interpolated = self.model_runner.model(
tensor1_padded, tensor2_padded, self.embt, scale_factor=scale, eval=True
)["imgt_pred"]
with torch.amp.autocast(self.device.type):
interpolated = self.model_runner.model(
tensor1, tensor2, self.embt
)["imgt_pred"]
logging.debug(f"Interpolated image shape before unpadding: {interpolated.shape}")
(interpolated,) = padder.unpad(interpolated)
logging.debug(f"Interpolated image shape after unpadding: {interpolated.shape}")
imwrite(output_path, tensor2img(interpolated.cpu()))
logging.debug(f"Saved interpolated image to: {output_path}")
return tensor2img(interpolated.cpu())
def scale(self, height: int, width: int) -> float:
scale = (

View File

@@ -5,23 +5,26 @@ import numpy as np
def tensor2img(tensor: torch.Tensor):
return (
(tensor * 255.0)
.detach()
tensor = (
tensor.mul(255.0)
.clamp_(0, 255)
.to(torch.uint8)
.squeeze(0)
.permute(1, 2, 0)
.cpu()
.numpy()
.clip(0, 255)
.astype(np.uint8)
)
return tensor.cpu().numpy()
def img2tensor(img: np.ndarray) -> torch.Tensor:
def img2tensor(img: np.ndarray, device: torch.device) -> torch.Tensor:
logging.debug(f"Converting image of shape {img.shape} to tensor")
if img.shape[-1] > 3:
img = img[:, :, :3]
return torch.tensor(img).permute(2, 0, 1).unsqueeze(0) / 255.0
tensor = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0)
if device.type != "cuda":
return tensor.float() / 255.0
return tensor.cuda(non_blocking=True).float().div_(255.0)
def check_dim_and_resize(*args: torch.Tensor) -> list[torch.Tensor]:

View File

@@ -2,9 +2,10 @@ import os
import logging
import subprocess
from pathlib import Path
from typing import Generator, Iterable
import cv2
from typing import Generator
import numpy as np
class VideoMaker:
@@ -35,7 +36,7 @@ class VideoMaker:
with open(file, "w") as f:
for video in videos:
f.write(f"file '{video}'\n")
cmd = f"ffmpeg -f concat -safe 0 -i {file} -c copy {output_path}"
cmd = f"ffmpeg -y -f concat -safe 0 -i {file} -c copy {output_path}"
logging.info(f"Running command: {cmd}")
result = self.run_command(cmd)
if result != 0:
@@ -66,7 +67,13 @@ class VideoMaker:
def run_command(self, cmd: str) -> int:
try:
subprocess.run(cmd, shell=True, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
subprocess.run(
cmd,
shell=True,
check=True,
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
)
return 0
except subprocess.CalledProcessError as e:
logging.error(f"Command failed with error: {e}")
@@ -74,7 +81,7 @@ class VideoMaker:
def video_to_frames_generator(
self, video_path: Path, output_dir: Path, chunk_seconds: int = 10
) -> Generator[tuple[Path, ...], None, None]:
) -> Generator[tuple[np.ndarray, ...], None, None]:
"""Extracts frames from a video and saves them to disk, yielding paths to the saved frames."""
cap = cv2.VideoCapture(str(video_path))
@@ -85,21 +92,56 @@ class VideoMaker:
fps = cap.get(cv2.CAP_PROP_FPS)
frames_per_chunk = int(fps * chunk_seconds)
frame_index = 0
while True:
paths = []
for _ in range(frames_per_chunk):
ret, frame = cap.read()
if not ret:
cap.release()
return
frame_path = output_dir / f"img_{frame_index:08d}.png"
cv2.imwrite(str(frame_path), frame)
paths.append(frame_path)
frame_index += 1
paths.append(frame)
yield tuple(paths)
def images_to_video_pipeline(
self,
frames: Iterable[np.ndarray],
output_path: Path,
width: int,
height: int,
fps: float,
):
pipeline = subprocess.Popen(
[
"ffmpeg",
"-y",
"-f", "rawvideo",
"-vcodec", "rawvideo",
"-pix_fmt", "bgr24",
"-s", f"{width}x{height}",
"-r", str(fps),
"-i", "-",
"-an",
"-vcodec", "libx264",
"-pix_fmt", "yuv420p",
str(output_path),
],
stdin=subprocess.PIPE,
stderr=subprocess.DEVNULL
)
if pipeline.stdin is None:
raise Exception("STDIN closed")
for frame in frames:
pipeline.stdin.write(frame.tobytes())
pipeline.stdin.close()
pipeline.wait()
def get_size(self, video_path: Path) -> tuple[int, int]:
cap = cv2.VideoCapture(str(video_path))
if not cap.isOpened():
raise ValueError(f"Cannot open video: {video_path}")
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
cap.release()
return width, height