немного переработал черновик для работы с видео
This commit is contained in:
274
main.py
274
main.py
@@ -1,153 +1,146 @@
|
||||
import logging
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from omegaconf import OmegaConf, DictConfig
|
||||
import cv2
|
||||
from tqdm import tqdm
|
||||
from time import perf_counter
|
||||
from decimal import Decimal
|
||||
|
||||
from src.utils import utils
|
||||
from src.utils.torch import img2tensor, check_dim_and_resize, tensor2img
|
||||
from src.utils.build import build_from_cfg
|
||||
from src.utils.padder import InputPadder
|
||||
from interpolator import get_device
|
||||
from interpolator import ImageInterpolator
|
||||
from interpolator import ModelRunner, Anchor
|
||||
|
||||
|
||||
# Configure root logging once at import time: timestamped INFO-level messages.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
|
||||
|
||||
|
||||
class Anchor:
    """Reference point for VRAM-based scaling: a known-good resolution and the
    memory figures observed for it, used to derive a scale factor elsewhere."""

    def __init__(self, resolution: int, memory: int, memory_bias: int) -> None:
        # Pixel count of the reference resolution.
        self.resolution = resolution
        # VRAM consumed at that resolution (bytes).
        self.memory = memory
        # Fixed VRAM overhead independent of resolution (bytes).
        self.memory_bias = memory_bias

    def __str__(self) -> str:
        return (
            f"Anchor(resolution={self.resolution}, "
            f"memory={self.memory}, memory_bias={self.memory_bias})"
        )
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class ModelRunner:
    """Loads the interpolation network from a YAML config and a checkpoint."""

    def __init__(self, config: Path, ckpt_path: Path, device: torch.device) -> None:
        """Initializes the ModelRunner with configuration and checkpoint.

        Args:
            config (Path): Path to model configuration in YAML format
            ckpt_path (Path): Path to model checkpoint in .pth format
            device (torch.device): Device to load the model on
        """
        omega_config = OmegaConf.load(config)
        network_config: DictConfig = omega_config.network
        logging.info(
            f"Loaded network configuration: {network_config} from [{ckpt_path}]"
        )
        model = build_from_cfg(network_config)
        # weights_only=False: the checkpoint stores more than raw tensors
        # (a dict with a "state_dict" entry).
        checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False)
        model.load_state_dict(checkpoint["state_dict"])
        # Bug fix: move the model to the caller-supplied `device` instead of
        # calling get_device() again — the parameter was previously ignored.
        model = model.to(device)
        model.eval()
        self.model = model


def move_images(src_dir: str, interpolated_dir: str, output_dir: str):
    """Interleaves source and interpolated frames into one numbered sequence.

    Source frame i is followed by interpolated frame i (when it exists),
    producing img_00000000.png, img_00000001.png, ... in *output_dir*.
    Files are moved (renamed), not copied.

    Args:
        src_dir (str): Directory holding the original img_*.png frames.
        interpolated_dir (str): Directory holding the in-between img_*.png frames.
        output_dir (str): Destination directory (created if missing).
    """
    src_path = Path(src_dir)
    interp_path = Path(interpolated_dir)
    out_path = Path(output_dir)
    out_path.mkdir(parents=True, exist_ok=True)

    src_frames = sorted(src_path.glob("img_*.png"))
    interp_frames = sorted(interp_path.glob("img_*.png"))

    index = 0
    for i, src_frame in enumerate(src_frames):
        src_frame.rename(out_path / f"img_{index:08d}.png")
        index += 1
        # Interpolated frame i sits between source frames i and i+1, so there
        # is one fewer of them than of source frames.
        if i < len(interp_frames):
            interp_frames[i].rename(out_path / f"img_{index:08d}.png")
            index += 1
|
||||
|
||||
|
||||
def get_vram_available(device: torch.device) -> int:
    """Returns the available VRAM in bytes."""
    if device.type == "cuda" and torch.cuda.is_available():
        props = torch.cuda.get_device_properties(device)
        # Free VRAM = total card memory minus what this process has allocated.
        return props.total_memory - torch.cuda.memory_allocated(device)
    if device.type == "mps" and torch.mps.is_available():
        # MPS does not provide a way to query available memory, so we return a large number to avoid issues
        return torch.mps.recommended_max_memory()
    # CPU (or unknown) device: no VRAM; a small non-zero value keeps the
    # downstream scale computation finite.
    return 1
