Files
AMT-Apple/main.py

197 lines
6.8 KiB
Python

import logging
from pathlib import Path
from typing import TYPE_CHECKING
import tqdm
from src.config import presets
from src.utils.fs import FileSystem
from src.utils.video import VideoMaker
from src.interpolator import (
ImageInterpolator,
Anchor,
get_device,
get_vram_available,
ModelRunner,
)
if TYPE_CHECKING:
import torch
def performing_warning_message(device: "torch.device"):
    """Log a performance warning for slow device types.

    Nothing is logged for ``"cuda"``. A warning is logged for ``"mps"`` and
    for ``"cpu"``.

    Args:
        device: Object whose ``type`` attribute names the device kind.

    Raises:
        Exception: If ``device.type`` is not ``"cpu"``, ``"mps"`` or ``"cuda"``.
    """
    kind = device.type

    # CUDA is the fast path: nothing to report.
    if kind == "cuda":
        return

    if kind == "mps":
        logging.warning(
            "Running on Apple Silicon GPU (MPS) may have limited performance. Consider using a CUDA-enabled GPU for better performance."
        )
        return

    if kind == "cpu":
        logging.warning(
            "Running on CPU may be very slow. Consider using a GPU for better performance."
        )
        return

    raise Exception(f"Unsupported device type: {device.type}")
def init_fs(base_path: Path) -> FileSystem:
    """Build the pipeline's ``FileSystem`` and clear its working directories.

    ``clear_directory`` is called, in order, on the frames, interpolated,
    moved and video-part paths of the new ``FileSystem``.

    Args:
        base_path: Path handed to the ``FileSystem`` constructor.

    Returns:
        The ``FileSystem`` instance after its four directories were cleared.
    """
    workspace = FileSystem(base_path)
    working_dirs = (
        "frames_path",
        "interpolated_path",
        "moved_path",
        "video_part_path",
    )
    for attribute in working_dirs:
        workspace.clear_directory(getattr(workspace, attribute))
    return workspace
def init_video_maker() -> VideoMaker:
    """Create the ``VideoMaker`` used by the pipeline (no constructor arguments)."""
    maker = VideoMaker()
    return maker
def init_device() -> "torch.device":
    """Pick the compute device, warn about slow devices, and log available memory.

    Returns:
        The device returned by ``get_device()``.

    Raises:
        Exception: Propagated from ``performing_warning_message`` when the
            device type is not supported.
    """
    selected = get_device()

    # Warn (or fail on an unsupported type) before querying memory.
    performing_warning_message(selected)

    # The value is divided by 1024**3 for display, i.e. it is treated as bytes.
    vram_available = get_vram_available(selected)
    logging.info(f"Available VRAM: {vram_available / (1024 ** 3):.2f} GB")

    return selected
def init_anchor(device: "torch.device") -> Anchor:
    """Return the ``Anchor`` settings that match the device type.

    Args:
        device: Object whose ``type`` attribute names the device kind.

    Returns:
        An ``Anchor`` built with the CUDA constants for ``"cuda"``, or with
        the CPU/MPS constants for ``"cpu"`` and ``"mps"``.

    Raises:
        Exception: If ``device.type`` is not ``"cpu"``, ``"mps"`` or ``"cuda"``.
    """
    kind = device.type

    if kind == "cuda":
        return Anchor(
            resolution=1024 * 512,
            memory=1500 * 1024**2,
            memory_bias=2500 * 1024**2,
        )

    if kind in ("cpu", "mps"):
        return Anchor(resolution=8192 * 8192, memory=1, memory_bias=0)

    raise Exception(f"Unsupported device type: {device.type}")
def init_model_runner(
    config: Path, checkpoint_path: Path, device: "torch.device"
) -> ModelRunner:
    """Construct the ``ModelRunner`` for the given config, checkpoint and device.

    Args:
        config: Passed as the first ``ModelRunner`` argument.
        checkpoint_path: Passed as the second ``ModelRunner`` argument.
        device: Passed as the third ``ModelRunner`` argument.

    Returns:
        The newly created ``ModelRunner``.
    """
    runner = ModelRunner(config, checkpoint_path, device)
    return runner
def init_interpolator(
    model_runner: ModelRunner, device: "torch.device"
) -> ImageInterpolator:
    """Create the ``ImageInterpolator`` with device-specific anchor settings.

    Args:
        model_runner: Runner handed to the interpolator.
        device: Device handed to the interpolator and used to pick the anchor
            via ``init_anchor``.

    Returns:
        The newly created ``ImageInterpolator``.

    Raises:
        Exception: Propagated from ``init_anchor`` for an unsupported device type.
    """
    return ImageInterpolator(device, init_anchor(device), model_runner)
