10 Commits

Author SHA1 Message Date
Viner Abubakirov
0c871c2314 Refactor image tensor conversion and model inference in interpolator.py and torch.py 2026-04-13 18:56:00 +05:00
Viner Abubakirov
61f8e0abe1 Changed the video assembly method 2026-04-03 17:03:10 +05:00
Viner Abubakirov
faf7aa8e81 Started reworking the pipeline logic 2026-04-02 18:26:36 +05:00
Viner Abubakirov
bc09cd7b6c Removed ffmpeg log spam 2026-04-02 11:53:02 +05:00
Viner Abubakirov
be794539ac Updated video.py
The command run when concatenating the video was interpreted differently across shells (sh, bash, zsh)
2026-04-02 10:47:41 +05:00
Viner Abubakirov
28e51d1c5e Forgot to remove variables from runner 2026-04-02 10:17:17 +05:00
Viner Abubakirov
97ca8b19f8 Added argument parser 2026-04-02 10:16:36 +05:00
Viner Abubakirov
4fc13db0e8 Moved interpolator.py into src | added presets | added new models 2026-04-02 10:06:06 +05:00
Viner Abubakirov
c984b38904 Updated main.py, added fs.py and video.py 2026-04-01 23:41:00 +05:00
Viner Abubakirov
888cdb3151 Moved networks into src 2026-04-01 21:41:05 +05:00
25 changed files with 849 additions and 706 deletions

.gitignore vendored

@@ -175,5 +175,6 @@ cython_debug/
 .pypirc
+.DS_Store
 source/
 output/

main.py

@@ -1,149 +1,28 @@
 import logging
-import subprocess
 from pathlib import Path
+from typing import TYPE_CHECKING
 
-import cv2
-from tqdm import tqdm
-from time import perf_counter
-from decimal import Decimal
+from cv2 import imwrite
+import tqdm
 
-from interpolator import get_device
-from interpolator import ImageInterpolator
-from interpolator import ModelRunner, Anchor
-
-logging.basicConfig(
-    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
-)
+from src.config import presets
+from src.utils.fs import FileSystem
+from src.utils.video import VideoMaker
+from src.interpolator import (
+    ImageInterpolator,
+    Anchor,
+    get_device,
+    get_vram_available,
+    ModelRunner,
+)
+from pathlib import Path
 
+if TYPE_CHECKING:
+    import torch
+    import numpy as np
 
-def move_images(src_dir: str, interpolated_dir: str, output_dir: str):
-    src_dir = Path(src_dir)
-    interpolated_dir = Path(interpolated_dir)
-    output_dir = Path(output_dir)
-    output_dir.mkdir(parents=True, exist_ok=True)
-    index = 0
-    src_frames = sorted(src_dir.glob("img_*.png"))
-    interp_frames = sorted(interpolated_dir.glob("img_*.png"))
-    for i in range(len(src_frames)):
-        output_frame = output_dir / f"img_{index:08d}.png"
-        src_frames[i].rename(output_frame)
-        index += 1
-        if i < len(interp_frames):
-            output_interp = output_dir / f"img_{index:08d}.png"
-            interp_frames[i].rename(output_interp)
-            index += 1
-
-
-def build_file_list(moved_dir: str, list_path: str):
-    import os
-    moved_dir = Path(moved_dir)
-    frames = sorted(moved_dir.glob("img_*.png"))
-    print(frames[0])
-    with open(list_path, "w") as f:
-        for frame in frames:
-            f.write(f"file '{os.path.abspath(frame)}'\n")
-
-
-def build_ffmpeg_file_list(frames_dir: str, interpolated_dir: str, list_path: str):
-    frames = sorted(Path(frames_dir).glob("img_*.png"))
-    interps = sorted(Path(interpolated_dir).glob("img_*.png"))
-    if len(interps) != len(frames) - 1:
-        raise ValueError("Interpolated frames must be N-1")
-    with open(list_path, "w") as f:
-        for i in range(len(frames)):
-            f.write(f"file '{frames[i].resolve().as_posix()}'\n")
-            if i < len(interps):
-                f.write(f"file '{interps[i].resolve().as_posix()}'\n")
-
-
-def merge_with_ffmpeg(
-    original_video: str,
-    file_list: str,
-    output_video: str,
-):
-    cap = cv2.VideoCapture(original_video)
-    if not cap.isOpened():
-        raise ValueError("Cannot open original video")
-    fps = cap.get(cv2.CAP_PROP_FPS)
-    cap.release()
-    new_fps = Decimal(fps * 2)
-    cmd = [
-        "ffmpeg",
-        "-y",
-        "-r", str(new_fps.quantize(Decimal("1.0000000000"))),
-        "-f", "concat",
-        "-safe", "0",
-        "-i", file_list,
-        "-c:v", "libx264rgb",
-        output_video,
-    ]
-    print("Running ffmpeg command:", " ".join(cmd))
-    subprocess.run(cmd, check=True)
-
-
-def video_frames_to_disk_generator(
-    video_path: str | Path,
-    output_dir: str | Path,
-    chunk_seconds: int = 10
-):
-    output_dir = Path(output_dir)
-    output_dir.mkdir(parents=True, exist_ok=True)
-    cap = cv2.VideoCapture(str(video_path))
-    if not cap.isOpened():
-        raise ValueError(f"Cannot open video: {video_path}")
-    fps = cap.get(cv2.CAP_PROP_FPS)
-    frames_per_chunk = int(fps * chunk_seconds)
-    frame_index = 0
-    while True:
-        paths = []
-        for _ in range(frames_per_chunk):
-            ret, frame = cap.read()
-            if not ret:
-                cap.release()
-                return
-            frame_path = output_dir / f"img_{frame_index:08d}.png"
-            cv2.imwrite(str(frame_path), frame)
-            paths.append(frame_path)
-            frame_index += 1
-        yield tuple(paths)
-
-
-def main():
-    start = perf_counter()
-    logging.info("Starting video interpolation process")
-    config_path = Path("src/config/AMT-G.yaml")
-    ckpt_path = Path("src/pretrained/amt-g.pth")
-    video_path = Path("example/video.mp4")
-    output_dir = Path("output/frames")
-    output_interpolated_dir = Path("output/interpolated")
-    output_interpolated_dir.mkdir(parents=True, exist_ok=True)
-    device = get_device()
-    model_runner = ModelRunner(config_path, ckpt_path, device)
+
+def performing_warning_message(device: "torch.device"):
     if device.type in ("cpu", "mps"):
         if device.type == "mps":
             logging.warning(
@@ -153,87 +32,203 @@ def main():
         logging.warning(
             "Running on CPU may be very slow. Consider using a GPU for better performance."
         )
-        anchor = Anchor(resolution=8192 * 8192, memory=1, memory_bias=0)
     elif device.type == "cuda":
-        anchor = Anchor(
+        pass
+    else:
+        raise Exception(f"Unsupported device type: {device.type}")
+
+
+def init_fs(base_path: Path) -> FileSystem:
+    fs = FileSystem(base_path)
+    fs.clear_directory(fs.frames_path)
+    fs.clear_directory(fs.interpolated_path)
+    fs.clear_directory(fs.moved_path)
+    fs.clear_directory(fs.video_part_path)
+    return fs
+
+
+def init_video_maker() -> VideoMaker:
+    return VideoMaker()
+
+
+def init_device() -> "torch.device":
+    device = get_device()
+    performing_warning_message(device)
+    vram_available = get_vram_available(device)
+    logging.info(f"Available VRAM: {vram_available / (1024**3):.2f} GB")
+    return device
+
+
+def init_anchor(device: "torch.device") -> Anchor:
+    if device.type in ("cpu", "mps"):
+        return Anchor(resolution=8192 * 8192, memory=1, memory_bias=0)
+    elif device.type == "cuda":
+        return Anchor(
             resolution=1024 * 512, memory=1500 * 1024**2, memory_bias=2500 * 1024**2
         )
     else:
         raise Exception(f"Unsupported device type: {device.type}")
-    interpolator = ImageInterpolator(device, anchor, model_runner)
-    loaded_time = perf_counter() - start
-    logging.info(f"Model loaded and initialized in {loaded_time:.2f} seconds")
-    prev_frame_path = None
-    frame_count = 0
-    for frame_paths in video_frames_to_disk_generator(video_path, output_dir):
-        logging.info(f"Processing frames: {len(frame_paths)}")
-        if prev_frame_path is not None:
-            img1 = prev_frame_path[-1]
-            img2 = frame_paths[0]
-            output_path = output_interpolated_dir / f"img_{frame_count:08d}.png"
-            interpolator.interpolate(img1, img2, output_path)
-            logging.debug(f"Interpolated image saved to: {output_path}")
-            frame_count += 1
-        for i in tqdm(range(len(frame_paths) - 1), desc="Interpolating frames"):
-            img1 = frame_paths[i]
-            img2 = frame_paths[i + 1]
-            output_path = output_interpolated_dir / f"img_{frame_count:08d}.png"
-            interpolator.interpolate(img1, img2, output_path)
-            logging.debug(f"Interpolated image saved to: {output_path}")
-            frame_count += 1
-        prev_frame_path = frame_paths
-    total_time = perf_counter() - start
-    logging.info(f"Video interpolation completed in {total_time:.2f} seconds")
-
-
-def builder():
-    frames_dir = "output/frames"
-    interpolated_dir = "output/interpolated"
-    moved_dir = "output/moved"
-    video_path = "example/video.mp4"
-    output_video = "output/interpolated_video.mp4"
-    move_images(frames_dir, interpolated_dir, moved_dir)
-    cap = cv2.VideoCapture(video_path)
-    if not cap.isOpened():
-        raise ValueError("Cannot open original video")
-    fps = cap.get(cv2.CAP_PROP_FPS)
-    cmd = [
-        "ffmpeg",
-        "-y",
-        "-framerate", str(fps * 2),
-        "-i", f"{moved_dir}/img_%08d.png",
-        "-i", video_path,
-        "-c:v", "libx264",
-        "-c:a", "copy",
-        "-shortest",
-        output_video,
-    ]
-    logging.info("Running ffmpeg command to build final video: " + " ".join(cmd))
-    subprocess.run(cmd, check=True)
-
-
-def cleanup():
-    import os
-    import shutil
-    frames_dir = "output/frames"
-    interpolated_dir = "output/interpolated"
-    moved_dir = "output/moved"
-    os.makedirs(frames_dir, exist_ok=True)
-    os.makedirs(interpolated_dir, exist_ok=True)
-    os.makedirs(moved_dir, exist_ok=True)
-    shutil.rmtree(frames_dir)
-    shutil.rmtree(interpolated_dir)
-    shutil.rmtree(moved_dir)
+
+
+def init_model_runner(
+    config: Path, checkpoint_path: Path, device: "torch.device"
+) -> ModelRunner:
+    return ModelRunner(config, checkpoint_path, device)
+
+
+def init_interpolator(
+    model_runner: ModelRunner, device: "torch.device"
+) -> ImageInterpolator:
+    anchor = init_anchor(device)
+    return ImageInterpolator(device, anchor, model_runner)
+
+
+class InterpolationPipeline:
+    def __init__(
+        self,
+        config: Path,
+        checkpoint_path: Path,
+        base_path: Path,
+    ):
+        self.fs = init_fs(base_path)
+        self.video_maker = init_video_maker()
+        self.device = init_device()
+        self.model_runner = init_model_runner(config, checkpoint_path, self.device)
+        self.interpolator = init_interpolator(self.model_runner, self.device)
+
+    def run(self, video_path: Path, output_video: str):
+        prev_frames = tuple()
+        interpolated_frames: list["np.ndarray"] = []
+        part = 0
+        chunk_seconds = 10
+        length = self.video_maker.get_video_duration(video_path)
+        last_part_seconds = 1 if length % chunk_seconds else 0
+        total_parts = int(length // chunk_seconds) + last_part_seconds
+        fps = self.video_maker.get_fps(video_path)
+        logging.info(f"Video FPS: {fps}")
+        fps *= 2  # Doubling FPS
+        width, height = self.video_maker.get_size(video_path)
+        for frames in self.video_maker.video_to_frames_generator(
+            video_path, self.fs.frames_path, chunk_seconds
+        ):
+            logging.info(f"Processing frames: {len(frames)}")
+            if prev_frames:
+                img1 = prev_frames[-1]
+                img2 = frames[0]
+                img1_2 = self.interpolator.interpolate(img1, img2)
+                interpolated_frames.append(img1_2)
+                generator = self._frame_generator(prev_frames, interpolated_frames)
+                part_path = self.fs.video_part_path / f"video_{part:08d}.mp4"
+                self.video_maker.images_to_video_pipeline(
+                    generator, part_path, width, height, fps
+                )
+                interpolated_frames = []
+                logging.info(f"Finished processing part {part:08d}")
+                part += 1
+            for i in tqdm.tqdm(
+                range(len(frames) - 1),
+                desc=f"Processing video frames {part + 1} / {total_parts}",
+            ):
+                img1 = frames[i]
+                img2 = frames[i + 1]
+                img1_2 = self.interpolator.interpolate(img1, img2)
+                interpolated_frames.append(img1_2)
+            prev_frames = frames
+        generator = self._frame_generator(prev_frames, interpolated_frames)
+        part_path = self.fs.video_part_path / f"video_{part:08d}.mp4"
+        self.video_maker.images_to_video_pipeline(
+            generator, part_path, width, height, fps
+        )
+        logging.info(f"Finished processing part {part:08d}")
+        self._merge_video_parts(self.fs.output_path / output_video)
+        logging.info(
+            f"Video interpolation completed. Output saved to: {self.fs.output_path / output_video}"
+        )
+
+    def _save_images(
+        self,
+        source: tuple["np.ndarray", ...],
+        interpolated: list["np.ndarray"],
+    ):
+        logging.info("Saving images...")
+        self.fs.clear_directory(self.fs.moved_path)
+        index = 0
+        for i, frame in enumerate(source):
+            name = self.fs.moved_path / f"img_{index:08d}.png"
+            index += 1
+            imwrite(name, frame)
+            if i < len(interpolated):
+                name = self.fs.moved_path / f"img_{index:08d}.png"
+                index += 1
+                imwrite(name, interpolated[i])
+        logging.info("Success...")
+
+    def _merge_frames_to_video(self, output_video: Path, fps: float):
+        self.video_maker.images_to_video(self.fs.moved_path, output_video, fps)
+
+    def _merge_video_parts(self, output_video: Path):
+        self.video_maker.concatenate_videos(self.fs.video_part_path, output_video)
+        self.fs.clear_directory(self.fs.video_part_path)
+
+    def _frame_generator(
+        self,
+        source: tuple["np.ndarray", ...],
+        interpolated: list["np.ndarray"],
+    ):
+        for i, frame in enumerate(source):
+            yield frame
+            if i < len(interpolated):
+                yield interpolated[i]
+
+
+def runner(
+    base_path: Path,
+    video_path: Path,
+    output_video: str,
+    preset: presets.Preset = presets.LARGE,
+):
+    pipeline = InterpolationPipeline(
+        config=preset.config,
+        checkpoint_path=preset.checkpoint,
+        base_path=base_path,
+    )
+    pipeline.run(video_path, output_video)
+
+
+def main():
+    import argparse
+
+    logging.basicConfig(
+        level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
+    )
+    parser = argparse.ArgumentParser()
+    parser.add_argument("-b", "--base_path", help="Base path", default="output")
+    parser.add_argument(
+        "-v", "--video_path", help="Video path", default="example/video.mp4"
+    )
+    parser.add_argument(
+        "-o",
+        "--output",
+        help="Output video name (example: 'interpolated_video.mp4')",
+        default="interpolated_video.mp4",
+    )
+    parser.add_argument(
+        "-p",
+        "--preset",
+        help="Model preset",
+        choices=["small", "large", "global"],
+        default="global",
+    )
+    args = parser.parse_args()
+    runner(
+        base_path=Path(args.base_path),
+        video_path=Path(args.video_path),
+        output_video=args.output,
+        preset=getattr(presets, args.preset.upper()),
+    )
+
+
 if __name__ == "__main__":
-    cleanup()
     main()
-    builder()
-    cleanup()
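The new entry point can also be driven from Python instead of the CLI defined in `main()`; a minimal sketch, assuming the repository root as the working directory and the default example paths:

```python
# Minimal sketch: call the refactored pipeline directly (equivalent to
# `python main.py -b output -v example/video.mp4 -o interpolated_video.mp4 -p large`).
from pathlib import Path

from main import runner
from src.config import presets

runner(
    base_path=Path("output"),              # working tree managed by FileSystem
    video_path=Path("example/video.mp4"),  # input video (the argparse default)
    output_video="interpolated_video.mp4",
    preset=presets.LARGE,                  # SMALL / LARGE / GLOBAL from presets.py
)
```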

networks/blocks/multi_flow.py

@@ -1,69 +0,0 @@
import torch
import torch.nn as nn
from src.utils.flow_utils import warp
from networks.blocks.ifrnet import (
convrelu, resize,
ResBlock,
)
def multi_flow_combine(comb_block, img0, img1, flow0, flow1,
mask=None, img_res=None, mean=None):
'''
A parallel implementation of multiple flow field warping
comb_block: An nn.Seqential object.
img shape: [b, c, h, w]
flow shape: [b, 2*num_flows, h, w]
mask (opt):
If 'mask' is None, the function conduct a simple average.
img_res (opt):
If 'img_res' is None, the function adds zero instead.
mean (opt):
If 'mean' is None, the function adds zero instead.
'''
b, c, h, w = flow0.shape
num_flows = c // 2
flow0 = flow0.reshape(b, num_flows, 2, h, w).reshape(-1, 2, h, w)
flow1 = flow1.reshape(b, num_flows, 2, h, w).reshape(-1, 2, h, w)
mask = mask.reshape(b, num_flows, 1, h, w
).reshape(-1, 1, h, w) if mask is not None else None
img_res = img_res.reshape(b, num_flows, 3, h, w
).reshape(-1, 3, h, w) if img_res is not None else 0
img0 = torch.stack([img0] * num_flows, 1).reshape(-1, 3, h, w)
img1 = torch.stack([img1] * num_flows, 1).reshape(-1, 3, h, w)
mean = torch.stack([mean] * num_flows, 1).reshape(-1, 1, 1, 1
) if mean is not None else 0
img0_warp = warp(img0, flow0)
img1_warp = warp(img1, flow1)
img_warps = mask * img0_warp + (1 - mask) * img1_warp + mean + img_res
img_warps = img_warps.reshape(b, num_flows, 3, h, w)
imgt_pred = img_warps.mean(1) + comb_block(img_warps.view(b, -1, h, w))
return imgt_pred
class MultiFlowDecoder(nn.Module):
def __init__(self, in_ch, skip_ch, num_flows=3):
super(MultiFlowDecoder, self).__init__()
self.num_flows = num_flows
self.convblock = nn.Sequential(
convrelu(in_ch*3+4, in_ch*3),
ResBlock(in_ch*3, skip_ch),
nn.ConvTranspose2d(in_ch*3, 8*num_flows, 4, 2, 1, bias=True)
)
def forward(self, ft_, f0, f1, flow0, flow1):
n = self.num_flows
f0_warp = warp(f0, flow0)
f1_warp = warp(f1, flow1)
out = self.convblock(torch.cat([ft_, f0_warp, f1_warp, flow0, flow1], 1))
delta_flow0, delta_flow1, mask, img_res = torch.split(out, [2*n, 2*n, n, 3*n], 1)
mask = torch.sigmoid(mask)
flow0 = delta_flow0 + 2.0 * resize(flow0, scale_factor=2.0
).repeat(1, self.num_flows, 1, 1)
flow1 = delta_flow1 + 2.0 * resize(flow1, scale_factor=2.0
).repeat(1, self.num_flows, 1, 1)
return flow0, flow1, mask, img_res

src/config/AMT-G.yaml

@@ -10,7 +10,7 @@ save_dir: work_dir
eval_interval: 1 eval_interval: 1
network: network:
name: networks.AMT-G.Model name: src.networks.AMT-G.Model
params: params:
corr_radius: 3 corr_radius: 3
corr_lvls: 4 corr_lvls: 4

src/config/AMT-L.yaml Normal file

@@ -0,0 +1,62 @@
exp_name: floloss1e-2_300epoch_bs24_lr2e-4
seed: 2023
epochs: 300
distributed: true
lr: 2e-4
lr_min: 2e-5
weight_decay: 0.0
resume_state: null
save_dir: work_dir
eval_interval: 1
network:
name: src.networks.AMT-L.Model
params:
corr_radius: 3
corr_lvls: 4
num_flows: 5
data:
train:
name: datasets.vimeo_datasets.Vimeo90K_Train_Dataset
params:
dataset_dir: data/vimeo_triplet
val:
name: datasets.vimeo_datasets.Vimeo90K_Test_Dataset
params:
dataset_dir: data/vimeo_triplet
train_loader:
batch_size: 24
num_workers: 12
val_loader:
batch_size: 24
num_workers: 3
logger:
use_wandb: true
resume_id: null
losses:
- {
name: losses.loss.CharbonnierLoss,
nickname: l_rec,
params: {
loss_weight: 1.0,
keys: [imgt_pred, imgt]
}
}
- {
name: losses.loss.TernaryLoss,
nickname: l_ter,
params: {
loss_weight: 1.0,
keys: [imgt_pred, imgt]
}
}
- {
name: losses.loss.MultipleFlowLoss,
nickname: l_flo,
params: {
loss_weight: 0.002,
keys: [flow0_pred, flow1_pred, flow]
}
}
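These configs load with OmegaConf (which the interpolator already imports); a short sketch of reading the new file, assuming it is run from the repository root:

```python
# Read the new AMT-L config the same way ModelRunner plausibly does.
from omegaconf import OmegaConf

cfg = OmegaConf.load("src/config/AMT-L.yaml")
print(cfg.network.name)               # src.networks.AMT-L.Model
print(cfg.network.params.num_flows)   # 5
```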

src/config/AMT-S.yaml

@@ -10,7 +10,7 @@ save_dir: work_dir
eval_interval: 1 eval_interval: 1
network: network:
name: networks.AMT-S.Model name: src.networks.AMT-S.Model
params: params:
corr_radius: 3 corr_radius: 3
corr_lvls: 4 corr_lvls: 4

src/config/presets.py Normal file

@@ -0,0 +1,24 @@
from pathlib import Path
from dataclasses import dataclass
@dataclass(frozen=True)
class Preset:
config: Path
checkpoint: Path
SMALL = Preset(
config=Path("src/config/AMT-S.yaml"),
checkpoint=Path("src/pretrained/amt-s.pth"),
)
LARGE = Preset(
config=Path("src/config/AMT-L.yaml"),
checkpoint=Path("src/pretrained/amt-l.pth"),
)
GLOBAL = Preset(
config=Path("src/config/AMT-G.yaml"),
checkpoint=Path("src/pretrained/amt-g.pth"),
)
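Because `Preset` is a frozen dataclass, pointing the pipeline at another config/checkpoint pair is just one more constant; a hypothetical example (the checkpoint path below is illustrative only, not a shipped file):

```python
# Hypothetical extra preset; "my-finetuned.pth" is a placeholder.
CUSTOM = Preset(
    config=Path("src/config/AMT-S.yaml"),
    checkpoint=Path("src/pretrained/my-finetuned.pth"),
)
```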

src/interpolator.py

@@ -1,12 +1,12 @@
 import logging
 from pathlib import Path
+from typing import Optional
 
 import torch
 import numpy as np
 from omegaconf import OmegaConf, DictConfig
-from src.utils import utils
-from src.utils.torch import img2tensor, check_dim_and_resize, tensor2img
+from src.utils.torch import img2tensor, tensor2img
 from src.utils.build import build_from_cfg
 from src.utils.padder import InputPadder
@@ -40,7 +40,7 @@ class ModelRunner:
         model.load_state_dict(checkpoint["state_dict"])
         model = model.to(get_device())
         model.eval()
-        self.model = model
+        self.model = torch.compile(model)
 
 
 def get_vram_available(device: torch.device) -> int:
@@ -83,40 +83,30 @@ class ImageInterpolator:
             f"Initialized ImageInterpolator with device: {device}, anchor: {anchor}, available VRAM: {self.vram_available} bytes"
         )
 
-    def interpolate(self, image1: Path, image2: Path, output_path: Path):
-        """
-        Interpolates between two images and saves the result.
-
-        Args:
-            image1 (Path): Path to the first input image (only png and jpg formats are supported)
-            image2 (Path): Path to the second input image (only png and jpg formats are supported)
-            output_path (Path): Path to save the interpolated image (only png and jpg formats are supported)
-        """
+    def interpolate(self, image1: np.ndarray, image2: np.ndarray) -> np.ndarray:
         logging.debug(f"Reading images: {image1} and {image2}")
-        tensor1 = img2tensor(utils.read(image1)).to(self.device)
-        tensor2 = img2tensor(utils.read(image2)).to(self.device)
+        tensor1 = img2tensor(image1, self.device)
+        tensor2 = img2tensor(image2, self.device)
         logging.debug(
             f"Image shapes after conversion to tensors: {tensor1.shape}, {tensor2.shape}"
         )
-        tensor1, tensor2 = check_dim_and_resize(tensor1, tensor2)
-        logging.debug(f"Image shapes after resizing: {tensor1.shape}, {tensor2.shape}")
-        h, w = tensor1.shape[2], tensor1.shape[3]
-        logging.debug(f"Interpolating images of size: {h}x{w}")
-        scale = self.scale(h, w)
-        logging.debug(f"Calculated scale factor: {scale:.2f}")
-        padding = int(16 / scale)
-        logging.debug(f"Calculated padding: {padding} pixels")
-        padder = InputPadder(tensor1.shape, divisor=padding)
-        tensor1_padded, tensor2_padded = padder.pad(tensor1, tensor2)
-        logging.debug(
-            f"Image shapes after padding: {tensor1_padded.shape}, {tensor2_padded.shape}"
-        )
-        tensor1_padded = tensor1_padded.to(self.device)
-        tensor2_padded = tensor2_padded.to(self.device)
         logging.debug("Running model inference for interpolation")
         with torch.no_grad():
-            interpolated = self.model_runner.model(
-                tensor1_padded, tensor2_padded, self.embt, scale_factor=scale, eval=True
-            )["imgt_pred"]
+            with torch.amp.autocast(self.device.type):
+                interpolated = self.model_runner.model(
+                    tensor1, tensor2, self.embt
+                )["imgt_pred"]
         logging.debug(f"Interpolated image shape before unpadding: {interpolated.shape}")
-        (interpolated,) = padder.unpad(interpolated)
         logging.debug(f"Interpolated image shape after unpadding: {interpolated.shape}")
-        utils.write(output_path, tensor2img(interpolated.cpu()))
-        logging.debug(f"Saved interpolated image to: {output_path}")
+        return tensor2img(interpolated.cpu())
 
     def scale(self, height: int, width: int) -> float:
         scale = (
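The inference pattern introduced above (model wrapped in `torch.compile` at load time, then `no_grad` plus device-typed `autocast` per call) can be demonstrated in isolation; a sketch with a stand-in module, since the AMT weights are not part of this diff:

```python
# Sketch of the compile + no_grad + autocast pattern; nn.Conv2d stands in
# for the AMT model that ModelRunner builds from the config.
import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torch.compile(nn.Conv2d(3, 3, 3, padding=1).to(device).eval())

x = torch.rand(1, 3, 256, 256, device=device)
with torch.no_grad():
    with torch.amp.autocast(device.type):
        y = model(x)
print(y.shape, y.dtype)  # may be float16 under CUDA autocast, float32 on CPU
```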

src/networks/AMT-G.py

@@ -1,9 +1,11 @@
+from typing import Optional
+
 import torch
 import torch.nn as nn
 
-from networks.blocks.raft import coords_grid, BasicUpdateBlock, BidirCorrBlock
-from networks.blocks.feat_enc import LargeEncoder
-from networks.blocks.ifrnet import resize, Encoder, InitDecoder, IntermediateDecoder
-from networks.blocks.multi_flow import multi_flow_combine, MultiFlowDecoder
+from src.networks.blocks.raft import coords_grid, BasicUpdateBlock, BidirCorrBlock
+from src.networks.blocks.feat_enc import LargeEncoder
+from src.networks.blocks.ifrnet import resize, Encoder, InitDecoder, IntermediateDecoder
+from src.networks.blocks.multi_flow import multi_flow_combine, MultiFlowDecoder
 
 
 class Model(nn.Module):
@@ -42,7 +44,7 @@ class Model(nn.Module):
             nn.Conv2d(6 * self.num_flows, 3, 7, 1, 3),
         )
 
-    def _get_updateblock(self, cdim, scale_factor=None):
+    def _get_updateblock(self, cdim: int, scale_factor: Optional[float] = None):
         return BasicUpdateBlock(
             cdim=cdim,
             hidden_dim=192,
@@ -55,7 +57,15 @@ class Model(nn.Module):
             radius=self.radius,
         )
 
-    def _corr_scale_lookup(self, corr_fn, coord, flow0, flow1, embt, downsample=1):
+    def _corr_scale_lookup(
+        self,
+        corr_fn: BidirCorrBlock,
+        coord: torch.Tensor,
+        flow0: torch.Tensor,
+        flow1: torch.Tensor,
+        embt: torch.Tensor,
+        downsample: int = 1,
+    ):
         # convert t -> 0 to 0 -> 1 | convert t -> 1 to 1 -> 0
         # based on linear assumption
         t1_scale = 1.0 / embt
@@ -70,7 +80,15 @@ class Model(nn.Module):
         flow = torch.cat([flow0, flow1], dim=1)
         return corr, flow
 
-    def forward(self, img0, img1, embt, scale_factor=1.0, eval=False, **kwargs):
+    def forward(
+        self,
+        img0: torch.Tensor,
+        img1: torch.Tensor,
+        embt: torch.Tensor,
+        scale_factor: float = 1.0,
+        eval: bool = False,
+        **kwargs,
+    ):
         mean_ = (
             torch.cat([img0, img1], 2)
             .mean(1, keepdim=True)

src/networks/AMT-L.py

@@ -1,40 +1,31 @@
 import torch
 import torch.nn as nn
 
-from networks.blocks.raft import (
-    coords_grid,
-    BasicUpdateBlock, BidirCorrBlock
-)
-from networks.blocks.feat_enc import (
-    BasicEncoder
-)
-from networks.blocks.ifrnet import (
-    resize,
-    Encoder,
-    InitDecoder,
-    IntermediateDecoder
-)
-from networks.blocks.multi_flow import (
-    multi_flow_combine,
-    MultiFlowDecoder
-)
+from src.networks.blocks.raft import coords_grid, BasicUpdateBlock, BidirCorrBlock
+from src.networks.blocks.feat_enc import BasicEncoder
+from src.networks.blocks.ifrnet import resize, Encoder, InitDecoder, IntermediateDecoder
+from src.networks.blocks.multi_flow import multi_flow_combine, MultiFlowDecoder
 
 
 class Model(nn.Module):
-    def __init__(self,
-                 corr_radius=3,
-                 corr_lvls=4,
-                 num_flows=5,
-                 channels=[48, 64, 72, 128],
-                 skip_channels=48
-                 ):
+    def __init__(
+        self,
+        corr_radius=3,
+        corr_lvls=4,
+        num_flows=5,
+        channels=[48, 64, 72, 128],
+        skip_channels=48,
+    ):
         super(Model, self).__init__()
         self.radius = corr_radius
         self.corr_levels = corr_lvls
         self.num_flows = num_flows
-        self.feat_encoder = BasicEncoder(output_dim=128, norm_fn='instance', dropout=0.)
+        self.feat_encoder = BasicEncoder(
+            output_dim=128, norm_fn="instance", dropout=0.0
+        )
         self.encoder = Encoder([48, 64, 72, 128], large=True)
         self.decoder4 = InitDecoder(channels[3], channels[2], skip_channels)
         self.decoder3 = IntermediateDecoder(channels[2], channels[1], skip_channels)
         self.decoder2 = IntermediateDecoder(channels[1], channels[0], skip_channels)
@@ -43,45 +34,59 @@ class Model(nn.Module):
         self.update4 = self._get_updateblock(72, None)
         self.update3 = self._get_updateblock(64, 2.0)
         self.update2 = self._get_updateblock(48, 4.0)
 
         self.comb_block = nn.Sequential(
-            nn.Conv2d(3*self.num_flows, 6*self.num_flows, 7, 1, 3),
-            nn.PReLU(6*self.num_flows),
-            nn.Conv2d(6*self.num_flows, 3, 7, 1, 3),
+            nn.Conv2d(3 * self.num_flows, 6 * self.num_flows, 7, 1, 3),
+            nn.PReLU(6 * self.num_flows),
+            nn.Conv2d(6 * self.num_flows, 3, 7, 1, 3),
         )
 
     def _get_updateblock(self, cdim, scale_factor=None):
-        return BasicUpdateBlock(cdim=cdim, hidden_dim=128, flow_dim=48,
-                                corr_dim=256, corr_dim2=160, fc_dim=124,
-                                scale_factor=scale_factor, corr_levels=self.corr_levels,
-                                radius=self.radius)
+        return BasicUpdateBlock(
+            cdim=cdim,
+            hidden_dim=128,
+            flow_dim=48,
+            corr_dim=256,
+            corr_dim2=160,
+            fc_dim=124,
+            scale_factor=scale_factor,
+            corr_levels=self.corr_levels,
+            radius=self.radius,
+        )
 
     def _corr_scale_lookup(self, corr_fn, coord, flow0, flow1, embt, downsample=1):
         # convert t -> 0 to 0 -> 1 | convert t -> 1 to 1 -> 0
         # based on linear assumption
-        t1_scale = 1. / embt
-        t0_scale = 1. / (1. - embt)
+        t1_scale = 1.0 / embt
+        t0_scale = 1.0 / (1.0 - embt)
         if downsample != 1:
             inv = 1 / downsample
             flow0 = inv * resize(flow0, scale_factor=inv)
             flow1 = inv * resize(flow1, scale_factor=inv)
 
         corr0, corr1 = corr_fn(coord + flow1 * t1_scale, coord + flow0 * t0_scale)
         corr = torch.cat([corr0, corr1], dim=1)
         flow = torch.cat([flow0, flow1], dim=1)
         return corr, flow
 
     def forward(self, img0, img1, embt, scale_factor=1.0, eval=False, **kwargs):
-        mean_ = torch.cat([img0, img1], 2).mean(1, keepdim=True).mean(2, keepdim=True).mean(3, keepdim=True)
+        mean_ = (
+            torch.cat([img0, img1], 2)
+            .mean(1, keepdim=True)
+            .mean(2, keepdim=True)
+            .mean(3, keepdim=True)
+        )
         img0 = img0 - mean_
         img1 = img1 - mean_
         img0_ = resize(img0, scale_factor) if scale_factor != 1.0 else img0
         img1_ = resize(img1, scale_factor) if scale_factor != 1.0 else img1
         b, _, h, w = img0_.shape
         coord = coords_grid(b, h // 8, w // 8, img0.device)
 
         fmap0, fmap1 = self.feat_encoder([img0_, img1_])  # [1, 128, H//8, W//8]
-        corr_fn = BidirCorrBlock(fmap0, fmap1, radius=self.radius, num_levels=self.corr_levels)
+        corr_fn = BidirCorrBlock(
+            fmap0, fmap1, radius=self.radius, num_levels=self.corr_levels
+        )
 
         # f0_1: [1, c0, H//2, W//2] | f0_2: [1, c1, H//4, W//4]
         # f0_3: [1, c2, H//8, W//8] | f0_4: [1, c3, H//16, W//16]
@@ -90,9 +95,9 @@ class Model(nn.Module):
         ######################################### the 4th decoder #########################################
         up_flow0_4, up_flow1_4, ft_3_ = self.decoder4(f0_4, f1_4, embt)
-        corr_4, flow_4 = self._corr_scale_lookup(corr_fn, coord,
-                                                 up_flow0_4, up_flow1_4,
-                                                 embt, downsample=1)
+        corr_4, flow_4 = self._corr_scale_lookup(
+            corr_fn, coord, up_flow0_4, up_flow1_4, embt, downsample=1
+        )
 
         # residue update with lookup corr
         delta_ft_3_, delta_flow_4 = self.update4(ft_3_, flow_4, corr_4)
@@ -102,10 +107,12 @@ class Model(nn.Module):
         ft_3_ = ft_3_ + delta_ft_3_
 
         ######################################### the 3rd decoder #########################################
-        up_flow0_3, up_flow1_3, ft_2_ = self.decoder3(ft_3_, f0_3, f1_3, up_flow0_4, up_flow1_4)
-        corr_3, flow_3 = self._corr_scale_lookup(corr_fn,
-                                                 coord, up_flow0_3, up_flow1_3,
-                                                 embt, downsample=2)
+        up_flow0_3, up_flow1_3, ft_2_ = self.decoder3(
+            ft_3_, f0_3, f1_3, up_flow0_4, up_flow1_4
+        )
+        corr_3, flow_3 = self._corr_scale_lookup(
+            corr_fn, coord, up_flow0_3, up_flow1_3, embt, downsample=2
+        )
 
         # residue update with lookup corr
         delta_ft_2_, delta_flow_3 = self.update3(ft_2_, flow_3, corr_3)
@@ -115,11 +122,13 @@ class Model(nn.Module):
         ft_2_ = ft_2_ + delta_ft_2_
 
         ######################################### the 2nd decoder #########################################
-        up_flow0_2, up_flow1_2, ft_1_ = self.decoder2(ft_2_, f0_2, f1_2, up_flow0_3, up_flow1_3)
-        corr_2, flow_2 = self._corr_scale_lookup(corr_fn,
-                                                 coord, up_flow0_2, up_flow1_2,
-                                                 embt, downsample=4)
+        up_flow0_2, up_flow1_2, ft_1_ = self.decoder2(
+            ft_2_, f0_2, f1_2, up_flow0_3, up_flow1_3
+        )
+        corr_2, flow_2 = self._corr_scale_lookup(
+            corr_fn, coord, up_flow0_2, up_flow1_2, embt, downsample=4
+        )
 
         # residue update with lookup corr
         delta_ft_1_, delta_flow_2 = self.update2(ft_1_, flow_2, corr_2)
         delta_flow0_2, delta_flow1_2 = torch.chunk(delta_flow_2, 2, 1)
@@ -128,28 +137,36 @@ class Model(nn.Module):
         ft_1_ = ft_1_ + delta_ft_1_
 
         ######################################### the 1st decoder #########################################
-        up_flow0_1, up_flow1_1, mask, img_res = self.decoder1(ft_1_, f0_1, f1_1, up_flow0_2, up_flow1_2)
-
-        if scale_factor != 1.0:
-            up_flow0_1 = resize(up_flow0_1, scale_factor=(1.0/scale_factor)) * (1.0/scale_factor)
-            up_flow1_1 = resize(up_flow1_1, scale_factor=(1.0/scale_factor)) * (1.0/scale_factor)
-            mask = resize(mask, scale_factor=(1.0/scale_factor))
-            img_res = resize(img_res, scale_factor=(1.0/scale_factor))
-
-        # Merge multiple predictions
-        imgt_pred = multi_flow_combine(self.comb_block, img0, img1, up_flow0_1, up_flow1_1,
-                                       mask, img_res, mean_)
+        up_flow0_1, up_flow1_1, mask, img_res = self.decoder1(
+            ft_1_, f0_1, f1_1, up_flow0_2, up_flow1_2
+        )
+
+        if scale_factor != 1.0:
+            up_flow0_1 = resize(up_flow0_1, scale_factor=(1.0 / scale_factor)) * (
+                1.0 / scale_factor
+            )
+            up_flow1_1 = resize(up_flow1_1, scale_factor=(1.0 / scale_factor)) * (
+                1.0 / scale_factor
+            )
+            mask = resize(mask, scale_factor=(1.0 / scale_factor))
+            img_res = resize(img_res, scale_factor=(1.0 / scale_factor))
+
+        # Merge multiple predictions
+        imgt_pred = multi_flow_combine(
+            self.comb_block, img0, img1, up_flow0_1, up_flow1_1, mask, img_res, mean_
+        )
         imgt_pred = torch.clamp(imgt_pred, 0, 1)
 
         if eval:
-            return { 'imgt_pred': imgt_pred, }
+            return {
+                "imgt_pred": imgt_pred,
+            }
         else:
             up_flow0_1 = up_flow0_1.reshape(b, self.num_flows, 2, h, w)
             up_flow1_1 = up_flow1_1.reshape(b, self.num_flows, 2, h, w)
             return {
-                'imgt_pred': imgt_pred,
-                'flow0_pred': [up_flow0_1, up_flow0_2, up_flow0_3, up_flow0_4],
-                'flow1_pred': [up_flow1_1, up_flow1_2, up_flow1_3, up_flow1_4],
-                'ft_pred': [ft_1_, ft_2_, ft_3_],
+                "imgt_pred": imgt_pred,
+                "flow0_pred": [up_flow0_1, up_flow0_2, up_flow0_3, up_flow0_4],
+                "flow1_pred": [up_flow1_1, up_flow1_2, up_flow1_3, up_flow1_4],
+                "ft_pred": [ft_1_, ft_2_, ft_3_],
             }

src/networks/AMT-S.py

@@ -1,31 +1,20 @@
 import torch
 import torch.nn as nn
 
-from networks.blocks.raft import (
-    coords_grid,
-    SmallUpdateBlock, BidirCorrBlock
-)
-from networks.blocks.feat_enc import (
-    SmallEncoder
-)
-from networks.blocks.ifrnet import (
-    resize,
-    Encoder,
-    InitDecoder,
-    IntermediateDecoder
-)
-from networks.blocks.multi_flow import (
-    multi_flow_combine,
-    MultiFlowDecoder
-)
+from src.networks.blocks.raft import coords_grid, SmallUpdateBlock, BidirCorrBlock
+from src.networks.blocks.feat_enc import SmallEncoder
+from src.networks.blocks.ifrnet import resize, Encoder, InitDecoder, IntermediateDecoder
+from src.networks.blocks.multi_flow import multi_flow_combine, MultiFlowDecoder
 
 
 class Model(nn.Module):
-    def __init__(self,
-                 corr_radius=3,
-                 corr_lvls=4,
-                 num_flows=3,
-                 channels=[20, 32, 44, 56],
-                 skip_channels=20):
+    def __init__(
+        self,
+        corr_radius=3,
+        corr_lvls=4,
+        num_flows=3,
+        channels=[20, 32, 44, 56],
+        skip_channels=20,
+    ):
         super(Model, self).__init__()
         self.radius = corr_radius
         self.corr_levels = corr_lvls
@@ -33,7 +22,7 @@ class Model(nn.Module):
         self.channels = channels
         self.skip_channels = skip_channels
-        self.feat_encoder = SmallEncoder(output_dim=84, norm_fn='instance', dropout=0.)
+        self.feat_encoder = SmallEncoder(output_dim=84, norm_fn="instance", dropout=0.0)
         self.encoder = Encoder(channels)
 
         self.decoder4 = InitDecoder(channels[3], channels[2], skip_channels)
@@ -44,44 +33,58 @@ class Model(nn.Module):
         self.update4 = self._get_updateblock(44)
         self.update3 = self._get_updateblock(32, 2)
         self.update2 = self._get_updateblock(20, 4)
 
         self.comb_block = nn.Sequential(
-            nn.Conv2d(3*num_flows, 6*num_flows, 3, 1, 1),
-            nn.PReLU(6*num_flows),
-            nn.Conv2d(6*num_flows, 3, 3, 1, 1),
+            nn.Conv2d(3 * num_flows, 6 * num_flows, 3, 1, 1),
+            nn.PReLU(6 * num_flows),
+            nn.Conv2d(6 * num_flows, 3, 3, 1, 1),
        )
 
     def _get_updateblock(self, cdim, scale_factor=None):
-        return SmallUpdateBlock(cdim=cdim, hidden_dim=76, flow_dim=20, corr_dim=64,
-                                fc_dim=68, scale_factor=scale_factor,
-                                corr_levels=self.corr_levels, radius=self.radius)
+        return SmallUpdateBlock(
+            cdim=cdim,
+            hidden_dim=76,
+            flow_dim=20,
+            corr_dim=64,
+            fc_dim=68,
+            scale_factor=scale_factor,
+            corr_levels=self.corr_levels,
+            radius=self.radius,
+        )
 
     def _corr_scale_lookup(self, corr_fn, coord, flow0, flow1, embt, downsample=1):
         # convert t -> 0 to 0 -> 1 | convert t -> 1 to 1 -> 0
         # based on linear assumption
-        t1_scale = 1. / embt
-        t0_scale = 1. / (1. - embt)
+        t1_scale = 1.0 / embt
+        t0_scale = 1.0 / (1.0 - embt)
         if downsample != 1:
             inv = 1 / downsample
             flow0 = inv * resize(flow0, scale_factor=inv)
             flow1 = inv * resize(flow1, scale_factor=inv)
 
         corr0, corr1 = corr_fn(coord + flow1 * t1_scale, coord + flow0 * t0_scale)
         corr = torch.cat([corr0, corr1], dim=1)
         flow = torch.cat([flow0, flow1], dim=1)
         return corr, flow
 
     def forward(self, img0, img1, embt, scale_factor=1.0, eval=False, **kwargs):
-        mean_ = torch.cat([img0, img1], 2).mean(1, keepdim=True).mean(2, keepdim=True).mean(3, keepdim=True)
+        mean_ = (
+            torch.cat([img0, img1], 2)
+            .mean(1, keepdim=True)
+            .mean(2, keepdim=True)
+            .mean(3, keepdim=True)
+        )
         img0 = img0 - mean_
         img1 = img1 - mean_
         img0_ = resize(img0, scale_factor) if scale_factor != 1.0 else img0
         img1_ = resize(img1, scale_factor) if scale_factor != 1.0 else img1
         b, _, h, w = img0_.shape
         coord = coords_grid(b, h // 8, w // 8, img0.device)
 
         fmap0, fmap1 = self.feat_encoder([img0_, img1_])  # [1, 128, H//8, W//8]
-        corr_fn = BidirCorrBlock(fmap0, fmap1, radius=self.radius, num_levels=self.corr_levels)
+        corr_fn = BidirCorrBlock(
+            fmap0, fmap1, radius=self.radius, num_levels=self.corr_levels
+        )
 
         # f0_1: [1, c0, H//2, W//2] | f0_2: [1, c1, H//4, W//4]
         # f0_3: [1, c2, H//8, W//8] | f0_4: [1, c3, H//16, W//16]
@@ -90,9 +93,9 @@ class Model(nn.Module):
         ######################################### the 4th decoder #########################################
         up_flow0_4, up_flow1_4, ft_3_ = self.decoder4(f0_4, f1_4, embt)
-        corr_4, flow_4 = self._corr_scale_lookup(corr_fn, coord,
-                                                 up_flow0_4, up_flow1_4,
-                                                 embt, downsample=1)
+        corr_4, flow_4 = self._corr_scale_lookup(
+            corr_fn, coord, up_flow0_4, up_flow1_4, embt, downsample=1
+        )
 
         # residue update with lookup corr
         delta_ft_3_, delta_flow_4 = self.update4(ft_3_, flow_4, corr_4)
@@ -102,10 +105,12 @@ class Model(nn.Module):
         ft_3_ = ft_3_ + delta_ft_3_
 
         ######################################### the 3rd decoder #########################################
-        up_flow0_3, up_flow1_3, ft_2_ = self.decoder3(ft_3_, f0_3, f1_3, up_flow0_4, up_flow1_4)
-        corr_3, flow_3 = self._corr_scale_lookup(corr_fn,
-                                                 coord, up_flow0_3, up_flow1_3,
-                                                 embt, downsample=2)
+        up_flow0_3, up_flow1_3, ft_2_ = self.decoder3(
+            ft_3_, f0_3, f1_3, up_flow0_4, up_flow1_4
+        )
+        corr_3, flow_3 = self._corr_scale_lookup(
+            corr_fn, coord, up_flow0_3, up_flow1_3, embt, downsample=2
+        )
 
         # residue update with lookup corr
         delta_ft_2_, delta_flow_3 = self.update3(ft_2_, flow_3, corr_3)
@@ -115,11 +120,13 @@ class Model(nn.Module):
         ft_2_ = ft_2_ + delta_ft_2_
 
         ######################################### the 2nd decoder #########################################
-        up_flow0_2, up_flow1_2, ft_1_ = self.decoder2(ft_2_, f0_2, f1_2, up_flow0_3, up_flow1_3)
-        corr_2, flow_2 = self._corr_scale_lookup(corr_fn,
-                                                 coord, up_flow0_2, up_flow1_2,
-                                                 embt, downsample=4)
+        up_flow0_2, up_flow1_2, ft_1_ = self.decoder2(
+            ft_2_, f0_2, f1_2, up_flow0_3, up_flow1_3
+        )
+        corr_2, flow_2 = self._corr_scale_lookup(
+            corr_fn, coord, up_flow0_2, up_flow1_2, embt, downsample=4
+        )
 
         # residue update with lookup corr
         delta_ft_1_, delta_flow_2 = self.update2(ft_1_, flow_2, corr_2)
         delta_flow0_2, delta_flow1_2 = torch.chunk(delta_flow_2, 2, 1)
@@ -128,27 +135,36 @@ class Model(nn.Module):
         ft_1_ = ft_1_ + delta_ft_1_
 
         ######################################### the 1st decoder #########################################
-        up_flow0_1, up_flow1_1, mask, img_res = self.decoder1(ft_1_, f0_1, f1_1, up_flow0_2, up_flow1_2)
-
-        if scale_factor != 1.0:
-            up_flow0_1 = resize(up_flow0_1, scale_factor=(1.0/scale_factor)) * (1.0/scale_factor)
-            up_flow1_1 = resize(up_flow1_1, scale_factor=(1.0/scale_factor)) * (1.0/scale_factor)
-            mask = resize(mask, scale_factor=(1.0/scale_factor))
-            img_res = resize(img_res, scale_factor=(1.0/scale_factor))
-
-        # Merge multiple predictions
-        imgt_pred = multi_flow_combine(self.comb_block, img0, img1, up_flow0_1, up_flow1_1,
-                                       mask, img_res, mean_)
+        up_flow0_1, up_flow1_1, mask, img_res = self.decoder1(
+            ft_1_, f0_1, f1_1, up_flow0_2, up_flow1_2
+        )
+
+        if scale_factor != 1.0:
+            up_flow0_1 = resize(up_flow0_1, scale_factor=(1.0 / scale_factor)) * (
+                1.0 / scale_factor
+            )
+            up_flow1_1 = resize(up_flow1_1, scale_factor=(1.0 / scale_factor)) * (
+                1.0 / scale_factor
+            )
+            mask = resize(mask, scale_factor=(1.0 / scale_factor))
+            img_res = resize(img_res, scale_factor=(1.0 / scale_factor))
+
+        # Merge multiple predictions
+        imgt_pred = multi_flow_combine(
+            self.comb_block, img0, img1, up_flow0_1, up_flow1_1, mask, img_res, mean_
+        )
         imgt_pred = torch.clamp(imgt_pred, 0, 1)
 
         if eval:
-            return { 'imgt_pred': imgt_pred, }
+            return {
+                "imgt_pred": imgt_pred,
+            }
        else:
             up_flow0_1 = up_flow0_1.reshape(b, self.num_flows, 2, h, w)
             up_flow1_1 = up_flow1_1.reshape(b, self.num_flows, 2, h, w)
             return {
-                'imgt_pred': imgt_pred,
-                'flow0_pred': [up_flow0_1, up_flow0_2, up_flow0_3, up_flow0_4],
-                'flow1_pred': [up_flow1_1, up_flow1_2, up_flow1_3, up_flow1_4],
-                'ft_pred': [ft_1_, ft_2_, ft_3_],
+                "imgt_pred": imgt_pred,
+                "flow0_pred": [up_flow0_1, up_flow0_2, up_flow0_3, up_flow0_4],
+                "flow1_pred": [up_flow1_1, up_flow1_2, up_flow1_3, up_flow1_4],
+                "ft_pred": [ft_1_, ft_2_, ft_3_],
             }

src/networks/IFRNet.py

@@ -1,32 +1,25 @@
 import torch
 import torch.nn as nn
 
 from src.utils.flow_utils import warp
-from networks.blocks.ifrnet import (
-    convrelu, resize,
-    ResBlock,
-)
+from src.networks.blocks.ifrnet import convrelu, resize, ResBlock
 
 
 class Encoder(nn.Module):
     def __init__(self):
         super(Encoder, self).__init__()
         self.pyramid1 = nn.Sequential(
-            convrelu(3, 32, 3, 2, 1),
-            convrelu(32, 32, 3, 1, 1)
+            convrelu(3, 32, 3, 2, 1), convrelu(32, 32, 3, 1, 1)
         )
         self.pyramid2 = nn.Sequential(
-            convrelu(32, 48, 3, 2, 1),
-            convrelu(48, 48, 3, 1, 1)
+            convrelu(32, 48, 3, 2, 1), convrelu(48, 48, 3, 1, 1)
        )
         self.pyramid3 = nn.Sequential(
-            convrelu(48, 72, 3, 2, 1),
-            convrelu(72, 72, 3, 1, 1)
+            convrelu(48, 72, 3, 2, 1), convrelu(72, 72, 3, 1, 1)
        )
         self.pyramid4 = nn.Sequential(
-            convrelu(72, 96, 3, 2, 1),
-            convrelu(96, 96, 3, 1, 1)
+            convrelu(72, 96, 3, 2, 1), convrelu(96, 96, 3, 1, 1)
        )
 
     def forward(self, img):
         f1 = self.pyramid1(img)
         f2 = self.pyramid2(f1)
@@ -39,11 +32,11 @@ class Decoder4(nn.Module):
     def __init__(self):
         super(Decoder4, self).__init__()
         self.convblock = nn.Sequential(
-            convrelu(192+1, 192),
+            convrelu(192 + 1, 192),
             ResBlock(192, 32),
-            nn.ConvTranspose2d(192, 76, 4, 2, 1, bias=True)
+            nn.ConvTranspose2d(192, 76, 4, 2, 1, bias=True),
         )
 
     def forward(self, f0, f1, embt):
         b, c, h, w = f0.shape
         embt = embt.repeat(1, 1, h, w)
@@ -56,9 +49,9 @@ class Decoder3(nn.Module):
     def __init__(self):
         super(Decoder3, self).__init__()
         self.convblock = nn.Sequential(
             convrelu(220, 216),
             ResBlock(216, 32),
-            nn.ConvTranspose2d(216, 52, 4, 2, 1, bias=True)
+            nn.ConvTranspose2d(216, 52, 4, 2, 1, bias=True),
         )
 
     def forward(self, ft_, f0, f1, up_flow0, up_flow1):
@@ -73,9 +66,9 @@ class Decoder2(nn.Module):
     def __init__(self):
         super(Decoder2, self).__init__()
         self.convblock = nn.Sequential(
             convrelu(148, 144),
             ResBlock(144, 32),
-            nn.ConvTranspose2d(144, 36, 4, 2, 1, bias=True)
+            nn.ConvTranspose2d(144, 36, 4, 2, 1, bias=True),
         )
 
     def forward(self, ft_, f0, f1, up_flow0, up_flow1):
@@ -90,11 +83,11 @@ class Decoder1(nn.Module):
     def __init__(self):
         super(Decoder1, self).__init__()
         self.convblock = nn.Sequential(
             convrelu(100, 96),
             ResBlock(96, 32),
-            nn.ConvTranspose2d(96, 8, 4, 2, 1, bias=True)
+            nn.ConvTranspose2d(96, 8, 4, 2, 1, bias=True),
         )
 
     def forward(self, ft_, f0, f1, up_flow0, up_flow1):
         f0_warp = warp(f0, up_flow0)
         f1_warp = warp(f1, up_flow1)
@@ -113,13 +106,18 @@ class Model(nn.Module):
         self.decoder1 = Decoder1()
 
     def forward(self, img0, img1, embt, scale_factor=1.0, eval=False, **kwargs):
-        mean_ = torch.cat([img0, img1], 2).mean(1, keepdim=True).mean(2, keepdim=True).mean(3, keepdim=True)
+        mean_ = (
+            torch.cat([img0, img1], 2)
+            .mean(1, keepdim=True)
+            .mean(2, keepdim=True)
+            .mean(3, keepdim=True)
+        )
         img0 = img0 - mean_
         img1 = img1 - mean_
         img0_ = resize(img0, scale_factor) if scale_factor != 1.0 else img0
         img1_ = resize(img1, scale_factor) if scale_factor != 1.0 else img1
 
         f0_1, f0_2, f0_3, f0_4 = self.encoder(img0_)
         f1_1, f1_2, f1_3, f1_4 = self.encoder(img1_)
@@ -143,13 +141,17 @@ class Model(nn.Module):
         up_flow1_1 = out1[:, 2:4] + 2.0 * resize(up_flow1_2, scale_factor=2.0)
         up_mask_1 = torch.sigmoid(out1[:, 4:5])
         up_res_1 = out1[:, 5:]
 
         if scale_factor != 1.0:
-            up_flow0_1 = resize(up_flow0_1, scale_factor=(1.0/scale_factor)) * (1.0/scale_factor)
-            up_flow1_1 = resize(up_flow1_1, scale_factor=(1.0/scale_factor)) * (1.0/scale_factor)
-            up_mask_1 = resize(up_mask_1, scale_factor=(1.0/scale_factor))
-            up_res_1 = resize(up_res_1, scale_factor=(1.0/scale_factor))
+            up_flow0_1 = resize(up_flow0_1, scale_factor=(1.0 / scale_factor)) * (
+                1.0 / scale_factor
+            )
+            up_flow1_1 = resize(up_flow1_1, scale_factor=(1.0 / scale_factor)) * (
+                1.0 / scale_factor
+            )
+            up_mask_1 = resize(up_mask_1, scale_factor=(1.0 / scale_factor))
+            up_res_1 = resize(up_res_1, scale_factor=(1.0 / scale_factor))
 
         img0_warp = warp(img0, up_flow0_1)
         img1_warp = warp(img1, up_flow1_1)
         imgt_merge = up_mask_1 * img0_warp + (1 - up_mask_1) * img1_warp + mean_
@@ -157,13 +159,15 @@ class Model(nn.Module):
         imgt_pred = torch.clamp(imgt_pred, 0, 1)
 
         if eval:
-            return { 'imgt_pred': imgt_pred, }
+            return {
+                "imgt_pred": imgt_pred,
+            }
         else:
             return {
-                'imgt_pred': imgt_pred,
-                'flow0_pred': [up_flow0_1, up_flow0_2, up_flow0_3, up_flow0_4],
-                'flow1_pred': [up_flow1_1, up_flow1_2, up_flow1_3, up_flow1_4],
-                'ft_pred': [ft_1_, ft_2_, ft_3_],
-                'img0_warp': img0_warp,
-                'img1_warp': img1_warp
+                "imgt_pred": imgt_pred,
+                "flow0_pred": [up_flow0_1, up_flow0_2, up_flow0_3, up_flow0_4],
+                "flow1_pred": [up_flow1_1, up_flow1_2, up_flow1_3, up_flow1_4],
+                "ft_pred": [ft_1_, ft_2_, ft_3_],
+                "img0_warp": img0_warp,
+                "img1_warp": img1_warp,
             }

src/networks/blocks/multi_flow.py Normal file

@@ -0,0 +1,80 @@
import torch
import torch.nn as nn
from src.utils.flow_utils import warp
from src.networks.blocks.ifrnet import convrelu, resize, ResBlock
def multi_flow_combine(
comb_block, img0, img1, flow0, flow1, mask=None, img_res=None, mean=None
):
"""
A parallel implementation of multiple flow field warping
    comb_block: An nn.Sequential object.
img shape: [b, c, h, w]
flow shape: [b, 2*num_flows, h, w]
mask (opt):
    If 'mask' is None, the function conducts a simple average.
img_res (opt):
If 'img_res' is None, the function adds zero instead.
mean (opt):
If 'mean' is None, the function adds zero instead.
"""
b, c, h, w = flow0.shape
num_flows = c // 2
flow0 = flow0.reshape(b, num_flows, 2, h, w).reshape(-1, 2, h, w)
flow1 = flow1.reshape(b, num_flows, 2, h, w).reshape(-1, 2, h, w)
mask = (
mask.reshape(b, num_flows, 1, h, w).reshape(-1, 1, h, w)
if mask is not None
else None
)
img_res = (
img_res.reshape(b, num_flows, 3, h, w).reshape(-1, 3, h, w)
if img_res is not None
else 0
)
img0 = torch.stack([img0] * num_flows, 1).reshape(-1, 3, h, w)
img1 = torch.stack([img1] * num_flows, 1).reshape(-1, 3, h, w)
mean = (
torch.stack([mean] * num_flows, 1).reshape(-1, 1, 1, 1)
if mean is not None
else 0
)
img0_warp = warp(img0, flow0)
img1_warp = warp(img1, flow1)
img_warps = mask * img0_warp + (1 - mask) * img1_warp + mean + img_res
img_warps = img_warps.reshape(b, num_flows, 3, h, w)
imgt_pred = img_warps.mean(1) + comb_block(img_warps.view(b, -1, h, w))
return imgt_pred
class MultiFlowDecoder(nn.Module):
def __init__(self, in_ch, skip_ch, num_flows=3):
super(MultiFlowDecoder, self).__init__()
self.num_flows = num_flows
self.convblock = nn.Sequential(
convrelu(in_ch * 3 + 4, in_ch * 3),
ResBlock(in_ch * 3, skip_ch),
nn.ConvTranspose2d(in_ch * 3, 8 * num_flows, 4, 2, 1, bias=True),
)
def forward(self, ft_, f0, f1, flow0, flow1):
n = self.num_flows
f0_warp = warp(f0, flow0)
f1_warp = warp(f1, flow1)
out = self.convblock(torch.cat([ft_, f0_warp, f1_warp, flow0, flow1], 1))
delta_flow0, delta_flow1, mask, img_res = torch.split(
out, [2 * n, 2 * n, n, 3 * n], 1
)
mask = torch.sigmoid(mask)
flow0 = delta_flow0 + 2.0 * resize(flow0, scale_factor=2.0).repeat(
1, self.num_flows, 1, 1
)
flow1 = delta_flow1 + 2.0 * resize(flow1, scale_factor=2.0).repeat(
1, self.num_flows, 1, 1
)
return flow0, flow1, mask, img_res
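The reshapes in `multi_flow_combine` fold the flow dimension into the batch so that `warp` runs once over all flows; a toy shape check, with `comb_block` built the same way as in the model classes above (note the current code path requires `mask` to be given, despite the docstring):

```python
# Toy shape check for multi_flow_combine: batch 1, 3 flows, 32x32 frames.
import torch
import torch.nn as nn

from src.networks.blocks.multi_flow import multi_flow_combine

num_flows = 3
comb_block = nn.Sequential(
    nn.Conv2d(3 * num_flows, 6 * num_flows, 3, 1, 1),
    nn.PReLU(6 * num_flows),
    nn.Conv2d(6 * num_flows, 3, 3, 1, 1),
)
img0 = torch.rand(1, 3, 32, 32)
img1 = torch.rand(1, 3, 32, 32)
flow0 = torch.zeros(1, 2 * num_flows, 32, 32)  # zero flow -> plain blend
flow1 = torch.zeros(1, 2 * num_flows, 32, 32)
mask = torch.rand(1, num_flows, 32, 32)

out = multi_flow_combine(comb_block, img0, img1, flow0, flow1, mask=mask)
print(out.shape)  # torch.Size([1, 3, 32, 32])
```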

src/pretrained/amt-l.pth Normal file

Binary file not shown.

src/pretrained/amt-s.pth Normal file

Binary file not shown.

src/utils/build.py

@@ -1,6 +1,7 @@
+from typing import TYPE_CHECKING
+
 import importlib
-from typing import TYPE_CHECKING
 
 if TYPE_CHECKING:
     from omegaconf import DictConfig
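This hunk only reorders imports, but for context: `build_from_cfg` is what turns the dotted `network.name` strings from the configs (e.g. `src.networks.AMT-G.Model`) into classes, and `importlib` is the only way to import hyphenated module names like `AMT-G`. Its body is not shown in this changeset; a sketch of the usual importlib pattern, as an assumption rather than the repository's actual code:

```python
# Assumed shape of an importlib-based factory; the real build_from_cfg in
# src/utils/build.py is not shown in this diff.
import importlib
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from omegaconf import DictConfig


def build_from_cfg_sketch(cfg: "DictConfig"):
    module_path, _, class_name = str(cfg.name).rpartition(".")
    cls = getattr(importlib.import_module(module_path), class_name)
    return cls(**cfg.params)
```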

src/utils/fs.py Normal file

@@ -0,0 +1,53 @@
from pathlib import Path
class FileSystem:
SOURCE_PATH = "source"
OUTPUT_PATH = "output"
FRAMES_PATH = "frames"
INTERPOLATED_PATH = "interpolated"
MOVED_PATH = "moved"
VIDEO_PART_PATH = "video_parts"
def __init__(self, base_path: Path):
self.base_path = base_path
self.base_path.mkdir(parents=True, exist_ok=True)
def create_directory(self, dir_name: str) -> Path:
"""Creates a directory under the base path."""
dir_path = self.base_path / dir_name
dir_path.mkdir(parents=True, exist_ok=True)
return dir_path
def clear_directory(self, dir_path: Path):
"""Clears all files in the specified directory."""
for item in dir_path.iterdir():
if item.is_file():
item.unlink()
elif item.is_dir():
self.clear_directory(item)
item.rmdir()
@property
def source_path(self) -> Path:
return self.create_directory(self.SOURCE_PATH)
@property
def output_path(self) -> Path:
return self.create_directory(self.OUTPUT_PATH)
@property
def frames_path(self) -> Path:
return self.create_directory(self.FRAMES_PATH)
@property
def interpolated_path(self) -> Path:
return self.create_directory(self.INTERPOLATED_PATH)
@property
def moved_path(self) -> Path:
return self.create_directory(self.MOVED_PATH)
@property
def video_part_path(self) -> Path:
return self.create_directory(self.VIDEO_PART_PATH)
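The property accessors call `create_directory`, so directories are created lazily on first use; for example:

```python
# Directories spring into existence on first property access (mkdir exist_ok=True).
from pathlib import Path

from src.utils.fs import FileSystem

fs = FileSystem(Path("output"))
print(fs.frames_path)               # output/frames, created if missing
fs.clear_directory(fs.frames_path)  # recursively empties it
```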

src/utils/torch.py

@@ -5,23 +5,26 @@ import numpy as np
 
 
 def tensor2img(tensor: torch.Tensor):
-    return (
-        (tensor * 255.0)
-        .detach()
+    tensor = (
+        tensor.mul(255.0)
+        .clamp_(0, 255)
+        .to(torch.uint8)
         .squeeze(0)
         .permute(1, 2, 0)
-        .cpu()
-        .numpy()
-        .clip(0, 255)
-        .astype(np.uint8)
     )
+    return tensor.cpu().numpy()
 
 
-def img2tensor(img: np.ndarray) -> torch.Tensor:
+def img2tensor(img: np.ndarray, device: torch.device) -> torch.Tensor:
     logging.debug(f"Converting image of shape {img.shape} to tensor")
     if img.shape[-1] > 3:
         img = img[:, :, :3]
-    return torch.tensor(img).permute(2, 0, 1).unsqueeze(0) / 255.0
+    tensor = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0)
+    if device.type != "cuda":
+        return tensor.float() / 255.0
+    return tensor.cuda(non_blocking=True).float().div_(255.0)
 
 
 def check_dim_and_resize(*args: torch.Tensor) -> list[torch.Tensor]:
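The device-aware `img2tensor` keeps the tensor as uint8 until after the (non-blocking) transfer and normalizes in place on CUDA; a CPU round-trip sketch:

```python
# Round trip: HWC uint8 image -> [1, 3, H, W] float tensor in [0, 1] -> uint8 image.
import numpy as np
import torch

from src.utils.torch import img2tensor, tensor2img

img = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)
tensor = img2tensor(img, torch.device("cpu"))
restored = tensor2img(tensor)
assert restored.shape == (64, 64, 3) and restored.dtype == np.uint8
```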

src/utils/utils.py

@@ -1,199 +0,0 @@
import re
import sys
from pathlib import Path
import numpy as np
from imageio import imread, imwrite
def read(file: Path) -> np.ndarray:
readers = {
".float3": readFloat,
".flo": readFlow,
".ppm": readImage,
".pgm": readImage,
".png": readImage,
".jpg": readImage,
".pfm": lambda f: readPFM(f)[0],
}
func = readers.get(file.suffix.lower())
if func is None:
raise Exception("don't know how to read %s" % file)
return func(file)
def write(file: Path, data: np.ndarray) -> None:
writers = {
".float3": writeFloat,
".flo": writeFlow,
".ppm": writeImage,
".pgm": writeImage,
".png": writeImage,
".jpg": writeImage,
".pfm": writePFM,
}
func = writers.get(file.suffix.lower())
if func is None:
raise Exception("don't know how to write %s" % file)
return func(file, data)
def readPFM(file: Path):
data = open(file, "rb")
color = None
width = None
height = None
scale = None
endian = None
header = data.readline().rstrip()
if header.decode("ascii") == "PF":
color = True
elif header.decode("ascii") == "Pf":
color = False
else:
raise Exception("Not a PFM file.")
dim_match = re.match(r"^(\d+)\s(\d+)\s$", data.readline().decode("ascii"))
if dim_match:
width, height = list(map(int, dim_match.groups()))
else:
raise Exception("Malformed PFM header.")
scale = float(data.readline().decode("ascii").rstrip())
if scale < 0:
endian = "<"
scale = -scale
else:
endian = ">"
result = np.fromfile(data, endian + "f")
shape = (height, width, 3) if color else (height, width)
result = np.reshape(result, shape)
result = np.flipud(result)
return result, scale
def writePFM(file: Path, image: np.ndarray, scale=1):
data = open(file, "wb")
color = None
if image.dtype.name != "float32":
raise Exception("Image dtype must be float32.")
image = np.flipud(image)
if len(image.shape) == 3 and image.shape[2] == 3:
color = True
elif len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1:
color = False
else:
raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.")
data.write("PF\n" if color else "Pf\n".encode()) # type: ignore
data.write("%d %d\n".encode() % (image.shape[1], image.shape[0]))
endian = image.dtype.byteorder
if endian == "<" or endian == "=" and sys.byteorder == "little":
scale = -scale
data.write("%f\n".encode() % scale)
image.tofile(data)
def readFlow(file: Path):
if file.suffix.lower() == ".pfm":
return readPFM(file)[0][:, :, 0:2]
f = open(file, "rb")
header = f.read(4)
if header.decode("utf-8") != "PIEH":
raise Exception("Flow file header does not contain PIEH")
width = np.fromfile(f, np.int32, 1).squeeze()
height = np.fromfile(f, np.int32, 1).squeeze()
flow = np.fromfile(f, np.float32, width * height * 2).reshape((height, width, 2))
return flow.astype(np.float32)
def readImage(file: Path):
if file.suffix.lower() == ".pfm":
data = readPFM(file)[0]
if len(data.shape) == 3:
return data[:, :, 0:3]
else:
return data
return imread(file)
def writeImage(file: Path, data: np.ndarray):
if file.suffix.lower() == ".pfm":
return writePFM(file, data, 1)
return imwrite(file, data)
def writeFlow(file: Path, flow: np.ndarray):
f = open(file, "wb")
f.write("PIEH".encode("utf-8"))
np.array([flow.shape[1], flow.shape[0]], dtype=np.int32).tofile(f)
flow = flow.astype(np.float32)
flow.tofile(f)
def readFloat(file: Path):
f = open(file, "rb")
if (f.readline().decode("utf-8")) != "float\n":
raise Exception("float file %s did not contain <float> keyword" % file)
dim = int(f.readline())
dims = []
count = 1
for _ in range(0, dim):
d = int(f.readline())
dims.append(d)
count *= d
dims = list(reversed(dims))
data = np.fromfile(f, np.float32, count).reshape(dims)
if dim > 2:
data = np.transpose(data, (2, 1, 0))
data = np.transpose(data, (1, 0, 2))
return data
def writeFloat(file: Path, data: np.ndarray):
f = open(file, "wb")
dim = len(data.shape)
if dim > 3:
raise Exception("bad float file dimension: %d" % dim)
f.write(("float\n").encode("ascii"))
f.write(("%d\n" % dim).encode("ascii"))
if dim == 1:
f.write(("%d\n" % data.shape[0]).encode("ascii"))
else:
f.write(("%d\n" % data.shape[1]).encode("ascii"))
f.write(("%d\n" % data.shape[0]).encode("ascii"))
for i in range(2, dim):
f.write(("%d\n" % data.shape[i]).encode("ascii"))
data = data.astype(np.float32)
if dim == 2:
data.tofile(f)
else:
np.transpose(data, (2, 0, 1)).tofile(f)

src/utils/video.py Normal file

@@ -0,0 +1,147 @@
import os
import logging
import subprocess
from pathlib import Path
from typing import Generator, Iterable
import cv2
import numpy as np
class VideoMaker:
def images_to_video(
self,
images_path: Path,
output_path: Path,
fps: float,
image_numerator: str = "img_%08d.png",
):
"""Converts a sequence of images to a video using ffmpeg."""
cmd = f"ffmpeg -framerate {fps} -i {images_path / image_numerator} -c:v libx264 -pix_fmt yuv420p {output_path}"
logging.info(f"Running command: {cmd}")
result = self.run_command(cmd)
if result != 0:
logging.error(f"Failed to create video. Command returned {result}")
def concatenate_videos(
self,
videos_path: Path,
output_path: Path,
video_numerator: str = "video_%08d.mp4",
):
"""Concatenates a sequence of videos using ffmpeg."""
videos = sorted(videos_path.glob("*.mp4"))
file = "file.txt"
with open(file, "w") as f:
for video in videos:
f.write(f"file '{video}'\n")
cmd = f"ffmpeg -y -f concat -safe 0 -i {file} -c copy {output_path}"
logging.info(f"Running command: {cmd}")
result = self.run_command(cmd)
if result != 0:
logging.error(f"Failed to concatenate videos. Command returned {result}")
os.remove(file)
def get_fps(self, video_path: Path) -> float:
"""Gets the frames per second (FPS) of a video."""
cap = cv2.VideoCapture(str(video_path))
if not cap.isOpened():
raise ValueError(f"Cannot open video: {video_path}")
fps = cap.get(cv2.CAP_PROP_FPS)
cap.release()
logging.debug(f"FPS of video {video_path}: {fps}")
return fps
def get_video_duration(self, video_path: Path) -> float:
"""Gets the duration of a video in seconds."""
cap = cv2.VideoCapture(str(video_path))
if not cap.isOpened():
raise ValueError(f"Cannot open video: {video_path}")
fps = cap.get(cv2.CAP_PROP_FPS)
frame_count = cap.get(cv2.CAP_PROP_FRAME_COUNT)
cap.release()
duration = frame_count / fps
logging.debug(f"Duration of video {video_path}: {duration:.2f} seconds")
return duration
def run_command(self, cmd: str) -> int:
try:
subprocess.run(
cmd,
shell=True,
check=True,
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
)
return 0
except subprocess.CalledProcessError as e:
logging.error(f"Command failed with error: {e}")
return e.returncode
def video_to_frames_generator(
self, video_path: Path, output_dir: Path, chunk_seconds: int = 10
) -> Generator[tuple[np.ndarray, ...], None, None]:
"""Extracts frames from a video and saves them to disk, yielding paths to the saved frames."""
cap = cv2.VideoCapture(str(video_path))
if not cap.isOpened():
raise ValueError(f"Cannot open video: {video_path}")
fps = cap.get(cv2.CAP_PROP_FPS)
frames_per_chunk = int(fps * chunk_seconds)
while True:
paths = []
for _ in range(frames_per_chunk):
ret, frame = cap.read()
if not ret:
cap.release()
return
paths.append(frame)
yield tuple(paths)
def images_to_video_pipeline(
self,
frames: Iterable[np.ndarray],
output_path: Path,
width: int,
height: int,
fps: float,
):
pipeline = subprocess.Popen(
[
"ffmpeg",
"-y",
"-f", "rawvideo",
"-vcodec", "rawvideo",
"-pix_fmt", "bgr24",
"-s", f"{width}x{height}",
"-r", str(fps),
"-i", "-",
"-an",
"-vcodec", "libx264",
"-pix_fmt", "yuv420p",
str(output_path),
],
stdin=subprocess.PIPE,
stderr=subprocess.DEVNULL
)
if pipeline.stdin is None:
raise Exception("STDIN closed")
for frame in frames:
pipeline.stdin.write(frame.tobytes())
pipeline.stdin.close()
pipeline.wait()
def get_size(self, video_path: Path) -> tuple[int, int]:
cap = cv2.VideoCapture(str(video_path))
if not cap.isOpened():
raise ValueError(f"Cannot open video: {video_path}")
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
cap.release()
return width, height
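`images_to_video_pipeline` streams raw BGR frames through ffmpeg's stdin, so no intermediate PNGs are written; a minimal sketch with synthetic frames (assumes ffmpeg is on PATH):

```python
# Stream synthetic BGR frames straight into ffmpeg via stdin.
from pathlib import Path

import numpy as np

from src.utils.video import VideoMaker

maker = VideoMaker()
width, height, fps = 320, 240, 30.0
frames = (
    np.full((height, width, 3), shade, dtype=np.uint8)  # one flat gray frame per shade
    for shade in range(0, 256, 5)
)
maker.images_to_video_pipeline(frames, Path("output/demo.mp4"), width, height, fps)
```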