|
||||
def build_file_list(moved_dir: str, list_path: str):
    """Writes an ffmpeg concat-demuxer file list for the frames in *moved_dir*.

    One ``file '<absolute path>'`` line per img_*.png frame, in sorted
    (playback) order.

    Args:
        moved_dir (str): Directory containing the sequentially numbered frames.
        list_path (str): Destination path of the generated list file.
    """
    import os

    frames = sorted(Path(moved_dir).glob("img_*.png"))
    # Bug fix: removed the stray debug `print(frames[0])`, which raised
    # IndexError when the directory was empty and polluted stdout otherwise.

    with open(list_path, "w") as f:
        for frame in frames:
            f.write(f"file '{os.path.abspath(frame)}'\n")
|
||||
|
||||
|
||||
def get_device():
    """Detects and returns the best available device for PyTorch computation.

    Returns:
        torch.device: CUDA device if available, MPS device for Apple Silicon if available, otherwise CPU.
    """
    if torch.cuda.is_available():
        logging.info("Using CUDA-enabled GPU")
        return torch.device("cuda")
    elif torch.mps.is_available():
        logging.info("Using Apple Silicon GPU (MPS)")
        return torch.device("mps")
    logging.info("No GPU available, using CPU")
    return torch.device("cpu")


def build_ffmpeg_file_list(frames_dir: str, interpolated_dir: str, list_path: str):
    """Writes an ffmpeg concat list interleaving source and interpolated frames.

    Playback order is frame 0, interp 0, frame 1, interp 1, ..., frame N-1 —
    each interpolated frame sits between the two source frames it was built
    from, which is why exactly N-1 interpolated frames are required.

    Args:
        frames_dir (str): Directory with the original img_*.png frames.
        interpolated_dir (str): Directory with the interpolated img_*.png frames.
        list_path (str): Destination path of the concat list file.

    Raises:
        ValueError: If the interpolated frame count is not exactly N-1.
    """
    frames = sorted(Path(frames_dir).glob("img_*.png"))
    interps = sorted(Path(interpolated_dir).glob("img_*.png"))

    if len(interps) != len(frames) - 1:
        raise ValueError("Interpolated frames must be N-1")

    with open(list_path, "w") as f:
        for i in range(len(frames)):
            f.write(f"file '{frames[i].resolve().as_posix()}'\n")
            if i < len(interps):
                f.write(f"file '{interps[i].resolve().as_posix()}'\n")