class InterpolationPipeline:
    """Double a video's frame rate by inserting interpolated frames.

    The video is handled chunk by chunk. For every chunk, one interpolated
    frame is generated between each pair of consecutive frames, the source and
    interpolated frames are interleaved into a part video, and finally all
    part videos are concatenated into the output video.
    """

    def __init__(
        self,
        config: Path,
        checkpoint_path: Path,
        base_path: Path,
    ):
        """Set up the file system, video maker, device, model runner and interpolator.

        Args:
            config: Model configuration handed to ``init_model_runner``.
            checkpoint_path: Model checkpoint handed to ``init_model_runner``.
            base_path: Base path handed to ``init_fs``.
        """
        self.fs = init_fs(base_path)
        self.video_maker = init_video_maker()
        self.device = init_device()
        self.model_runner = init_model_runner(config, checkpoint_path, self.device)
        self.interpolator = init_interpolator(self.model_runner, self.device)

    def run(self, video_path: Path, output_video: str):
        """Interpolate ``video_path`` and write the result into the output directory.

        Args:
            video_path: Source video to process.
            output_video: File name of the final video, created under
                ``self.fs.output_path``.
        """
        prev_frame_path = None
        frame_count = 0  # running total of interpolated frames written so far
        part = 0
        source_frame_length = 0
        chunk_seconds = 10

        # Expected number of parts; a trailing partial chunk counts as one part.
        # Only used for the progress-bar label.
        length = self.video_maker.get_video_duration(video_path)
        extra_part = 1 if length % chunk_seconds else 0
        total_parts = int(length // chunk_seconds) + extra_part

        fps = self.video_maker.get_fps(video_path)
        logging.info(f"Video FPS: {fps}")
        fps *= 2  # Doubling FPS

        for frame_paths in self.video_maker.video_to_frames_generator(
            video_path, self.fs.frames_path, chunk_seconds
        ):
            logging.info(f"Processing frames: {len(frame_paths)}")

            if prev_frame_path is not None:
                # The previous part can only be finished now: its last frame
                # needs the first frame of this chunk to interpolate against.
                # The file is named with the running total, which is larger
                # than every per-chunk index used for the previous chunk, so
                # it sorts after that chunk's other interpolated frames.
                output_path = self.fs.interpolated_path / f"img_{frame_count:08d}.png"
                self.interpolator.interpolate(
                    prev_frame_path[-1], frame_paths[0], output_path
                )
                logging.debug(f"Interpolated image saved to: {output_path}")
                self._finish_part(part, fps, source_frame_length)
                frame_count += 1
                part += 1

            # One interpolated frame between each pair of neighbours in the chunk.
            for i in tqdm.tqdm(
                range(len(frame_paths) - 1),
                desc=f"Processing video frames {part + 1} / {total_parts}",
            ):
                output_path = self.fs.interpolated_path / f"img_{i:08d}.png"
                self.interpolator.interpolate(
                    frame_paths[i], frame_paths[i + 1], output_path
                )
                logging.debug(f"Interpolated image saved to: {output_path}")
                frame_count += 1

            source_frame_length = len(frame_paths)
            prev_frame_path = frame_paths

        # The last chunk has no successor, so it is finished without a
        # boundary frame.
        self._finish_part(part, fps, source_frame_length)

        final_video = self.fs.output_path / output_video
        self._merge_video_parts(final_video)
        logging.info(
            f"Video interpolation completed. Output saved to: {final_video}"
        )

    def _finish_part(self, part: int, fps: float, source_frame_length: int):
        """Encode the frames of part ``part`` into its part video and log completion."""
        self._merge_frames_to_video(
            self.fs.video_part_path / f"video_{part:08d}.mp4",
            fps,
            source_frame_length=source_frame_length,
        )
        logging.info(f"Finished processing part {part:08d}")

    def _merge_frames_to_video(
        self, output_video: Path, fps: float, source_frame_length: int = 0
    ):
        """Interleave the current frames and encode them into ``output_video``.

        Args:
            output_video: Path of the part video to create.
            fps: Frame rate passed to ``images_to_video``.
            source_frame_length: Number of source frames belonging to this part.
        """
        self._move_frames(source_frame_length)
        self.video_maker.images_to_video(self.fs.moved_path, output_video, fps)

    def _merge_video_parts(self, output_video: Path):
        """Concatenate all part videos into ``output_video`` and clear the parts directory."""
        self.video_maker.concatenate_videos(self.fs.video_part_path, output_video)
        self.fs.clear_directory(self.fs.video_part_path)

    def _move_frames(self, source_frame_length: int = 0):
        """Interleave source and interpolated frames into the moved directory.

        The first ``source_frame_length`` source frames (in sorted name order)
        are renamed into ``self.fs.moved_path`` as consecutively numbered
        ``img_XXXXXXXX.png`` files, each followed by the interpolated frame at
        the same position when one exists.

        Args:
            source_frame_length: Number of source frames to move.

        Raises:
            IndexError: If fewer than ``source_frame_length`` source frames exist.
        """
        self.fs.clear_directory(self.fs.moved_path)
        src_frames = sorted(self.fs.frames_path.glob("*.png"))
        interpolated_frames = sorted(self.fs.interpolated_path.glob("*.png"))

        index = 0
        moved_sources = 0
        moved_interpolated = 0
        for i in range(source_frame_length):
            src_frames[i].rename(self.fs.moved_path / f"img_{index:08d}.png")
            index += 1
            moved_sources += 1
            # The last part has one interpolated frame fewer than source frames.
            if i < len(interpolated_frames):
                interpolated_frames[i].rename(
                    self.fs.moved_path / f"img_{index:08d}.png"
                )
                index += 1
                moved_interpolated += 1

        # Report what was actually moved: the frames directory may hold more
        # files than were moved (e.g. frames of a chunk that is not merged yet).
        logging.info(
            f"Moved {moved_sources} source frames and {moved_interpolated} interpolated frames to {self.fs.moved_path}"
        )
def main(preset: presets.Preset = presets.LARGE):
    """Run the interpolation pipeline on the bundled example video.

    Args:
        preset: Supplies the model ``config`` and ``checkpoint`` paths.
    """
    pipeline = InterpolationPipeline(
        config=preset.config,
        checkpoint_path=preset.checkpoint,
        base_path=Path("output"),
    )
    pipeline.run(Path("example/video.mp4"), "interpolated_video.mp4")
if __name__ == "__main__":
    # Configure root logging before the pipeline emits its first message.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(message)s",
        level=logging.INFO,
    )
    main()