Compare commits
5 Commits
c72e34f9dc
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c7acd66974 | ||
|
|
2d67b72128 | ||
| c91cf6b53a | |||
|
|
61f8e0abe1 | ||
|
|
faf7aa8e81 |
192
main.py
192
main.py
@@ -1,193 +1,7 @@
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import tqdm
|
||||
|
||||
from src.runner import run
|
||||
from src.config import presets
|
||||
from src.utils.fs import FileSystem
|
||||
from src.utils.video import VideoMaker
|
||||
from src.interpolator import (
|
||||
ImageInterpolator,
|
||||
Anchor,
|
||||
get_device,
|
||||
get_vram_available,
|
||||
ModelRunner,
|
||||
)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import torch
|
||||
|
||||
|
||||
def performing_warning_message(device: "torch.device"):
|
||||
if device.type in ("cpu", "mps"):
|
||||
if device.type == "mps":
|
||||
logging.warning(
|
||||
"Running on Apple Silicon GPU (MPS) may have limited performance. Consider using a CUDA-enabled GPU for better performance."
|
||||
)
|
||||
else:
|
||||
logging.warning(
|
||||
"Running on CPU may be very slow. Consider using a GPU for better performance."
|
||||
)
|
||||
elif device.type == "cuda":
|
||||
pass
|
||||
else:
|
||||
raise Exception(f"Unsupported device type: {device.type}")
|
||||
|
||||
|
||||
def init_fs(base_path: Path) -> FileSystem:
|
||||
fs = FileSystem(base_path)
|
||||
fs.clear_directory(fs.frames_path)
|
||||
fs.clear_directory(fs.interpolated_path)
|
||||
fs.clear_directory(fs.moved_path)
|
||||
fs.clear_directory(fs.video_part_path)
|
||||
return fs
|
||||
|
||||
|
||||
def init_video_maker() -> VideoMaker:
|
||||
return VideoMaker()
|
||||
|
||||
|
||||
def init_device() -> "torch.device":
|
||||
device = get_device()
|
||||
performing_warning_message(device)
|
||||
vram_available = get_vram_available(device)
|
||||
logging.info(f"Available VRAM: {vram_available / (1024 ** 3):.2f} GB")
|
||||
return device
|
||||
|
||||
|
||||
def init_anchor(device: "torch.device") -> Anchor:
|
||||
if device.type in ("cpu", "mps"):
|
||||
return Anchor(resolution=8192 * 8192, memory=1, memory_bias=0)
|
||||
elif device.type == "cuda":
|
||||
return Anchor(
|
||||
resolution=1024 * 512, memory=1500 * 1024**2, memory_bias=2500 * 1024**2
|
||||
)
|
||||
else:
|
||||
raise Exception(f"Unsupported device type: {device.type}")
|
||||
|
||||
|
||||
def init_model_runner(
|
||||
config: Path, checkpoint_path: Path, device: "torch.device"
|
||||
) -> ModelRunner:
|
||||
return ModelRunner(config, checkpoint_path, device)
|
||||
|
||||
|
||||
def init_interpolator(
|
||||
model_runner: ModelRunner, device: "torch.device"
|
||||
) -> ImageInterpolator:
|
||||
anchor = init_anchor(device)
|
||||
return ImageInterpolator(device, anchor, model_runner)
|
||||
|
||||
|
||||
class InterpolationPipeline:
|
||||
def __init__(
|
||||
self,
|
||||
config: Path,
|
||||
checkpoint_path: Path,
|
||||
base_path: Path,
|
||||
):
|
||||
self.fs = init_fs(base_path)
|
||||
self.video_maker = init_video_maker()
|
||||
self.device = init_device()
|
||||
self.model_runner = init_model_runner(config, checkpoint_path, self.device)
|
||||
self.interpolator = init_interpolator(self.model_runner, self.device)
|
||||
|
||||
def run(self, video_path: Path, output_video: str):
|
||||
prev_frame_path = None
|
||||
frame_count = 0
|
||||
part = 0
|
||||
source_frame_length = 0
|
||||
chunk_seconds = 10
|
||||
length = self.video_maker.get_video_duration(video_path)
|
||||
last_part_seconds = 1 if length % chunk_seconds else 0
|
||||
total_parts = int(length // chunk_seconds) + last_part_seconds
|
||||
fps = self.video_maker.get_fps(video_path)
|
||||
logging.info(f"Video FPS: {fps}")
|
||||
fps *= 2 # Doubling FPS
|
||||
for frame_paths in self.video_maker.video_to_frames_generator(
|
||||
video_path, self.fs.frames_path, chunk_seconds
|
||||
):
|
||||
logging.info(f"Processing frames: {len(frame_paths)}")
|
||||
if prev_frame_path is not None:
|
||||
img1 = prev_frame_path[-1]
|
||||
img2 = frame_paths[0]
|
||||
output_path = self.fs.interpolated_path / f"img_{frame_count:08d}.png"
|
||||
self.interpolator.interpolate(img1, img2, output_path)
|
||||
logging.debug(f"Interpolated image saved to: {output_path}")
|
||||
self._merge_frames_to_video(
|
||||
self.fs.video_part_path / f"video_{part:08d}.mp4",
|
||||
fps,
|
||||
source_frame_length=source_frame_length,
|
||||
)
|
||||
logging.info(f"Finished processing part {part:08d}")
|
||||
frame_count += 1
|
||||
part += 1
|
||||
for i in tqdm.tqdm(
|
||||
range(len(frame_paths) - 1),
|
||||
desc=f"Processing video frames {part + 1} / {total_parts}",
|
||||
):
|
||||
img1 = frame_paths[i]
|
||||
img2 = frame_paths[i + 1]
|
||||
output_path = self.fs.interpolated_path / f"img_{i:08d}.png"
|
||||
self.interpolator.interpolate(img1, img2, output_path)
|
||||
logging.debug(f"Interpolated image saved to: {output_path}")
|
||||
frame_count += 1
|
||||
source_frame_length = len(frame_paths)
|
||||
prev_frame_path = frame_paths
|
||||
|
||||
self._merge_frames_to_video(
|
||||
self.fs.video_part_path / f"video_{part:08d}.mp4",
|
||||
fps,
|
||||
source_frame_length=source_frame_length,
|
||||
)
|
||||
logging.info(f"Finished processing part {part:08d}")
|
||||
self._merge_video_parts(self.fs.output_path / output_video)
|
||||
logging.info(
|
||||
f"Video interpolation completed. Output saved to: {self.fs.output_path / output_video}"
|
||||
)
|
||||
|
||||
def _merge_frames_to_video(
|
||||
self, output_video: Path, fps: float, source_frame_length: int = 0
|
||||
):
|
||||
self._move_frames(source_frame_length)
|
||||
self.video_maker.images_to_video(self.fs.moved_path, output_video, fps)
|
||||
|
||||
def _merge_video_parts(self, output_video: Path):
|
||||
self.video_maker.concatenate_videos(self.fs.video_part_path, output_video)
|
||||
self.fs.clear_directory(self.fs.video_part_path)
|
||||
|
||||
def _move_frames(self, source_frame_length: int = 0):
|
||||
self.fs.clear_directory(self.fs.moved_path)
|
||||
src_frames = sorted(self.fs.frames_path.glob("*.png"))
|
||||
interpolated_frames = sorted(self.fs.interpolated_path.glob("*.png"))
|
||||
index = 0
|
||||
for i in range(source_frame_length):
|
||||
moved_frame_path = self.fs.moved_path / f"img_{index:08d}.png"
|
||||
src_frames[i].rename(moved_frame_path)
|
||||
index += 1
|
||||
if i < len(interpolated_frames):
|
||||
moved_interpolated_path = self.fs.moved_path / f"img_{index:08d}.png"
|
||||
interpolated_frames[i].rename(moved_interpolated_path)
|
||||
index += 1
|
||||
logging.info(
|
||||
f"Moved {len(src_frames)} source frames and {len(interpolated_frames)} interpolated frames to {self.fs.moved_path}"
|
||||
)
|
||||
|
||||
|
||||
def runner(
|
||||
base_path: Path,
|
||||
video_path: Path,
|
||||
output_video: str,
|
||||
preset: presets.Preset = presets.LARGE,
|
||||
):
|
||||
pipeline = InterpolationPipeline(
|
||||
config=preset.config,
|
||||
checkpoint_path=preset.checkpoint,
|
||||
base_path=base_path,
|
||||
)
|
||||
pipeline.run(video_path, output_video)
|
||||
|
||||
|
||||
def main():
|
||||
@@ -216,11 +30,11 @@ def main():
|
||||
default="global",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
runner(
|
||||
run(
|
||||
base_path=Path(args.base_path),
|
||||
video_path=Path(args.video_path),
|
||||
output_video=args.output,
|
||||
preset=getattr(presets, args.preset.upper())
|
||||
preset=getattr(presets, args.preset.upper()),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ save_dir: work_dir
|
||||
eval_interval: 1
|
||||
|
||||
network:
|
||||
name: src.networks.AMT-G.Model
|
||||
name: AMT-G.Model
|
||||
params:
|
||||
corr_radius: 3
|
||||
corr_lvls: 4
|
||||
|
||||
@@ -10,7 +10,7 @@ save_dir: work_dir
|
||||
eval_interval: 1
|
||||
|
||||
network:
|
||||
name: src.networks.AMT-L.Model
|
||||
name: AMT-L.Model
|
||||
params:
|
||||
corr_radius: 3
|
||||
corr_lvls: 4
|
||||
|
||||
@@ -10,7 +10,7 @@ save_dir: work_dir
|
||||
eval_interval: 1
|
||||
|
||||
network:
|
||||
name: src.networks.AMT-S.Model
|
||||
name: AMT-S.Model
|
||||
params:
|
||||
corr_radius: 3
|
||||
corr_lvls: 4
|
||||
|
||||
@@ -4,11 +4,10 @@ from pathlib import Path
|
||||
import torch
|
||||
import numpy as np
|
||||
from omegaconf import OmegaConf, DictConfig
|
||||
from imageio import imread, imwrite
|
||||
|
||||
from src.utils.torch import img2tensor, check_dim_and_resize, tensor2img
|
||||
from src.utils.build import build_from_cfg
|
||||
from src.utils.padder import InputPadder
|
||||
from .utils.torch import img2tensor, check_dim_and_resize, tensor2img
|
||||
from .utils.build import build_from_cfg
|
||||
from .utils.padder import InputPadder
|
||||
|
||||
|
||||
class Anchor:
|
||||
@@ -83,7 +82,7 @@ class ImageInterpolator:
|
||||
f"Initialized ImageInterpolator with device: {device}, anchor: {anchor}, available VRAM: {self.vram_available} bytes"
|
||||
)
|
||||
|
||||
def interpolate(self, image1: Path, image2: Path, output_path: Path):
|
||||
def interpolate(self, image1: np.ndarray, image2: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Interpolates between two images and saves the result.
|
||||
Args:
|
||||
@@ -92,8 +91,8 @@ class ImageInterpolator:
|
||||
output_path (Path): Path to save the interpolated image (only png and jpg formats are supported)
|
||||
"""
|
||||
logging.debug(f"Reading images: {image1} and {image2}")
|
||||
tensor1 = img2tensor(imread(image1)).to(self.device)
|
||||
tensor2 = img2tensor(imread(image2)).to(self.device)
|
||||
tensor1 = img2tensor(image1).to(self.device)
|
||||
tensor2 = img2tensor(image2).to(self.device)
|
||||
logging.debug(
|
||||
f"Image shapes after conversion to tensors: {tensor1.shape}, {tensor2.shape}"
|
||||
)
|
||||
@@ -122,8 +121,7 @@ class ImageInterpolator:
|
||||
logging.debug(f"Interpolated image shape before unpadding: {interpolated.shape}")
|
||||
(interpolated,) = padder.unpad(interpolated)
|
||||
logging.debug(f"Interpolated image shape after unpadding: {interpolated.shape}")
|
||||
imwrite(output_path, tensor2img(interpolated.cpu()))
|
||||
logging.debug(f"Saved interpolated image to: {output_path}")
|
||||
return tensor2img(interpolated.cpu())
|
||||
|
||||
def scale(self, height: int, width: int) -> float:
|
||||
scale = (
|
||||
|
||||
@@ -2,10 +2,10 @@ from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from src.networks.blocks.raft import coords_grid, BasicUpdateBlock, BidirCorrBlock
|
||||
from src.networks.blocks.feat_enc import LargeEncoder
|
||||
from src.networks.blocks.ifrnet import resize, Encoder, InitDecoder, IntermediateDecoder
|
||||
from src.networks.blocks.multi_flow import multi_flow_combine, MultiFlowDecoder
|
||||
from .blocks.raft import coords_grid, BasicUpdateBlock, BidirCorrBlock
|
||||
from .blocks.feat_enc import LargeEncoder
|
||||
from .blocks.ifrnet import resize, Encoder, InitDecoder, IntermediateDecoder
|
||||
from .blocks.multi_flow import multi_flow_combine, MultiFlowDecoder
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
@@ -1,10 +1,10 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from src.networks.blocks.raft import coords_grid, BasicUpdateBlock, BidirCorrBlock
|
||||
from src.networks.blocks.feat_enc import BasicEncoder
|
||||
from src.networks.blocks.ifrnet import resize, Encoder, InitDecoder, IntermediateDecoder
|
||||
from .blocks.raft import coords_grid, BasicUpdateBlock, BidirCorrBlock
|
||||
from .blocks.feat_enc import BasicEncoder
|
||||
from .blocks.ifrnet import resize, Encoder, InitDecoder, IntermediateDecoder
|
||||
|
||||
from src.networks.blocks.multi_flow import multi_flow_combine, MultiFlowDecoder
|
||||
from .blocks.multi_flow import multi_flow_combine, MultiFlowDecoder
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
@@ -1,9 +1,9 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from src.networks.blocks.raft import coords_grid, SmallUpdateBlock, BidirCorrBlock
|
||||
from src.networks.blocks.feat_enc import SmallEncoder
|
||||
from src.networks.blocks.ifrnet import resize, Encoder, InitDecoder, IntermediateDecoder
|
||||
from src.networks.blocks.multi_flow import multi_flow_combine, MultiFlowDecoder
|
||||
from .blocks.raft import coords_grid, SmallUpdateBlock, BidirCorrBlock
|
||||
from .blocks.feat_enc import SmallEncoder
|
||||
from .blocks.ifrnet import resize, Encoder, InitDecoder, IntermediateDecoder
|
||||
from .blocks.multi_flow import multi_flow_combine, MultiFlowDecoder
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
@@ -1,7 +1,7 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from src.utils.flow_utils import warp
|
||||
from src.networks.blocks.ifrnet import convrelu, resize, ResBlock
|
||||
from ..utils.flow_utils import warp
|
||||
from .blocks.ifrnet import convrelu, resize, ResBlock
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from src.utils.flow_utils import warp
|
||||
from ...utils.flow_utils import warp
|
||||
|
||||
|
||||
def resize(x, scale_factor):
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from src.utils.flow_utils import warp
|
||||
from src.networks.blocks.ifrnet import convrelu, resize, ResBlock
|
||||
from ...utils.flow_utils import warp
|
||||
from .ifrnet import convrelu, resize, ResBlock
|
||||
|
||||
|
||||
def multi_flow_combine(
|
||||
|
||||
196
src/runner.py
Normal file
196
src/runner.py
Normal file
@@ -0,0 +1,196 @@
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from cv2 import imwrite
|
||||
import tqdm
|
||||
|
||||
from .config import presets
|
||||
from .utils.fs import FileSystem
|
||||
from .utils.video import VideoMaker
|
||||
from .interpolator import (
|
||||
ImageInterpolator,
|
||||
Anchor,
|
||||
get_device,
|
||||
get_vram_available,
|
||||
ModelRunner,
|
||||
)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
|
||||
def performing_warning_message(device: "torch.device"):
|
||||
if device.type in ("cpu", "mps"):
|
||||
if device.type == "mps":
|
||||
logging.warning(
|
||||
"Running on Apple Silicon GPU (MPS) may have limited performance. Consider using a CUDA-enabled GPU for better performance."
|
||||
)
|
||||
else:
|
||||
logging.warning(
|
||||
"Running on CPU may be very slow. Consider using a GPU for better performance."
|
||||
)
|
||||
elif device.type == "cuda":
|
||||
pass
|
||||
else:
|
||||
raise Exception(f"Unsupported device type: {device.type}")
|
||||
|
||||
|
||||
def init_fs(base_path: Path) -> FileSystem:
|
||||
fs = FileSystem(base_path)
|
||||
fs.clear_directory(fs.frames_path)
|
||||
fs.clear_directory(fs.interpolated_path)
|
||||
fs.clear_directory(fs.moved_path)
|
||||
fs.clear_directory(fs.video_part_path)
|
||||
return fs
|
||||
|
||||
|
||||
def init_video_maker() -> VideoMaker:
|
||||
return VideoMaker()
|
||||
|
||||
|
||||
def init_device() -> "torch.device":
|
||||
device = get_device()
|
||||
performing_warning_message(device)
|
||||
vram_available = get_vram_available(device)
|
||||
logging.info(f"Available VRAM: {vram_available / (1024**3):.2f} GB")
|
||||
return device
|
||||
|
||||
|
||||
def init_anchor(device: "torch.device") -> Anchor:
|
||||
if device.type in ("cpu", "mps"):
|
||||
return Anchor(resolution=8192 * 8192, memory=1, memory_bias=0)
|
||||
elif device.type == "cuda":
|
||||
return Anchor(
|
||||
resolution=1024 * 512, memory=1500 * 1024**2, memory_bias=2500 * 1024**2
|
||||
)
|
||||
else:
|
||||
raise Exception(f"Unsupported device type: {device.type}")
|
||||
|
||||
|
||||
def init_model_runner(
|
||||
config: Path, checkpoint_path: Path, device: "torch.device"
|
||||
) -> ModelRunner:
|
||||
return ModelRunner(config, checkpoint_path, device)
|
||||
|
||||
|
||||
def init_interpolator(
|
||||
model_runner: ModelRunner, device: "torch.device"
|
||||
) -> ImageInterpolator:
|
||||
anchor = init_anchor(device)
|
||||
return ImageInterpolator(device, anchor, model_runner)
|
||||
|
||||
|
||||
class InterpolationPipeline:
|
||||
def __init__(
|
||||
self,
|
||||
config: Path,
|
||||
checkpoint_path: Path,
|
||||
base_path: Path,
|
||||
):
|
||||
self.fs = init_fs(base_path)
|
||||
self.video_maker = init_video_maker()
|
||||
self.device = init_device()
|
||||
self.model_runner = init_model_runner(config, checkpoint_path, self.device)
|
||||
self.interpolator = init_interpolator(self.model_runner, self.device)
|
||||
|
||||
def run(self, video_path: Path, output_video: str):
|
||||
prev_frames: tuple["np.ndarray", ...] = tuple()
|
||||
interpolated_frames: list["np.ndarray"] = []
|
||||
part = 0
|
||||
chunk_seconds = 10
|
||||
length = self.video_maker.get_video_duration(video_path)
|
||||
last_part_seconds = 1 if length % chunk_seconds else 0
|
||||
total_parts = int(length // chunk_seconds) + last_part_seconds
|
||||
fps = self.video_maker.get_fps(video_path)
|
||||
logging.info(f"Video FPS: {fps}")
|
||||
fps *= 2 # Doubling FPS
|
||||
width, height = self.video_maker.get_size(video_path)
|
||||
for frames in self.video_maker.video_to_frames_generator(
|
||||
video_path, self.fs.frames_path, chunk_seconds
|
||||
):
|
||||
logging.info(f"Processing frames: {len(frames)}")
|
||||
if prev_frames:
|
||||
img1 = prev_frames[-1]
|
||||
img2 = frames[0]
|
||||
img1_2 = self.interpolator.interpolate(img1, img2)
|
||||
interpolated_frames.append(img1_2)
|
||||
generator = self._frame_generator(prev_frames, interpolated_frames)
|
||||
part_path = self.fs.video_part_path / f"video_{part:08d}.mp4"
|
||||
self.video_maker.images_to_video_pipeline(
|
||||
generator, part_path, width, height, fps
|
||||
)
|
||||
interpolated_frames = []
|
||||
logging.info(f"Finished processing part {part:08d}")
|
||||
part += 1
|
||||
for i in tqdm.tqdm(
|
||||
range(len(frames) - 1),
|
||||
desc=f"Processing video frames {part + 1} / {total_parts}",
|
||||
):
|
||||
img1 = frames[i]
|
||||
img2 = frames[i + 1]
|
||||
img1_2 = self.interpolator.interpolate(img1, img2)
|
||||
interpolated_frames.append(img1_2)
|
||||
prev_frames = frames
|
||||
|
||||
generator = self._frame_generator(prev_frames, interpolated_frames)
|
||||
part_path = self.fs.video_part_path / f"video_{part:08d}.mp4"
|
||||
self.video_maker.images_to_video_pipeline(
|
||||
generator, part_path, width, height, fps
|
||||
)
|
||||
logging.info(f"Finished processing part {part:08d}")
|
||||
self._merge_video_parts(self.fs.output_path / output_video)
|
||||
logging.info(
|
||||
f"Video interpolation completed. Output saved to: {self.fs.output_path / output_video}"
|
||||
)
|
||||
|
||||
def _save_images(
|
||||
self,
|
||||
source: tuple["np.ndarray", ...],
|
||||
interpolated: list["np.ndarray"],
|
||||
):
|
||||
logging.info("Saving images...")
|
||||
self.fs.clear_directory(self.fs.moved_path)
|
||||
index = 0
|
||||
for i, frame in enumerate(source):
|
||||
name = self.fs.moved_path / f"img_{index:08d}.png"
|
||||
index += 1
|
||||
imwrite(name, frame)
|
||||
if i < len(interpolated):
|
||||
name = self.fs.moved_path / f"img_{index:08d}.png"
|
||||
index += 1
|
||||
imwrite(name, interpolated[i])
|
||||
logging.info("Success...")
|
||||
|
||||
def _merge_frames_to_video(self, output_video: Path, fps: float):
|
||||
self.video_maker.images_to_video(self.fs.moved_path, output_video, fps)
|
||||
|
||||
def _merge_video_parts(self, output_video: Path):
|
||||
self.video_maker.concatenate_videos(self.fs.video_part_path, output_video)
|
||||
self.fs.clear_directory(self.fs.video_part_path)
|
||||
|
||||
def _frame_generator(
|
||||
self,
|
||||
source: tuple["np.ndarray", ...],
|
||||
interpolated: list["np.ndarray"],
|
||||
):
|
||||
for i, frame in enumerate(source):
|
||||
yield frame
|
||||
if i < len(interpolated):
|
||||
yield interpolated[i]
|
||||
|
||||
|
||||
def run(
|
||||
base_path: Path,
|
||||
video_path: Path,
|
||||
output_video: str,
|
||||
preset: presets.Preset = presets.LARGE,
|
||||
):
|
||||
pipeline = InterpolationPipeline(
|
||||
config=preset.config,
|
||||
checkpoint_path=preset.checkpoint,
|
||||
base_path=base_path,
|
||||
)
|
||||
pipeline.run(video_path, output_video)
|
||||
@@ -1,16 +1,19 @@
|
||||
import importlib
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
from ..networks import AMT_G, AMT_L, AMT_S
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from omegaconf import DictConfig
|
||||
|
||||
|
||||
def base_build_fn(module: str, cls: str, params: dict):
|
||||
return getattr(importlib.import_module(module, package=None), cls)(**params)
|
||||
|
||||
|
||||
def build_from_cfg(config: "DictConfig"):
|
||||
packages = {
|
||||
"AMT-G": AMT_G,
|
||||
"AMT-L": AMT_L,
|
||||
"AMT-S": AMT_S
|
||||
}
|
||||
|
||||
module, cls = config["name"].rsplit(".", 1)
|
||||
params: dict = config.get("params", {})
|
||||
return base_build_fn(module, cls, params)
|
||||
return getattr(packages[module], cls)(**params)
|
||||
|
||||
|
||||
@@ -2,9 +2,10 @@ import os
|
||||
import logging
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from typing import Generator, Iterable
|
||||
|
||||
import cv2
|
||||
from typing import Generator
|
||||
import numpy as np
|
||||
|
||||
|
||||
class VideoMaker:
|
||||
@@ -35,7 +36,7 @@ class VideoMaker:
|
||||
with open(file, "w") as f:
|
||||
for video in videos:
|
||||
f.write(f"file '{video}'\n")
|
||||
cmd = f"ffmpeg -f concat -safe 0 -i {file} -c copy {output_path}"
|
||||
cmd = f"ffmpeg -y -f concat -safe 0 -i {file} -c copy {output_path}"
|
||||
logging.info(f"Running command: {cmd}")
|
||||
result = self.run_command(cmd)
|
||||
if result != 0:
|
||||
@@ -66,7 +67,13 @@ class VideoMaker:
|
||||
|
||||
def run_command(self, cmd: str) -> int:
|
||||
try:
|
||||
subprocess.run(cmd, shell=True, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
|
||||
subprocess.run(
|
||||
cmd,
|
||||
shell=True,
|
||||
check=True,
|
||||
stdout=subprocess.DEVNULL,
|
||||
stderr=subprocess.DEVNULL,
|
||||
)
|
||||
return 0
|
||||
except subprocess.CalledProcessError as e:
|
||||
logging.error(f"Command failed with error: {e}")
|
||||
@@ -74,7 +81,7 @@ class VideoMaker:
|
||||
|
||||
def video_to_frames_generator(
|
||||
self, video_path: Path, output_dir: Path, chunk_seconds: int = 10
|
||||
) -> Generator[tuple[Path, ...], None, None]:
|
||||
) -> Generator[tuple[np.ndarray, ...], None, None]:
|
||||
"""Extracts frames from a video and saves them to disk, yielding paths to the saved frames."""
|
||||
|
||||
cap = cv2.VideoCapture(str(video_path))
|
||||
@@ -85,21 +92,56 @@ class VideoMaker:
|
||||
fps = cap.get(cv2.CAP_PROP_FPS)
|
||||
frames_per_chunk = int(fps * chunk_seconds)
|
||||
|
||||
frame_index = 0
|
||||
|
||||
while True:
|
||||
paths = []
|
||||
|
||||
for _ in range(frames_per_chunk):
|
||||
ret, frame = cap.read()
|
||||
if not ret:
|
||||
cap.release()
|
||||
return
|
||||
|
||||
frame_path = output_dir / f"img_{frame_index:08d}.png"
|
||||
cv2.imwrite(str(frame_path), frame)
|
||||
|
||||
paths.append(frame_path)
|
||||
frame_index += 1
|
||||
|
||||
paths.append(frame)
|
||||
yield tuple(paths)
|
||||
|
||||
def images_to_video_pipeline(
|
||||
self,
|
||||
frames: Iterable[np.ndarray],
|
||||
output_path: Path,
|
||||
width: int,
|
||||
height: int,
|
||||
fps: float,
|
||||
):
|
||||
pipeline = subprocess.Popen(
|
||||
[
|
||||
"ffmpeg",
|
||||
"-y",
|
||||
"-f", "rawvideo",
|
||||
"-vcodec", "rawvideo",
|
||||
"-pix_fmt", "bgr24",
|
||||
"-s", f"{width}x{height}",
|
||||
"-r", str(fps),
|
||||
"-i", "-",
|
||||
"-an",
|
||||
"-vcodec", "libx264",
|
||||
"-pix_fmt", "yuv420p",
|
||||
str(output_path),
|
||||
],
|
||||
stdin=subprocess.PIPE,
|
||||
stderr=subprocess.DEVNULL
|
||||
)
|
||||
if pipeline.stdin is None:
|
||||
raise Exception("STDIN closed")
|
||||
for frame in frames:
|
||||
pipeline.stdin.write(frame.tobytes())
|
||||
|
||||
pipeline.stdin.close()
|
||||
pipeline.wait()
|
||||
|
||||
def get_size(self, video_path: Path) -> tuple[int, int]:
|
||||
cap = cv2.VideoCapture(str(video_path))
|
||||
if not cap.isOpened():
|
||||
raise ValueError(f"Cannot open video: {video_path}")
|
||||
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
||||
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
||||
|
||||
cap.release()
|
||||
return width, height
|
||||
|
||||
Reference in New Issue
Block a user