|
||||
|
||||
|
||||
class ImageInterpolator:
    """Produces the mid-point frame between two images with an AMT-style model."""

    def __init__(self, device: torch.device, anchor: Anchor, model_runner: ModelRunner):
        self.device = device
        self.anchor = anchor
        self.vram_available = get_vram_available(device)
        # Embedding for t = 0.5: the model synthesizes the frame exactly
        # halfway between the two inputs.
        self.embt = torch.tensor(1 / 2).float().view(1, 1, 1, 1).to(device)
        self.model_runner = model_runner
        logging.debug(
            f"Initialized ImageInterpolator with device: {device}, anchor: {anchor}, available VRAM: {self.vram_available} bytes"
        )

    def interpolate(self, image1: Path, image2: Path, output_path: Path):
        """Reads two frames, runs the model, and writes the in-between frame.

        Args:
            image1 (Path): First (earlier) input frame.
            image2 (Path): Second (later) input frame.
            output_path (Path): Where to write the interpolated PNG.
        """
        logging.debug(f"Reading images: {image1} and {image2}")
        tensor1 = img2tensor(utils.read(image1)).to(self.device)
        tensor2 = img2tensor(utils.read(image2)).to(self.device)
        logging.debug(
            f"Image shapes after conversion to tensors: {tensor1.shape}, {tensor2.shape}"
        )
        tensor1, tensor2 = check_dim_and_resize(tensor1, tensor2)
        logging.debug(f"Image shapes after resizing: {tensor1.shape}, {tensor2.shape}")
        h, w = tensor1.shape[2], tensor1.shape[3]
        logging.debug(f"Interpolating images of size: {h}x{w}")
        scale = self.scale(h, w)
        logging.debug(f"Calculated scale factor: {scale:.2f}")
        # Padding divisor grows as the scale shrinks so the downscaled tensor
        # still has dimensions the network can process.
        padding = int(16 / scale)
        logging.debug(f"Calculated padding: {padding} pixels")
        padder = InputPadder(tensor1.shape, divisor=padding)
        tensor1_padded, tensor2_padded = padder.pad(tensor1, tensor2)
        logging.debug(
            f"Image shapes after padding: {tensor1_padded.shape}, {tensor2_padded.shape}"
        )
        tensor1_padded = tensor1_padded.to(self.device)
        tensor2_padded = tensor2_padded.to(self.device)
        logging.debug("Running model inference for interpolation")
        with torch.no_grad():
            interpolated = self.model_runner.model(
                tensor1_padded, tensor2_padded, self.embt, scale_factor=scale, eval=True
            )["imgt_pred"]
        logging.debug(f"Interpolated image shape before unpadding: {interpolated.shape}")
        (interpolated,) = padder.unpad(interpolated)
        logging.debug(f"Interpolated image shape after unpadding: {interpolated.shape}")
        utils.write(output_path, tensor2img(interpolated.cpu()))
        logging.debug(f"Saved interpolated image to: {output_path}")

    def scale(self, height: int, width: int) -> float:
        """Returns a downscale factor (<= 1) sized to the available VRAM.

        Extrapolates from the anchor's resolution/memory figures, then snaps
        the result so the scaled dimensions stay friendly to the /16 padding.
        """
        # NOTE(review): if vram_available < memory_bias the sqrt argument goes
        # negative and this yields NaN — presumably anchors are chosen so that
        # never happens; confirm against the anchor values used in main().
        scale = (
            self.anchor.resolution
            / (height * width)
            * np.sqrt(
                (self.vram_available - self.anchor.memory_bias) / self.anchor.memory
            )
        )
        scale = 1 if scale > 1 else scale
        scale = 1 / np.floor(1 / np.sqrt(scale) * 16) * 16
        if scale < 1:
            logging.info(
                f"Due to the limited VRAM, the video will be scaled by {scale:.2f}"
            )
        return scale


def merge_with_ffmpeg(
    original_video: str,
    file_list: str,
    output_video: str,
):
    """Re-encodes the frames in *file_list* into a video at double the FPS.

    Reads the source FPS from *original_video*, doubles it (the interpolated
    frames double the frame count), and invokes ffmpeg's concat demuxer.

    Args:
        original_video (str): Original video, used only to query its FPS.
        file_list (str): ffmpeg concat list (see build_file_list / build_ffmpeg_file_list).
        output_video (str): Path of the encoded result.

    Raises:
        ValueError: If the original video cannot be opened.
        subprocess.CalledProcessError: If ffmpeg exits non-zero.
    """
    cap = cv2.VideoCapture(original_video)
    if not cap.isOpened():
        raise ValueError("Cannot open original video")
    fps = cap.get(cv2.CAP_PROP_FPS)
    cap.release()

    # Twice the source FPS: every source frame now has an in-between frame.
    new_fps = Decimal(fps * 2)

    cmd = [
        "ffmpeg",
        "-y",
        "-r", str(new_fps.quantize(Decimal("1.0000000000"))),
        "-f", "concat",
        "-safe", "0",
        "-i", file_list,
        "-c:v", "libx264rgb",
        output_video,
    ]
    print("Running ffmpeg command:", " ".join(cmd))

    subprocess.run(cmd, check=True)
