Compare commits
829d0c8c59...dev-cuda-u — 11 commits:

- 7addcf051c
- 0c871c2314
- 61f8e0abe1
- faf7aa8e81
- bc09cd7b6c
- be794539ac
- 28e51d1c5e
- 97ca8b19f8
- 4fc13db0e8
- c984b38904
- 888cdb3151
.gitignore (vendored, 1 line added)

```diff
@@ -175,5 +175,6 @@ cython_debug/
 .pypirc


 .DS_Store
 source/
+output/
```
example/frame_01.png (BIN, new file; binary not shown; size after: 50 KiB)

example/frame_02.png (BIN, new file; binary not shown; size after: 57 KiB)
interpolator.py — deleted (135 lines; superseded by src/interpolator.py below):

```python
import logging
from pathlib import Path

import torch
import numpy as np
from omegaconf import OmegaConf, DictConfig

from src.utils import utils
from src.utils.torch import img2tensor, check_dim_and_resize, tensor2img
from src.utils.build import build_from_cfg
from src.utils.padder import InputPadder


class Anchor:
    def __init__(self, resolution: int, memory: int, memory_bias: int) -> None:
        self.resolution = resolution
        self.memory = memory
        self.memory_bias = memory_bias

    def __str__(self) -> str:
        return f"Anchor(resolution={self.resolution}, memory={self.memory}, memory_bias={self.memory_bias})"


class ModelRunner:
    def __init__(self, config: Path, ckpt_path: Path, device: torch.device) -> None:
        """Initializes the ModelRunner with configuration and checkpoint.

        Args:
            config (Path): Path to model configuration in YAML format
            ckpt_path (Path): Path to model checkpoint in .pth format
            device (torch.device): Device to load the model on
        """
        omega_config = OmegaConf.load(config)
        network_config: DictConfig = omega_config.network
        logging.info(
            f"Loaded network configuration: {network_config} from [{ckpt_path}]"
        )
        model = build_from_cfg(network_config)
        checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False)
        model.load_state_dict(checkpoint["state_dict"])
        model = model.to(get_device())
        model.eval()
        self.model = model


def get_vram_available(device: torch.device) -> int:
    """Returns the available VRAM in bytes."""
    if device.type == "cuda" and torch.cuda.is_available():
        return torch.cuda.get_device_properties(
            device
        ).total_memory - torch.cuda.memory_allocated(device)
    elif device.type == "mps" and torch.mps.is_available():
        # MPS does not provide a way to query available memory, so we return a large number to avoid issues
        return torch.mps.recommended_max_memory()
    else:
        return 1


def get_device():
    """Detects and returns the best available device for PyTorch computation.

    Returns:
        torch.device: CUDA device if available, MPS device for Apple Silicon if available, otherwise CPU.
    """
    if torch.cuda.is_available():
        logging.info("Using CUDA-enabled GPU")
        return torch.device("cuda")
    elif torch.mps.is_available():
        logging.info("Using Apple Silicon GPU (MPS)")
        return torch.device("mps")
    logging.info("No GPU available, using CPU")
    return torch.device("cpu")


class ImageInterpolator:
    def __init__(self, device: torch.device, anchor: Anchor, model_runner: ModelRunner):
        self.device = device
        self.anchor = anchor
        self.vram_available = get_vram_available(device)
        self.embt = torch.tensor(1 / 2).float().view(1, 1, 1, 1).to(device)
        self.model_runner = model_runner
        logging.debug(
            f"Initialized ImageInterpolator with device: {device}, anchor: {anchor}, available VRAM: {self.vram_available} bytes"
        )

    def interpolate(self, image1: Path, image2: Path, output_path: Path):
        logging.debug(f"Reading images: {image1} and {image2}")
        tensor1 = img2tensor(utils.read(image1)).to(self.device)
        tensor2 = img2tensor(utils.read(image2)).to(self.device)
        logging.debug(
            f"Image shapes after conversion to tensors: {tensor1.shape}, {tensor2.shape}"
        )
        tensor1, tensor2 = check_dim_and_resize(tensor1, tensor2)
        logging.debug(f"Image shapes after resizing: {tensor1.shape}, {tensor2.shape}")
        h, w = tensor1.shape[2], tensor1.shape[3]
        logging.debug(f"Interpolating images of size: {h}x{w}")

        scale = self.scale(h, w)
        logging.debug(f"Calculated scale factor: {scale:.2f}")
        padding = int(16 / scale)
        logging.debug(f"Calculated padding: {padding} pixels")
        padder = InputPadder(tensor1.shape, divisor=padding)
        tensor1_padded, tensor2_padded = padder.pad(tensor1, tensor2)
        logging.debug(
            f"Image shapes after padding: {tensor1_padded.shape}, {tensor2_padded.shape}"
        )

        tensor1_padded = tensor1_padded.to(self.device)
        tensor2_padded = tensor2_padded.to(self.device)
        logging.debug("Running model inference for interpolation")
        with torch.no_grad():
            interpolated = self.model_runner.model(
                tensor1_padded, tensor2_padded, self.embt, scale_factor=scale, eval=True
            )["imgt_pred"]
        logging.debug(f"Interpolated image shape before unpadding: {interpolated.shape}")
        (interpolated,) = padder.unpad(interpolated)
        logging.debug(f"Interpolated image shape after unpadding: {interpolated.shape}")
        utils.write(output_path, tensor2img(interpolated.cpu()))
        logging.debug(f"Saved interpolated image to: {output_path}")

    def scale(self, height: int, width: int) -> float:
        scale = (
            self.anchor.resolution
            / (height * width)
            * np.sqrt(
                (self.vram_available - self.anchor.memory_bias) / self.anchor.memory
            )
        )
        scale = 1 if scale > 1 else scale
        scale = 1 / np.floor(1 / np.sqrt(scale) * 16) * 16
        if scale < 1:
            logging.info(
                f"Due to the limited VRAM, the video will be scaled by {scale:.2f}"
            )
        return scale
```
main.py (407 changed lines; old and new lines appear interleaved — the side-by-side markers did not survive extraction)

```
@@ -1,149 +1,28 @@
import logging
import subprocess
from pathlib import Path
from typing import TYPE_CHECKING

import cv2
from tqdm import tqdm
from time import perf_counter
from decimal import Decimal
from cv2 import imwrite
import tqdm

from interpolator import get_device
from interpolator import ImageInterpolator
from interpolator import ModelRunner, Anchor


logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
from src.config import presets
from src.utils.fs import FileSystem
from src.utils.video import VideoMaker
from src.interpolator import (
    ImageInterpolator,
    Anchor,
    get_device,
    get_vram_available,
    ModelRunner,
)

from pathlib import Path

if TYPE_CHECKING:
    import torch
    import numpy as np


def move_images(src_dir: str, interpolated_dir: str, output_dir: str):
    src_dir = Path(src_dir)
    interpolated_dir = Path(interpolated_dir)
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    index = 0
    src_frames = sorted(src_dir.glob("img_*.png"))
    interp_frames = sorted(interpolated_dir.glob("img_*.png"))
    for i in range(len(src_frames)):
        output_frame = output_dir / f"img_{index:08d}.png"
        src_frames[i].rename(output_frame)
        index += 1

        if i < len(interp_frames):
            output_interp = output_dir / f"img_{index:08d}.png"
            interp_frames[i].rename(output_interp)
            index += 1


def build_file_list(moved_dir: str, list_path: str):
    import os
    moved_dir = Path(moved_dir)
    frames = sorted(moved_dir.glob("img_*.png"))
    print(frames[0])

    with open(list_path, "w") as f:
        for frame in frames:
            f.write(f"file '{os.path.abspath(frame)}'\n")


def build_ffmpeg_file_list(frames_dir: str, interpolated_dir: str, list_path: str):
    frames = sorted(Path(frames_dir).glob("img_*.png"))
    interps = sorted(Path(interpolated_dir).glob("img_*.png"))

    if len(interps) != len(frames) - 1:
        raise ValueError("Interpolated frames must be N-1")

    with open(list_path, "w") as f:
        for i in range(len(frames)):
            f.write(f"file '{frames[i].resolve().as_posix()}'\n")

            if i < len(interps):
                f.write(f"file '{interps[i].resolve().as_posix()}'\n")


def merge_with_ffmpeg(
    original_video: str,
    file_list: str,
    output_video: str,
):
    cap = cv2.VideoCapture(original_video)

    if not cap.isOpened():
        raise ValueError("Cannot open original video")

    fps = cap.get(cv2.CAP_PROP_FPS)
    cap.release()

    new_fps = Decimal(fps * 2)

    cmd = [
        "ffmpeg",
        "-y",
        "-r", str(new_fps.quantize(Decimal("1.0000000000"))),
        "-f", "concat",
        "-safe", "0",
        "-i", file_list,
        "-c:v", "libx264rgb",
        output_video,
    ]
    print("Running ffmpeg command:", " ".join(cmd))

    subprocess.run(cmd, check=True)


def video_frames_to_disk_generator(
    video_path: str | Path,
    output_dir: str | Path,
    chunk_seconds: int = 10
):
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    cap = cv2.VideoCapture(str(video_path))

    if not cap.isOpened():
        raise ValueError(f"Cannot open video: {video_path}")

    fps = cap.get(cv2.CAP_PROP_FPS)
    frames_per_chunk = int(fps * chunk_seconds)

    frame_index = 0

    while True:
        paths = []

        for _ in range(frames_per_chunk):
            ret, frame = cap.read()
            if not ret:
                cap.release()
                return

            frame_path = output_dir / f"img_{frame_index:08d}.png"
            cv2.imwrite(str(frame_path), frame)

            paths.append(frame_path)
            frame_index += 1

        yield tuple(paths)
```
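The generator spills decoded frames to disk in fixed-length chunks, so only one chunk's worth of PNGs is in flight at a time; the caller is expected to bridge chunk boundaries itself (last frame of the previous chunk against the first frame of the next). A minimal consumption sketch, assuming an `interpolate(img1, img2, out)` callable standing in for `ImageInterpolator.interpolate`:

```python
# Sketch: consume chunks and interpolate across each chunk boundary.
prev_chunk = None
for chunk in video_frames_to_disk_generator("example/video.mp4", "output/frames"):
    if prev_chunk is not None:
        # Without this call, the midpoint between the last frame of the
        # previous chunk and the first frame of this one would be skipped.
        interpolate(prev_chunk[-1], chunk[0], "output/interpolated/bridge.png")
    prev_chunk = chunk
```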
```
def main():
    start = perf_counter()
    logging.info("Starting video interpolation process")
    config_path = Path("src/config/AMT-G.yaml")
    ckpt_path = Path("src/pretrained/amt-g.pth")
    video_path = Path("example/video.mp4")
    output_dir = Path("output/frames")
    output_interpolated_dir = Path("output/interpolated")
    output_interpolated_dir.mkdir(parents=True, exist_ok=True)

    device = get_device()
    model_runner = ModelRunner(config_path, ckpt_path, device)
def performing_warning_message(device: "torch.device"):
    if device.type in ("cpu", "mps"):
        if device.type == "mps":
            logging.warning(
@@ -153,87 +32,199 @@ def main():
            logging.warning(
                "Running on CPU may be very slow. Consider using a GPU for better performance."
            )
        anchor = Anchor(resolution=8192 * 8192, memory=1, memory_bias=0)
    elif device.type == "cuda":
        anchor = Anchor(
        pass
    else:
        raise Exception(f"Unsupported device type: {device.type}")


def init_fs(base_path: Path) -> FileSystem:
    fs = FileSystem(base_path)
    fs.clear_directory(fs.frames_path)
    fs.clear_directory(fs.interpolated_path)
    fs.clear_directory(fs.moved_path)
    fs.clear_directory(fs.video_part_path)
    return fs


def init_video_maker() -> VideoMaker:
    return VideoMaker()


def init_device() -> "torch.device":
    device = get_device()
    performing_warning_message(device)
    vram_available = get_vram_available(device)
    logging.info(f"Available VRAM: {vram_available / (1024**3):.2f} GB")
    return device


def init_anchor(device: "torch.device") -> Anchor:
    if device.type in ("cpu", "mps"):
        return Anchor(resolution=8192 * 8192, memory=1, memory_bias=0)
    elif device.type == "cuda":
        return Anchor(
            resolution=1024 * 512, memory=1500 * 1024**2, memory_bias=2500 * 1024**2
        )
    else:
        raise Exception(f"Unsupported device type: {device.type}")

    interpolator = ImageInterpolator(device, anchor, model_runner)

    loaded_time = perf_counter() - start
    logging.info(f"Model loaded and initialized in {loaded_time:.2f} seconds")

    prev_frame_path = None
    frame_count = 0
    for frame_paths in video_frames_to_disk_generator(video_path, output_dir):
        logging.info(f"Processing frames: {len(frame_paths)}")

        if prev_frame_path is not None:
            img1 = prev_frame_path[-1]
            img2 = frame_paths[0]
            output_path = output_interpolated_dir / f"img_{frame_count:08d}.png"
            interpolator.interpolate(img1, img2, output_path)
            logging.debug(f"Interpolated image saved to: {output_path}")
            frame_count += 1
        for i in tqdm(range(len(frame_paths) - 1), desc="Interpolating frames"):
            img1 = frame_paths[i]
            img2 = frame_paths[i + 1]
            output_path = output_interpolated_dir / f"img_{frame_count:08d}.png"
            interpolator.interpolate(img1, img2, output_path)
            logging.debug(f"Interpolated image saved to: {output_path}")
            frame_count += 1
        prev_frame_path = frame_paths
    total_time = perf_counter() - start
    logging.info(f"Video interpolation completed in {total_time:.2f} seconds")
def init_model_runner(preset: presets.Preset, device: "torch.device") -> ModelRunner:
    return ModelRunner(preset, device)


def builder():
    frames_dir = "output/frames"
    interpolated_dir = "output/interpolated"
    moved_dir = "output/moved"
    video_path = "example/video.mp4"
    output_video = "output/interpolated_video.mp4"
    move_images(frames_dir, interpolated_dir, moved_dir)

    cap = cv2.VideoCapture(video_path)

    if not cap.isOpened():
        raise ValueError("Cannot open original video")

    fps = cap.get(cv2.CAP_PROP_FPS)
    cmd = [
        "ffmpeg",
        "-y",
        "-framerate", str(fps * 2),
        "-i", f"{moved_dir}/img_%08d.png",
        "-i", video_path,
        "-c:v", "libx264",
        "-c:a", "copy",
        "-shortest",
        output_video,
    ]
    logging.info("Running ffmpeg command to build final video: " + " ".join(cmd))
    subprocess.run(cmd, check=True)
def init_interpolator(
    model_runner: ModelRunner, device: "torch.device"
) -> ImageInterpolator:
    anchor = init_anchor(device)
    return ImageInterpolator(device, anchor, model_runner)


def cleanup():
    import os
    import shutil
    frames_dir = "output/frames"
    interpolated_dir = "output/interpolated"
    moved_dir = "output/moved"
    os.makedirs(frames_dir, exist_ok=True)
    os.makedirs(interpolated_dir, exist_ok=True)
    os.makedirs(moved_dir, exist_ok=True)
    shutil.rmtree(frames_dir)
    shutil.rmtree(interpolated_dir)
    shutil.rmtree(moved_dir)
class InterpolationPipeline:
    def __init__(
        self,
        preset: presets.Preset,
        base_path: Path,
    ):
        self.fs = init_fs(base_path)
        self.video_maker = init_video_maker()
        self.device = init_device()
        self.model_runner = init_model_runner(preset, self.device)
        self.interpolator = init_interpolator(self.model_runner, self.device)

    def run(self, video_path: Path, output_video: str):
        prev_frames = tuple()
        interpolated_frames: list["np.ndarray"] = []
        part = 0
        chunk_seconds = 1
        length = self.video_maker.get_video_duration(video_path)
        last_part_seconds = 1 if length % chunk_seconds else 0
        total_parts = int(length // chunk_seconds) + last_part_seconds
        fps = self.video_maker.get_fps(video_path)
        logging.info(f"Video FPS: {fps}")
        fps *= 2  # Doubling FPS
        width, height = self.video_maker.get_size(video_path)
        for frames in self.video_maker.video_to_frames_generator(
            video_path, self.fs.frames_path, chunk_seconds
        ):
            logging.info(f"Processing frames: {len(frames)}")
            if prev_frames:
                img1 = prev_frames[-1]
                img2 = frames[0]
                img1_2 = self.interpolator.interpolate(img1, img2)
                interpolated_frames.append(img1_2)
                generator = self._frame_generator(prev_frames, interpolated_frames)
                part_path = self.fs.video_part_path / f"video_{part:08d}.mp4"
                self.video_maker.images_to_video_pipeline(
                    generator, part_path, width, height, fps
                )
                interpolated_frames = []
                logging.info(f"Finished processing part {part:08d}")
                part += 1
            for i in tqdm.tqdm(
                range(len(frames) - 1),
                desc=f"Processing video frames {part + 1} / {total_parts}",
            ):
                img1 = frames[i]
                img2 = frames[i + 1]
                img1_2 = self.interpolator.interpolate(img1, img2)
                interpolated_frames.append(img1_2)
            prev_frames = frames

        generator = self._frame_generator(prev_frames, interpolated_frames)
        part_path = self.fs.video_part_path / f"video_{part:08d}.mp4"
        self.video_maker.images_to_video_pipeline(
            generator, part_path, width, height, fps
        )
        logging.info(f"Finished processing part {part:08d}")
        self._merge_video_parts(self.fs.output_path / output_video)
        logging.info(
            f"Video interpolation completed. Output saved to: {self.fs.output_path / output_video}"
        )

    def _save_images(
        self,
        source: tuple["np.ndarray", ...],
        interpolated: list["np.ndarray"],
    ):
        logging.info("Saving images...")
        self.fs.clear_directory(self.fs.moved_path)
        index = 0
        for i, frame in enumerate(source):
            name = self.fs.moved_path / f"img_{index:08d}.png"
            index += 1
            imwrite(name, frame)
            if i < len(interpolated):
                name = self.fs.moved_path / f"img_{index:08d}.png"
                index += 1
                imwrite(name, interpolated[i])
        logging.info("Success...")

    def _merge_frames_to_video(self, output_video: Path, fps: float):
        self.video_maker.images_to_video(self.fs.moved_path, output_video, fps)

    def _merge_video_parts(self, output_video: Path):
        self.video_maker.concatenate_videos(self.fs.video_part_path, output_video)
        self.fs.clear_directory(self.fs.video_part_path)

    def _frame_generator(
        self,
        source: tuple["np.ndarray", ...],
        interpolated: list["np.ndarray"],
    ):
        for i, frame in enumerate(source):
            yield frame
            if i < len(interpolated):
                yield interpolated[i]
```
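`_frame_generator` interleaves source frames with their interpolated midpoints, which is what actually doubles the frame rate: sources S0..S2 and midpoints M01, M12 stream out as S0, M01, S1, M12, S2. A tiny self-contained sketch of the same pattern:

```python
source = ("S0", "S1", "S2")
interpolated = ["M01", "M12"]

def frame_generator(source, interpolated):
    # Same interleaving as InterpolationPipeline._frame_generator.
    for i, frame in enumerate(source):
        yield frame
        if i < len(interpolated):
            yield(interpolated[i]) if False else None  # placeholder removed below
```

Correcting the sketch to the exact logic above:

```python
def frame_generator(source, interpolated):
    for i, frame in enumerate(source):
        yield frame
        if i < len(interpolated):
            yield interpolated[i]

assert list(frame_generator(("S0", "S1", "S2"), ["M01", "M12"])) == [
    "S0", "M01", "S1", "M12", "S2"
]
```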
```
def runner(
    base_path: Path,
    video_path: Path,
    output_video: str,
    preset: presets.Preset = presets.LARGE,
):
    pipeline = InterpolationPipeline(
        preset=preset,
        base_path=base_path,
    )
    pipeline.run(video_path, output_video)


def main():
    import argparse

    logging.basicConfig(
        level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
    )

    parser = argparse.ArgumentParser()
    parser.add_argument("-b", "--base_path", help="Base path", default="output")
    parser.add_argument(
        "-v", "--video_path", help="Video path", default="example/video.mp4"
    )
    parser.add_argument(
        "-o",
        "--output",
        help="Output video name (example: 'interpolated_video.mp4')",
        default="interpolated_video.mp4",
    )
    parser.add_argument(
        "-p",
        "--preset",
        help="Model preset",
        choices=["small", "large", "global"],
        default="global",
    )
    args = parser.parse_args()
    runner(
        base_path=Path(args.base_path),
        video_path=Path(args.video_path),
        output_video=args.output,
        preset=getattr(presets, args.preset.upper()),
    )


if __name__ == "__main__":
    cleanup()
    main()
    builder()
    cleanup()
```
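The CLI wires argparse into `runner` (e.g. `python main.py -v example/video.mp4 -p small`), and the same entry point can be driven programmatically. A sketch using the defaults from `main()` above:

```python
from pathlib import Path

from main import runner
from src.config import presets

# Equivalent to the CLI defaults: base path "output", the bundled example
# video, and the GLOBAL (AMT-G) preset.
runner(
    base_path=Path("output"),
    video_path=Path("example/video.mp4"),
    output_video="interpolated_video.mp4",
    preset=presets.GLOBAL,
)
```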
Deleted file (111 lines) — by its contents, the old networks/blocks/ifrnet.py, relocated to src/networks/blocks/ifrnet.py:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from src.utils.flow_utils import warp


def resize(x, scale_factor):
    return F.interpolate(x, scale_factor=scale_factor, mode="bilinear", align_corners=False)


def convrelu(in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=True):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias=bias),
        nn.PReLU(out_channels)
    )


class ResBlock(nn.Module):
    def __init__(self, in_channels, side_channels, bias=True):
        super(ResBlock, self).__init__()
        self.side_channels = side_channels
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias),
            nn.PReLU(in_channels)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(side_channels, side_channels, kernel_size=3, stride=1, padding=1, bias=bias),
            nn.PReLU(side_channels)
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias),
            nn.PReLU(in_channels)
        )
        self.conv4 = nn.Sequential(
            nn.Conv2d(side_channels, side_channels, kernel_size=3, stride=1, padding=1, bias=bias),
            nn.PReLU(side_channels)
        )
        self.conv5 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias)
        self.prelu = nn.PReLU(in_channels)

    def forward(self, x):
        out = self.conv1(x)

        res_feat = out[:, :-self.side_channels, ...]
        side_feat = out[:, -self.side_channels:, :, :]
        side_feat = self.conv2(side_feat)
        out = self.conv3(torch.cat([res_feat, side_feat], 1))

        res_feat = out[:, :-self.side_channels, ...]
        side_feat = out[:, -self.side_channels:, :, :]
        side_feat = self.conv4(side_feat)
        out = self.conv5(torch.cat([res_feat, side_feat], 1))

        out = self.prelu(x + out)
        return out


class Encoder(nn.Module):
    def __init__(self, channels, large=False):
        super(Encoder, self).__init__()
        self.channels = channels
        prev_ch = 3
        for idx, ch in enumerate(channels, 1):
            k = 7 if large and idx == 1 else 3
            p = 3 if k == 7 else 1
            self.register_module(
                f'pyramid{idx}',
                nn.Sequential(
                    convrelu(prev_ch, ch, k, 2, p),
                    convrelu(ch, ch, 3, 1, 1)
                ))
            prev_ch = ch

    def forward(self, in_x):
        fs = []
        for idx in range(len(self.channels)):
            out_x = getattr(self, f'pyramid{idx+1}')(in_x)
            fs.append(out_x)
            in_x = out_x
        return fs


class InitDecoder(nn.Module):
    def __init__(self, in_ch, out_ch, skip_ch) -> None:
        super().__init__()
        self.convblock = nn.Sequential(
            convrelu(in_ch*2+1, in_ch*2),
            ResBlock(in_ch*2, skip_ch),
            nn.ConvTranspose2d(in_ch*2, out_ch+4, 4, 2, 1, bias=True)
        )

    def forward(self, f0, f1, embt):
        h, w = f0.shape[2:]
        embt = embt.repeat(1, 1, h, w)
        out = self.convblock(torch.cat([f0, f1, embt], 1))
        flow0, flow1 = torch.chunk(out[:, :4, ...], 2, 1)
        ft_ = out[:, 4:, ...]
        return flow0, flow1, ft_


class IntermediateDecoder(nn.Module):
    def __init__(self, in_ch, out_ch, skip_ch) -> None:
        super().__init__()
        self.convblock = nn.Sequential(
            convrelu(in_ch*3+4, in_ch*3),
            ResBlock(in_ch*3, skip_ch),
            nn.ConvTranspose2d(in_ch*3, out_ch+4, 4, 2, 1, bias=True)
        )

    def forward(self, ft_, f0, f1, flow0_in, flow1_in):
        f0_warp = warp(f0, flow0_in)
        f1_warp = warp(f1, flow1_in)
        f_in = torch.cat([ft_, f0_warp, f1_warp, flow0_in, flow1_in], 1)
        out = self.convblock(f_in)
        flow0, flow1 = torch.chunk(out[:, :4, ...], 2, 1)
        ft_ = out[:, 4:, ...]
        flow0 = flow0 + 2.0 * resize(flow0_in, scale_factor=2.0)
        flow1 = flow1 + 2.0 * resize(flow1_in, scale_factor=2.0)
        return flow0, flow1, ft_
```
Deleted file (69 lines) — by its contents, the old networks/blocks/multi_flow.py, relocated to src/networks/blocks/multi_flow.py:

```python
import torch
import torch.nn as nn
from src.utils.flow_utils import warp
from networks.blocks.ifrnet import (
    convrelu, resize,
    ResBlock,
)


def multi_flow_combine(comb_block, img0, img1, flow0, flow1,
                       mask=None, img_res=None, mean=None):
    '''
    A parallel implementation of multiple flow field warping.

    comb_block: an nn.Sequential object.
    img shape: [b, c, h, w]
    flow shape: [b, 2*num_flows, h, w]
    mask (opt): if 'mask' is None, the function conducts a simple average.
    img_res (opt): if 'img_res' is None, the function adds zero instead.
    mean (opt): if 'mean' is None, the function adds zero instead.
    '''
    b, c, h, w = flow0.shape
    num_flows = c // 2
    flow0 = flow0.reshape(b, num_flows, 2, h, w).reshape(-1, 2, h, w)
    flow1 = flow1.reshape(b, num_flows, 2, h, w).reshape(-1, 2, h, w)

    mask = mask.reshape(b, num_flows, 1, h, w
                        ).reshape(-1, 1, h, w) if mask is not None else None
    img_res = img_res.reshape(b, num_flows, 3, h, w
                              ).reshape(-1, 3, h, w) if img_res is not None else 0
    img0 = torch.stack([img0] * num_flows, 1).reshape(-1, 3, h, w)
    img1 = torch.stack([img1] * num_flows, 1).reshape(-1, 3, h, w)
    mean = torch.stack([mean] * num_flows, 1).reshape(-1, 1, 1, 1
                                                      ) if mean is not None else 0

    img0_warp = warp(img0, flow0)
    img1_warp = warp(img1, flow1)
    img_warps = mask * img0_warp + (1 - mask) * img1_warp + mean + img_res
    img_warps = img_warps.reshape(b, num_flows, 3, h, w)
    imgt_pred = img_warps.mean(1) + comb_block(img_warps.view(b, -1, h, w))
    return imgt_pred


class MultiFlowDecoder(nn.Module):
    def __init__(self, in_ch, skip_ch, num_flows=3):
        super(MultiFlowDecoder, self).__init__()
        self.num_flows = num_flows
        self.convblock = nn.Sequential(
            convrelu(in_ch*3+4, in_ch*3),
            ResBlock(in_ch*3, skip_ch),
            nn.ConvTranspose2d(in_ch*3, 8*num_flows, 4, 2, 1, bias=True)
        )

    def forward(self, ft_, f0, f1, flow0, flow1):
        n = self.num_flows
        f0_warp = warp(f0, flow0)
        f1_warp = warp(f1, flow1)
        out = self.convblock(torch.cat([ft_, f0_warp, f1_warp, flow0, flow1], 1))
        delta_flow0, delta_flow1, mask, img_res = torch.split(out, [2*n, 2*n, n, 3*n], 1)
        mask = torch.sigmoid(mask)

        flow0 = delta_flow0 + 2.0 * resize(flow0, scale_factor=2.0
                                           ).repeat(1, self.num_flows, 1, 1)
        flow1 = delta_flow1 + 2.0 * resize(flow1, scale_factor=2.0
                                           ).repeat(1, self.num_flows, 1, 1)

        return flow0, flow1, mask, img_res
```
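The docstring's "parallel implementation" refers to folding the per-flow dimension into the batch dimension, so a single `warp` call services all `num_flows` flow fields for an image pair. A shape-only sketch of that fold:

```python
import torch

b, num_flows, h, w = 2, 3, 8, 8
flow0 = torch.randn(b, 2 * num_flows, h, w)  # num_flows stacked 2-channel flows

# Fold flows into the batch: [b, 2*n, h, w] -> [b*n, 2, h, w]
flow0_flat = flow0.reshape(b, num_flows, 2, h, w).reshape(-1, 2, h, w)
assert flow0_flat.shape == (b * num_flows, 2, h, w)

# After warping, predictions are unfolded and averaged over the flow axis:
img_warps = torch.randn(b * num_flows, 3, h, w).reshape(b, num_flows, 3, h, w)
assert img_warps.mean(1).shape == (b, 3, h, w)
```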
onnx_export.py (new file, 8 lines)

```python
import torch
from src.export_to_onnx import export_to_onnx
from src.config.presets import SMALL


if __name__ == "__main__":
    device = torch.device("cuda")
    export_to_onnx(SMALL, "src/pretrained/amt_s.onnx", device)
```
pyproject.toml

```diff
@@ -7,8 +7,14 @@ requires-python = ">=3.12"
 dependencies = [
     "imageio>=2.37.3",
     "numpy>=2.4.4",
+    "nvidia-modelopt[all]>=0.33.1",
     "omegaconf>=2.3.0",
+    "onnx>=1.21.0",
+    "onnxscript>=0.6.2",
     "opencv-python>=4.13.0.92",
-    "torch>=2.11.0",
+    "tensorrt>=10.16.1.11",
+    "torch==2.5.1",
+    "torch-tensorrt>=2.5.0",
     "torchvision>=0.20.1",
     "tqdm>=4.67.3",
 ]
```
src/config/AMT-G.yaml

```diff
@@ -10,7 +10,7 @@ save_dir: work_dir
 eval_interval: 1

 network:
-  name: networks.AMT-G.Model
+  name: src.networks.AMT-G.Model
   params:
     corr_radius: 3
     corr_lvls: 4
```
src/config/AMT-L.yaml (new file, 62 lines)

```yaml
exp_name: floloss1e-2_300epoch_bs24_lr2e-4
seed: 2023
epochs: 300
distributed: true
lr: 2e-4
lr_min: 2e-5
weight_decay: 0.0
resume_state: null
save_dir: work_dir
eval_interval: 1

network:
  name: src.networks.AMT-L.Model
  params:
    corr_radius: 3
    corr_lvls: 4
    num_flows: 5
data:
  train:
    name: datasets.vimeo_datasets.Vimeo90K_Train_Dataset
    params:
      dataset_dir: data/vimeo_triplet
  val:
    name: datasets.vimeo_datasets.Vimeo90K_Test_Dataset
    params:
      dataset_dir: data/vimeo_triplet
  train_loader:
    batch_size: 24
    num_workers: 12
  val_loader:
    batch_size: 24
    num_workers: 3

logger:
  use_wandb: true
  resume_id: null

losses:
  - {
      name: losses.loss.CharbonnierLoss,
      nickname: l_rec,
      params: {
        loss_weight: 1.0,
        keys: [imgt_pred, imgt]
      }
    }
  - {
      name: losses.loss.TernaryLoss,
      nickname: l_ter,
      params: {
        loss_weight: 1.0,
        keys: [imgt_pred, imgt]
      }
    }
  - {
      name: losses.loss.MultipleFlowLoss,
      nickname: l_flo,
      params: {
        loss_weight: 0.002,
        keys: [flow0_pred, flow1_pred, flow]
      }
    }
```
src/config/AMT-S.yaml

```diff
@@ -10,7 +10,7 @@ save_dir: work_dir
 eval_interval: 1

 network:
-  name: networks.AMT-S.Model
+  name: src.networks.AMT-S.Model
   params:
     corr_radius: 3
     corr_lvls: 4
```
src/config/presets.py (new file, 26 lines)

```python
from typing import Literal
from pathlib import Path
from dataclasses import dataclass


@dataclass(frozen=True)
class Preset:
    config: Path
    checkpoint: Path
    onnx: Path | None = None


SMALL = Preset(
    config=Path("src/config/AMT-S.yaml"),
    checkpoint=Path("src/pretrained/amt-s.pth"),
)

LARGE = Preset(
    config=Path("src/config/AMT-L.yaml"),
    checkpoint=Path("src/pretrained/amt-l.pth"),
)

GLOBAL = Preset(
    config=Path("src/config/AMT-G.yaml"),
    checkpoint=Path("src/pretrained/amt-g.pth"),
)
```
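A preset whose `onnx` field is set routes `ModelRunner` through onnxruntime instead of the PyTorch checkpoint (see src/interpolator.py below). A hypothetical ONNX-backed variant of SMALL, assuming the graph exported by onnx_export.py:

```python
# Hypothetical; not part of the committed presets.
SMALL_ONNX = Preset(
    config=Path("src/config/AMT-S.yaml"),
    checkpoint=Path("src/pretrained/amt-s.pth"),
    onnx=Path("src/pretrained/amt_s.onnx"),
)
```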
src/export_to_onnx.py (new file, 29 lines)

```python
import torch
import torchvision

torchvision.disable_beta_transforms_warning()
torch.backends.cudnn.enabled = False
from .interpolator import ModelRunner
from .config.presets import Preset


def export_to_onnx(preset: Preset, output_path: str, device: torch.device):
    model_runner = ModelRunner(preset, device)
    # model_runner.model.eval()
    dummy_input = model_runner.get_dummy_input()
    torch.onnx.export(
        model_runner.model,
        dummy_input,
        output_path,
        opset_version=17,
        input_names=['img0', 'img1', 'embt'],
        output_names=["imgt_pred"],
        dynamic_axes={
            "img0": {0: "batch", 2: "height", 3: "width"},
            "img1": {0: "batch", 2: "height", 3: "width"},
            "embt": {0: "batch", 2: "height", 3: "width"},
            "imgt_pred": {0: "batch", 2: "height", 3: "width"},
        },
        dynamo=True,
        use_external_data_format=False,
    )
```
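A quick structural sanity check on the exported graph, assuming the `onnx` package from pyproject.toml and a completed export:

```python
import onnx

model = onnx.load("src/pretrained/amt_s.onnx")
onnx.checker.check_model(model)
# The export above names its inputs explicitly:
print([i.name for i in model.graph.input])  # expected: ['img0', 'img1', 'embt']
```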
src/interpolator.py (new file, 186 lines)

```python
import logging
from pathlib import Path
from typing import Optional

from cv2 import imread
import torch
import onnxruntime as ort
import numpy as np
from omegaconf import OmegaConf, DictConfig

from src.config.presets import Preset
from src.utils.torch import img2tensor, tensor2img
from src.utils.build import build_from_cfg


class Anchor:
    def __init__(self, resolution: int, memory: int, memory_bias: int) -> None:
        self.resolution = resolution
        self.memory = memory
        self.memory_bias = memory_bias

    def __str__(self) -> str:
        return f"Anchor(resolution={self.resolution}, memory={self.memory}, memory_bias={self.memory_bias})"


class ONNXWrapper:
    def __init__(self, path):
        self.session = ort.InferenceSession(path)

        self.input_names = [i.name for i in self.session.get_inputs()]
        self.output_names = [o.name for o in self.session.get_outputs()]

    def __call__(self, tensor1, tensor2, embt):
        inputs = {
            self.input_names[0]: tensor1.cpu().numpy(),
            self.input_names[1]: tensor2.cpu().numpy(),
            self.input_names[2]: embt.cpu().numpy(),
        }

        outputs = self.session.run(self.output_names, inputs)

        return {"imgt_pred": torch.from_numpy(outputs[0])}


class ModelRunner:
    def __init__(self, preset: Preset, device: torch.device) -> None:
        """Initializes the ModelRunner from a preset.

        Args:
            preset (Preset): Preset with config, checkpoint, and optional ONNX paths
            device (torch.device): Device to load the model on
        """
        self.model: Optional[torch.nn.Module] = None
        self.session: Optional[ONNXWrapper] = None
        self.embt = torch.tensor(1 / 2).float().view(1, 1, 1, 1).to(device)

        if preset.onnx:
            self.session = ONNXWrapper(preset.onnx)
            self.device = device
            self.embt = self.embt.cpu().numpy()
            return

        omega_config = OmegaConf.load(preset.config)
        network_config: DictConfig = omega_config.network
        logging.info(
            f"Loaded network configuration: {network_config} from [{preset.checkpoint}]"
        )
        model = build_from_cfg(network_config)
        checkpoint = torch.load(
            preset.checkpoint, map_location=device, weights_only=False
        )
        model.load_state_dict(checkpoint["state_dict"])
        model = model.to(device)
        model.eval()
        # self.model = torch.compile(model)
        self.model = model
        self.device = device
        # self.model = torch.compile(model, backend="tensorrt")
        if logging.getLogger().isEnabledFor(logging.DEBUG):
            for name, param in self.model.named_parameters():
                logging.debug(
                    f"Parameter: {name}, shape: {param.shape}, dtype: {param.dtype}"
                )

    def get_dummy_input(self):
        """Generates a dummy input tensor for ONNX export."""
        return (
            img2tensor(imread(filename="example/frame_01.png"), self.device),
            img2tensor(imread(filename="example/frame_02.png"), self.device),
            self.embt,
        )

    def run(self, image1: np.ndarray, image2: np.ndarray) -> np.ndarray:
        """Runs the model inference to interpolate between two images.

        Args:
            image1 (np.ndarray): First input image as a NumPy array
            image2 (np.ndarray): Second input image as a NumPy array

        Returns:
            np.ndarray: Interpolated image as a NumPy array
        """
        if self.session:
            image1 = img2tensor(image1, self.device).cpu().numpy()
            image2 = img2tensor(image2, self.device).cpu().numpy()
            inputs = {
                "img0": image1,
                "img1": image2,
                "embt": self.embt,
            }
            outputs = self.session.session.run(None, inputs)
            return outputs[0]

        tensor1 = img2tensor(image1, self.device)
        tensor2 = img2tensor(image2, self.device)
        with torch.no_grad():
            with torch.amp.autocast(self.device.type):
                interpolated = self.model(tensor1, tensor2, self.embt)["imgt_pred"]
        return tensor2img(interpolated.cpu())
```
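A minimal end-to-end sketch of `ModelRunner.run` on the repo's bundled example frames (SMALL preset assumed; `get_device` is defined just below):

```python
import cv2

from src.config.presets import SMALL
from src.interpolator import ModelRunner, get_device

runner = ModelRunner(SMALL, get_device())
frame1 = cv2.imread("example/frame_01.png")
frame2 = cv2.imread("example/frame_02.png")
middle = runner.run(frame1, frame2)  # np.ndarray midpoint frame (embt = 1/2)
cv2.imwrite("output/frame_01_5.png", middle)
```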
```python
def get_vram_available(device: torch.device) -> int:
    """Returns the available VRAM in bytes."""
    if device.type == "cuda" and torch.cuda.is_available():
        return torch.cuda.get_device_properties(
            device
        ).total_memory - torch.cuda.memory_allocated(device)
    elif device.type == "mps" and torch.mps.is_available():
        # MPS does not provide a way to query available memory, so we return a large number to avoid issues
        return torch.mps.recommended_max_memory()
    else:
        return 1


def get_device():
    """Detects and returns the best available device for PyTorch computation.

    Returns:
        torch.device: CUDA device if available, MPS device for Apple Silicon if available, otherwise CPU.
    """
    if torch.cuda.is_available():
        logging.info("Using CUDA-enabled GPU")
        return torch.device("cuda")
    elif torch.mps.is_available():
        logging.info("Using Apple Silicon GPU (MPS)")
        return torch.device("mps")
    logging.info("No GPU available, using CPU")
    return torch.device("cpu")


class ImageInterpolator:
    def __init__(self, device: torch.device, anchor: Anchor, model_runner: ModelRunner):
        self.device = device
        self.anchor = anchor
        self.vram_available = get_vram_available(device)
        self.model_runner = model_runner
        logging.debug(
            f"Initialized ImageInterpolator with device: {device}, anchor: {anchor}, available VRAM: {self.vram_available} bytes"
        )

    def interpolate(self, image1: np.ndarray, image2: np.ndarray) -> np.ndarray:
        """Interpolates between two images.

        Args:
            image1 (np.ndarray): First input image
            image2 (np.ndarray): Second input image

        Returns:
            np.ndarray: The interpolated middle frame
        """
        return self.model_runner.run(image1, image2)

    def scale(self, height: int, width: int) -> float:
        scale = (
            self.anchor.resolution
            / (height * width)
            * np.sqrt(
                (self.vram_available - self.anchor.memory_bias) / self.anchor.memory
            )
        )
        scale = 1 if scale > 1 else scale
        scale = 1 / np.floor(1 / np.sqrt(scale) * 16) * 16
        if scale < 1:
            logging.info(
                f"Due to the limited VRAM, the video will be scaled by {scale:.2f}"
            )
        return scale
```
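A worked example of the VRAM-aware scale heuristic, using the CUDA anchor from main.py (resolution=1024*512, memory=1500 MiB, bias=2500 MiB) with illustrative values of 8 GiB free VRAM and a 1080p input:

```python
import numpy as np

resolution, memory, bias = 1024 * 512, 1500 * 1024**2, 2500 * 1024**2
vram, h, w = 8 * 1024**3, 1080, 1920

scale = resolution / (h * w) * np.sqrt((vram - bias) / memory)  # ~0.49
scale = min(scale, 1)
# Snap the factor so that 16 / scale (the padding divisor used by the old
# interpolator.py above) comes out as an integer:
scale = 1 / np.floor(1 / np.sqrt(scale) * 16) * 16
print(round(float(scale), 4))  # 0.7273 -> frames are processed at ~73 % size
```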
src/networks/AMT-G.py

```diff
@@ -1,9 +1,11 @@
+from typing import Optional
+
 import torch
 import torch.nn as nn
-from networks.blocks.raft import coords_grid, BasicUpdateBlock, BidirCorrBlock
-from networks.blocks.feat_enc import LargeEncoder
-from networks.blocks.ifrnet import resize, Encoder, InitDecoder, IntermediateDecoder
-from networks.blocks.multi_flow import multi_flow_combine, MultiFlowDecoder
+from src.networks.blocks.raft import coords_grid, BasicUpdateBlock, BidirCorrBlock
+from src.networks.blocks.feat_enc import LargeEncoder
+from src.networks.blocks.ifrnet import resize, Encoder, InitDecoder, IntermediateDecoder
+from src.networks.blocks.multi_flow import multi_flow_combine, MultiFlowDecoder


 class Model(nn.Module):
@@ -42,7 +44,7 @@ class Model(nn.Module):
             nn.Conv2d(6 * self.num_flows, 3, 7, 1, 3),
         )

-    def _get_updateblock(self, cdim, scale_factor=None):
+    def _get_updateblock(self, cdim: int, scale_factor: Optional[float] = None):
         return BasicUpdateBlock(
             cdim=cdim,
             hidden_dim=192,
@@ -55,7 +57,15 @@ class Model(nn.Module):
             radius=self.radius,
         )

-    def _corr_scale_lookup(self, corr_fn, coord, flow0, flow1, embt, downsample=1):
+    def _corr_scale_lookup(
+        self,
+        corr_fn: BidirCorrBlock,
+        coord: torch.Tensor,
+        flow0: torch.Tensor,
+        flow1: torch.Tensor,
+        embt: torch.Tensor,
+        downsample: int = 1,
+    ):
         # convert t -> 0 to 0 -> 1 | convert t -> 1 to 1 -> 0
         # based on linear assumption
         t1_scale = 1.0 / embt
```
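The comment above encodes the linear-motion assumption: a flow that spans only part of the 0-to-1 trajectory is stretched to the full trajectory before the correlation lookup. In plain numbers:

```python
# Illustrative values for the rescaling factors used above:
embt = 0.25                    # intermediate time t
t1_scale = 1.0 / embt          # 4.0:  stretch a flow spanning t to span 1
t0_scale = 1.0 / (1.0 - embt)  # ~1.33: stretch a flow spanning (1 - t) to span 1
```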
```diff
@@ -70,7 +80,15 @@ class Model(nn.Module):
         flow = torch.cat([flow0, flow1], dim=1)
         return corr, flow

-    def forward(self, img0, img1, embt, scale_factor=1.0, eval=False, **kwargs):
+    def forward(
+        self,
+        img0: torch.Tensor,
+        img1: torch.Tensor,
+        embt: torch.Tensor,
+        scale_factor: float = 1.0,
+        eval: bool = False,
+        **kwargs,
+    ):
         mean_ = (
             torch.cat([img0, img1], 2)
             .mean(1, keepdim=True)
```
src/networks/AMT-L.py

```diff
@@ -1,38 +1,29 @@
 import torch
 import torch.nn as nn
-from networks.blocks.raft import (
-    coords_grid,
-    BasicUpdateBlock, BidirCorrBlock
-)
-from networks.blocks.feat_enc import (
-    BasicEncoder
-)
-from networks.blocks.ifrnet import (
-    resize,
-    Encoder,
-    InitDecoder,
-    IntermediateDecoder
-)
-from networks.blocks.multi_flow import (
-    multi_flow_combine,
-    MultiFlowDecoder
-)
+from src.networks.blocks.raft import coords_grid, BasicUpdateBlock, BidirCorrBlock
+from src.networks.blocks.feat_enc import BasicEncoder
+from src.networks.blocks.ifrnet import resize, Encoder, InitDecoder, IntermediateDecoder
+
+from src.networks.blocks.multi_flow import multi_flow_combine, MultiFlowDecoder


 class Model(nn.Module):
-    def __init__(self,
-                 corr_radius=3,
-                 corr_lvls=4,
-                 num_flows=5,
-                 channels=[48, 64, 72, 128],
-                 skip_channels=48
-    ):
+    def __init__(
+        self,
+        corr_radius=3,
+        corr_lvls=4,
+        num_flows=5,
+        channels=[48, 64, 72, 128],
+        skip_channels=48,
+    ):
         super(Model, self).__init__()
         self.radius = corr_radius
         self.corr_levels = corr_lvls
         self.num_flows = num_flows

-        self.feat_encoder = BasicEncoder(output_dim=128, norm_fn='instance', dropout=0.)
+        self.feat_encoder = BasicEncoder(
+            output_dim=128, norm_fn="instance", dropout=0.0
+        )
         self.encoder = Encoder([48, 64, 72, 128], large=True)

         self.decoder4 = InitDecoder(channels[3], channels[2], skip_channels)
@@ -45,22 +36,29 @@ class Model(nn.Module):
         self.update2 = self._get_updateblock(48, 4.0)

         self.comb_block = nn.Sequential(
-            nn.Conv2d(3*self.num_flows, 6*self.num_flows, 7, 1, 3),
-            nn.PReLU(6*self.num_flows),
-            nn.Conv2d(6*self.num_flows, 3, 7, 1, 3),
+            nn.Conv2d(3 * self.num_flows, 6 * self.num_flows, 7, 1, 3),
+            nn.PReLU(6 * self.num_flows),
+            nn.Conv2d(6 * self.num_flows, 3, 7, 1, 3),
         )

     def _get_updateblock(self, cdim, scale_factor=None):
-        return BasicUpdateBlock(cdim=cdim, hidden_dim=128, flow_dim=48,
-                                corr_dim=256, corr_dim2=160, fc_dim=124,
-                                scale_factor=scale_factor, corr_levels=self.corr_levels,
-                                radius=self.radius)
+        return BasicUpdateBlock(
+            cdim=cdim,
+            hidden_dim=128,
+            flow_dim=48,
+            corr_dim=256,
+            corr_dim2=160,
+            fc_dim=124,
+            scale_factor=scale_factor,
+            corr_levels=self.corr_levels,
+            radius=self.radius,
+        )

     def _corr_scale_lookup(self, corr_fn, coord, flow0, flow1, embt, downsample=1):
         # convert t -> 0 to 0 -> 1 | convert t -> 1 to 1 -> 0
         # based on linear assumption
-        t1_scale = 1. / embt
-        t0_scale = 1. / (1. - embt)
+        t1_scale = 1.0 / embt
+        t0_scale = 1.0 / (1.0 - embt)
         if downsample != 1:
             inv = 1 / downsample
             flow0 = inv * resize(flow0, scale_factor=inv)
@@ -72,7 +70,12 @@ class Model(nn.Module):
         return corr, flow

     def forward(self, img0, img1, embt, scale_factor=1.0, eval=False, **kwargs):
-        mean_ = torch.cat([img0, img1], 2).mean(1, keepdim=True).mean(2, keepdim=True).mean(3, keepdim=True)
+        mean_ = (
+            torch.cat([img0, img1], 2)
+            .mean(1, keepdim=True)
+            .mean(2, keepdim=True)
+            .mean(3, keepdim=True)
+        )
         img0 = img0 - mean_
         img1 = img1 - mean_
         img0_ = resize(img0, scale_factor) if scale_factor != 1.0 else img0
@@ -80,8 +83,10 @@ class Model(nn.Module):
         b, _, h, w = img0_.shape
         coord = coords_grid(b, h // 8, w // 8, img0.device)

-        fmap0, fmap1 = self.feat_encoder([img0_, img1_]) # [1, 128, H//8, W//8]
-        corr_fn = BidirCorrBlock(fmap0, fmap1, radius=self.radius, num_levels=self.corr_levels)
+        fmap0, fmap1 = self.feat_encoder([img0_, img1_])  # [1, 128, H//8, W//8]
+        corr_fn = BidirCorrBlock(
+            fmap0, fmap1, radius=self.radius, num_levels=self.corr_levels
+        )

         # f0_1: [1, c0, H//2, W//2] | f0_2: [1, c1, H//4, W//4]
         # f0_3: [1, c2, H//8, W//8] | f0_4: [1, c3, H//16, W//16]
@@ -90,9 +95,9 @@ class Model(nn.Module):

         ######################################### the 4th decoder #########################################
         up_flow0_4, up_flow1_4, ft_3_ = self.decoder4(f0_4, f1_4, embt)
-        corr_4, flow_4 = self._corr_scale_lookup(corr_fn, coord,
-                                                 up_flow0_4, up_flow1_4,
-                                                 embt, downsample=1)
+        corr_4, flow_4 = self._corr_scale_lookup(
+            corr_fn, coord, up_flow0_4, up_flow1_4, embt, downsample=1
+        )

         # residue update with lookup corr
         delta_ft_3_, delta_flow_4 = self.update4(ft_3_, flow_4, corr_4)
@@ -102,10 +107,12 @@ class Model(nn.Module):
         ft_3_ = ft_3_ + delta_ft_3_

         ######################################### the 3rd decoder #########################################
-        up_flow0_3, up_flow1_3, ft_2_ = self.decoder3(ft_3_, f0_3, f1_3, up_flow0_4, up_flow1_4)
-        corr_3, flow_3 = self._corr_scale_lookup(corr_fn,
-                                                 coord, up_flow0_3, up_flow1_3,
-                                                 embt, downsample=2)
+        up_flow0_3, up_flow1_3, ft_2_ = self.decoder3(
+            ft_3_, f0_3, f1_3, up_flow0_4, up_flow1_4
+        )
+        corr_3, flow_3 = self._corr_scale_lookup(
+            corr_fn, coord, up_flow0_3, up_flow1_3, embt, downsample=2
+        )

         # residue update with lookup corr
         delta_ft_2_, delta_flow_3 = self.update3(ft_2_, flow_3, corr_3)
@@ -115,10 +122,12 @@ class Model(nn.Module):
         ft_2_ = ft_2_ + delta_ft_2_

         ######################################### the 2nd decoder #########################################
-        up_flow0_2, up_flow1_2, ft_1_ = self.decoder2(ft_2_, f0_2, f1_2, up_flow0_3, up_flow1_3)
-        corr_2, flow_2 = self._corr_scale_lookup(corr_fn,
-                                                 coord, up_flow0_2, up_flow1_2,
-                                                 embt, downsample=4)
+        up_flow0_2, up_flow1_2, ft_1_ = self.decoder2(
+            ft_2_, f0_2, f1_2, up_flow0_3, up_flow1_3
+        )
+        corr_2, flow_2 = self._corr_scale_lookup(
+            corr_fn, coord, up_flow0_2, up_flow1_2, embt, downsample=4
+        )

         # residue update with lookup corr
         delta_ft_1_, delta_flow_2 = self.update2(ft_1_, flow_2, corr_2)
@@ -128,28 +137,36 @@ class Model(nn.Module):
         ft_1_ = ft_1_ + delta_ft_1_

         ######################################### the 1st decoder #########################################
-        up_flow0_1, up_flow1_1, mask, img_res = self.decoder1(ft_1_, f0_1, f1_1, up_flow0_2, up_flow1_2)
+        up_flow0_1, up_flow1_1, mask, img_res = self.decoder1(
+            ft_1_, f0_1, f1_1, up_flow0_2, up_flow1_2
+        )

         if scale_factor != 1.0:
-            up_flow0_1 = resize(up_flow0_1, scale_factor=(1.0/scale_factor)) * (1.0/scale_factor)
-            up_flow1_1 = resize(up_flow1_1, scale_factor=(1.0/scale_factor)) * (1.0/scale_factor)
-            mask = resize(mask, scale_factor=(1.0/scale_factor))
-            img_res = resize(img_res, scale_factor=(1.0/scale_factor))
+            up_flow0_1 = resize(up_flow0_1, scale_factor=(1.0 / scale_factor)) * (
+                1.0 / scale_factor
+            )
+            up_flow1_1 = resize(up_flow1_1, scale_factor=(1.0 / scale_factor)) * (
+                1.0 / scale_factor
+            )
+            mask = resize(mask, scale_factor=(1.0 / scale_factor))
+            img_res = resize(img_res, scale_factor=(1.0 / scale_factor))

         # Merge multiple predictions
-        imgt_pred = multi_flow_combine(self.comb_block, img0, img1, up_flow0_1, up_flow1_1,
-                                       mask, img_res, mean_)
+        imgt_pred = multi_flow_combine(
+            self.comb_block, img0, img1, up_flow0_1, up_flow1_1, mask, img_res, mean_
+        )
         imgt_pred = torch.clamp(imgt_pred, 0, 1)

         if eval:
-            return { 'imgt_pred': imgt_pred, }
+            return {
+                "imgt_pred": imgt_pred,
+            }
         else:
             up_flow0_1 = up_flow0_1.reshape(b, self.num_flows, 2, h, w)
             up_flow1_1 = up_flow1_1.reshape(b, self.num_flows, 2, h, w)
             return {
-                'imgt_pred': imgt_pred,
-                'flow0_pred': [up_flow0_1, up_flow0_2, up_flow0_3, up_flow0_4],
-                'flow1_pred': [up_flow1_1, up_flow1_2, up_flow1_3, up_flow1_4],
-                'ft_pred': [ft_1_, ft_2_, ft_3_],
+                "imgt_pred": imgt_pred,
+                "flow0_pred": [up_flow0_1, up_flow0_2, up_flow0_3, up_flow0_4],
+                "flow1_pred": [up_flow1_1, up_flow1_2, up_flow1_3, up_flow1_4],
+                "ft_pred": [ft_1_, ft_2_, ft_3_],
             }
```
@@ -1,31 +1,20 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from networks.blocks.raft import (
|
||||
coords_grid,
|
||||
SmallUpdateBlock, BidirCorrBlock
|
||||
)
|
||||
from networks.blocks.feat_enc import (
|
||||
SmallEncoder
|
||||
)
|
||||
from networks.blocks.ifrnet import (
|
||||
resize,
|
||||
Encoder,
|
||||
InitDecoder,
|
||||
IntermediateDecoder
|
||||
)
|
||||
from networks.blocks.multi_flow import (
|
||||
multi_flow_combine,
|
||||
MultiFlowDecoder
|
||||
)
|
||||
from src.networks.blocks.raft import coords_grid, SmallUpdateBlock, BidirCorrBlock
|
||||
from src.networks.blocks.feat_enc import SmallEncoder
|
||||
from src.networks.blocks.ifrnet import resize, Encoder, InitDecoder, IntermediateDecoder
|
||||
from src.networks.blocks.multi_flow import multi_flow_combine, MultiFlowDecoder
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self,
|
||||
corr_radius=3,
|
||||
corr_lvls=4,
|
||||
num_flows=3,
|
||||
channels=[20, 32, 44, 56],
|
||||
skip_channels=20):
|
||||
def __init__(
|
||||
self,
|
||||
corr_radius=3,
|
||||
corr_lvls=4,
|
||||
num_flows=3,
|
||||
channels=[20, 32, 44, 56],
|
||||
skip_channels=20,
|
||||
):
|
||||
super(Model, self).__init__()
|
||||
self.radius = corr_radius
|
||||
self.corr_levels = corr_lvls
|
||||
@@ -33,7 +22,7 @@ class Model(nn.Module):
|
||||
self.channels = channels
|
||||
self.skip_channels = skip_channels
|
||||
|
||||
self.feat_encoder = SmallEncoder(output_dim=84, norm_fn='instance', dropout=0.)
|
||||
self.feat_encoder = SmallEncoder(output_dim=84, norm_fn="instance", dropout=0.0)
|
||||
self.encoder = Encoder(channels)
|
||||
|
||||
self.decoder4 = InitDecoder(channels[3], channels[2], skip_channels)
|
||||
@@ -46,21 +35,28 @@ class Model(nn.Module):
|
||||
self.update2 = self._get_updateblock(20, 4)
|
||||
|
||||
self.comb_block = nn.Sequential(
|
||||
nn.Conv2d(3*num_flows, 6*num_flows, 3, 1, 1),
|
||||
nn.PReLU(6*num_flows),
|
||||
nn.Conv2d(6*num_flows, 3, 3, 1, 1),
|
||||
nn.Conv2d(3 * num_flows, 6 * num_flows, 3, 1, 1),
|
||||
nn.PReLU(6 * num_flows),
|
||||
nn.Conv2d(6 * num_flows, 3, 3, 1, 1),
|
||||
)
|
||||
|
||||
def _get_updateblock(self, cdim, scale_factor=None):
|
||||
return SmallUpdateBlock(cdim=cdim, hidden_dim=76, flow_dim=20, corr_dim=64,
|
||||
fc_dim=68, scale_factor=scale_factor,
|
||||
corr_levels=self.corr_levels, radius=self.radius)
|
||||
return SmallUpdateBlock(
|
||||
cdim=cdim,
|
||||
hidden_dim=76,
|
||||
flow_dim=20,
|
||||
corr_dim=64,
|
||||
fc_dim=68,
|
||||
scale_factor=scale_factor,
|
||||
corr_levels=self.corr_levels,
|
||||
radius=self.radius,
|
||||
)
|
||||
|
||||
def _corr_scale_lookup(self, corr_fn, coord, flow0, flow1, embt, downsample=1):
|
||||
# convert t -> 0 to 0 -> 1 | convert t -> 1 to 1 -> 0
|
||||
# based on linear assumption
|
||||
t1_scale = 1. / embt
|
||||
t0_scale = 1. / (1. - embt)
t1_scale = 1.0 / embt
t0_scale = 1.0 / (1.0 - embt)
if downsample != 1:
inv = 1 / downsample
flow0 = inv * resize(flow0, scale_factor=inv)
@@ -71,8 +67,20 @@ class Model(nn.Module):
flow = torch.cat([flow0, flow1], dim=1)
return corr, flow

def forward(self, img0, img1, embt, scale_factor=1.0, eval=False, **kwargs):
mean_ = torch.cat([img0, img1], 2).mean(1, keepdim=True).mean(2, keepdim=True).mean(3, keepdim=True)
def forward(
self,
img0: torch.Tensor,
img1: torch.Tensor,
embt: torch.Tensor,
):
scale_factor = 1.0
eval = False
mean_ = (
torch.cat([img0, img1], 2)
.mean(1, keepdim=True)
.mean(2, keepdim=True)
.mean(3, keepdim=True)
)
img0 = img0 - mean_
img1 = img1 - mean_
img0_ = resize(img0, scale_factor) if scale_factor != 1.0 else img0
@@ -80,8 +88,10 @@ class Model(nn.Module):
b, _, h, w = img0_.shape
coord = coords_grid(b, h // 8, w // 8, img0.device)

fmap0, fmap1 = self.feat_encoder([img0_, img1_])  # [1, 128, H//8, W//8]
corr_fn = BidirCorrBlock(fmap0, fmap1, radius=self.radius, num_levels=self.corr_levels)
fmap0, fmap1 = self.feat_encoder([img0_, img1_])  # [1, 128, H//8, W//8]
corr_fn = BidirCorrBlock(
fmap0, fmap1, radius=self.radius, num_levels=self.corr_levels
)

# f0_1: [1, c0, H//2, W//2] | f0_2: [1, c1, H//4, W//4]
# f0_3: [1, c2, H//8, W//8] | f0_4: [1, c3, H//16, W//16]
@@ -90,9 +100,9 @@ class Model(nn.Module):

######################################### the 4th decoder #########################################
up_flow0_4, up_flow1_4, ft_3_ = self.decoder4(f0_4, f1_4, embt)
corr_4, flow_4 = self._corr_scale_lookup(corr_fn, coord,
up_flow0_4, up_flow1_4,
embt, downsample=1)
corr_4, flow_4 = self._corr_scale_lookup(
corr_fn, coord, up_flow0_4, up_flow1_4, embt, downsample=1
)

# residue update with lookup corr
delta_ft_3_, delta_flow_4 = self.update4(ft_3_, flow_4, corr_4)
@@ -102,10 +112,12 @@ class Model(nn.Module):
ft_3_ = ft_3_ + delta_ft_3_

######################################### the 3rd decoder #########################################
up_flow0_3, up_flow1_3, ft_2_ = self.decoder3(ft_3_, f0_3, f1_3, up_flow0_4, up_flow1_4)
corr_3, flow_3 = self._corr_scale_lookup(corr_fn,
coord, up_flow0_3, up_flow1_3,
embt, downsample=2)
up_flow0_3, up_flow1_3, ft_2_ = self.decoder3(
ft_3_, f0_3, f1_3, up_flow0_4, up_flow1_4
)
corr_3, flow_3 = self._corr_scale_lookup(
corr_fn, coord, up_flow0_3, up_flow1_3, embt, downsample=2
)

# residue update with lookup corr
delta_ft_2_, delta_flow_3 = self.update3(ft_2_, flow_3, corr_3)
@@ -115,10 +127,12 @@ class Model(nn.Module):
ft_2_ = ft_2_ + delta_ft_2_

######################################### the 2nd decoder #########################################
up_flow0_2, up_flow1_2, ft_1_ = self.decoder2(ft_2_, f0_2, f1_2, up_flow0_3, up_flow1_3)
corr_2, flow_2 = self._corr_scale_lookup(corr_fn,
coord, up_flow0_2, up_flow1_2,
embt, downsample=4)
up_flow0_2, up_flow1_2, ft_1_ = self.decoder2(
ft_2_, f0_2, f1_2, up_flow0_3, up_flow1_3
)
corr_2, flow_2 = self._corr_scale_lookup(
corr_fn, coord, up_flow0_2, up_flow1_2, embt, downsample=4
)

# residue update with lookup corr
delta_ft_1_, delta_flow_2 = self.update2(ft_1_, flow_2, corr_2)
@@ -128,27 +142,36 @@ class Model(nn.Module):
ft_1_ = ft_1_ + delta_ft_1_

######################################### the 1st decoder #########################################
up_flow0_1, up_flow1_1, mask, img_res = self.decoder1(ft_1_, f0_1, f1_1, up_flow0_2, up_flow1_2)
up_flow0_1, up_flow1_1, mask, img_res = self.decoder1(
ft_1_, f0_1, f1_1, up_flow0_2, up_flow1_2
)

if scale_factor != 1.0:
up_flow0_1 = resize(up_flow0_1, scale_factor=(1.0/scale_factor)) * (1.0/scale_factor)
up_flow1_1 = resize(up_flow1_1, scale_factor=(1.0/scale_factor)) * (1.0/scale_factor)
mask = resize(mask, scale_factor=(1.0/scale_factor))
img_res = resize(img_res, scale_factor=(1.0/scale_factor))
up_flow0_1 = resize(up_flow0_1, scale_factor=(1.0 / scale_factor)) * (
1.0 / scale_factor
)
up_flow1_1 = resize(up_flow1_1, scale_factor=(1.0 / scale_factor)) * (
1.0 / scale_factor
)
mask = resize(mask, scale_factor=(1.0 / scale_factor))
img_res = resize(img_res, scale_factor=(1.0 / scale_factor))

# Merge multiple predictions
imgt_pred = multi_flow_combine(self.comb_block, img0, img1, up_flow0_1, up_flow1_1,
mask, img_res, mean_)
imgt_pred = multi_flow_combine(
self.comb_block, img0, img1, up_flow0_1, up_flow1_1, mask, img_res, mean_
)
imgt_pred = torch.clamp(imgt_pred, 0, 1)

if eval:
return { 'imgt_pred': imgt_pred, }
return {
"imgt_pred": imgt_pred,
}
else:
up_flow0_1 = up_flow0_1.reshape(b, self.num_flows, 2, h, w)
up_flow1_1 = up_flow1_1.reshape(b, self.num_flows, 2, h, w)
return {
'imgt_pred': imgt_pred,
'flow0_pred': [up_flow0_1, up_flow0_2, up_flow0_3, up_flow0_4],
'flow1_pred': [up_flow1_1, up_flow1_2, up_flow1_3, up_flow1_4],
'ft_pred': [ft_1_, ft_2_, ft_3_],
"imgt_pred": imgt_pred,
"flow0_pred": [up_flow0_1, up_flow0_2, up_flow0_3, up_flow0_4],
"flow1_pred": [up_flow1_1, up_flow1_2, up_flow1_3, up_flow1_4],
"ft_pred": [ft_1_, ft_2_, ft_3_],
}
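A quick numeric check of the time scaling in `_corr_scale_lookup` above: the decoder predicts flows that span only a fraction of the frame gap, so they are stretched by t1_scale = 1 / embt and t0_scale = 1 / (1 - embt) before the correlation lookup. A minimal arithmetic sketch (plain Python, not tied to the repository):

embt = 0.5  # midpoint interpolation
t1_scale = 1.0 / embt          # 2.0
t0_scale = 1.0 / (1.0 - embt)  # 2.0
# a half-step displacement of 3.5 px stretches to the full 7.0 px frame-to-frame displacement
assert 3.5 * t1_scale == 7.0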
@@ -1,30 +1,23 @@
import torch
import torch.nn as nn
from src.utils.flow_utils import warp
from networks.blocks.ifrnet import (
convrelu, resize,
ResBlock,
)
from src.networks.blocks.ifrnet import convrelu, resize, ResBlock


class Encoder(nn.Module):
def __init__(self):
super(Encoder, self).__init__()
self.pyramid1 = nn.Sequential(
convrelu(3, 32, 3, 2, 1),
convrelu(32, 32, 3, 1, 1)
convrelu(3, 32, 3, 2, 1), convrelu(32, 32, 3, 1, 1)
)
self.pyramid2 = nn.Sequential(
convrelu(32, 48, 3, 2, 1),
convrelu(48, 48, 3, 1, 1)
convrelu(32, 48, 3, 2, 1), convrelu(48, 48, 3, 1, 1)
)
self.pyramid3 = nn.Sequential(
convrelu(48, 72, 3, 2, 1),
convrelu(72, 72, 3, 1, 1)
convrelu(48, 72, 3, 2, 1), convrelu(72, 72, 3, 1, 1)
)
self.pyramid4 = nn.Sequential(
convrelu(72, 96, 3, 2, 1),
convrelu(96, 96, 3, 1, 1)
convrelu(72, 96, 3, 2, 1), convrelu(96, 96, 3, 1, 1)
)

def forward(self, img):
@@ -39,9 +32,9 @@ class Decoder4(nn.Module):
def __init__(self):
super(Decoder4, self).__init__()
self.convblock = nn.Sequential(
convrelu(192+1, 192),
convrelu(192 + 1, 192),
ResBlock(192, 32),
nn.ConvTranspose2d(192, 76, 4, 2, 1, bias=True)
nn.ConvTranspose2d(192, 76, 4, 2, 1, bias=True),
)

def forward(self, f0, f1, embt):
@@ -58,7 +51,7 @@ class Decoder3(nn.Module):
self.convblock = nn.Sequential(
convrelu(220, 216),
ResBlock(216, 32),
nn.ConvTranspose2d(216, 52, 4, 2, 1, bias=True)
nn.ConvTranspose2d(216, 52, 4, 2, 1, bias=True),
)

def forward(self, ft_, f0, f1, up_flow0, up_flow1):
@@ -75,7 +68,7 @@ class Decoder2(nn.Module):
self.convblock = nn.Sequential(
convrelu(148, 144),
ResBlock(144, 32),
nn.ConvTranspose2d(144, 36, 4, 2, 1, bias=True)
nn.ConvTranspose2d(144, 36, 4, 2, 1, bias=True),
)

def forward(self, ft_, f0, f1, up_flow0, up_flow1):
@@ -92,7 +85,7 @@ class Decoder1(nn.Module):
self.convblock = nn.Sequential(
convrelu(100, 96),
ResBlock(96, 32),
nn.ConvTranspose2d(96, 8, 4, 2, 1, bias=True)
nn.ConvTranspose2d(96, 8, 4, 2, 1, bias=True),
)

def forward(self, ft_, f0, f1, up_flow0, up_flow1):
@@ -113,7 +106,12 @@ class Model(nn.Module):
self.decoder1 = Decoder1()

def forward(self, img0, img1, embt, scale_factor=1.0, eval=False, **kwargs):
mean_ = torch.cat([img0, img1], 2).mean(1, keepdim=True).mean(2, keepdim=True).mean(3, keepdim=True)
mean_ = (
torch.cat([img0, img1], 2)
.mean(1, keepdim=True)
.mean(2, keepdim=True)
.mean(3, keepdim=True)
)
img0 = img0 - mean_
img1 = img1 - mean_

@@ -145,10 +143,14 @@ class Model(nn.Module):
up_res_1 = out1[:, 5:]

if scale_factor != 1.0:
up_flow0_1 = resize(up_flow0_1, scale_factor=(1.0/scale_factor)) * (1.0/scale_factor)
up_flow1_1 = resize(up_flow1_1, scale_factor=(1.0/scale_factor)) * (1.0/scale_factor)
up_mask_1 = resize(up_mask_1, scale_factor=(1.0/scale_factor))
up_res_1 = resize(up_res_1, scale_factor=(1.0/scale_factor))
up_flow0_1 = resize(up_flow0_1, scale_factor=(1.0 / scale_factor)) * (
1.0 / scale_factor
)
up_flow1_1 = resize(up_flow1_1, scale_factor=(1.0 / scale_factor)) * (
1.0 / scale_factor
)
up_mask_1 = resize(up_mask_1, scale_factor=(1.0 / scale_factor))
up_res_1 = resize(up_res_1, scale_factor=(1.0 / scale_factor))

img0_warp = warp(img0, up_flow0_1)
img1_warp = warp(img1, up_flow1_1)
@@ -157,13 +159,15 @@ class Model(nn.Module):
imgt_pred = torch.clamp(imgt_pred, 0, 1)

if eval:
return { 'imgt_pred': imgt_pred, }
return {
"imgt_pred": imgt_pred,
}
else:
return {
'imgt_pred': imgt_pred,
'flow0_pred': [up_flow0_1, up_flow0_2, up_flow0_3, up_flow0_4],
'flow1_pred': [up_flow1_1, up_flow1_2, up_flow1_3, up_flow1_4],
'ft_pred': [ft_1_, ft_2_, ft_3_],
'img0_warp': img0_warp,
'img1_warp': img1_warp
"imgt_pred": imgt_pred,
"flow0_pred": [up_flow0_1, up_flow0_2, up_flow0_3, up_flow0_4],
"flow1_pred": [up_flow1_1, up_flow1_2, up_flow1_3, up_flow1_4],
"ft_pred": [ft_1_, ft_2_, ft_3_],
"img0_warp": img0_warp,
"img1_warp": img1_warp,
}
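Both model variants center the input pair on a shared scalar mean and add it back when compositing. A minimal sketch of the statistic being computed (hypothetical shapes, independent of the repository):

import torch

img0, img1 = torch.rand(1, 3, 4, 4), torch.rand(1, 3, 4, 4)
# concatenate along height, then collapse channels, height and width:
# one scalar per batch element, shaped [b, 1, 1, 1] for broadcasting
mean_ = (
    torch.cat([img0, img1], 2)
    .mean(1, keepdim=True)
    .mean(2, keepdim=True)
    .mean(3, keepdim=True)
)
assert mean_.shape == (1, 1, 1, 1)
img0_centered, img1_centered = img0 - mean_, img1 - mean_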
@@ -1,40 +1,44 @@
from typing import Any

import torch
import torch.nn as nn


class BottleneckBlock(nn.Module):
def __init__(self, in_planes, planes, norm_fn='group', stride=1):
def __init__(self, in_planes, planes, norm_fn="group", stride=1):
super(BottleneckBlock, self).__init__()

self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0)
self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride)
self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0)
self.conv1 = nn.Conv2d(in_planes, planes // 4, kernel_size=1, padding=0)
self.conv2 = nn.Conv2d(
planes // 4, planes // 4, kernel_size=3, padding=1, stride=stride
)
self.conv3 = nn.Conv2d(planes // 4, planes, kernel_size=1, padding=0)
self.relu = nn.ReLU(inplace=True)

num_groups = planes // 8

if norm_fn == 'group':
self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
if norm_fn == "group":
self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes // 4)
self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes // 4)
self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
if not stride == 1:
self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)

elif norm_fn == 'batch':
self.norm1 = nn.BatchNorm2d(planes//4)
self.norm2 = nn.BatchNorm2d(planes//4)
elif norm_fn == "batch":
self.norm1 = nn.BatchNorm2d(planes // 4)
self.norm2 = nn.BatchNorm2d(planes // 4)
self.norm3 = nn.BatchNorm2d(planes)
if not stride == 1:
self.norm4 = nn.BatchNorm2d(planes)

elif norm_fn == 'instance':
self.norm1 = nn.InstanceNorm2d(planes//4)
self.norm2 = nn.InstanceNorm2d(planes//4)
elif norm_fn == "instance":
self.norm1 = nn.InstanceNorm2d(planes // 4)
self.norm2 = nn.InstanceNorm2d(planes // 4)
self.norm3 = nn.InstanceNorm2d(planes)
if not stride == 1:
self.norm4 = nn.InstanceNorm2d(planes)

elif norm_fn == 'none':
elif norm_fn == "none":
self.norm1 = nn.Sequential()
self.norm2 = nn.Sequential()
self.norm3 = nn.Sequential()
@@ -46,8 +50,8 @@ class BottleneckBlock(nn.Module):

else:
self.downsample = nn.Sequential(
nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4)

nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4
)

def forward(self, x):
y = x
@@ -58,38 +62,40 @@ class BottleneckBlock(nn.Module):
if self.downsample is not None:
x = self.downsample(x)

return self.relu(x+y)
return self.relu(x + y)


class ResidualBlock(nn.Module):
def __init__(self, in_planes, planes, norm_fn='group', stride=1):
def __init__(self, in_planes, planes, norm_fn="group", stride=1):
super(ResidualBlock, self).__init__()

self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride)
self.conv1 = nn.Conv2d(
in_planes, planes, kernel_size=3, padding=1, stride=stride
)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
self.relu = nn.ReLU(inplace=True)

num_groups = planes // 8

if norm_fn == 'group':
if norm_fn == "group":
self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
if not stride == 1:
self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)

elif norm_fn == 'batch':
elif norm_fn == "batch":
self.norm1 = nn.BatchNorm2d(planes)
self.norm2 = nn.BatchNorm2d(planes)
if not stride == 1:
self.norm3 = nn.BatchNorm2d(planes)

elif norm_fn == 'instance':
elif norm_fn == "instance":
self.norm1 = nn.InstanceNorm2d(planes)
self.norm2 = nn.InstanceNorm2d(planes)
if not stride == 1:
self.norm3 = nn.InstanceNorm2d(planes)

elif norm_fn == 'none':
elif norm_fn == "none":
self.norm1 = nn.Sequential()
self.norm2 = nn.Sequential()
if not stride == 1:
@@ -100,8 +106,8 @@ class ResidualBlock(nn.Module):

else:
self.downsample = nn.Sequential(
nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)

nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3
)

def forward(self, x):
y = x
@@ -111,31 +117,31 @@ class ResidualBlock(nn.Module):
if self.downsample is not None:
x = self.downsample(x)

return self.relu(x+y)
return self.relu(x + y)


class SmallEncoder(nn.Module):
def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
def __init__(self, output_dim=128, norm_fn="batch", dropout=0.0):
super(SmallEncoder, self).__init__()
self.norm_fn = norm_fn

if self.norm_fn == 'group':
if self.norm_fn == "group":
self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32)

elif self.norm_fn == 'batch':
elif self.norm_fn == "batch":
self.norm1 = nn.BatchNorm2d(32)

elif self.norm_fn == 'instance':
elif self.norm_fn == "instance":
self.norm1 = nn.InstanceNorm2d(32)

elif self.norm_fn == 'none':
elif self.norm_fn == "none":
self.norm1 = nn.Sequential()

self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3)
self.relu1 = nn.ReLU(inplace=True)

self.in_planes = 32
self.layer1 = self._make_layer(32, stride=1)
self.layer1 = self._make_layer(32, stride=1)
self.layer2 = self._make_layer(64, stride=2)
self.layer3 = self._make_layer(96, stride=2)

@@ -147,7 +153,7 @@ class SmallEncoder(nn.Module):

for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
if m.weight is not None:
nn.init.constant_(m.weight, 1)
@@ -162,14 +168,15 @@ class SmallEncoder(nn.Module):
self.in_planes = dim
return nn.Sequential(*layers)


def forward(self, x):
def forward(
self, x: torch.Tensor | list[torch.Tensor] | tuple[torch.Tensor, ...]
):

# if input is list, combine batch dimension
is_list = isinstance(x, tuple) or isinstance(x, list)
if is_list:
batch_dim = None
if is_list := isinstance(x, tuple) or isinstance(x, list):
batch_dim = x[0].shape[0]
x = torch.cat(x, dim=0)
x: torch.Tensor = torch.cat(x, dim=0)

x = self.conv1(x)
x = self.norm1(x)
@@ -183,33 +190,37 @@ class SmallEncoder(nn.Module):
if self.training and self.dropout is not None:
x = self.dropout(x)

if is_list:
x = torch.split(x, [batch_dim, batch_dim], dim=0)
if is_list and batch_dim is not None:
return torch.split(x, [batch_dim, batch_dim], dim=0)

return x

def __call__(self, *args: Any, **kwds: Any) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
return super().__call__(*args, **kwds)


class BasicEncoder(nn.Module):
def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
def __init__(self, output_dim=128, norm_fn="batch", dropout=0.0):
super(BasicEncoder, self).__init__()
self.norm_fn = norm_fn

if self.norm_fn == 'group':
if self.norm_fn == "group":
self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)

elif self.norm_fn == 'batch':
elif self.norm_fn == "batch":
self.norm1 = nn.BatchNorm2d(64)

elif self.norm_fn == 'instance':
elif self.norm_fn == "instance":
self.norm1 = nn.InstanceNorm2d(64)

elif self.norm_fn == 'none':
elif self.norm_fn == "none":
self.norm1 = nn.Sequential()

self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
self.relu1 = nn.ReLU(inplace=True)

self.in_planes = 64
self.layer1 = self._make_layer(64, stride=1)
self.layer1 = self._make_layer(64, stride=1)
self.layer2 = self._make_layer(72, stride=2)
self.layer3 = self._make_layer(128, stride=2)

@@ -222,7 +233,7 @@ class BasicEncoder(nn.Module):

for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
if m.weight is not None:
nn.init.constant_(m.weight, 1)
@@ -237,7 +248,6 @@ class BasicEncoder(nn.Module):
self.in_planes = dim
return nn.Sequential(*layers)


def forward(self, x):

# if input is list, combine batch dimension
@@ -264,21 +274,22 @@ class BasicEncoder(nn.Module):

return x


class LargeEncoder(nn.Module):
def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
def __init__(self, output_dim=128, norm_fn="batch", dropout=0.0):
super(LargeEncoder, self).__init__()
self.norm_fn = norm_fn

if self.norm_fn == 'group':
if self.norm_fn == "group":
self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)

elif self.norm_fn == 'batch':
elif self.norm_fn == "batch":
self.norm1 = nn.BatchNorm2d(64)

elif self.norm_fn == 'instance':
elif self.norm_fn == "instance":
self.norm1 = nn.InstanceNorm2d(64)

elif self.norm_fn == 'none':
elif self.norm_fn == "none":
self.norm1 = nn.Sequential()

self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
@@ -299,7 +310,7 @@ class LargeEncoder(nn.Module):

for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
if m.weight is not None:
nn.init.constant_(m.weight, 1)
@@ -314,7 +325,6 @@ class LargeEncoder(nn.Module):
self.in_planes = dim
return nn.Sequential(*layers)


def forward(self, x):

# if input is list, combine batch dimension
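The encoders' list-input path batches both frames through a single forward pass and splits the features afterwards. A minimal usage sketch (the import path is assumed for illustration):

import torch
from src.networks.blocks.raft import SmallEncoder  # hypothetical module path

encoder = SmallEncoder(output_dim=128).eval()
img0, img1 = torch.rand(2, 3, 64, 64), torch.rand(2, 3, 64, 64)
# the list is concatenated along the batch dimension, encoded once,
# then split back into two [2, output_dim, H/8, W/8] feature maps
fmap0, fmap1 = encoder([img0, img1])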
159
src/networks/blocks/ifrnet.py
Executable file
@@ -0,0 +1,159 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from src.utils.flow_utils import warp


def resize(x: torch.Tensor, scale_factor: float) -> torch.Tensor:
return F.interpolate(
x, scale_factor=scale_factor, mode="bilinear", align_corners=False
)


def convrelu(
in_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1,
dilation=1,
groups=1,
bias=True,
):
return nn.Sequential(
nn.Conv2d(
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
groups,
bias=bias,
),
nn.PReLU(out_channels),
)


class ResBlock(nn.Module):
def __init__(self, in_channels, side_channels, bias=True):
super(ResBlock, self).__init__()
self.side_channels = side_channels
self.conv1 = nn.Sequential(
nn.Conv2d(
in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias
),
nn.PReLU(in_channels),
)
self.conv2 = nn.Sequential(
nn.Conv2d(
side_channels,
side_channels,
kernel_size=3,
stride=1,
padding=1,
bias=bias,
),
nn.PReLU(side_channels),
)
self.conv3 = nn.Sequential(
nn.Conv2d(
in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias
),
nn.PReLU(in_channels),
)
self.conv4 = nn.Sequential(
nn.Conv2d(
side_channels,
side_channels,
kernel_size=3,
stride=1,
padding=1,
bias=bias,
),
nn.PReLU(side_channels),
)
self.conv5 = nn.Conv2d(
in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias
)
self.prelu = nn.PReLU(in_channels)

def forward(self, x):
out = self.conv1(x)

res_feat = out[:, : -self.side_channels, ...]
side_feat = out[:, -self.side_channels :, :, :]
side_feat = self.conv2(side_feat)
out = self.conv3(torch.cat([res_feat, side_feat], 1))

res_feat = out[:, : -self.side_channels, ...]
side_feat = out[:, -self.side_channels :, :, :]
side_feat = self.conv4(side_feat)
out = self.conv5(torch.cat([res_feat, side_feat], 1))

out = self.prelu(x + out)
return out


class Encoder(nn.Module):
def __init__(self, channels, large=False):
super(Encoder, self).__init__()
self.channels = channels
prev_ch = 3
for idx, ch in enumerate(channels, 1):
k = 7 if large and idx == 1 else 3
p = 3 if k == 7 else 1
self.register_module(
f"pyramid{idx}",
nn.Sequential(
convrelu(prev_ch, ch, k, 2, p), convrelu(ch, ch, 3, 1, 1)
),
)
prev_ch = ch

def forward(self, in_x):
fs = []
for idx in range(len(self.channels)):
out_x = getattr(self, f"pyramid{idx + 1}")(in_x)
fs.append(out_x)
in_x = out_x
return fs


class InitDecoder(nn.Module):
def __init__(self, in_ch, out_ch, skip_ch) -> None:
super().__init__()
self.convblock = nn.Sequential(
convrelu(in_ch * 2 + 1, in_ch * 2),
ResBlock(in_ch * 2, skip_ch),
nn.ConvTranspose2d(in_ch * 2, out_ch + 4, 4, 2, 1, bias=True),
)

def forward(self, f0, f1, embt):
h, w = f0.shape[2:]
embt = embt.repeat(1, 1, h, w)
out = self.convblock(torch.cat([f0, f1, embt], 1))
flow0, flow1 = torch.chunk(out[:, :4, ...], 2, 1)
ft_ = out[:, 4:, ...]
return flow0, flow1, ft_


class IntermediateDecoder(nn.Module):
def __init__(self, in_ch, out_ch, skip_ch) -> None:
super().__init__()
self.convblock = nn.Sequential(
convrelu(in_ch * 3 + 4, in_ch * 3),
ResBlock(in_ch * 3, skip_ch),
nn.ConvTranspose2d(in_ch * 3, out_ch + 4, 4, 2, 1, bias=True),
)

def forward(self, ft_, f0, f1, flow0_in, flow1_in):
f0_warp = warp(f0, flow0_in)
f1_warp = warp(f1, flow1_in)
f_in = torch.cat([ft_, f0_warp, f1_warp, flow0_in, flow1_in], 1)
out = self.convblock(f_in)
flow0, flow1 = torch.chunk(out[:, :4, ...], 2, 1)
ft_ = out[:, 4:, ...]
flow0 = flow0 + 2.0 * resize(flow0_in, scale_factor=2.0)
flow1 = flow1 + 2.0 * resize(flow1_in, scale_factor=2.0)
return flow0, flow1, ft_
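The `flow = delta + 2.0 * resize(flow_in, scale_factor=2.0)` rule in IntermediateDecoder is the standard coarse-to-fine convention: doubling the resolution of a flow field also doubles its vectors, which are measured in pixels. A toy check (plain PyTorch, independent of the repository):

import torch
import torch.nn.functional as F

flow_in = torch.ones(1, 2, 4, 4)  # a constant 1 px flow at 4x4
up = 2.0 * F.interpolate(
    flow_in, scale_factor=2.0, mode="bilinear", align_corners=False
)
# the same motion expressed at 8x8 is 2 px
assert up.shape == (1, 2, 8, 8) and torch.allclose(up, torch.full_like(up, 2.0))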
80
src/networks/blocks/multi_flow.py
Executable file
@@ -0,0 +1,80 @@
import torch
import torch.nn as nn
from src.utils.flow_utils import warp
from src.networks.blocks.ifrnet import convrelu, resize, ResBlock


def multi_flow_combine(
comb_block, img0, img1, flow0, flow1, mask=None, img_res=None, mean=None
):
"""
A parallel implementation of multiple flow field warping.
comb_block: An nn.Sequential object.
img shape: [b, c, h, w]
flow shape: [b, 2*num_flows, h, w]
mask (opt):
If 'mask' is None, the function conducts a simple average.
img_res (opt):
If 'img_res' is None, the function adds zero instead.
mean (opt):
If 'mean' is None, the function adds zero instead.
"""
b, c, h, w = flow0.shape
num_flows = c // 2
flow0 = flow0.reshape(b, num_flows, 2, h, w).reshape(-1, 2, h, w)
flow1 = flow1.reshape(b, num_flows, 2, h, w).reshape(-1, 2, h, w)

mask = (
mask.reshape(b, num_flows, 1, h, w).reshape(-1, 1, h, w)
if mask is not None
else 0.5  # a None mask would fail in the blend below; 0.5 yields the simple average
)
img_res = (
img_res.reshape(b, num_flows, 3, h, w).reshape(-1, 3, h, w)
if img_res is not None
else 0
)
img0 = torch.stack([img0] * num_flows, 1).reshape(-1, 3, h, w)
img1 = torch.stack([img1] * num_flows, 1).reshape(-1, 3, h, w)
mean = (
torch.stack([mean] * num_flows, 1).reshape(-1, 1, 1, 1)
if mean is not None
else 0
)

img0_warp = warp(img0, flow0)
img1_warp = warp(img1, flow1)
img_warps = mask * img0_warp + (1 - mask) * img1_warp + mean + img_res
img_warps = img_warps.reshape(b, num_flows, 3, h, w)
imgt_pred = img_warps.mean(1) + comb_block(img_warps.view(b, -1, h, w))
return imgt_pred


class MultiFlowDecoder(nn.Module):
def __init__(self, in_ch, skip_ch, num_flows=3):
super(MultiFlowDecoder, self).__init__()
self.num_flows = num_flows
self.convblock = nn.Sequential(
convrelu(in_ch * 3 + 4, in_ch * 3),
ResBlock(in_ch * 3, skip_ch),
nn.ConvTranspose2d(in_ch * 3, 8 * num_flows, 4, 2, 1, bias=True),
)

def forward(self, ft_, f0, f1, flow0, flow1):
n = self.num_flows
f0_warp = warp(f0, flow0)
f1_warp = warp(f1, flow1)
out = self.convblock(torch.cat([ft_, f0_warp, f1_warp, flow0, flow1], 1))
delta_flow0, delta_flow1, mask, img_res = torch.split(
out, [2 * n, 2 * n, n, 3 * n], 1
)
mask = torch.sigmoid(mask)

flow0 = delta_flow0 + 2.0 * resize(flow0, scale_factor=2.0).repeat(
1, self.num_flows, 1, 1
)
flow1 = delta_flow1 + 2.0 * resize(flow1, scale_factor=2.0).repeat(
1, self.num_flows, 1, 1
)

return flow0, flow1, mask, img_res
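The reshapes in multi_flow_combine fold the num_flows hypotheses into the batch dimension so that one warp call covers all of them at once. A toy shape check (plain PyTorch):

import torch

b, num_flows, h, w = 2, 3, 8, 8
flow0 = torch.rand(b, 2 * num_flows, h, w)
# [b, 2n, h, w] -> [b, n, 2, h, w] -> [b*n, 2, h, w]
folded = flow0.reshape(b, num_flows, 2, h, w).reshape(-1, 2, h, w)
assert folded.shape == (b * num_flows, 2, h, w)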
@@ -4,15 +4,17 @@ import torch.nn.functional as F


def resize(x, scale_factor):
return F.interpolate(x, scale_factor=scale_factor, mode="bilinear", align_corners=False)
return F.interpolate(
x, scale_factor=scale_factor, mode="bilinear", align_corners=False
)


def bilinear_sampler(img, coords, mask=False):
""" Wrapper for grid_sample, uses pixel coordinates """
def bilinear_sampler(img: torch.Tensor, coords: torch.Tensor, mask=False):
"""Wrapper for grid_sample, uses pixel coordinates"""
H, W = img.shape[-2:]
xgrid, ygrid = coords.split([1,1], dim=-1)
xgrid = 2*xgrid/(W-1) - 1
ygrid = 2*ygrid/(H-1) - 1
xgrid, ygrid = coords.split([1, 1], dim=-1)
xgrid = 2 * xgrid / (W - 1) - 1
ygrid = 2 * ygrid / (H - 1) - 1

grid = torch.cat([xgrid, ygrid], dim=-1)
img = F.grid_sample(img, grid, align_corners=True)
@@ -25,27 +27,36 @@ def bilinear_sampler(img, coords, mask=False):


def coords_grid(batch, ht, wd, device):
coords = torch.meshgrid(torch.arange(ht, device=device),
torch.arange(wd, device=device),
indexing='ij')
coords = torch.meshgrid(
torch.arange(ht, device=device), torch.arange(wd, device=device), indexing="ij"
)
coords = torch.stack(coords[::-1], dim=0).float()
return coords[None].repeat(batch, 1, 1, 1)


class SmallUpdateBlock(nn.Module):
def __init__(self, cdim, hidden_dim, flow_dim, corr_dim, fc_dim,
corr_levels=4, radius=3, scale_factor=None):
def __init__(
self,
cdim,
hidden_dim,
flow_dim,
corr_dim,
fc_dim,
corr_levels=4,
radius=3,
scale_factor=None,
):
super(SmallUpdateBlock, self).__init__()
cor_planes = corr_levels * (2 * radius + 1) **2
cor_planes = corr_levels * (2 * radius + 1) ** 2
self.scale_factor = scale_factor

self.convc1 = nn.Conv2d(2 * cor_planes, corr_dim, 1, padding=0)
self.convf1 = nn.Conv2d(4, flow_dim*2, 7, padding=3)
self.convf2 = nn.Conv2d(flow_dim*2, flow_dim, 3, padding=1)
self.conv = nn.Conv2d(corr_dim+flow_dim, fc_dim, 3, padding=1)
self.convf1 = nn.Conv2d(4, flow_dim * 2, 7, padding=3)
self.convf2 = nn.Conv2d(flow_dim * 2, flow_dim, 3, padding=1)
self.conv = nn.Conv2d(corr_dim + flow_dim, fc_dim, 3, padding=1)

self.gru = nn.Sequential(
nn.Conv2d(fc_dim+4+cdim, hidden_dim, 3, padding=1),
nn.Conv2d(fc_dim + 4 + cdim, hidden_dim, 3, padding=1),
nn.LeakyReLU(negative_slope=0.1, inplace=True),
nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
)
@@ -65,8 +76,9 @@ class SmallUpdateBlock(nn.Module):
self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

def forward(self, net, flow, corr):
net = resize(net, 1 / self.scale_factor
) if self.scale_factor is not None else net
net = (
resize(net, 1 / self.scale_factor) if self.scale_factor is not None else net
)
cor = self.lrelu(self.convc1(corr))
flo = self.lrelu(self.convf1(flow))
flo = self.lrelu(self.convf2(flo))
@@ -80,26 +92,39 @@ class SmallUpdateBlock(nn.Module):

if self.scale_factor is not None:
delta_net = resize(delta_net, scale_factor=self.scale_factor)
delta_flow = self.scale_factor * resize(delta_flow, scale_factor=self.scale_factor)
delta_flow = self.scale_factor * resize(
delta_flow, scale_factor=self.scale_factor
)

return delta_net, delta_flow


class BasicUpdateBlock(nn.Module):
def __init__(self, cdim, hidden_dim, flow_dim, corr_dim, corr_dim2,
fc_dim, corr_levels=4, radius=3, scale_factor=None, out_num=1):
def __init__(
self,
cdim,
hidden_dim,
flow_dim,
corr_dim,
corr_dim2,
fc_dim,
corr_levels=4,
radius=3,
scale_factor=None,
out_num=1,
):
super(BasicUpdateBlock, self).__init__()
cor_planes = corr_levels * (2 * radius + 1) **2
cor_planes = corr_levels * (2 * radius + 1) ** 2

self.scale_factor = scale_factor
self.convc1 = nn.Conv2d(2 * cor_planes, corr_dim, 1, padding=0)
self.convc2 = nn.Conv2d(corr_dim, corr_dim2, 3, padding=1)
self.convf1 = nn.Conv2d(4, flow_dim*2, 7, padding=3)
self.convf2 = nn.Conv2d(flow_dim*2, flow_dim, 3, padding=1)
self.conv = nn.Conv2d(flow_dim+corr_dim2, fc_dim, 3, padding=1)
self.convf1 = nn.Conv2d(4, flow_dim * 2, 7, padding=3)
self.convf2 = nn.Conv2d(flow_dim * 2, flow_dim, 3, padding=1)
self.conv = nn.Conv2d(flow_dim + corr_dim2, fc_dim, 3, padding=1)

self.gru = nn.Sequential(
nn.Conv2d(fc_dim+4+cdim, hidden_dim, 3, padding=1),
nn.Conv2d(fc_dim + 4 + cdim, hidden_dim, 3, padding=1),
nn.LeakyReLU(negative_slope=0.1, inplace=True),
nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
)
@@ -113,14 +138,15 @@ class BasicUpdateBlock(nn.Module):
self.flow_head = nn.Sequential(
nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
nn.LeakyReLU(negative_slope=0.1, inplace=True),
nn.Conv2d(hidden_dim, 4*out_num, 3, padding=1),
nn.Conv2d(hidden_dim, 4 * out_num, 3, padding=1),
)

self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

def forward(self, net, flow, corr):
net = resize(net, 1 / self.scale_factor
) if self.scale_factor is not None else net
net = (
resize(net, 1 / self.scale_factor) if self.scale_factor is not None else net
)
cor = self.lrelu(self.convc1(corr))
cor = self.lrelu(self.convc2(cor))
flo = self.lrelu(self.convf1(flow))
@@ -135,38 +161,44 @@ class BasicUpdateBlock(nn.Module):

if self.scale_factor is not None:
delta_net = resize(delta_net, scale_factor=self.scale_factor)
delta_flow = self.scale_factor * resize(delta_flow, scale_factor=self.scale_factor)
delta_flow = self.scale_factor * resize(
delta_flow, scale_factor=self.scale_factor
)
return delta_net, delta_flow


class BidirCorrBlock:
def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
def __init__(
self, fmap1: torch.Tensor, fmap2: torch.Tensor, num_levels=4, radius=4
):
self.num_levels = num_levels
self.radius = radius
self.corr_pyramid = []
self.corr_pyramid_T = []
self.corr_pyramid: list[torch.Tensor] = []
self.corr_pyramid_T: list[torch.Tensor] = []

corr = BidirCorrBlock.corr(fmap1, fmap2)
batch, h1, w1, dim, h2, w2 = corr.shape
corr_T = corr.clone().permute(0, 4, 5, 3, 1, 2)

corr = corr.reshape(batch*h1*w1, dim, h2, w2)
corr_T = corr_T.reshape(batch*h2*w2, dim, h1, w1)
corr = corr.reshape(batch * h1 * w1, dim, h2, w2)
corr_T = corr_T.reshape(batch * h2 * w2, dim, h1, w1)

self.corr_pyramid.append(corr)
self.corr_pyramid_T.append(corr_T)

for _ in range(self.num_levels-1):
for _ in range(self.num_levels - 1):
corr = F.avg_pool2d(corr, 2, stride=2)
corr_T = F.avg_pool2d(corr_T, 2, stride=2)
self.corr_pyramid.append(corr)
self.corr_pyramid_T.append(corr_T)

def __call__(self, coords0, coords1):
def __call__(self, coords0: torch.Tensor, coords1: torch.Tensor):
r = self.radius
coords0 = coords0.permute(0, 2, 3, 1)
coords1 = coords1.permute(0, 2, 3, 1)
assert coords0.shape == coords1.shape, f"coords0 shape: [{coords0.shape}] is not equal to [{coords1.shape}]"
assert coords0.shape == coords1.shape, (
f"coords0 shape: [{coords0.shape}] is not equal to [{coords1.shape}]"
)
batch, h1, w1, _ = coords0.shape

out_pyramid = []
@@ -175,15 +207,15 @@ class BidirCorrBlock:
corr = self.corr_pyramid[i]
corr_T = self.corr_pyramid_T[i]

dx = torch.linspace(-r, r, 2*r+1, device=coords0.device)
dy = torch.linspace(-r, r, 2*r+1, device=coords0.device)
delta = torch.stack(torch.meshgrid(dy, dx, indexing='ij'), axis=-1)
delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2)
dx = torch.linspace(-r, r, 2 * r + 1, device=coords0.device)
dy = torch.linspace(-r, r, 2 * r + 1, device=coords0.device)
delta: torch.Tensor = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), axis=-1)
delta_lvl: torch.Tensor = delta.view(1, 2 * r + 1, 2 * r + 1, 2)

centroid_lvl_0 = coords0.reshape(batch*h1*w1, 1, 1, 2) / 2**i
centroid_lvl_1 = coords1.reshape(batch*h1*w1, 1, 1, 2) / 2**i
coords_lvl_0 = centroid_lvl_0 + delta_lvl
coords_lvl_1 = centroid_lvl_1 + delta_lvl
centroid_lvl_0: torch.Tensor = coords0.reshape(batch * h1 * w1, 1, 1, 2) / 2**i
centroid_lvl_1: torch.Tensor = coords1.reshape(batch * h1 * w1, 1, 1, 2) / 2**i
coords_lvl_0: torch.Tensor = centroid_lvl_0 + delta_lvl
coords_lvl_1: torch.Tensor = centroid_lvl_1 + delta_lvl

corr = bilinear_sampler(corr, coords_lvl_0)
corr_T = bilinear_sampler(corr_T, coords_lvl_1)
@@ -194,14 +226,16 @@ class BidirCorrBlock:

out = torch.cat(out_pyramid, dim=-1)
out_T = torch.cat(out_pyramid_T, dim=-1)
return out.permute(0, 3, 1, 2).contiguous().float(), out_T.permute(0, 3, 1, 2).contiguous().float()
return out.permute(0, 3, 1, 2).contiguous().float(), out_T.permute(
0, 3, 1, 2
).contiguous().float()

@staticmethod
def corr(fmap1, fmap2):
def corr(fmap1: torch.Tensor, fmap2: torch.Tensor):
batch, dim, ht, wd = fmap1.shape
fmap1 = fmap1.view(batch, dim, ht*wd)
fmap2 = fmap2.view(batch, dim, ht*wd)
fmap1 = fmap1.view(batch, dim, ht * wd)
fmap2 = fmap2.view(batch, dim, ht * wd)

corr = torch.matmul(fmap1.transpose(1,2), fmap2)
corr = torch.matmul(fmap1.transpose(1, 2), fmap2)
corr = corr.view(batch, ht, wd, 1, ht, wd)
return corr / torch.sqrt(torch.tensor(dim).float())
return corr * (dim**-0.5)
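The final change above swaps a tensor-valued divisor for scalar arithmetic; both forms apply the usual 1/sqrt(dim) scaling of dot-product similarity. A quick equivalence check:

import torch

dim = 128
corr = torch.rand(2, 8, 8, 1, 8, 8)
a = corr / torch.sqrt(torch.tensor(dim).float())
b = corr * (dim**-0.5)
assert torch.allclose(a, b)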
BIN
src/pretrained/amt-l.pth
Normal file
Binary file not shown.
BIN
src/pretrained/amt-s.pth
Normal file
Binary file not shown.
@@ -1,6 +1,7 @@
from typing import TYPE_CHECKING
import importlib

from typing import TYPE_CHECKING

if TYPE_CHECKING:
from omegaconf import DictConfig

53
src/utils/fs.py
Normal file
@@ -0,0 +1,53 @@
from pathlib import Path


class FileSystem:
SOURCE_PATH = "source"
OUTPUT_PATH = "output"
FRAMES_PATH = "frames"
INTERPOLATED_PATH = "interpolated"
MOVED_PATH = "moved"
VIDEO_PART_PATH = "video_parts"

def __init__(self, base_path: Path):
self.base_path = base_path
self.base_path.mkdir(parents=True, exist_ok=True)

def create_directory(self, dir_name: str) -> Path:
"""Creates a directory under the base path."""
dir_path = self.base_path / dir_name
dir_path.mkdir(parents=True, exist_ok=True)
return dir_path

def clear_directory(self, dir_path: Path):
"""Recursively removes all files and subdirectories in the specified directory."""
for item in dir_path.iterdir():
if item.is_file():
item.unlink()
elif item.is_dir():
self.clear_directory(item)
item.rmdir()

@property
def source_path(self) -> Path:
return self.create_directory(self.SOURCE_PATH)

@property
def output_path(self) -> Path:
return self.create_directory(self.OUTPUT_PATH)

@property
def frames_path(self) -> Path:
return self.create_directory(self.FRAMES_PATH)

@property
def interpolated_path(self) -> Path:
return self.create_directory(self.INTERPOLATED_PATH)

@property
def moved_path(self) -> Path:
return self.create_directory(self.MOVED_PATH)

@property
def video_part_path(self) -> Path:
return self.create_directory(self.VIDEO_PART_PATH)
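A short usage sketch of the layout FileSystem manages; directories are created lazily on first access (the base path below is illustrative):

from pathlib import Path
from src.utils.fs import FileSystem

fs = FileSystem(Path("workdir"))  # hypothetical base directory
frames_dir = fs.frames_path       # workdir/frames, created if missing
parts_dir = fs.video_part_path    # workdir/video_parts
fs.clear_directory(frames_dir)    # recursively empties the tree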
@@ -5,23 +5,26 @@ import numpy as np


def tensor2img(tensor: torch.Tensor):
return (
(tensor * 255.0)
.detach()
tensor = (
tensor.mul(255.0)
.clamp_(0, 255)
.to(torch.uint8)
.squeeze(0)
.permute(1, 2, 0)
.cpu()
.numpy()
.clip(0, 255)
.astype(np.uint8)
)

return tensor.cpu().numpy()

def img2tensor(img: np.ndarray) -> torch.Tensor:

def img2tensor(img: np.ndarray, device: torch.device) -> torch.Tensor:
logging.debug(f"Converting image of shape {img.shape} to tensor")
if img.shape[-1] > 3:
img = img[:, :, :3]
return torch.tensor(img).permute(2, 0, 1).unsqueeze(0) / 255.0
tensor = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0)
if device.type != "cuda":
return tensor.float() / 255.0

return tensor.cuda(non_blocking=True).float().div_(255.0)


def check_dim_and_resize(*args: torch.Tensor) -> list[torch.Tensor]:

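A round-trip sketch for the two converters above: uint8 HxWx3 in, [1, 3, H, W] float in [0, 1] out, and back (CPU path shown):

import numpy as np
import torch
from src.utils.torch import img2tensor, tensor2img

img = (np.random.rand(32, 48, 3) * 255).astype(np.uint8)
t = img2tensor(img, device=torch.device("cpu"))  # [1, 3, 32, 48], float in [0, 1]
back = tensor2img(t)                             # [32, 48, 3], uint8
assert back.shape == img.shape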
@@ -1,199 +0,0 @@
|
||||
import re
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
from imageio import imread, imwrite
|
||||
|
||||
|
||||
def read(file: Path) -> np.ndarray:
|
||||
readers = {
|
||||
".float3": readFloat,
|
||||
".flo": readFlow,
|
||||
".ppm": readImage,
|
||||
".pgm": readImage,
|
||||
".png": readImage,
|
||||
".jpg": readImage,
|
||||
".pfm": lambda f: readPFM(f)[0],
|
||||
}
|
||||
func = readers.get(file.suffix.lower())
|
||||
if func is None:
|
||||
raise Exception("don't know how to read %s" % file)
|
||||
return func(file)
|
||||
|
||||
|
||||
def write(file: Path, data: np.ndarray) -> None:
|
||||
writers = {
|
||||
".float3": writeFloat,
|
||||
".flo": writeFlow,
|
||||
".ppm": writeImage,
|
||||
".pgm": writeImage,
|
||||
".png": writeImage,
|
||||
".jpg": writeImage,
|
||||
".pfm": writePFM,
|
||||
}
|
||||
func = writers.get(file.suffix.lower())
|
||||
if func is None:
|
||||
raise Exception("don't know how to write %s" % file)
|
||||
return func(file, data)
|
||||
|
||||
|
||||
def readPFM(file: Path):
|
||||
data = open(file, "rb")
|
||||
|
||||
color = None
|
||||
width = None
|
||||
height = None
|
||||
scale = None
|
||||
endian = None
|
||||
|
||||
header = data.readline().rstrip()
|
||||
if header.decode("ascii") == "PF":
|
||||
color = True
|
||||
elif header.decode("ascii") == "Pf":
|
||||
color = False
|
||||
else:
|
||||
raise Exception("Not a PFM file.")
|
||||
|
||||
dim_match = re.match(r"^(\d+)\s(\d+)\s$", data.readline().decode("ascii"))
|
||||
if dim_match:
|
||||
width, height = list(map(int, dim_match.groups()))
|
||||
else:
|
||||
raise Exception("Malformed PFM header.")
|
||||
|
||||
scale = float(data.readline().decode("ascii").rstrip())
|
||||
if scale < 0:
|
||||
endian = "<"
|
||||
scale = -scale
|
||||
else:
|
||||
endian = ">"
|
||||
|
||||
result = np.fromfile(data, endian + "f")
|
||||
shape = (height, width, 3) if color else (height, width)
|
||||
|
||||
result = np.reshape(result, shape)
|
||||
result = np.flipud(result)
|
||||
return result, scale
|
||||
|
||||
|
||||
def writePFM(file: Path, image: np.ndarray, scale=1):
|
||||
data = open(file, "wb")
|
||||
|
||||
color = None
|
||||
|
||||
if image.dtype.name != "float32":
|
||||
raise Exception("Image dtype must be float32.")
|
||||
|
||||
image = np.flipud(image)
|
||||
|
||||
if len(image.shape) == 3 and image.shape[2] == 3:
|
||||
color = True
|
||||
elif len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1:
|
||||
color = False
|
||||
else:
|
||||
raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.")
|
||||
|
||||
data.write("PF\n" if color else "Pf\n".encode()) # type: ignore
|
||||
data.write("%d %d\n".encode() % (image.shape[1], image.shape[0]))
|
||||
|
||||
endian = image.dtype.byteorder
|
||||
|
||||
if endian == "<" or endian == "=" and sys.byteorder == "little":
|
||||
scale = -scale
|
||||
|
||||
data.write("%f\n".encode() % scale)
|
||||
|
||||
image.tofile(data)
|
||||
|
||||
|
||||
def readFlow(file: Path):
|
||||
if file.suffix.lower() == ".pfm":
|
||||
return readPFM(file)[0][:, :, 0:2]
|
||||
|
||||
f = open(file, "rb")
|
||||
|
||||
header = f.read(4)
|
||||
if header.decode("utf-8") != "PIEH":
|
||||
raise Exception("Flow file header does not contain PIEH")
|
||||
|
||||
width = np.fromfile(f, np.int32, 1).squeeze()
|
||||
height = np.fromfile(f, np.int32, 1).squeeze()
|
||||
|
||||
flow = np.fromfile(f, np.float32, width * height * 2).reshape((height, width, 2))
|
||||
|
||||
return flow.astype(np.float32)
|
||||
|
||||
|
||||
def readImage(file: Path):
|
||||
if file.suffix.lower() == ".pfm":
|
||||
data = readPFM(file)[0]
|
||||
if len(data.shape) == 3:
|
||||
return data[:, :, 0:3]
|
||||
else:
|
||||
return data
|
||||
return imread(file)
|
||||
|
||||
|
||||
def writeImage(file: Path, data: np.ndarray):
|
||||
if file.suffix.lower() == ".pfm":
|
||||
return writePFM(file, data, 1)
|
||||
return imwrite(file, data)
|
||||
|
||||
|
||||
def writeFlow(file: Path, flow: np.ndarray):
|
||||
f = open(file, "wb")
|
||||
f.write("PIEH".encode("utf-8"))
|
||||
np.array([flow.shape[1], flow.shape[0]], dtype=np.int32).tofile(f)
|
||||
flow = flow.astype(np.float32)
|
||||
flow.tofile(f)
|
||||
|
||||
|
||||
def readFloat(file: Path):
|
||||
f = open(file, "rb")
|
||||
|
||||
if (f.readline().decode("utf-8")) != "float\n":
|
||||
raise Exception("float file %s did not contain <float> keyword" % file)
|
||||
|
||||
dim = int(f.readline())
|
||||
|
||||
dims = []
|
||||
count = 1
|
||||
for _ in range(0, dim):
|
||||
d = int(f.readline())
|
||||
dims.append(d)
|
||||
count *= d
|
||||
|
||||
dims = list(reversed(dims))
|
||||
|
||||
data = np.fromfile(f, np.float32, count).reshape(dims)
|
||||
if dim > 2:
|
||||
data = np.transpose(data, (2, 1, 0))
|
||||
data = np.transpose(data, (1, 0, 2))
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def writeFloat(file: Path, data: np.ndarray):
|
||||
f = open(file, "wb")
|
||||
|
||||
dim = len(data.shape)
|
||||
if dim > 3:
|
||||
raise Exception("bad float file dimension: %d" % dim)
|
||||
|
||||
f.write(("float\n").encode("ascii"))
|
||||
f.write(("%d\n" % dim).encode("ascii"))
|
||||
|
||||
if dim == 1:
|
||||
f.write(("%d\n" % data.shape[0]).encode("ascii"))
|
||||
else:
|
||||
f.write(("%d\n" % data.shape[1]).encode("ascii"))
|
||||
f.write(("%d\n" % data.shape[0]).encode("ascii"))
|
||||
for i in range(2, dim):
|
||||
f.write(("%d\n" % data.shape[i]).encode("ascii"))
|
||||
|
||||
data = data.astype(np.float32)
|
||||
if dim == 2:
|
||||
data.tofile(f)
|
||||
|
||||
else:
|
||||
np.transpose(data, (2, 0, 1)).tofile(f)
|
||||
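For reference, the reader deleted above handled the Middlebury .flo layout: a 4-byte "PIEH" magic, two int32 values (width, height), then width*height*2 float32 (u, v) pairs in row-major order. A minimal standalone sketch of the same format:

import numpy as np

def read_flo(path: str) -> np.ndarray:
    # minimal reader for the layout the deleted module parsed
    with open(path, "rb") as f:
        if f.read(4) != b"PIEH":
            raise ValueError("Flow file header does not contain PIEH")
        w = int(np.fromfile(f, np.int32, 1)[0])
        h = int(np.fromfile(f, np.int32, 1)[0])
        return np.fromfile(f, np.float32, w * h * 2).reshape(h, w, 2)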
147
src/utils/video.py
Normal file
@@ -0,0 +1,147 @@
import os
import logging
import subprocess
from pathlib import Path
from typing import Generator, Iterable

import cv2
import numpy as np


class VideoMaker:
def images_to_video(
self,
images_path: Path,
output_path: Path,
fps: float,
image_numerator: str = "img_%08d.png",
):
"""Converts a sequence of images to a video using ffmpeg."""
cmd = f"ffmpeg -framerate {fps} -i {images_path / image_numerator} -c:v libx264 -pix_fmt yuv420p {output_path}"
logging.info(f"Running command: {cmd}")
result = self.run_command(cmd)
if result != 0:
logging.error(f"Failed to create video. Command returned {result}")

def concatenate_videos(
self,
videos_path: Path,
output_path: Path,
video_numerator: str = "video_%08d.mp4",
):
"""Concatenates a sequence of videos using ffmpeg."""

# NOTE: video_numerator is currently unused; parts are discovered via glob.
videos = sorted(videos_path.glob("*.mp4"))
file = "file.txt"
with open(file, "w") as f:
for video in videos:
f.write(f"file '{video}'\n")
cmd = f"ffmpeg -y -f concat -safe 0 -i {file} -c copy {output_path}"
logging.info(f"Running command: {cmd}")
result = self.run_command(cmd)
if result != 0:
logging.error(f"Failed to concatenate videos. Command returned {result}")
os.remove(file)

def get_fps(self, video_path: Path) -> float:
"""Gets the frames per second (FPS) of a video."""
cap = cv2.VideoCapture(str(video_path))
if not cap.isOpened():
raise ValueError(f"Cannot open video: {video_path}")
fps = cap.get(cv2.CAP_PROP_FPS)
cap.release()
logging.debug(f"FPS of video {video_path}: {fps}")
return fps

def get_video_duration(self, video_path: Path) -> float:
"""Gets the duration of a video in seconds."""
cap = cv2.VideoCapture(str(video_path))
if not cap.isOpened():
raise ValueError(f"Cannot open video: {video_path}")
fps = cap.get(cv2.CAP_PROP_FPS)
frame_count = cap.get(cv2.CAP_PROP_FRAME_COUNT)
cap.release()
duration = frame_count / fps
logging.debug(f"Duration of video {video_path}: {duration:.2f} seconds")
return duration

def run_command(self, cmd: str) -> int:
try:
subprocess.run(
cmd,
shell=True,
check=True,
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
)
return 0
except subprocess.CalledProcessError as e:
logging.error(f"Command failed with error: {e}")
return e.returncode

def video_to_frames_generator(
self, video_path: Path, output_dir: Path, chunk_seconds: int = 10
) -> Generator[tuple[np.ndarray, ...], None, None]:
"""Reads a video in chunks of roughly `chunk_seconds` seconds, yielding each chunk as a tuple of decoded BGR frames."""

# NOTE: output_dir is currently unused; frames are yielded in memory.
cap = cv2.VideoCapture(str(video_path))

if not cap.isOpened():
raise ValueError(f"Cannot open video: {video_path}")

fps = cap.get(cv2.CAP_PROP_FPS)
frames_per_chunk = int(fps * chunk_seconds)

while True:
frames = []
for _ in range(frames_per_chunk):
ret, frame = cap.read()
if not ret:
cap.release()
if frames:
yield tuple(frames)  # flush the final partial chunk
return
frames.append(frame)
yield tuple(frames)

def images_to_video_pipeline(
self,
frames: Iterable[np.ndarray],
output_path: Path,
width: int,
height: int,
fps: float,
):
pipeline = subprocess.Popen(
[
"ffmpeg",
"-y",
"-f", "rawvideo",
"-vcodec", "rawvideo",
"-pix_fmt", "bgr24",
"-s", f"{width}x{height}",
"-r", str(fps),
"-i", "-",
"-an",
"-vcodec", "libx264",
"-pix_fmt", "yuv420p",
str(output_path),
],
stdin=subprocess.PIPE,
stderr=subprocess.DEVNULL
)
if pipeline.stdin is None:
raise Exception("STDIN closed")
for frame in frames:
pipeline.stdin.write(frame.tobytes())

pipeline.stdin.close()
pipeline.wait()

def get_size(self, video_path: Path) -> tuple[int, int]:
cap = cv2.VideoCapture(str(video_path))
if not cap.isOpened():
raise ValueError(f"Cannot open video: {video_path}")
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

cap.release()
return width, height
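A sketch wiring the chunked reader into the raw-video pipe above; ffmpeg must be on PATH, and the file names are illustrative:

from pathlib import Path
from src.utils.video import VideoMaker

vm = VideoMaker()
src = Path("input.mp4")  # hypothetical input
w, h = vm.get_size(src)
fps = vm.get_fps(src)
for i, chunk in enumerate(vm.video_to_frames_generator(src, Path("frames"))):
    # each chunk is a tuple of decoded BGR frames; re-encode it as one part
    vm.images_to_video_pipeline(chunk, Path(f"part_{i:08d}.mp4"), w, h, fps)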