4 Commits

Author SHA1 Message Date
Viner Abubakirov
7addcf051c Attempt to add ONNX for NVIDIA inference 2026-04-15 17:53:06 +05:00
Viner Abubakirov
0c871c2314 Refactor image tensor conversion and model inference in interpolator.py and torch.py 2026-04-13 18:56:00 +05:00
Viner Abubakirov
61f8e0abe1 Changed the video assembly method 2026-04-03 17:03:10 +05:00
Viner Abubakirov
faf7aa8e81 Started reworking the pipeline logic 2026-04-02 18:26:36 +05:00
15 changed files with 2611 additions and 464 deletions

BIN
example/frame_01.png Normal file (50 KiB)

BIN
example/frame_02.png Normal file (57 KiB)

122
main.py
View File

@@ -2,6 +2,7 @@ import logging
from pathlib import Path
from typing import TYPE_CHECKING
from cv2 import imwrite
import tqdm
from src.config import presets
@@ -18,6 +19,7 @@ from src.interpolator import (
if TYPE_CHECKING:
import torch
import numpy as np
def performing_warning_message(device: "torch.device"):
@@ -53,7 +55,7 @@ def init_device() -> "torch.device":
device = get_device()
performing_warning_message(device)
vram_available = get_vram_available(device)
logging.info(f"Available VRAM: {vram_available / (1024 ** 3):.2f} GB")
logging.info(f"Available VRAM: {vram_available / (1024**3):.2f} GB")
return device
@@ -68,10 +70,8 @@ def init_anchor(device: "torch.device") -> Anchor:
raise Exception(f"Unsupported device type: {device.type}")
def init_model_runner(
config: Path, checkpoint_path: Path, device: "torch.device"
) -> ModelRunner:
return ModelRunner(config, checkpoint_path, device)
def init_model_runner(preset: presets.Preset, device: "torch.device") -> ModelRunner:
return ModelRunner(preset, device)
def init_interpolator(
@@ -84,63 +84,58 @@ def init_interpolator(
class InterpolationPipeline:
def __init__(
self,
config: Path,
checkpoint_path: Path,
preset: presets.Preset,
base_path: Path,
):
self.fs = init_fs(base_path)
self.video_maker = init_video_maker()
self.device = init_device()
self.model_runner = init_model_runner(config, checkpoint_path, self.device)
self.model_runner = init_model_runner(preset, self.device)
self.interpolator = init_interpolator(self.model_runner, self.device)
def run(self, video_path: Path, output_video: str):
prev_frame_path = None
frame_count = 0
prev_frames = tuple()
interpolated_frames: list["np.ndarray"] = []
part = 0
source_frame_length = 0
chunk_seconds = 10
chunk_seconds = 1
length = self.video_maker.get_video_duration(video_path)
last_part_seconds = 1 if length % chunk_seconds else 0
total_parts = int(length // chunk_seconds) + last_part_seconds
fps = self.video_maker.get_fps(video_path)
logging.info(f"Video FPS: {fps}")
fps *= 2 # Doubling FPS
for frame_paths in self.video_maker.video_to_frames_generator(
width, height = self.video_maker.get_size(video_path)
for frames in self.video_maker.video_to_frames_generator(
video_path, self.fs.frames_path, chunk_seconds
):
logging.info(f"Processing frames: {len(frame_paths)}")
if prev_frame_path is not None:
img1 = prev_frame_path[-1]
img2 = frame_paths[0]
output_path = self.fs.interpolated_path / f"img_{frame_count:08d}.png"
self.interpolator.interpolate(img1, img2, output_path)
logging.debug(f"Interpolated image saved to: {output_path}")
self._merge_frames_to_video(
self.fs.video_part_path / f"video_{part:08d}.mp4",
fps,
source_frame_length=source_frame_length,
logging.info(f"Processing frames: {len(frames)}")
if prev_frames:
img1 = prev_frames[-1]
img2 = frames[0]
img1_2 = self.interpolator.interpolate(img1, img2)
interpolated_frames.append(img1_2)
generator = self._frame_generator(prev_frames, interpolated_frames)
part_path = self.fs.video_part_path / f"video_{part:08d}.mp4"
self.video_maker.images_to_video_pipeline(
generator, part_path, width, height, fps
)
interpolated_frames = []
logging.info(f"Finished processing part {part:08d}")
frame_count += 1
part += 1
for i in tqdm.tqdm(
range(len(frame_paths) - 1),
range(len(frames) - 1),
desc=f"Processing video frames {part + 1} / {total_parts}",
):
img1 = frame_paths[i]
img2 = frame_paths[i + 1]
output_path = self.fs.interpolated_path / f"img_{i:08d}.png"
self.interpolator.interpolate(img1, img2, output_path)
logging.debug(f"Interpolated image saved to: {output_path}")
frame_count += 1
source_frame_length = len(frame_paths)
prev_frame_path = frame_paths
img1 = frames[i]
img2 = frames[i + 1]
img1_2 = self.interpolator.interpolate(img1, img2)
interpolated_frames.append(img1_2)
prev_frames = frames
self._merge_frames_to_video(
self.fs.video_part_path / f"video_{part:08d}.mp4",
fps,
source_frame_length=source_frame_length,
generator = self._frame_generator(prev_frames, interpolated_frames)
part_path = self.fs.video_part_path / f"video_{part:08d}.mp4"
self.video_maker.images_to_video_pipeline(
generator, part_path, width, height, fps
)
logging.info(f"Finished processing part {part:08d}")
self._merge_video_parts(self.fs.output_path / output_video)
@@ -148,32 +143,40 @@ class InterpolationPipeline:
f"Video interpolation completed. Output saved to: {self.fs.output_path / output_video}"
)
def _merge_frames_to_video(
self, output_video: Path, fps: float, source_frame_length: int = 0
def _save_images(
self,
source: tuple["np.ndarray", ...],
interpolated: list["np.ndarray"],
):
self._move_frames(source_frame_length)
logging.info("Saving images...")
self.fs.clear_directory(self.fs.moved_path)
index = 0
for i, frame in enumerate(source):
name = self.fs.moved_path / f"img_{index:08d}.png"
index += 1
imwrite(name, frame)
if i < len(interpolated):
name = self.fs.moved_path / f"img_{index:08d}.png"
index += 1
imwrite(name, interpolated[i])
logging.info("Success...")
def _merge_frames_to_video(self, output_video: Path, fps: float):
self.video_maker.images_to_video(self.fs.moved_path, output_video, fps)
def _merge_video_parts(self, output_video: Path):
self.video_maker.concatenate_videos(self.fs.video_part_path, output_video)
self.fs.clear_directory(self.fs.video_part_path)
def _move_frames(self, source_frame_length: int = 0):
self.fs.clear_directory(self.fs.moved_path)
src_frames = sorted(self.fs.frames_path.glob("*.png"))
interpolated_frames = sorted(self.fs.interpolated_path.glob("*.png"))
index = 0
for i in range(source_frame_length):
moved_frame_path = self.fs.moved_path / f"img_{index:08d}.png"
src_frames[i].rename(moved_frame_path)
index += 1
if i < len(interpolated_frames):
moved_interpolated_path = self.fs.moved_path / f"img_{index:08d}.png"
interpolated_frames[i].rename(moved_interpolated_path)
index += 1
logging.info(
f"Moved {len(src_frames)} source frames and {len(interpolated_frames)} interpolated frames to {self.fs.moved_path}"
)
def _frame_generator(
self,
source: tuple["np.ndarray", ...],
interpolated: list["np.ndarray"],
):
for i, frame in enumerate(source):
yield frame
if i < len(interpolated):
yield interpolated[i]
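
The generator above simply interleaves source frames with the interpolated ones, which is what doubles the effective FPS. A toy illustration, with strings standing in for `np.ndarray` frames:

```python
# Toy illustration of _frame_generator's interleaving; strings stand in
# for np.ndarray frames.
def frame_generator(source, interpolated):
    for i, frame in enumerate(source):
        yield frame
        if i < len(interpolated):
            yield interpolated[i]

print(list(frame_generator(("A", "B", "C"), ["A-B", "B-C"])))
# ['A', 'A-B', 'B', 'B-C', 'C']
```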
def runner(
@@ -183,8 +186,7 @@ def runner(
preset: presets.Preset = presets.LARGE,
):
pipeline = InterpolationPipeline(
config=preset.config,
checkpoint_path=preset.checkpoint,
preset=preset,
base_path=base_path,
)
pipeline.run(video_path, output_video)
@@ -220,7 +222,7 @@ def main():
base_path=Path(args.base_path),
video_path=Path(args.video_path),
output_video=args.output,
preset=getattr(presets, args.preset.upper())
preset=getattr(presets, args.preset.upper()),
)

8
onnx_export.py Normal file
View File

@@ -0,0 +1,8 @@
import torch
from src.export_to_onnx import export_to_onnx
from src.config.presets import SMALL
if __name__ == "__main__":
device = torch.device("cuda")
export_to_onnx(SMALL, "src/pretrained/amt_s.onnx", device)

pyproject.toml
View File

@@ -7,8 +7,14 @@ requires-python = ">=3.12"
dependencies = [
"imageio>=2.37.3",
"numpy>=2.4.4",
"nvidia-modelopt[all]>=0.33.1",
"omegaconf>=2.3.0",
"onnx>=1.21.0",
"onnxscript>=0.6.2",
"opencv-python>=4.13.0.92",
"torch>=2.11.0",
"tensorrt>=10.16.1.11",
"torch==2.5.1",
"torch-tensorrt>=2.5.0",
"torchvision>=0.20.1",
"tqdm>=4.67.3",
]
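
The new dependency set can be smoke-tested after installation; a minimal sketch, assuming the packages above are present. Note that interpolator.py also imports onnxruntime, which is not in this list and would need to be installed separately.

```python
# Sanity check for the new inference stack; a sketch, not part of the repo.
import onnx
import tensorrt
import torch

print(torch.__version__)               # pinned to 2.5.1 for torch-tensorrt
print(torch.cuda.is_available())       # the ONNX/TensorRT path targets CUDA
print(onnx.__version__, tensorrt.__version__)
```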

src/config/presets.py
View File

@@ -1,3 +1,4 @@
from typing import Literal
from pathlib import Path
from dataclasses import dataclass
@@ -6,6 +7,7 @@ from dataclasses import dataclass
class Preset:
config: Path
checkpoint: Path
onnx: Path | None = None
SMALL = Preset(
@@ -19,6 +21,6 @@ LARGE = Preset(
)
GLOBAL = Preset(
config=Path("src/config/AMT-g.yaml"),
config=Path("src/config/AMT-G.yaml"),
checkpoint=Path("src/pretrained/amt-g.pth"),
)

29
src/export_to_onnx.py Normal file
View File

@@ -0,0 +1,29 @@
import torch
import torchvision
torchvision.disable_beta_transforms_warning()
torch.backends.cudnn.enabled = False
from .interpolator import ModelRunner
from .config.presets import Preset
def export_to_onnx(preset: Preset, output_path: str, device: torch.device):
model_runner = ModelRunner(preset, device)
# model_runner.model.eval()
dummy_input = model_runner.get_dummy_input()
torch.onnx.export(
model_runner.model,
dummy_input,
output_path,
opset_version=17,
input_names=['img0', 'img1', 'embt'],
output_names=["imgt_pred"],
dynamic_axes={
"img0": {0: "batch", 2: "height", 3: "width"},
"img1": {0: "batch", 2: "height", 3: "width"},
"embt": {0: "batch", 2: "height", 3: "width"},
"imgt_pred": {0: "batch", 2: "height", 3: "width"},
},
dynamo=True,
use_external_data_format=False,
)
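
After export, the graph can be validated independently of the runner; a short sketch, assuming the output path used in onnx_export.py above:

```python
# Minimal post-export check; the amt_s.onnx path is taken from onnx_export.py.
import onnx

model = onnx.load("src/pretrained/amt_s.onnx")
onnx.checker.check_model(model)               # raises if the graph is malformed
print([i.name for i in model.graph.input])    # expected: ['img0', 'img1', 'embt']
print([o.name for o in model.graph.output])   # expected: ['imgt_pred']
```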

src/interpolator.py
View File

@@ -1,14 +1,16 @@
import logging
from pathlib import Path
from typing import Optional
from cv2 import imread
import torch
import onnxruntime as ort
import numpy as np
from omegaconf import OmegaConf, DictConfig
from imageio import imread, imwrite
from src.utils.torch import img2tensor, check_dim_and_resize, tensor2img
from src.config.presets import Preset
from src.utils.torch import img2tensor, tensor2img
from src.utils.build import build_from_cfg
from src.utils.padder import InputPadder
class Anchor:
@@ -21,8 +23,27 @@ class Anchor:
return f"Anchor(resolution={self.resolution}, memory={self.memory}, memory_bias={self.memory_bias})"
class ONNXWrapper:
def __init__(self, path):
self.session = ort.InferenceSession(path)
self.input_names = [i.name for i in self.session.get_inputs()]
self.output_names = [o.name for o in self.session.get_outputs()]
def __call__(self, tensor1, tensor2, embt):
inputs = {
self.input_names[0]: tensor1.cpu().numpy(),
self.input_names[1]: tensor2.cpu().numpy(),
self.input_names[2]: embt.cpu().numpy(),
}
outputs = self.session.run(self.output_names, inputs)
return {"imgt_pred": torch.from_numpy(outputs[0])}
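
For reference, a standalone usage sketch of the wrapper; the model path and input shapes are assumptions, not part of the diff:

```python
# Hypothetical usage of ONNXWrapper; path and shapes are assumptions.
import torch

wrapper = ONNXWrapper("src/pretrained/amt_s.onnx")
img0 = torch.rand(1, 3, 256, 256)              # NCHW float32 in [0, 1]
img1 = torch.rand(1, 3, 256, 256)
embt = torch.tensor(0.5).view(1, 1, 1, 1)      # temporal midpoint embedding
pred = wrapper(img0, img1, embt)["imgt_pred"]  # CPU torch.Tensor, NCHW
```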
class ModelRunner:
def __init__(self, config: Path, ckpt_path: Path, device: torch.device) -> None:
def __init__(self, preset: Preset, device: torch.device) -> None:
"""Initializes the ModelRunner with configuration and checkpoint.
Args:
@@ -30,17 +51,73 @@ class ModelRunner:
preset (Preset): Preset bundling the model config, checkpoint, and optional ONNX paths
device (torch.device): Device to load the model on
"""
omega_config = OmegaConf.load(config)
self.model: Optional[torch.nn.Module] = None
self.session: Optional[ONNXWrapper] = None
self.embt = torch.tensor(1 / 2).float().view(1, 1, 1, 1).to(device)
if preset.onnx:
self.session = ONNXWrapper(preset.onnx)
self.device = device
self.embt = self.embt.cpu().numpy()
return
omega_config = OmegaConf.load(preset.config)
network_config: DictConfig = omega_config.network
logging.info(
f"Loaded network configuration: {network_config} from [{ckpt_path}]"
f"Loaded network configuration: {network_config} from [{preset.checkpoint}]"
)
model = build_from_cfg(network_config)
checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False)
checkpoint = torch.load(
preset.checkpoint, map_location=device, weights_only=False
)
model.load_state_dict(checkpoint["state_dict"])
model = model.to(get_device())
model = model.to(device)
model.eval()
# self.model = torch.compile(model)
self.model = model
self.device = device
# self.model = torch.compile(model, backend="tensorrt")
if logging.getLogger().isEnabledFor(logging.DEBUG):
for name, param in self.model.named_parameters():
logging.debug(
f"Parameter: {name}, shape: {param.shape}, dtype: {param.dtype}"
)
def get_dummy_input(self):
"""Generates a dummy input tensor for ONNX export."""
return (
img2tensor(imread(filename="example/frame_01.png"), self.device),
img2tensor(imread(filename="example/frame_02.png"), self.device),
self.embt,
)
def run(self, image1: np.ndarray, image2: np.ndarray) -> np.ndarray:
"""Runs the model inference to interpolate between two images.
Args:
image1 (np.ndarray): First input image as a NumPy array
image2 (np.ndarray): Second input image as a NumPy array
Returns:
np.ndarray: Interpolated image as a NumPy array
"""
if self.session:
image1 = img2tensor(image1, self.device).cpu().numpy()
image2 = img2tensor(image2, self.device).cpu().numpy()
inputs = {
"img0": image1,
"img1": image2,
"embt": self.embt,
}
outputs = self.session.session.run(None, inputs)
return outputs[0]
tensor1 = img2tensor(image1, self.device)
tensor2 = img2tensor(image2, self.device)
with torch.no_grad():
with torch.amp.autocast(self.device.type):
interpolated = self.model(tensor1, tensor2, self.embt)["imgt_pred"]
return tensor2img(interpolated.cpu())
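
End to end, the runner can be exercised with the example frames added in this commit; a sketch assuming the SMALL preset's checkpoint (or ONNX file) is in place:

```python
# Sketch of a single inference through ModelRunner; assumes the SMALL
# preset's weights are present under src/pretrained/.
import torch
from imageio import imread
from src.config.presets import SMALL
from src.interpolator import ModelRunner

runner = ModelRunner(SMALL, torch.device("cuda"))
mid = runner.run(imread("example/frame_01.png"), imread("example/frame_02.png"))
print(mid.shape, mid.dtype)  # HWC uint8 on the PyTorch path
```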
def get_vram_available(device: torch.device) -> int:
@@ -77,13 +154,12 @@ class ImageInterpolator:
self.device = device
self.anchor = anchor
self.vram_available = get_vram_available(device)
self.embt = torch.tensor(1 / 2).float().view(1, 1, 1, 1).to(device)
self.model_runner = model_runner
logging.debug(
f"Initialized ImageInterpolator with device: {device}, anchor: {anchor}, available VRAM: {self.vram_available} bytes"
)
def interpolate(self, image1: Path, image2: Path, output_path: Path):
def interpolate(self, image1: np.ndarray, image2: np.ndarray) -> np.ndarray:
"""
Interpolates between two images and returns the result.
Args:
@@ -91,39 +167,7 @@ class ImageInterpolator:
image1 (np.ndarray): First input image as a NumPy array
image2 (np.ndarray): Second input image as a NumPy array
Returns:
np.ndarray: Interpolated image as a NumPy array
"""
logging.debug(f"Reading images: {image1} and {image2}")
tensor1 = img2tensor(imread(image1)).to(self.device)
tensor2 = img2tensor(imread(image2)).to(self.device)
logging.debug(
f"Image shapes after conversion to tensors: {tensor1.shape}, {tensor2.shape}"
)
tensor1, tensor2 = check_dim_and_resize(tensor1, tensor2)
logging.debug(f"Image shapes after resizing: {tensor1.shape}, {tensor2.shape}")
h, w = tensor1.shape[2], tensor1.shape[3]
logging.debug(f"Interpolating images of size: {h}x{w}")
scale = self.scale(h, w)
logging.debug(f"Calculated scale factor: {scale:.2f}")
padding = int(16 / scale)
logging.debug(f"Calculated padding: {padding} pixels")
padder = InputPadder(tensor1.shape, divisor=padding)
tensor1_padded, tensor2_padded = padder.pad(tensor1, tensor2)
logging.debug(
f"Image shapes after padding: {tensor1_padded.shape}, {tensor2_padded.shape}"
)
tensor1_padded = tensor1_padded.to(self.device)
tensor2_padded = tensor2_padded.to(self.device)
logging.debug("Running model inference for interpolation")
with torch.no_grad():
interpolated = self.model_runner.model(
tensor1_padded, tensor2_padded, self.embt, scale_factor=scale, eval=True
)["imgt_pred"]
logging.debug(f"Interpolated image shape before unpadding: {interpolated.shape}")
(interpolated,) = padder.unpad(interpolated)
logging.debug(f"Interpolated image shape after unpadding: {interpolated.shape}")
imwrite(output_path, tensor2img(interpolated.cpu()))
logging.debug(f"Saved interpolated image to: {output_path}")
return self.model_runner.run(image1, image2)
def scale(self, height: int, width: int) -> float:
scale = (

View File

@@ -67,7 +67,14 @@ class Model(nn.Module):
flow = torch.cat([flow0, flow1], dim=1)
return corr, flow
def forward(self, img0, img1, embt, scale_factor=1.0, eval=False, **kwargs):
def forward(
self,
img0: torch.Tensor,
img1: torch.Tensor,
embt: torch.Tensor,
):
scale_factor = 1.0
eval = False
mean_ = (
torch.cat([img0, img1], 2)
.mean(1, keepdim=True)

View File

@@ -1,40 +1,44 @@
from typing import Any
import torch
import torch.nn as nn
class BottleneckBlock(nn.Module):
def __init__(self, in_planes, planes, norm_fn='group', stride=1):
def __init__(self, in_planes, planes, norm_fn="group", stride=1):
super(BottleneckBlock, self).__init__()
self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0)
self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride)
self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0)
self.conv1 = nn.Conv2d(in_planes, planes // 4, kernel_size=1, padding=0)
self.conv2 = nn.Conv2d(
planes // 4, planes // 4, kernel_size=3, padding=1, stride=stride
)
self.conv3 = nn.Conv2d(planes // 4, planes, kernel_size=1, padding=0)
self.relu = nn.ReLU(inplace=True)
num_groups = planes // 8
if norm_fn == 'group':
self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
if norm_fn == "group":
self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes // 4)
self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes // 4)
self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
if not stride == 1:
self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
elif norm_fn == 'batch':
self.norm1 = nn.BatchNorm2d(planes//4)
self.norm2 = nn.BatchNorm2d(planes//4)
elif norm_fn == "batch":
self.norm1 = nn.BatchNorm2d(planes // 4)
self.norm2 = nn.BatchNorm2d(planes // 4)
self.norm3 = nn.BatchNorm2d(planes)
if not stride == 1:
self.norm4 = nn.BatchNorm2d(planes)
elif norm_fn == 'instance':
self.norm1 = nn.InstanceNorm2d(planes//4)
self.norm2 = nn.InstanceNorm2d(planes//4)
elif norm_fn == "instance":
self.norm1 = nn.InstanceNorm2d(planes // 4)
self.norm2 = nn.InstanceNorm2d(planes // 4)
self.norm3 = nn.InstanceNorm2d(planes)
if not stride == 1:
self.norm4 = nn.InstanceNorm2d(planes)
elif norm_fn == 'none':
elif norm_fn == "none":
self.norm1 = nn.Sequential()
self.norm2 = nn.Sequential()
self.norm3 = nn.Sequential()
@@ -46,8 +50,8 @@ class BottleneckBlock(nn.Module):
else:
self.downsample = nn.Sequential(
nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4)
nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4
)
def forward(self, x):
y = x
@@ -58,38 +62,40 @@ class BottleneckBlock(nn.Module):
if self.downsample is not None:
x = self.downsample(x)
return self.relu(x+y)
return self.relu(x + y)
class ResidualBlock(nn.Module):
def __init__(self, in_planes, planes, norm_fn='group', stride=1):
def __init__(self, in_planes, planes, norm_fn="group", stride=1):
super(ResidualBlock, self).__init__()
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride)
self.conv1 = nn.Conv2d(
in_planes, planes, kernel_size=3, padding=1, stride=stride
)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
self.relu = nn.ReLU(inplace=True)
num_groups = planes // 8
if norm_fn == 'group':
if norm_fn == "group":
self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
if not stride == 1:
self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
elif norm_fn == 'batch':
elif norm_fn == "batch":
self.norm1 = nn.BatchNorm2d(planes)
self.norm2 = nn.BatchNorm2d(planes)
if not stride == 1:
self.norm3 = nn.BatchNorm2d(planes)
elif norm_fn == 'instance':
elif norm_fn == "instance":
self.norm1 = nn.InstanceNorm2d(planes)
self.norm2 = nn.InstanceNorm2d(planes)
if not stride == 1:
self.norm3 = nn.InstanceNorm2d(planes)
elif norm_fn == 'none':
elif norm_fn == "none":
self.norm1 = nn.Sequential()
self.norm2 = nn.Sequential()
if not stride == 1:
@@ -100,8 +106,8 @@ class ResidualBlock(nn.Module):
else:
self.downsample = nn.Sequential(
nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)
nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3
)
def forward(self, x):
y = x
@@ -111,31 +117,31 @@ class ResidualBlock(nn.Module):
if self.downsample is not None:
x = self.downsample(x)
return self.relu(x+y)
return self.relu(x + y)
class SmallEncoder(nn.Module):
def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
def __init__(self, output_dim=128, norm_fn="batch", dropout=0.0):
super(SmallEncoder, self).__init__()
self.norm_fn = norm_fn
if self.norm_fn == 'group':
if self.norm_fn == "group":
self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32)
elif self.norm_fn == 'batch':
elif self.norm_fn == "batch":
self.norm1 = nn.BatchNorm2d(32)
elif self.norm_fn == 'instance':
elif self.norm_fn == "instance":
self.norm1 = nn.InstanceNorm2d(32)
elif self.norm_fn == 'none':
elif self.norm_fn == "none":
self.norm1 = nn.Sequential()
self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3)
self.relu1 = nn.ReLU(inplace=True)
self.in_planes = 32
self.layer1 = self._make_layer(32, stride=1)
self.layer1 = self._make_layer(32, stride=1)
self.layer2 = self._make_layer(64, stride=2)
self.layer3 = self._make_layer(96, stride=2)
@@ -147,7 +153,7 @@ class SmallEncoder(nn.Module):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
if m.weight is not None:
nn.init.constant_(m.weight, 1)
@@ -162,14 +168,15 @@ class SmallEncoder(nn.Module):
self.in_planes = dim
return nn.Sequential(*layers)
def forward(self, x):
def forward(
self, x: torch.Tensor | list[torch.Tensor] | tuple[torch.Tensor, ...]
):
# if input is list, combine batch dimension
is_list = isinstance(x, tuple) or isinstance(x, list)
if is_list:
batch_dim = None
if is_list := isinstance(x, tuple) or isinstance(x, list):
batch_dim = x[0].shape[0]
x = torch.cat(x, dim=0)
x: torch.Tensor = torch.cat(x, dim=0)
x = self.conv1(x)
x = self.norm1(x)
@@ -183,33 +190,37 @@ class SmallEncoder(nn.Module):
if self.training and self.dropout is not None:
x = self.dropout(x)
if is_list:
x = torch.split(x, [batch_dim, batch_dim], dim=0)
if is_list and batch_dim is not None:
return torch.split(x, [batch_dim, batch_dim], dim=0)
return x
def __call__(self, *args: Any, **kwds: Any) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
return super().__call__(*args, **kwds)
class BasicEncoder(nn.Module):
def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
def __init__(self, output_dim=128, norm_fn="batch", dropout=0.0):
super(BasicEncoder, self).__init__()
self.norm_fn = norm_fn
if self.norm_fn == 'group':
if self.norm_fn == "group":
self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)
elif self.norm_fn == 'batch':
elif self.norm_fn == "batch":
self.norm1 = nn.BatchNorm2d(64)
elif self.norm_fn == 'instance':
elif self.norm_fn == "instance":
self.norm1 = nn.InstanceNorm2d(64)
elif self.norm_fn == 'none':
elif self.norm_fn == "none":
self.norm1 = nn.Sequential()
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
self.relu1 = nn.ReLU(inplace=True)
self.in_planes = 64
self.layer1 = self._make_layer(64, stride=1)
self.layer1 = self._make_layer(64, stride=1)
self.layer2 = self._make_layer(72, stride=2)
self.layer3 = self._make_layer(128, stride=2)
@@ -222,7 +233,7 @@ class BasicEncoder(nn.Module):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
if m.weight is not None:
nn.init.constant_(m.weight, 1)
@@ -237,7 +248,6 @@ class BasicEncoder(nn.Module):
self.in_planes = dim
return nn.Sequential(*layers)
def forward(self, x):
# if input is list, combine batch dimension
@@ -264,21 +274,22 @@ class BasicEncoder(nn.Module):
return x
class LargeEncoder(nn.Module):
def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
def __init__(self, output_dim=128, norm_fn="batch", dropout=0.0):
super(LargeEncoder, self).__init__()
self.norm_fn = norm_fn
if self.norm_fn == 'group':
if self.norm_fn == "group":
self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)
elif self.norm_fn == 'batch':
elif self.norm_fn == "batch":
self.norm1 = nn.BatchNorm2d(64)
elif self.norm_fn == 'instance':
elif self.norm_fn == "instance":
self.norm1 = nn.InstanceNorm2d(64)
elif self.norm_fn == 'none':
elif self.norm_fn == "none":
self.norm1 = nn.Sequential()
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
@@ -299,7 +310,7 @@ class LargeEncoder(nn.Module):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
if m.weight is not None:
nn.init.constant_(m.weight, 1)
@@ -314,7 +325,6 @@ class LargeEncoder(nn.Module):
self.in_planes = dim
return nn.Sequential(*layers)
def forward(self, x):
# if input is list, combine batch dimension

View File

@@ -4,54 +4,97 @@ import torch.nn.functional as F
from src.utils.flow_utils import warp
def resize(x, scale_factor):
return F.interpolate(x, scale_factor=scale_factor, mode="bilinear", align_corners=False)
def convrelu(in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=True):
return nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias=bias),
nn.PReLU(out_channels)
def resize(x: torch.Tensor, scale_factor: float) -> torch.Tensor:
return F.interpolate(
x, scale_factor=scale_factor, mode="bilinear", align_corners=False
)
def convrelu(
in_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1,
dilation=1,
groups=1,
bias=True,
):
return nn.Sequential(
nn.Conv2d(
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
groups,
bias=bias,
),
nn.PReLU(out_channels),
)
class ResBlock(nn.Module):
def __init__(self, in_channels, side_channels, bias=True):
super(ResBlock, self).__init__()
self.side_channels = side_channels
self.conv1 = nn.Sequential(
nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias),
nn.PReLU(in_channels)
nn.Conv2d(
in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias
),
nn.PReLU(in_channels),
)
self.conv2 = nn.Sequential(
nn.Conv2d(side_channels, side_channels, kernel_size=3, stride=1, padding=1, bias=bias),
nn.PReLU(side_channels)
nn.Conv2d(
side_channels,
side_channels,
kernel_size=3,
stride=1,
padding=1,
bias=bias,
),
nn.PReLU(side_channels),
)
self.conv3 = nn.Sequential(
nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias),
nn.PReLU(in_channels)
nn.Conv2d(
in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias
),
nn.PReLU(in_channels),
)
self.conv4 = nn.Sequential(
nn.Conv2d(side_channels, side_channels, kernel_size=3, stride=1, padding=1, bias=bias),
nn.PReLU(side_channels)
nn.Conv2d(
side_channels,
side_channels,
kernel_size=3,
stride=1,
padding=1,
bias=bias,
),
nn.PReLU(side_channels),
)
self.conv5 = nn.Conv2d(
in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias
)
self.conv5 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias)
self.prelu = nn.PReLU(in_channels)
def forward(self, x):
out = self.conv1(x)
res_feat = out[:, :-self.side_channels, ...]
side_feat = out[:, -self.side_channels:, :, :]
res_feat = out[:, : -self.side_channels, ...]
side_feat = out[:, -self.side_channels :, :, :]
side_feat = self.conv2(side_feat)
out = self.conv3(torch.cat([res_feat, side_feat], 1))
res_feat = out[:, :-self.side_channels, ...]
side_feat = out[:, -self.side_channels:, :, :]
res_feat = out[:, : -self.side_channels, ...]
side_feat = out[:, -self.side_channels :, :, :]
side_feat = self.conv4(side_feat)
out = self.conv5(torch.cat([res_feat, side_feat], 1))
out = self.prelu(x + out)
return out
class Encoder(nn.Module):
def __init__(self, channels, large=False):
super(Encoder, self).__init__()
@@ -59,30 +102,33 @@ class Encoder(nn.Module):
prev_ch = 3
for idx, ch in enumerate(channels, 1):
k = 7 if large and idx == 1 else 3
p = 3 if k ==7 else 1
self.register_module(f'pyramid{idx}',
nn.Sequential(
convrelu(prev_ch, ch, k, 2, p),
convrelu(ch, ch, 3, 1, 1)
))
p = 3 if k == 7 else 1
self.register_module(
f"pyramid{idx}",
nn.Sequential(
convrelu(prev_ch, ch, k, 2, p), convrelu(ch, ch, 3, 1, 1)
),
)
prev_ch = ch
def forward(self, in_x):
fs = []
for idx in range(len(self.channels)):
out_x = getattr(self, f'pyramid{idx+1}')(in_x)
out_x = getattr(self, f"pyramid{idx + 1}")(in_x)
fs.append(out_x)
in_x = out_x
return fs
class InitDecoder(nn.Module):
def __init__(self, in_ch, out_ch, skip_ch) -> None:
super().__init__()
self.convblock = nn.Sequential(
convrelu(in_ch*2+1, in_ch*2),
ResBlock(in_ch*2, skip_ch),
nn.ConvTranspose2d(in_ch*2, out_ch+4, 4, 2, 1, bias=True)
convrelu(in_ch * 2 + 1, in_ch * 2),
ResBlock(in_ch * 2, skip_ch),
nn.ConvTranspose2d(in_ch * 2, out_ch + 4, 4, 2, 1, bias=True),
)
def forward(self, f0, f1, embt):
h, w = f0.shape[2:]
embt = embt.repeat(1, 1, h, w)
@@ -91,14 +137,16 @@ class InitDecoder(nn.Module):
ft_ = out[:, 4:, ...]
return flow0, flow1, ft_
class IntermediateDecoder(nn.Module):
def __init__(self, in_ch, out_ch, skip_ch) -> None:
super().__init__()
self.convblock = nn.Sequential(
convrelu(in_ch*3+4, in_ch*3),
ResBlock(in_ch*3, skip_ch),
nn.ConvTranspose2d(in_ch*3, out_ch+4, 4, 2, 1, bias=True)
convrelu(in_ch * 3 + 4, in_ch * 3),
ResBlock(in_ch * 3, skip_ch),
nn.ConvTranspose2d(in_ch * 3, out_ch + 4, 4, 2, 1, bias=True),
)
def forward(self, ft_, f0, f1, flow0_in, flow1_in):
f0_warp = warp(f0, flow0_in)
f1_warp = warp(f1, flow1_in)

View File

@@ -4,15 +4,17 @@ import torch.nn.functional as F
def resize(x, scale_factor):
return F.interpolate(x, scale_factor=scale_factor, mode="bilinear", align_corners=False)
return F.interpolate(
x, scale_factor=scale_factor, mode="bilinear", align_corners=False
)
def bilinear_sampler(img, coords, mask=False):
""" Wrapper for grid_sample, uses pixel coordinates """
def bilinear_sampler(img: torch.Tensor, coords: torch.Tensor, mask=False):
"""Wrapper for grid_sample, uses pixel coordinates"""
H, W = img.shape[-2:]
xgrid, ygrid = coords.split([1,1], dim=-1)
xgrid = 2*xgrid/(W-1) - 1
ygrid = 2*ygrid/(H-1) - 1
xgrid, ygrid = coords.split([1, 1], dim=-1)
xgrid = 2 * xgrid / (W - 1) - 1
ygrid = 2 * ygrid / (H - 1) - 1
grid = torch.cat([xgrid, ygrid], dim=-1)
img = F.grid_sample(img, grid, align_corners=True)
@@ -25,27 +27,36 @@ def bilinear_sampler(img, coords, mask=False):
def coords_grid(batch, ht, wd, device):
coords = torch.meshgrid(torch.arange(ht, device=device),
torch.arange(wd, device=device),
indexing='ij')
coords = torch.meshgrid(
torch.arange(ht, device=device), torch.arange(wd, device=device), indexing="ij"
)
coords = torch.stack(coords[::-1], dim=0).float()
return coords[None].repeat(batch, 1, 1, 1)
class SmallUpdateBlock(nn.Module):
def __init__(self, cdim, hidden_dim, flow_dim, corr_dim, fc_dim,
corr_levels=4, radius=3, scale_factor=None):
def __init__(
self,
cdim,
hidden_dim,
flow_dim,
corr_dim,
fc_dim,
corr_levels=4,
radius=3,
scale_factor=None,
):
super(SmallUpdateBlock, self).__init__()
cor_planes = corr_levels * (2 * radius + 1) **2
cor_planes = corr_levels * (2 * radius + 1) ** 2
self.scale_factor = scale_factor
self.convc1 = nn.Conv2d(2 * cor_planes, corr_dim, 1, padding=0)
self.convf1 = nn.Conv2d(4, flow_dim*2, 7, padding=3)
self.convf2 = nn.Conv2d(flow_dim*2, flow_dim, 3, padding=1)
self.conv = nn.Conv2d(corr_dim+flow_dim, fc_dim, 3, padding=1)
self.convf1 = nn.Conv2d(4, flow_dim * 2, 7, padding=3)
self.convf2 = nn.Conv2d(flow_dim * 2, flow_dim, 3, padding=1)
self.conv = nn.Conv2d(corr_dim + flow_dim, fc_dim, 3, padding=1)
self.gru = nn.Sequential(
nn.Conv2d(fc_dim+4+cdim, hidden_dim, 3, padding=1),
nn.Conv2d(fc_dim + 4 + cdim, hidden_dim, 3, padding=1),
nn.LeakyReLU(negative_slope=0.1, inplace=True),
nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
)
@@ -65,8 +76,9 @@ class SmallUpdateBlock(nn.Module):
self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
def forward(self, net, flow, corr):
net = resize(net, 1 / self.scale_factor
) if self.scale_factor is not None else net
net = (
resize(net, 1 / self.scale_factor) if self.scale_factor is not None else net
)
cor = self.lrelu(self.convc1(corr))
flo = self.lrelu(self.convf1(flow))
flo = self.lrelu(self.convf2(flo))
@@ -80,26 +92,39 @@ class SmallUpdateBlock(nn.Module):
if self.scale_factor is not None:
delta_net = resize(delta_net, scale_factor=self.scale_factor)
delta_flow = self.scale_factor * resize(delta_flow, scale_factor=self.scale_factor)
delta_flow = self.scale_factor * resize(
delta_flow, scale_factor=self.scale_factor
)
return delta_net, delta_flow
class BasicUpdateBlock(nn.Module):
def __init__(self, cdim, hidden_dim, flow_dim, corr_dim, corr_dim2,
fc_dim, corr_levels=4, radius=3, scale_factor=None, out_num=1):
def __init__(
self,
cdim,
hidden_dim,
flow_dim,
corr_dim,
corr_dim2,
fc_dim,
corr_levels=4,
radius=3,
scale_factor=None,
out_num=1,
):
super(BasicUpdateBlock, self).__init__()
cor_planes = corr_levels * (2 * radius + 1) **2
cor_planes = corr_levels * (2 * radius + 1) ** 2
self.scale_factor = scale_factor
self.convc1 = nn.Conv2d(2 * cor_planes, corr_dim, 1, padding=0)
self.convc2 = nn.Conv2d(corr_dim, corr_dim2, 3, padding=1)
self.convf1 = nn.Conv2d(4, flow_dim*2, 7, padding=3)
self.convf2 = nn.Conv2d(flow_dim*2, flow_dim, 3, padding=1)
self.conv = nn.Conv2d(flow_dim+corr_dim2, fc_dim, 3, padding=1)
self.convf1 = nn.Conv2d(4, flow_dim * 2, 7, padding=3)
self.convf2 = nn.Conv2d(flow_dim * 2, flow_dim, 3, padding=1)
self.conv = nn.Conv2d(flow_dim + corr_dim2, fc_dim, 3, padding=1)
self.gru = nn.Sequential(
nn.Conv2d(fc_dim+4+cdim, hidden_dim, 3, padding=1),
nn.Conv2d(fc_dim + 4 + cdim, hidden_dim, 3, padding=1),
nn.LeakyReLU(negative_slope=0.1, inplace=True),
nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
)
@@ -113,14 +138,15 @@ class BasicUpdateBlock(nn.Module):
self.flow_head = nn.Sequential(
nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
nn.LeakyReLU(negative_slope=0.1, inplace=True),
nn.Conv2d(hidden_dim, 4*out_num, 3, padding=1),
nn.Conv2d(hidden_dim, 4 * out_num, 3, padding=1),
)
self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
def forward(self, net, flow, corr):
net = resize(net, 1 / self.scale_factor
) if self.scale_factor is not None else net
net = (
resize(net, 1 / self.scale_factor) if self.scale_factor is not None else net
)
cor = self.lrelu(self.convc1(corr))
cor = self.lrelu(self.convc2(cor))
flo = self.lrelu(self.convf1(flow))
@@ -135,38 +161,44 @@ class BasicUpdateBlock(nn.Module):
if self.scale_factor is not None:
delta_net = resize(delta_net, scale_factor=self.scale_factor)
delta_flow = self.scale_factor * resize(delta_flow, scale_factor=self.scale_factor)
delta_flow = self.scale_factor * resize(
delta_flow, scale_factor=self.scale_factor
)
return delta_net, delta_flow
class BidirCorrBlock:
def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
def __init__(
self, fmap1: torch.Tensor, fmap2: torch.Tensor, num_levels=4, radius=4
):
self.num_levels = num_levels
self.radius = radius
self.corr_pyramid = []
self.corr_pyramid_T = []
self.corr_pyramid: list[torch.Tensor] = []
self.corr_pyramid_T: list[torch.Tensor] = []
corr = BidirCorrBlock.corr(fmap1, fmap2)
batch, h1, w1, dim, h2, w2 = corr.shape
corr_T = corr.clone().permute(0, 4, 5, 3, 1, 2)
corr = corr.reshape(batch*h1*w1, dim, h2, w2)
corr_T = corr_T.reshape(batch*h2*w2, dim, h1, w1)
corr = corr.reshape(batch * h1 * w1, dim, h2, w2)
corr_T = corr_T.reshape(batch * h2 * w2, dim, h1, w1)
self.corr_pyramid.append(corr)
self.corr_pyramid_T.append(corr_T)
for _ in range(self.num_levels-1):
for _ in range(self.num_levels - 1):
corr = F.avg_pool2d(corr, 2, stride=2)
corr_T = F.avg_pool2d(corr_T, 2, stride=2)
self.corr_pyramid.append(corr)
self.corr_pyramid_T.append(corr_T)
def __call__(self, coords0, coords1):
def __call__(self, coords0: torch.Tensor, coords1: torch.Tensor):
r = self.radius
coords0 = coords0.permute(0, 2, 3, 1)
coords1 = coords1.permute(0, 2, 3, 1)
assert coords0.shape == coords1.shape, f"coords0 shape: [{coords0.shape}] is not equal to [{coords1.shape}]"
assert coords0.shape == coords1.shape, (
f"coords0 shape: [{coords0.shape}] is not equal to [{coords1.shape}]"
)
batch, h1, w1, _ = coords0.shape
out_pyramid = []
@@ -175,15 +207,15 @@ class BidirCorrBlock:
corr = self.corr_pyramid[i]
corr_T = self.corr_pyramid_T[i]
dx = torch.linspace(-r, r, 2*r+1, device=coords0.device)
dy = torch.linspace(-r, r, 2*r+1, device=coords0.device)
delta = torch.stack(torch.meshgrid(dy, dx, indexing='ij'), axis=-1)
delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2)
dx = torch.linspace(-r, r, 2 * r + 1, device=coords0.device)
dy = torch.linspace(-r, r, 2 * r + 1, device=coords0.device)
delta: torch.Tensor = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), axis=-1)
delta_lvl: torch.Tensor = delta.view(1, 2 * r + 1, 2 * r + 1, 2)
centroid_lvl_0 = coords0.reshape(batch*h1*w1, 1, 1, 2) / 2**i
centroid_lvl_1 = coords1.reshape(batch*h1*w1, 1, 1, 2) / 2**i
coords_lvl_0 = centroid_lvl_0 + delta_lvl
coords_lvl_1 = centroid_lvl_1 + delta_lvl
centroid_lvl_0: torch.Tensor = coords0.reshape(batch * h1 * w1, 1, 1, 2) / 2**i
centroid_lvl_1: torch.Tensor = coords1.reshape(batch * h1 * w1, 1, 1, 2) / 2**i
coords_lvl_0: torch.Tensor = centroid_lvl_0 + delta_lvl
coords_lvl_1: torch.Tensor = centroid_lvl_1 + delta_lvl
corr = bilinear_sampler(corr, coords_lvl_0)
corr_T = bilinear_sampler(corr_T, coords_lvl_1)
@@ -194,14 +226,16 @@ class BidirCorrBlock:
out = torch.cat(out_pyramid, dim=-1)
out_T = torch.cat(out_pyramid_T, dim=-1)
return out.permute(0, 3, 1, 2).contiguous().float(), out_T.permute(0, 3, 1, 2).contiguous().float()
return out.permute(0, 3, 1, 2).contiguous().float(), out_T.permute(
0, 3, 1, 2
).contiguous().float()
@staticmethod
def corr(fmap1, fmap2):
def corr(fmap1: torch.Tensor, fmap2: torch.Tensor):
batch, dim, ht, wd = fmap1.shape
fmap1 = fmap1.view(batch, dim, ht*wd)
fmap2 = fmap2.view(batch, dim, ht*wd)
fmap1 = fmap1.view(batch, dim, ht * wd)
fmap2 = fmap2.view(batch, dim, ht * wd)
corr = torch.matmul(fmap1.transpose(1,2), fmap2)
corr = torch.matmul(fmap1.transpose(1, 2), fmap2)
corr = corr.view(batch, ht, wd, 1, ht, wd)
return corr / torch.sqrt(torch.tensor(dim).float())
return corr * (dim**-0.5)

src/utils/torch.py
View File

@@ -5,23 +5,26 @@ import numpy as np
def tensor2img(tensor: torch.Tensor):
return (
(tensor * 255.0)
.detach()
tensor = (
tensor.mul(255.0)
.clamp_(0, 255)
.to(torch.uint8)
.squeeze(0)
.permute(1, 2, 0)
.cpu()
.numpy()
.clip(0, 255)
.astype(np.uint8)
)
return tensor.cpu().numpy()
def img2tensor(img: np.ndarray) -> torch.Tensor:
def img2tensor(img: np.ndarray, device: torch.device) -> torch.Tensor:
logging.debug(f"Converting image of shape {img.shape} to tensor")
if img.shape[-1] > 3:
img = img[:, :, :3]
return torch.tensor(img).permute(2, 0, 1).unsqueeze(0) / 255.0
tensor = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0)
if device.type != "cuda":
return tensor.float() / 255.0
return tensor.cuda(non_blocking=True).float().div_(255.0)
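
The two converters above are inverses up to rounding; a round-trip sketch for the CPU case:

```python
# Round-trip sketch for tensor2img/img2tensor (CPU case); small rounding
# differences from the float math are possible, hence the tolerance.
import numpy as np
import torch

img = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)
t = img2tensor(img, torch.device("cpu"))  # shape (1, 3, 64, 64), float in [0, 1]
back = tensor2img(t)                      # shape (64, 64, 3), uint8
assert np.abs(img.astype(int) - back.astype(int)).max() <= 1
```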
def check_dim_and_resize(*args: torch.Tensor) -> list[torch.Tensor]:

View File

@@ -2,9 +2,10 @@ import os
import logging
import subprocess
from pathlib import Path
from typing import Generator, Iterable
import cv2
from typing import Generator
import numpy as np
class VideoMaker:
@@ -35,7 +36,7 @@ class VideoMaker:
with open(file, "w") as f:
for video in videos:
f.write(f"file '{video}'\n")
cmd = f"ffmpeg -f concat -safe 0 -i {file} -c copy {output_path}"
cmd = f"ffmpeg -y -f concat -safe 0 -i {file} -c copy {output_path}"
logging.info(f"Running command: {cmd}")
result = self.run_command(cmd)
if result != 0:
@@ -66,7 +67,13 @@ class VideoMaker:
def run_command(self, cmd: str) -> int:
try:
subprocess.run(cmd, shell=True, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
subprocess.run(
cmd,
shell=True,
check=True,
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
)
return 0
except subprocess.CalledProcessError as e:
logging.error(f"Command failed with error: {e}")
@@ -74,7 +81,7 @@ class VideoMaker:
def video_to_frames_generator(
self, video_path: Path, output_dir: Path, chunk_seconds: int = 10
) -> Generator[tuple[Path, ...], None, None]:
) -> Generator[tuple[np.ndarray, ...], None, None]:
"""Extracts frames from a video and saves them to disk, yielding paths to the saved frames."""
cap = cv2.VideoCapture(str(video_path))
@@ -85,21 +92,56 @@ class VideoMaker:
fps = cap.get(cv2.CAP_PROP_FPS)
frames_per_chunk = int(fps * chunk_seconds)
frame_index = 0
while True:
paths = []
for _ in range(frames_per_chunk):
ret, frame = cap.read()
if not ret:
cap.release()
return
frame_path = output_dir / f"img_{frame_index:08d}.png"
cv2.imwrite(str(frame_path), frame)
paths.append(frame_path)
frame_index += 1
paths.append(frame)
yield tuple(paths)
def images_to_video_pipeline(
self,
frames: Iterable[np.ndarray],
output_path: Path,
width: int,
height: int,
fps: float,
):
pipeline = subprocess.Popen(
[
"ffmpeg",
"-y",
"-f", "rawvideo",
"-vcodec", "rawvideo",
"-pix_fmt", "bgr24",
"-s", f"{width}x{height}",
"-r", str(fps),
"-i", "-",
"-an",
"-vcodec", "libx264",
"-pix_fmt", "yuv420p",
str(output_path),
],
stdin=subprocess.PIPE,
stderr=subprocess.DEVNULL
)
if pipeline.stdin is None:
raise Exception("STDIN closed")
for frame in frames:
pipeline.stdin.write(frame.tobytes())
pipeline.stdin.close()
pipeline.wait()
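
A usage sketch for the new pipe-based encoder, streaming raw bgr24 frames straight into ffmpeg without touching disk; whether VideoMaker() takes constructor arguments is an assumption here:

```python
# Sketch: stream synthetic bgr24 frames into images_to_video_pipeline;
# assumes VideoMaker() needs no required constructor arguments.
import numpy as np
from pathlib import Path

def black_frames(n: int, w: int, h: int):
    for _ in range(n):
        yield np.zeros((h, w, 3), dtype=np.uint8)  # one bgr24 frame, h*w*3 bytes

maker = VideoMaker()
maker.images_to_video_pipeline(black_frames(48, 640, 360), Path("out.mp4"), 640, 360, 24.0)
```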
def get_size(self, video_path: Path) -> tuple[int, int]:
cap = cv2.VideoCapture(str(video_path))
if not cap.isOpened():
raise ValueError(f"Cannot open video: {video_path}")
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
cap.release()
return width, height

2242
uv.lock generated

File diff suppressed because it is too large.