|
||||
|
||||
|
||||
|
||||
def video_frames_to_disk_generator(
    video_path: str | Path,
    output_dir: str | Path,
    chunk_seconds: int = 10,
):
    """Decodes *video_path* into PNG frames on disk, yielding paths in chunks.

    Each yielded tuple holds the paths of roughly *chunk_seconds* worth of
    frames (fps * chunk_seconds), written as img_<index>.png in *output_dir*.

    Args:
        video_path (str | Path): Video file to decode.
        output_dir (str | Path): Directory for the extracted frames (created
            if missing).
        chunk_seconds (int): Approximate duration covered by each chunk.

    Yields:
        tuple[Path, ...]: Paths of the frames written in this chunk.

    Raises:
        ValueError: If the video cannot be opened.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    cap = cv2.VideoCapture(str(video_path))
    if not cap.isOpened():
        raise ValueError(f"Cannot open video: {video_path}")

    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        # Guard against fps == 0 (broken metadata): the original computed a
        # zero-sized chunk and looped forever yielding empty tuples.
        frames_per_chunk = max(1, int(fps * chunk_seconds))

        frame_index = 0
        while True:
            paths = []
            for _ in range(frames_per_chunk):
                ret, frame = cap.read()
                if not ret:
                    # Bug fix: the final partial chunk was written to disk but
                    # never yielded, silently dropping the tail of the video.
                    if paths:
                        yield tuple(paths)
                    return

                frame_path = output_dir / f"img_{frame_index:08d}.png"
                cv2.imwrite(str(frame_path), frame)
                paths.append(frame_path)
                frame_index += 1

            yield tuple(paths)
    finally:
        # Release the capture even if the consumer abandons the generator.
        cap.release()
|
||||
|
||||
|
||||
def main():
|
||||
start = perf_counter()
|
||||
logging.info("Starting video interpolation process")
|
||||
config_path = Path("src/config/AMT-G.yaml")
|
||||
ckpt_path = Path("src/pretrained/amt-g.pth")
|
||||
image1_path = Path("source/img0.png")
|
||||
image2_path = Path("source/img1.png")
|
||||
image3_path = Path("source/img2.png")
|
||||
output_path1 = Path("output/interpolated_image1.png")
|
||||
output_path2 = Path("output/interpolated_image2.png")
|
||||
video_path = Path("source/video.mp4")
|
||||
output_dir = Path("output/frames")
|
||||
output_interpolated_dir = Path("output/interpolated")
|
||||
output_interpolated_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
device = get_device()
|
||||
model_runner = ModelRunner(config_path, ckpt_path, device)
|
||||
@@ -167,9 +160,44 @@ def main():
|
||||
)
|
||||
else:
|
||||
raise Exception(f"Unsupported device type: {device.type}")
|
||||
|
||||
interpolator = ImageInterpolator(device, anchor, model_runner)
|
||||
interpolator.interpolate(image1_path, image2_path, output_path1)
|
||||
interpolator.interpolate(image2_path, image3_path, output_path2)
|
||||
|
||||
loaded_time = perf_counter() - start
|
||||
logging.info(f"Model loaded and initialized in {loaded_time:.2f} seconds")
|
||||
|
||||
prev_frame_path = None
|
||||
frame_count = 0
|
||||
for frame_paths in video_frames_to_disk_generator(video_path, output_dir):
|
||||
logging.info(f"Processing frames: {frame_paths}")
|
||||
|
||||
if prev_frame_path is not None:
|
||||
img1 = prev_frame_path[-1]
|
||||
img2 = frame_paths[0]
|
||||
output_path = output_interpolated_dir / f"img_{frame_count:08d}.png"
|
||||
interpolator.interpolate(img1, img2, output_path)
|
||||
logging.debug(f"Interpolated image saved to: {output_path}")
|
||||
frame_count += 1
|
||||
for i in tqdm(range(len(frame_paths) - 1), desc="Interpolating frames"):
|
||||
img1 = frame_paths[i]
|
||||
img2 = frame_paths[i + 1]
|
||||
output_path = output_interpolated_dir / f"img_{frame_count:08d}.png"
|
||||
interpolator.interpolate(img1, img2, output_path)
|
||||
logging.debug(f"Interpolated image saved to: {output_path}")
|
||||
frame_count += 1
|
||||
prev_frame_path = frame_paths
|
||||
total_time = perf_counter() - start
|
||||
logging.info(f"Video interpolation completed in {total_time:.2f} seconds")
|
||||
|
||||
|
||||
def builder():
    """Assembles the final doubled-FPS video from the already-merged frames.

    Writes an ffmpeg concat list for the frames in 'output/moved_frames'
    and re-encodes them at twice the FPS of the source video.
    """
    list_path = "file_list.txt"
    video_path = "source/video.mp4"
    output_video = "output/interpolated_video.mp4"
    # NOTE(review): frames are read from 'output/moved_frames' (the
    # move_images output), not 'output/frames' — the unused frames_dir /
    # interpolated_dir locals from the draft were removed; confirm the
    # hard-coded directory is the intended one.
    build_file_list('output/moved_frames', list_path)
    merge_with_ffmpeg(video_path, list_path, output_video)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user