Compare commits
3 Commits
c72e34f9dc
...
dev-cuda
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0c871c2314 | ||
|
|
61f8e0abe1 | ||
|
|
faf7aa8e81 |
106
main.py
106
main.py
@@ -2,6 +2,7 @@ import logging
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from cv2 import imwrite
|
||||||
import tqdm
|
import tqdm
|
||||||
|
|
||||||
from src.config import presets
|
from src.config import presets
|
||||||
@@ -18,6 +19,7 @@ from src.interpolator import (
|
|||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
import torch
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
def performing_warning_message(device: "torch.device"):
|
def performing_warning_message(device: "torch.device"):
|
||||||
@@ -53,7 +55,7 @@ def init_device() -> "torch.device":
|
|||||||
device = get_device()
|
device = get_device()
|
||||||
performing_warning_message(device)
|
performing_warning_message(device)
|
||||||
vram_available = get_vram_available(device)
|
vram_available = get_vram_available(device)
|
||||||
logging.info(f"Available VRAM: {vram_available / (1024 ** 3):.2f} GB")
|
logging.info(f"Available VRAM: {vram_available / (1024**3):.2f} GB")
|
||||||
return device
|
return device
|
||||||
|
|
||||||
|
|
||||||
@@ -95,10 +97,9 @@ class InterpolationPipeline:
|
|||||||
self.interpolator = init_interpolator(self.model_runner, self.device)
|
self.interpolator = init_interpolator(self.model_runner, self.device)
|
||||||
|
|
||||||
def run(self, video_path: Path, output_video: str):
|
def run(self, video_path: Path, output_video: str):
|
||||||
prev_frame_path = None
|
prev_frames = tuple()
|
||||||
frame_count = 0
|
interpolated_frames: list["np.ndarray"] = []
|
||||||
part = 0
|
part = 0
|
||||||
source_frame_length = 0
|
|
||||||
chunk_seconds = 10
|
chunk_seconds = 10
|
||||||
length = self.video_maker.get_video_duration(video_path)
|
length = self.video_maker.get_video_duration(video_path)
|
||||||
last_part_seconds = 1 if length % chunk_seconds else 0
|
last_part_seconds = 1 if length % chunk_seconds else 0
|
||||||
@@ -106,41 +107,38 @@ class InterpolationPipeline:
|
|||||||
fps = self.video_maker.get_fps(video_path)
|
fps = self.video_maker.get_fps(video_path)
|
||||||
logging.info(f"Video FPS: {fps}")
|
logging.info(f"Video FPS: {fps}")
|
||||||
fps *= 2 # Doubling FPS
|
fps *= 2 # Doubling FPS
|
||||||
for frame_paths in self.video_maker.video_to_frames_generator(
|
width, height = self.video_maker.get_size(video_path)
|
||||||
|
for frames in self.video_maker.video_to_frames_generator(
|
||||||
video_path, self.fs.frames_path, chunk_seconds
|
video_path, self.fs.frames_path, chunk_seconds
|
||||||
):
|
):
|
||||||
logging.info(f"Processing frames: {len(frame_paths)}")
|
logging.info(f"Processing frames: {len(frames)}")
|
||||||
if prev_frame_path is not None:
|
if prev_frames:
|
||||||
img1 = prev_frame_path[-1]
|
img1 = prev_frames[-1]
|
||||||
img2 = frame_paths[0]
|
img2 = frames[0]
|
||||||
output_path = self.fs.interpolated_path / f"img_{frame_count:08d}.png"
|
img1_2 = self.interpolator.interpolate(img1, img2)
|
||||||
self.interpolator.interpolate(img1, img2, output_path)
|
interpolated_frames.append(img1_2)
|
||||||
logging.debug(f"Interpolated image saved to: {output_path}")
|
generator = self._frame_generator(prev_frames, interpolated_frames)
|
||||||
self._merge_frames_to_video(
|
part_path = self.fs.video_part_path / f"video_{part:08d}.mp4"
|
||||||
self.fs.video_part_path / f"video_{part:08d}.mp4",
|
self.video_maker.images_to_video_pipeline(
|
||||||
fps,
|
generator, part_path, width, height, fps
|
||||||
source_frame_length=source_frame_length,
|
|
||||||
)
|
)
|
||||||
|
interpolated_frames = []
|
||||||
logging.info(f"Finished processing part {part:08d}")
|
logging.info(f"Finished processing part {part:08d}")
|
||||||
frame_count += 1
|
|
||||||
part += 1
|
part += 1
|
||||||
for i in tqdm.tqdm(
|
for i in tqdm.tqdm(
|
||||||
range(len(frame_paths) - 1),
|
range(len(frames) - 1),
|
||||||
desc=f"Processing video frames {part + 1} / {total_parts}",
|
desc=f"Processing video frames {part + 1} / {total_parts}",
|
||||||
):
|
):
|
||||||
img1 = frame_paths[i]
|
img1 = frames[i]
|
||||||
img2 = frame_paths[i + 1]
|
img2 = frames[i + 1]
|
||||||
output_path = self.fs.interpolated_path / f"img_{i:08d}.png"
|
img1_2 = self.interpolator.interpolate(img1, img2)
|
||||||
self.interpolator.interpolate(img1, img2, output_path)
|
interpolated_frames.append(img1_2)
|
||||||
logging.debug(f"Interpolated image saved to: {output_path}")
|
prev_frames = frames
|
||||||
frame_count += 1
|
|
||||||
source_frame_length = len(frame_paths)
|
|
||||||
prev_frame_path = frame_paths
|
|
||||||
|
|
||||||
self._merge_frames_to_video(
|
generator = self._frame_generator(prev_frames, interpolated_frames)
|
||||||
self.fs.video_part_path / f"video_{part:08d}.mp4",
|
part_path = self.fs.video_part_path / f"video_{part:08d}.mp4"
|
||||||
fps,
|
self.video_maker.images_to_video_pipeline(
|
||||||
source_frame_length=source_frame_length,
|
generator, part_path, width, height, fps
|
||||||
)
|
)
|
||||||
logging.info(f"Finished processing part {part:08d}")
|
logging.info(f"Finished processing part {part:08d}")
|
||||||
self._merge_video_parts(self.fs.output_path / output_video)
|
self._merge_video_parts(self.fs.output_path / output_video)
|
||||||
@@ -148,32 +146,40 @@ class InterpolationPipeline:
|
|||||||
f"Video interpolation completed. Output saved to: {self.fs.output_path / output_video}"
|
f"Video interpolation completed. Output saved to: {self.fs.output_path / output_video}"
|
||||||
)
|
)
|
||||||
|
|
||||||
def _merge_frames_to_video(
|
def _save_images(
|
||||||
self, output_video: Path, fps: float, source_frame_length: int = 0
|
self,
|
||||||
|
source: tuple["np.ndarray", ...],
|
||||||
|
interpolated: list["np.ndarray"],
|
||||||
):
|
):
|
||||||
self._move_frames(source_frame_length)
|
logging.info("Saving images...")
|
||||||
|
self.fs.clear_directory(self.fs.moved_path)
|
||||||
|
index = 0
|
||||||
|
for i, frame in enumerate(source):
|
||||||
|
name = self.fs.moved_path / f"img_{index:08d}.png"
|
||||||
|
index += 1
|
||||||
|
imwrite(name, frame)
|
||||||
|
if i < len(interpolated):
|
||||||
|
name = self.fs.moved_path / f"img_{index:08d}.png"
|
||||||
|
index += 1
|
||||||
|
imwrite(name, interpolated[i])
|
||||||
|
logging.info("Success...")
|
||||||
|
|
||||||
|
def _merge_frames_to_video(self, output_video: Path, fps: float):
|
||||||
self.video_maker.images_to_video(self.fs.moved_path, output_video, fps)
|
self.video_maker.images_to_video(self.fs.moved_path, output_video, fps)
|
||||||
|
|
||||||
def _merge_video_parts(self, output_video: Path):
|
def _merge_video_parts(self, output_video: Path):
|
||||||
self.video_maker.concatenate_videos(self.fs.video_part_path, output_video)
|
self.video_maker.concatenate_videos(self.fs.video_part_path, output_video)
|
||||||
self.fs.clear_directory(self.fs.video_part_path)
|
self.fs.clear_directory(self.fs.video_part_path)
|
||||||
|
|
||||||
def _move_frames(self, source_frame_length: int = 0):
|
def _frame_generator(
|
||||||
self.fs.clear_directory(self.fs.moved_path)
|
self,
|
||||||
src_frames = sorted(self.fs.frames_path.glob("*.png"))
|
source: tuple["np.ndarray", ...],
|
||||||
interpolated_frames = sorted(self.fs.interpolated_path.glob("*.png"))
|
interpolated: list["np.ndarray"],
|
||||||
index = 0
|
):
|
||||||
for i in range(source_frame_length):
|
for i, frame in enumerate(source):
|
||||||
moved_frame_path = self.fs.moved_path / f"img_{index:08d}.png"
|
yield frame
|
||||||
src_frames[i].rename(moved_frame_path)
|
if i < len(interpolated):
|
||||||
index += 1
|
yield interpolated[i]
|
||||||
if i < len(interpolated_frames):
|
|
||||||
moved_interpolated_path = self.fs.moved_path / f"img_{index:08d}.png"
|
|
||||||
interpolated_frames[i].rename(moved_interpolated_path)
|
|
||||||
index += 1
|
|
||||||
logging.info(
|
|
||||||
f"Moved {len(src_frames)} source frames and {len(interpolated_frames)} interpolated frames to {self.fs.moved_path}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def runner(
|
def runner(
|
||||||
@@ -220,7 +226,7 @@ def main():
|
|||||||
base_path=Path(args.base_path),
|
base_path=Path(args.base_path),
|
||||||
video_path=Path(args.video_path),
|
video_path=Path(args.video_path),
|
||||||
output_video=args.output,
|
output_video=args.output,
|
||||||
preset=getattr(presets, args.preset.upper())
|
preset=getattr(presets, args.preset.upper()),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,12 +1,12 @@
|
|||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from omegaconf import OmegaConf, DictConfig
|
from omegaconf import OmegaConf, DictConfig
|
||||||
from imageio import imread, imwrite
|
|
||||||
|
|
||||||
from src.utils.torch import img2tensor, check_dim_and_resize, tensor2img
|
from src.utils.torch import img2tensor, tensor2img
|
||||||
from src.utils.build import build_from_cfg
|
from src.utils.build import build_from_cfg
|
||||||
from src.utils.padder import InputPadder
|
from src.utils.padder import InputPadder
|
||||||
|
|
||||||
@@ -40,7 +40,7 @@ class ModelRunner:
|
|||||||
model.load_state_dict(checkpoint["state_dict"])
|
model.load_state_dict(checkpoint["state_dict"])
|
||||||
model = model.to(get_device())
|
model = model.to(get_device())
|
||||||
model.eval()
|
model.eval()
|
||||||
self.model = model
|
self.model = torch.compile(model)
|
||||||
|
|
||||||
|
|
||||||
def get_vram_available(device: torch.device) -> int:
|
def get_vram_available(device: torch.device) -> int:
|
||||||
@@ -83,7 +83,7 @@ class ImageInterpolator:
|
|||||||
f"Initialized ImageInterpolator with device: {device}, anchor: {anchor}, available VRAM: {self.vram_available} bytes"
|
f"Initialized ImageInterpolator with device: {device}, anchor: {anchor}, available VRAM: {self.vram_available} bytes"
|
||||||
)
|
)
|
||||||
|
|
||||||
def interpolate(self, image1: Path, image2: Path, output_path: Path):
|
def interpolate(self, image1: np.ndarray, image2: np.ndarray) -> np.ndarray:
|
||||||
"""
|
"""
|
||||||
Interpolates between two images and saves the result.
|
Interpolates between two images and saves the result.
|
||||||
Args:
|
Args:
|
||||||
@@ -92,38 +92,21 @@ class ImageInterpolator:
|
|||||||
output_path (Path): Path to save the interpolated image (only png and jpg formats are supported)
|
output_path (Path): Path to save the interpolated image (only png and jpg formats are supported)
|
||||||
"""
|
"""
|
||||||
logging.debug(f"Reading images: {image1} and {image2}")
|
logging.debug(f"Reading images: {image1} and {image2}")
|
||||||
tensor1 = img2tensor(imread(image1)).to(self.device)
|
tensor1 = img2tensor(image1, self.device)
|
||||||
tensor2 = img2tensor(imread(image2)).to(self.device)
|
tensor2 = img2tensor(image2, self.device)
|
||||||
logging.debug(
|
logging.debug(
|
||||||
f"Image shapes after conversion to tensors: {tensor1.shape}, {tensor2.shape}"
|
f"Image shapes after conversion to tensors: {tensor1.shape}, {tensor2.shape}"
|
||||||
)
|
)
|
||||||
tensor1, tensor2 = check_dim_and_resize(tensor1, tensor2)
|
|
||||||
logging.debug(f"Image shapes after resizing: {tensor1.shape}, {tensor2.shape}")
|
|
||||||
h, w = tensor1.shape[2], tensor1.shape[3]
|
|
||||||
logging.debug(f"Interpolating images of size: {h}x{w}")
|
|
||||||
|
|
||||||
scale = self.scale(h, w)
|
|
||||||
logging.debug(f"Calculated scale factor: {scale:.2f}")
|
|
||||||
padding = int(16 / scale)
|
|
||||||
logging.debug(f"Calculated padding: {padding} pixels")
|
|
||||||
padder = InputPadder(tensor1.shape, divisor=padding)
|
|
||||||
tensor1_padded, tensor2_padded = padder.pad(tensor1, tensor2)
|
|
||||||
logging.debug(
|
|
||||||
f"Image shapes after padding: {tensor1_padded.shape}, {tensor2_padded.shape}"
|
|
||||||
)
|
|
||||||
|
|
||||||
tensor1_padded = tensor1_padded.to(self.device)
|
|
||||||
tensor2_padded = tensor2_padded.to(self.device)
|
|
||||||
logging.debug("Running model inference for interpolation")
|
logging.debug("Running model inference for interpolation")
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
|
with torch.amp.autocast(self.device.type):
|
||||||
interpolated = self.model_runner.model(
|
interpolated = self.model_runner.model(
|
||||||
tensor1_padded, tensor2_padded, self.embt, scale_factor=scale, eval=True
|
tensor1, tensor2, self.embt
|
||||||
)["imgt_pred"]
|
)["imgt_pred"]
|
||||||
logging.debug(f"Interpolated image shape before unpadding: {interpolated.shape}")
|
logging.debug(f"Interpolated image shape before unpadding: {interpolated.shape}")
|
||||||
(interpolated,) = padder.unpad(interpolated)
|
|
||||||
logging.debug(f"Interpolated image shape after unpadding: {interpolated.shape}")
|
logging.debug(f"Interpolated image shape after unpadding: {interpolated.shape}")
|
||||||
imwrite(output_path, tensor2img(interpolated.cpu()))
|
return tensor2img(interpolated.cpu())
|
||||||
logging.debug(f"Saved interpolated image to: {output_path}")
|
|
||||||
|
|
||||||
def scale(self, height: int, width: int) -> float:
|
def scale(self, height: int, width: int) -> float:
|
||||||
scale = (
|
scale = (
|
||||||
|
|||||||
@@ -5,23 +5,26 @@ import numpy as np
|
|||||||
|
|
||||||
|
|
||||||
def tensor2img(tensor: torch.Tensor):
|
def tensor2img(tensor: torch.Tensor):
|
||||||
return (
|
tensor = (
|
||||||
(tensor * 255.0)
|
tensor.mul(255.0)
|
||||||
.detach()
|
.clamp_(0, 255)
|
||||||
|
.to(torch.uint8)
|
||||||
.squeeze(0)
|
.squeeze(0)
|
||||||
.permute(1, 2, 0)
|
.permute(1, 2, 0)
|
||||||
.cpu()
|
|
||||||
.numpy()
|
|
||||||
.clip(0, 255)
|
|
||||||
.astype(np.uint8)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
return tensor.cpu().numpy()
|
||||||
|
|
||||||
def img2tensor(img: np.ndarray) -> torch.Tensor:
|
|
||||||
|
def img2tensor(img: np.ndarray, device: torch.device) -> torch.Tensor:
|
||||||
logging.debug(f"Converting image of shape {img.shape} to tensor")
|
logging.debug(f"Converting image of shape {img.shape} to tensor")
|
||||||
if img.shape[-1] > 3:
|
if img.shape[-1] > 3:
|
||||||
img = img[:, :, :3]
|
img = img[:, :, :3]
|
||||||
return torch.tensor(img).permute(2, 0, 1).unsqueeze(0) / 255.0
|
tensor = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0)
|
||||||
|
if device.type != "cuda":
|
||||||
|
return tensor.float() / 255.0
|
||||||
|
|
||||||
|
return tensor.cuda(non_blocking=True).float().div_(255.0)
|
||||||
|
|
||||||
|
|
||||||
def check_dim_and_resize(*args: torch.Tensor) -> list[torch.Tensor]:
|
def check_dim_and_resize(*args: torch.Tensor) -> list[torch.Tensor]:
|
||||||
|
|||||||
@@ -2,9 +2,10 @@ import os
|
|||||||
import logging
|
import logging
|
||||||
import subprocess
|
import subprocess
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Generator, Iterable
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
from typing import Generator
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
class VideoMaker:
|
class VideoMaker:
|
||||||
@@ -35,7 +36,7 @@ class VideoMaker:
|
|||||||
with open(file, "w") as f:
|
with open(file, "w") as f:
|
||||||
for video in videos:
|
for video in videos:
|
||||||
f.write(f"file '{video}'\n")
|
f.write(f"file '{video}'\n")
|
||||||
cmd = f"ffmpeg -f concat -safe 0 -i {file} -c copy {output_path}"
|
cmd = f"ffmpeg -y -f concat -safe 0 -i {file} -c copy {output_path}"
|
||||||
logging.info(f"Running command: {cmd}")
|
logging.info(f"Running command: {cmd}")
|
||||||
result = self.run_command(cmd)
|
result = self.run_command(cmd)
|
||||||
if result != 0:
|
if result != 0:
|
||||||
@@ -66,7 +67,13 @@ class VideoMaker:
|
|||||||
|
|
||||||
def run_command(self, cmd: str) -> int:
|
def run_command(self, cmd: str) -> int:
|
||||||
try:
|
try:
|
||||||
subprocess.run(cmd, shell=True, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
|
subprocess.run(
|
||||||
|
cmd,
|
||||||
|
shell=True,
|
||||||
|
check=True,
|
||||||
|
stdout=subprocess.DEVNULL,
|
||||||
|
stderr=subprocess.DEVNULL,
|
||||||
|
)
|
||||||
return 0
|
return 0
|
||||||
except subprocess.CalledProcessError as e:
|
except subprocess.CalledProcessError as e:
|
||||||
logging.error(f"Command failed with error: {e}")
|
logging.error(f"Command failed with error: {e}")
|
||||||
@@ -74,7 +81,7 @@ class VideoMaker:
|
|||||||
|
|
||||||
def video_to_frames_generator(
|
def video_to_frames_generator(
|
||||||
self, video_path: Path, output_dir: Path, chunk_seconds: int = 10
|
self, video_path: Path, output_dir: Path, chunk_seconds: int = 10
|
||||||
) -> Generator[tuple[Path, ...], None, None]:
|
) -> Generator[tuple[np.ndarray, ...], None, None]:
|
||||||
"""Extracts frames from a video and saves them to disk, yielding paths to the saved frames."""
|
"""Extracts frames from a video and saves them to disk, yielding paths to the saved frames."""
|
||||||
|
|
||||||
cap = cv2.VideoCapture(str(video_path))
|
cap = cv2.VideoCapture(str(video_path))
|
||||||
@@ -85,21 +92,56 @@ class VideoMaker:
|
|||||||
fps = cap.get(cv2.CAP_PROP_FPS)
|
fps = cap.get(cv2.CAP_PROP_FPS)
|
||||||
frames_per_chunk = int(fps * chunk_seconds)
|
frames_per_chunk = int(fps * chunk_seconds)
|
||||||
|
|
||||||
frame_index = 0
|
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
paths = []
|
paths = []
|
||||||
|
|
||||||
for _ in range(frames_per_chunk):
|
for _ in range(frames_per_chunk):
|
||||||
ret, frame = cap.read()
|
ret, frame = cap.read()
|
||||||
if not ret:
|
if not ret:
|
||||||
cap.release()
|
cap.release()
|
||||||
return
|
return
|
||||||
|
paths.append(frame)
|
||||||
frame_path = output_dir / f"img_{frame_index:08d}.png"
|
|
||||||
cv2.imwrite(str(frame_path), frame)
|
|
||||||
|
|
||||||
paths.append(frame_path)
|
|
||||||
frame_index += 1
|
|
||||||
|
|
||||||
yield tuple(paths)
|
yield tuple(paths)
|
||||||
|
|
||||||
|
def images_to_video_pipeline(
|
||||||
|
self,
|
||||||
|
frames: Iterable[np.ndarray],
|
||||||
|
output_path: Path,
|
||||||
|
width: int,
|
||||||
|
height: int,
|
||||||
|
fps: float,
|
||||||
|
):
|
||||||
|
pipeline = subprocess.Popen(
|
||||||
|
[
|
||||||
|
"ffmpeg",
|
||||||
|
"-y",
|
||||||
|
"-f", "rawvideo",
|
||||||
|
"-vcodec", "rawvideo",
|
||||||
|
"-pix_fmt", "bgr24",
|
||||||
|
"-s", f"{width}x{height}",
|
||||||
|
"-r", str(fps),
|
||||||
|
"-i", "-",
|
||||||
|
"-an",
|
||||||
|
"-vcodec", "libx264",
|
||||||
|
"-pix_fmt", "yuv420p",
|
||||||
|
str(output_path),
|
||||||
|
],
|
||||||
|
stdin=subprocess.PIPE,
|
||||||
|
stderr=subprocess.DEVNULL
|
||||||
|
)
|
||||||
|
if pipeline.stdin is None:
|
||||||
|
raise Exception("STDIN closed")
|
||||||
|
for frame in frames:
|
||||||
|
pipeline.stdin.write(frame.tobytes())
|
||||||
|
|
||||||
|
pipeline.stdin.close()
|
||||||
|
pipeline.wait()
|
||||||
|
|
||||||
|
def get_size(self, video_path: Path) -> tuple[int, int]:
|
||||||
|
cap = cv2.VideoCapture(str(video_path))
|
||||||
|
if not cap.isOpened():
|
||||||
|
raise ValueError(f"Cannot open video: {video_path}")
|
||||||
|
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
||||||
|
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
||||||
|
|
||||||
|
cap.release()
|
||||||
|
return width, height
|
||||||
|
|||||||
Reference in New Issue
Block a user