Compare commits
2 Commits
dev-cuda-u
...
c72e34f9dc
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c72e34f9dc | ||
| 359f20c3c4 |
Binary file not shown.
|
Before Width: | Height: | Size: 50 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 57 KiB |
122
main.py
122
main.py
@@ -2,7 +2,6 @@ import logging
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from cv2 import imwrite
|
|
||||||
import tqdm
|
import tqdm
|
||||||
|
|
||||||
from src.config import presets
|
from src.config import presets
|
||||||
@@ -19,7 +18,6 @@ from src.interpolator import (
|
|||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
|
|
||||||
def performing_warning_message(device: "torch.device"):
|
def performing_warning_message(device: "torch.device"):
|
||||||
@@ -55,7 +53,7 @@ def init_device() -> "torch.device":
|
|||||||
device = get_device()
|
device = get_device()
|
||||||
performing_warning_message(device)
|
performing_warning_message(device)
|
||||||
vram_available = get_vram_available(device)
|
vram_available = get_vram_available(device)
|
||||||
logging.info(f"Available VRAM: {vram_available / (1024**3):.2f} GB")
|
logging.info(f"Available VRAM: {vram_available / (1024 ** 3):.2f} GB")
|
||||||
return device
|
return device
|
||||||
|
|
||||||
|
|
||||||
@@ -70,8 +68,10 @@ def init_anchor(device: "torch.device") -> Anchor:
|
|||||||
raise Exception(f"Unsupported device type: {device.type}")
|
raise Exception(f"Unsupported device type: {device.type}")
|
||||||
|
|
||||||
|
|
||||||
def init_model_runner(preset: presets.Preset, device: "torch.device") -> ModelRunner:
|
def init_model_runner(
|
||||||
return ModelRunner(preset, device)
|
config: Path, checkpoint_path: Path, device: "torch.device"
|
||||||
|
) -> ModelRunner:
|
||||||
|
return ModelRunner(config, checkpoint_path, device)
|
||||||
|
|
||||||
|
|
||||||
def init_interpolator(
|
def init_interpolator(
|
||||||
@@ -84,58 +84,63 @@ def init_interpolator(
|
|||||||
class InterpolationPipeline:
|
class InterpolationPipeline:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
preset: presets.Preset,
|
config: Path,
|
||||||
|
checkpoint_path: Path,
|
||||||
base_path: Path,
|
base_path: Path,
|
||||||
):
|
):
|
||||||
self.fs = init_fs(base_path)
|
self.fs = init_fs(base_path)
|
||||||
self.video_maker = init_video_maker()
|
self.video_maker = init_video_maker()
|
||||||
self.device = init_device()
|
self.device = init_device()
|
||||||
self.model_runner = init_model_runner(preset, self.device)
|
self.model_runner = init_model_runner(config, checkpoint_path, self.device)
|
||||||
self.interpolator = init_interpolator(self.model_runner, self.device)
|
self.interpolator = init_interpolator(self.model_runner, self.device)
|
||||||
|
|
||||||
def run(self, video_path: Path, output_video: str):
|
def run(self, video_path: Path, output_video: str):
|
||||||
prev_frames = tuple()
|
prev_frame_path = None
|
||||||
interpolated_frames: list["np.ndarray"] = []
|
frame_count = 0
|
||||||
part = 0
|
part = 0
|
||||||
chunk_seconds = 1
|
source_frame_length = 0
|
||||||
|
chunk_seconds = 10
|
||||||
length = self.video_maker.get_video_duration(video_path)
|
length = self.video_maker.get_video_duration(video_path)
|
||||||
last_part_seconds = 1 if length % chunk_seconds else 0
|
last_part_seconds = 1 if length % chunk_seconds else 0
|
||||||
total_parts = int(length // chunk_seconds) + last_part_seconds
|
total_parts = int(length // chunk_seconds) + last_part_seconds
|
||||||
fps = self.video_maker.get_fps(video_path)
|
fps = self.video_maker.get_fps(video_path)
|
||||||
logging.info(f"Video FPS: {fps}")
|
logging.info(f"Video FPS: {fps}")
|
||||||
fps *= 2 # Doubling FPS
|
fps *= 2 # Doubling FPS
|
||||||
width, height = self.video_maker.get_size(video_path)
|
for frame_paths in self.video_maker.video_to_frames_generator(
|
||||||
for frames in self.video_maker.video_to_frames_generator(
|
|
||||||
video_path, self.fs.frames_path, chunk_seconds
|
video_path, self.fs.frames_path, chunk_seconds
|
||||||
):
|
):
|
||||||
logging.info(f"Processing frames: {len(frames)}")
|
logging.info(f"Processing frames: {len(frame_paths)}")
|
||||||
if prev_frames:
|
if prev_frame_path is not None:
|
||||||
img1 = prev_frames[-1]
|
img1 = prev_frame_path[-1]
|
||||||
img2 = frames[0]
|
img2 = frame_paths[0]
|
||||||
img1_2 = self.interpolator.interpolate(img1, img2)
|
output_path = self.fs.interpolated_path / f"img_{frame_count:08d}.png"
|
||||||
interpolated_frames.append(img1_2)
|
self.interpolator.interpolate(img1, img2, output_path)
|
||||||
generator = self._frame_generator(prev_frames, interpolated_frames)
|
logging.debug(f"Interpolated image saved to: {output_path}")
|
||||||
part_path = self.fs.video_part_path / f"video_{part:08d}.mp4"
|
self._merge_frames_to_video(
|
||||||
self.video_maker.images_to_video_pipeline(
|
self.fs.video_part_path / f"video_{part:08d}.mp4",
|
||||||
generator, part_path, width, height, fps
|
fps,
|
||||||
|
source_frame_length=source_frame_length,
|
||||||
)
|
)
|
||||||
interpolated_frames = []
|
|
||||||
logging.info(f"Finished processing part {part:08d}")
|
logging.info(f"Finished processing part {part:08d}")
|
||||||
|
frame_count += 1
|
||||||
part += 1
|
part += 1
|
||||||
for i in tqdm.tqdm(
|
for i in tqdm.tqdm(
|
||||||
range(len(frames) - 1),
|
range(len(frame_paths) - 1),
|
||||||
desc=f"Processing video frames {part + 1} / {total_parts}",
|
desc=f"Processing video frames {part + 1} / {total_parts}",
|
||||||
):
|
):
|
||||||
img1 = frames[i]
|
img1 = frame_paths[i]
|
||||||
img2 = frames[i + 1]
|
img2 = frame_paths[i + 1]
|
||||||
img1_2 = self.interpolator.interpolate(img1, img2)
|
output_path = self.fs.interpolated_path / f"img_{i:08d}.png"
|
||||||
interpolated_frames.append(img1_2)
|
self.interpolator.interpolate(img1, img2, output_path)
|
||||||
prev_frames = frames
|
logging.debug(f"Interpolated image saved to: {output_path}")
|
||||||
|
frame_count += 1
|
||||||
|
source_frame_length = len(frame_paths)
|
||||||
|
prev_frame_path = frame_paths
|
||||||
|
|
||||||
generator = self._frame_generator(prev_frames, interpolated_frames)
|
self._merge_frames_to_video(
|
||||||
part_path = self.fs.video_part_path / f"video_{part:08d}.mp4"
|
self.fs.video_part_path / f"video_{part:08d}.mp4",
|
||||||
self.video_maker.images_to_video_pipeline(
|
fps,
|
||||||
generator, part_path, width, height, fps
|
source_frame_length=source_frame_length,
|
||||||
)
|
)
|
||||||
logging.info(f"Finished processing part {part:08d}")
|
logging.info(f"Finished processing part {part:08d}")
|
||||||
self._merge_video_parts(self.fs.output_path / output_video)
|
self._merge_video_parts(self.fs.output_path / output_video)
|
||||||
@@ -143,40 +148,32 @@ class InterpolationPipeline:
|
|||||||
f"Video interpolation completed. Output saved to: {self.fs.output_path / output_video}"
|
f"Video interpolation completed. Output saved to: {self.fs.output_path / output_video}"
|
||||||
)
|
)
|
||||||
|
|
||||||
def _save_images(
|
def _merge_frames_to_video(
|
||||||
self,
|
self, output_video: Path, fps: float, source_frame_length: int = 0
|
||||||
source: tuple["np.ndarray", ...],
|
|
||||||
interpolated: list["np.ndarray"],
|
|
||||||
):
|
):
|
||||||
logging.info("Saving images...")
|
self._move_frames(source_frame_length)
|
||||||
self.fs.clear_directory(self.fs.moved_path)
|
|
||||||
index = 0
|
|
||||||
for i, frame in enumerate(source):
|
|
||||||
name = self.fs.moved_path / f"img_{index:08d}.png"
|
|
||||||
index += 1
|
|
||||||
imwrite(name, frame)
|
|
||||||
if i < len(interpolated):
|
|
||||||
name = self.fs.moved_path / f"img_{index:08d}.png"
|
|
||||||
index += 1
|
|
||||||
imwrite(name, interpolated[i])
|
|
||||||
logging.info("Success...")
|
|
||||||
|
|
||||||
def _merge_frames_to_video(self, output_video: Path, fps: float):
|
|
||||||
self.video_maker.images_to_video(self.fs.moved_path, output_video, fps)
|
self.video_maker.images_to_video(self.fs.moved_path, output_video, fps)
|
||||||
|
|
||||||
def _merge_video_parts(self, output_video: Path):
|
def _merge_video_parts(self, output_video: Path):
|
||||||
self.video_maker.concatenate_videos(self.fs.video_part_path, output_video)
|
self.video_maker.concatenate_videos(self.fs.video_part_path, output_video)
|
||||||
self.fs.clear_directory(self.fs.video_part_path)
|
self.fs.clear_directory(self.fs.video_part_path)
|
||||||
|
|
||||||
def _frame_generator(
|
def _move_frames(self, source_frame_length: int = 0):
|
||||||
self,
|
self.fs.clear_directory(self.fs.moved_path)
|
||||||
source: tuple["np.ndarray", ...],
|
src_frames = sorted(self.fs.frames_path.glob("*.png"))
|
||||||
interpolated: list["np.ndarray"],
|
interpolated_frames = sorted(self.fs.interpolated_path.glob("*.png"))
|
||||||
):
|
index = 0
|
||||||
for i, frame in enumerate(source):
|
for i in range(source_frame_length):
|
||||||
yield frame
|
moved_frame_path = self.fs.moved_path / f"img_{index:08d}.png"
|
||||||
if i < len(interpolated):
|
src_frames[i].rename(moved_frame_path)
|
||||||
yield interpolated[i]
|
index += 1
|
||||||
|
if i < len(interpolated_frames):
|
||||||
|
moved_interpolated_path = self.fs.moved_path / f"img_{index:08d}.png"
|
||||||
|
interpolated_frames[i].rename(moved_interpolated_path)
|
||||||
|
index += 1
|
||||||
|
logging.info(
|
||||||
|
f"Moved {len(src_frames)} source frames and {len(interpolated_frames)} interpolated frames to {self.fs.moved_path}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def runner(
|
def runner(
|
||||||
@@ -186,7 +183,8 @@ def runner(
|
|||||||
preset: presets.Preset = presets.LARGE,
|
preset: presets.Preset = presets.LARGE,
|
||||||
):
|
):
|
||||||
pipeline = InterpolationPipeline(
|
pipeline = InterpolationPipeline(
|
||||||
preset=preset,
|
config=preset.config,
|
||||||
|
checkpoint_path=preset.checkpoint,
|
||||||
base_path=base_path,
|
base_path=base_path,
|
||||||
)
|
)
|
||||||
pipeline.run(video_path, output_video)
|
pipeline.run(video_path, output_video)
|
||||||
@@ -222,7 +220,7 @@ def main():
|
|||||||
base_path=Path(args.base_path),
|
base_path=Path(args.base_path),
|
||||||
video_path=Path(args.video_path),
|
video_path=Path(args.video_path),
|
||||||
output_video=args.output,
|
output_video=args.output,
|
||||||
preset=getattr(presets, args.preset.upper()),
|
preset=getattr(presets, args.preset.upper())
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,8 +0,0 @@
|
|||||||
import torch
|
|
||||||
from src.export_to_onnx import export_to_onnx
|
|
||||||
from src.config.presets import SMALL
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
device = torch.device("cuda")
|
|
||||||
export_to_onnx(SMALL, "src/pretrained/amt_s.onnx", device)
|
|
||||||
@@ -7,14 +7,8 @@ requires-python = ">=3.12"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"imageio>=2.37.3",
|
"imageio>=2.37.3",
|
||||||
"numpy>=2.4.4",
|
"numpy>=2.4.4",
|
||||||
"nvidia-modelopt[all]>=0.33.1",
|
|
||||||
"omegaconf>=2.3.0",
|
"omegaconf>=2.3.0",
|
||||||
"onnx>=1.21.0",
|
|
||||||
"onnxscript>=0.6.2",
|
|
||||||
"opencv-python>=4.13.0.92",
|
"opencv-python>=4.13.0.92",
|
||||||
"tensorrt>=10.16.1.11",
|
"torch>=2.11.0",
|
||||||
"torch==2.5.1",
|
|
||||||
"torch-tensorrt>=2.5.0",
|
|
||||||
"torchvision>=0.20.1",
|
|
||||||
"tqdm>=4.67.3",
|
"tqdm>=4.67.3",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
from typing import Literal
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
@@ -7,7 +6,6 @@ from dataclasses import dataclass
|
|||||||
class Preset:
|
class Preset:
|
||||||
config: Path
|
config: Path
|
||||||
checkpoint: Path
|
checkpoint: Path
|
||||||
onnx: Path | None = None
|
|
||||||
|
|
||||||
|
|
||||||
SMALL = Preset(
|
SMALL = Preset(
|
||||||
|
|||||||
@@ -1,29 +0,0 @@
|
|||||||
import torch
|
|
||||||
import torchvision
|
|
||||||
|
|
||||||
torchvision.disable_beta_transforms_warning()
|
|
||||||
torch.backends.cudnn.enabled = False
|
|
||||||
from .interpolator import ModelRunner
|
|
||||||
from .config.presets import Preset
|
|
||||||
|
|
||||||
|
|
||||||
def export_to_onnx(preset: Preset, output_path: str, device: torch.device):
|
|
||||||
model_runner = ModelRunner(preset, device)
|
|
||||||
# model_runner.model.eval()
|
|
||||||
dummy_input = model_runner.get_dummy_input()
|
|
||||||
torch.onnx.export(
|
|
||||||
model_runner.model,
|
|
||||||
dummy_input,
|
|
||||||
output_path,
|
|
||||||
opset_version=17,
|
|
||||||
input_names=['img0', 'img1', 'embt'],
|
|
||||||
output_names=["imgt_pred"],
|
|
||||||
dynamic_axes={
|
|
||||||
"img0": {0: "batch", 2: "height", 3: "width"},
|
|
||||||
"img1": {0: "batch", 2: "height", 3: "width"},
|
|
||||||
"embt": {0: "batch", 2: "height", 3: "width"},
|
|
||||||
"imgt_pred": {0: "batch", 2: "height", 3: "width"},
|
|
||||||
},
|
|
||||||
dynamo=True,
|
|
||||||
use_external_data_format=False,
|
|
||||||
)
|
|
||||||
@@ -1,16 +1,14 @@
|
|||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from cv2 import imread
|
|
||||||
import torch
|
import torch
|
||||||
import onnxruntime as ort
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from omegaconf import OmegaConf, DictConfig
|
from omegaconf import OmegaConf, DictConfig
|
||||||
|
from imageio import imread, imwrite
|
||||||
|
|
||||||
from src.config.presets import Preset
|
from src.utils.torch import img2tensor, check_dim_and_resize, tensor2img
|
||||||
from src.utils.torch import img2tensor, tensor2img
|
|
||||||
from src.utils.build import build_from_cfg
|
from src.utils.build import build_from_cfg
|
||||||
|
from src.utils.padder import InputPadder
|
||||||
|
|
||||||
|
|
||||||
class Anchor:
|
class Anchor:
|
||||||
@@ -23,27 +21,8 @@ class Anchor:
|
|||||||
return f"Anchor(resolution={self.resolution}, memory={self.memory}, memory_bias={self.memory_bias})"
|
return f"Anchor(resolution={self.resolution}, memory={self.memory}, memory_bias={self.memory_bias})"
|
||||||
|
|
||||||
|
|
||||||
class ONNXWrapper:
|
|
||||||
def __init__(self, path):
|
|
||||||
self.session = ort.InferenceSession(path)
|
|
||||||
|
|
||||||
self.input_names = [i.name for i in self.session.get_inputs()]
|
|
||||||
self.output_names = [o.name for o in self.session.get_outputs()]
|
|
||||||
|
|
||||||
def __call__(self, tensor1, tensor2, embt):
|
|
||||||
inputs = {
|
|
||||||
self.input_names[0]: tensor1.cpu().numpy(),
|
|
||||||
self.input_names[1]: tensor2.cpu().numpy(),
|
|
||||||
self.input_names[2]: embt.cpu().numpy(),
|
|
||||||
}
|
|
||||||
|
|
||||||
outputs = self.session.run(self.output_names, inputs)
|
|
||||||
|
|
||||||
return {"imgt_pred": torch.from_numpy(outputs[0])}
|
|
||||||
|
|
||||||
|
|
||||||
class ModelRunner:
|
class ModelRunner:
|
||||||
def __init__(self, preset: Preset, device: torch.device) -> None:
|
def __init__(self, config: Path, ckpt_path: Path, device: torch.device) -> None:
|
||||||
"""Initializes the ModelRunner with configuration and checkpoint.
|
"""Initializes the ModelRunner with configuration and checkpoint.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -51,73 +30,17 @@ class ModelRunner:
|
|||||||
ckpt_path (Path): Path to model checkpoint in .pth format
|
ckpt_path (Path): Path to model checkpoint in .pth format
|
||||||
device (torch.device): Device to load the model on
|
device (torch.device): Device to load the model on
|
||||||
"""
|
"""
|
||||||
self.model: Optional[torch.nn.Module] = None
|
omega_config = OmegaConf.load(config)
|
||||||
self.session: Optional[ONNXWrapper] = None
|
|
||||||
self.embt = torch.tensor(1 / 2).float().view(1, 1, 1, 1).to(device)
|
|
||||||
|
|
||||||
if preset.onnx:
|
|
||||||
self.session = ONNXWrapper(preset.onnx)
|
|
||||||
self.device = device
|
|
||||||
self.embt = self.embt.cpu().numpy()
|
|
||||||
return
|
|
||||||
|
|
||||||
omega_config = OmegaConf.load(preset.config)
|
|
||||||
network_config: DictConfig = omega_config.network
|
network_config: DictConfig = omega_config.network
|
||||||
logging.info(
|
logging.info(
|
||||||
f"Loaded network configuration: {network_config} from [{preset.checkpoint}]"
|
f"Loaded network configuration: {network_config} from [{ckpt_path}]"
|
||||||
)
|
)
|
||||||
model = build_from_cfg(network_config)
|
model = build_from_cfg(network_config)
|
||||||
checkpoint = torch.load(
|
checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False)
|
||||||
preset.checkpoint, map_location=device, weights_only=False
|
|
||||||
)
|
|
||||||
model.load_state_dict(checkpoint["state_dict"])
|
model.load_state_dict(checkpoint["state_dict"])
|
||||||
model = model.to(device)
|
model = model.to(get_device())
|
||||||
model.eval()
|
model.eval()
|
||||||
# self.model = torch.compile(model)
|
|
||||||
self.model = model
|
self.model = model
|
||||||
self.device = device
|
|
||||||
# self.model = torch.compile(model, backend="tensorrt")
|
|
||||||
if logging.getLogger().isEnabledFor(logging.DEBUG):
|
|
||||||
for name, param in self.model.named_parameters():
|
|
||||||
logging.debug(
|
|
||||||
f"Parameter: {name}, shape: {param.shape}, dtype: {param.dtype}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_dummy_input(self):
|
|
||||||
"""Generates a dummy input tensor for ONNX export."""
|
|
||||||
return (
|
|
||||||
img2tensor(imread(filename="example/frame_01.png"), self.device),
|
|
||||||
img2tensor(imread(filename="example/frame_02.png"), self.device),
|
|
||||||
self.embt,
|
|
||||||
)
|
|
||||||
|
|
||||||
def run(self, image1: np.ndarray, image2: np.ndarray) -> np.ndarray:
|
|
||||||
"""Runs the model inference to interpolate between two images.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
image1 (np.ndarray): First input image as a NumPy array
|
|
||||||
image2 (np.ndarray): Second input image as a NumPy array
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
np.ndarray: Interpolated image as a NumPy array
|
|
||||||
"""
|
|
||||||
if self.session:
|
|
||||||
image1 = img2tensor(image1, self.device).cpu().numpy()
|
|
||||||
image2 = img2tensor(image2, self.device).cpu().numpy()
|
|
||||||
inputs = {
|
|
||||||
"img0": image1,
|
|
||||||
"img1": image2,
|
|
||||||
"embt": self.embt,
|
|
||||||
}
|
|
||||||
outputs = self.session.session.run(None, inputs)
|
|
||||||
return outputs[0]
|
|
||||||
|
|
||||||
tensor1 = img2tensor(image1, self.device)
|
|
||||||
tensor2 = img2tensor(image2, self.device)
|
|
||||||
with torch.no_grad():
|
|
||||||
with torch.amp.autocast(self.device.type):
|
|
||||||
interpolated = self.model(tensor1, tensor2, self.embt)["imgt_pred"]
|
|
||||||
return tensor2img(interpolated.cpu())
|
|
||||||
|
|
||||||
|
|
||||||
def get_vram_available(device: torch.device) -> int:
|
def get_vram_available(device: torch.device) -> int:
|
||||||
@@ -154,12 +77,13 @@ class ImageInterpolator:
|
|||||||
self.device = device
|
self.device = device
|
||||||
self.anchor = anchor
|
self.anchor = anchor
|
||||||
self.vram_available = get_vram_available(device)
|
self.vram_available = get_vram_available(device)
|
||||||
|
self.embt = torch.tensor(1 / 2).float().view(1, 1, 1, 1).to(device)
|
||||||
self.model_runner = model_runner
|
self.model_runner = model_runner
|
||||||
logging.debug(
|
logging.debug(
|
||||||
f"Initialized ImageInterpolator with device: {device}, anchor: {anchor}, available VRAM: {self.vram_available} bytes"
|
f"Initialized ImageInterpolator with device: {device}, anchor: {anchor}, available VRAM: {self.vram_available} bytes"
|
||||||
)
|
)
|
||||||
|
|
||||||
def interpolate(self, image1: np.ndarray, image2: np.ndarray) -> np.ndarray:
|
def interpolate(self, image1: Path, image2: Path, output_path: Path):
|
||||||
"""
|
"""
|
||||||
Interpolates between two images and saves the result.
|
Interpolates between two images and saves the result.
|
||||||
Args:
|
Args:
|
||||||
@@ -167,7 +91,39 @@ class ImageInterpolator:
|
|||||||
image2 (Path): Path to the second input image (only png and jpg formats are supported)
|
image2 (Path): Path to the second input image (only png and jpg formats are supported)
|
||||||
output_path (Path): Path to save the interpolated image (only png and jpg formats are supported)
|
output_path (Path): Path to save the interpolated image (only png and jpg formats are supported)
|
||||||
"""
|
"""
|
||||||
return self.model_runner.run(image1, image2)
|
logging.debug(f"Reading images: {image1} and {image2}")
|
||||||
|
tensor1 = img2tensor(imread(image1)).to(self.device)
|
||||||
|
tensor2 = img2tensor(imread(image2)).to(self.device)
|
||||||
|
logging.debug(
|
||||||
|
f"Image shapes after conversion to tensors: {tensor1.shape}, {tensor2.shape}"
|
||||||
|
)
|
||||||
|
tensor1, tensor2 = check_dim_and_resize(tensor1, tensor2)
|
||||||
|
logging.debug(f"Image shapes after resizing: {tensor1.shape}, {tensor2.shape}")
|
||||||
|
h, w = tensor1.shape[2], tensor1.shape[3]
|
||||||
|
logging.debug(f"Interpolating images of size: {h}x{w}")
|
||||||
|
|
||||||
|
scale = self.scale(h, w)
|
||||||
|
logging.debug(f"Calculated scale factor: {scale:.2f}")
|
||||||
|
padding = int(16 / scale)
|
||||||
|
logging.debug(f"Calculated padding: {padding} pixels")
|
||||||
|
padder = InputPadder(tensor1.shape, divisor=padding)
|
||||||
|
tensor1_padded, tensor2_padded = padder.pad(tensor1, tensor2)
|
||||||
|
logging.debug(
|
||||||
|
f"Image shapes after padding: {tensor1_padded.shape}, {tensor2_padded.shape}"
|
||||||
|
)
|
||||||
|
|
||||||
|
tensor1_padded = tensor1_padded.to(self.device)
|
||||||
|
tensor2_padded = tensor2_padded.to(self.device)
|
||||||
|
logging.debug("Running model inference for interpolation")
|
||||||
|
with torch.no_grad():
|
||||||
|
interpolated = self.model_runner.model(
|
||||||
|
tensor1_padded, tensor2_padded, self.embt, scale_factor=scale, eval=True
|
||||||
|
)["imgt_pred"]
|
||||||
|
logging.debug(f"Interpolated image shape before unpadding: {interpolated.shape}")
|
||||||
|
(interpolated,) = padder.unpad(interpolated)
|
||||||
|
logging.debug(f"Interpolated image shape after unpadding: {interpolated.shape}")
|
||||||
|
imwrite(output_path, tensor2img(interpolated.cpu()))
|
||||||
|
logging.debug(f"Saved interpolated image to: {output_path}")
|
||||||
|
|
||||||
def scale(self, height: int, width: int) -> float:
|
def scale(self, height: int, width: int) -> float:
|
||||||
scale = (
|
scale = (
|
||||||
|
|||||||
@@ -67,14 +67,7 @@ class Model(nn.Module):
|
|||||||
flow = torch.cat([flow0, flow1], dim=1)
|
flow = torch.cat([flow0, flow1], dim=1)
|
||||||
return corr, flow
|
return corr, flow
|
||||||
|
|
||||||
def forward(
|
def forward(self, img0, img1, embt, scale_factor=1.0, eval=False, **kwargs):
|
||||||
self,
|
|
||||||
img0: torch.Tensor,
|
|
||||||
img1: torch.Tensor,
|
|
||||||
embt: torch.Tensor,
|
|
||||||
):
|
|
||||||
scale_factor = 1.0
|
|
||||||
eval = False
|
|
||||||
mean_ = (
|
mean_ = (
|
||||||
torch.cat([img0, img1], 2)
|
torch.cat([img0, img1], 2)
|
||||||
.mean(1, keepdim=True)
|
.mean(1, keepdim=True)
|
||||||
|
|||||||
@@ -1,44 +1,40 @@
|
|||||||
from typing import Any
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
|
|
||||||
class BottleneckBlock(nn.Module):
|
class BottleneckBlock(nn.Module):
|
||||||
def __init__(self, in_planes, planes, norm_fn="group", stride=1):
|
def __init__(self, in_planes, planes, norm_fn='group', stride=1):
|
||||||
super(BottleneckBlock, self).__init__()
|
super(BottleneckBlock, self).__init__()
|
||||||
|
|
||||||
self.conv1 = nn.Conv2d(in_planes, planes // 4, kernel_size=1, padding=0)
|
self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0)
|
||||||
self.conv2 = nn.Conv2d(
|
self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride)
|
||||||
planes // 4, planes // 4, kernel_size=3, padding=1, stride=stride
|
self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0)
|
||||||
)
|
|
||||||
self.conv3 = nn.Conv2d(planes // 4, planes, kernel_size=1, padding=0)
|
|
||||||
self.relu = nn.ReLU(inplace=True)
|
self.relu = nn.ReLU(inplace=True)
|
||||||
|
|
||||||
num_groups = planes // 8
|
num_groups = planes // 8
|
||||||
|
|
||||||
if norm_fn == "group":
|
if norm_fn == 'group':
|
||||||
self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes // 4)
|
self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
|
||||||
self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes // 4)
|
self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
|
||||||
self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
|
self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
|
||||||
if not stride == 1:
|
if not stride == 1:
|
||||||
self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
|
self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
|
||||||
|
|
||||||
elif norm_fn == "batch":
|
elif norm_fn == 'batch':
|
||||||
self.norm1 = nn.BatchNorm2d(planes // 4)
|
self.norm1 = nn.BatchNorm2d(planes//4)
|
||||||
self.norm2 = nn.BatchNorm2d(planes // 4)
|
self.norm2 = nn.BatchNorm2d(planes//4)
|
||||||
self.norm3 = nn.BatchNorm2d(planes)
|
self.norm3 = nn.BatchNorm2d(planes)
|
||||||
if not stride == 1:
|
if not stride == 1:
|
||||||
self.norm4 = nn.BatchNorm2d(planes)
|
self.norm4 = nn.BatchNorm2d(planes)
|
||||||
|
|
||||||
elif norm_fn == "instance":
|
elif norm_fn == 'instance':
|
||||||
self.norm1 = nn.InstanceNorm2d(planes // 4)
|
self.norm1 = nn.InstanceNorm2d(planes//4)
|
||||||
self.norm2 = nn.InstanceNorm2d(planes // 4)
|
self.norm2 = nn.InstanceNorm2d(planes//4)
|
||||||
self.norm3 = nn.InstanceNorm2d(planes)
|
self.norm3 = nn.InstanceNorm2d(planes)
|
||||||
if not stride == 1:
|
if not stride == 1:
|
||||||
self.norm4 = nn.InstanceNorm2d(planes)
|
self.norm4 = nn.InstanceNorm2d(planes)
|
||||||
|
|
||||||
elif norm_fn == "none":
|
elif norm_fn == 'none':
|
||||||
self.norm1 = nn.Sequential()
|
self.norm1 = nn.Sequential()
|
||||||
self.norm2 = nn.Sequential()
|
self.norm2 = nn.Sequential()
|
||||||
self.norm3 = nn.Sequential()
|
self.norm3 = nn.Sequential()
|
||||||
@@ -47,11 +43,11 @@ class BottleneckBlock(nn.Module):
|
|||||||
|
|
||||||
if stride == 1:
|
if stride == 1:
|
||||||
self.downsample = None
|
self.downsample = None
|
||||||
|
|
||||||
else:
|
else:
|
||||||
self.downsample = nn.Sequential(
|
self.downsample = nn.Sequential(
|
||||||
nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4
|
nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4)
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
y = x
|
y = x
|
||||||
@@ -62,40 +58,38 @@ class BottleneckBlock(nn.Module):
|
|||||||
if self.downsample is not None:
|
if self.downsample is not None:
|
||||||
x = self.downsample(x)
|
x = self.downsample(x)
|
||||||
|
|
||||||
return self.relu(x + y)
|
return self.relu(x+y)
|
||||||
|
|
||||||
|
|
||||||
class ResidualBlock(nn.Module):
|
class ResidualBlock(nn.Module):
|
||||||
def __init__(self, in_planes, planes, norm_fn="group", stride=1):
|
def __init__(self, in_planes, planes, norm_fn='group', stride=1):
|
||||||
super(ResidualBlock, self).__init__()
|
super(ResidualBlock, self).__init__()
|
||||||
|
|
||||||
self.conv1 = nn.Conv2d(
|
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride)
|
||||||
in_planes, planes, kernel_size=3, padding=1, stride=stride
|
|
||||||
)
|
|
||||||
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
|
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
|
||||||
self.relu = nn.ReLU(inplace=True)
|
self.relu = nn.ReLU(inplace=True)
|
||||||
|
|
||||||
num_groups = planes // 8
|
num_groups = planes // 8
|
||||||
|
|
||||||
if norm_fn == "group":
|
if norm_fn == 'group':
|
||||||
self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
|
self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
|
||||||
self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
|
self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
|
||||||
if not stride == 1:
|
if not stride == 1:
|
||||||
self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
|
self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
|
||||||
|
|
||||||
elif norm_fn == "batch":
|
elif norm_fn == 'batch':
|
||||||
self.norm1 = nn.BatchNorm2d(planes)
|
self.norm1 = nn.BatchNorm2d(planes)
|
||||||
self.norm2 = nn.BatchNorm2d(planes)
|
self.norm2 = nn.BatchNorm2d(planes)
|
||||||
if not stride == 1:
|
if not stride == 1:
|
||||||
self.norm3 = nn.BatchNorm2d(planes)
|
self.norm3 = nn.BatchNorm2d(planes)
|
||||||
|
|
||||||
elif norm_fn == "instance":
|
elif norm_fn == 'instance':
|
||||||
self.norm1 = nn.InstanceNorm2d(planes)
|
self.norm1 = nn.InstanceNorm2d(planes)
|
||||||
self.norm2 = nn.InstanceNorm2d(planes)
|
self.norm2 = nn.InstanceNorm2d(planes)
|
||||||
if not stride == 1:
|
if not stride == 1:
|
||||||
self.norm3 = nn.InstanceNorm2d(planes)
|
self.norm3 = nn.InstanceNorm2d(planes)
|
||||||
|
|
||||||
elif norm_fn == "none":
|
elif norm_fn == 'none':
|
||||||
self.norm1 = nn.Sequential()
|
self.norm1 = nn.Sequential()
|
||||||
self.norm2 = nn.Sequential()
|
self.norm2 = nn.Sequential()
|
||||||
if not stride == 1:
|
if not stride == 1:
|
||||||
@@ -103,11 +97,11 @@ class ResidualBlock(nn.Module):
|
|||||||
|
|
||||||
if stride == 1:
|
if stride == 1:
|
||||||
self.downsample = None
|
self.downsample = None
|
||||||
|
|
||||||
else:
|
else:
|
||||||
self.downsample = nn.Sequential(
|
self.downsample = nn.Sequential(
|
||||||
nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3
|
nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
y = x
|
y = x
|
||||||
@@ -117,43 +111,43 @@ class ResidualBlock(nn.Module):
|
|||||||
if self.downsample is not None:
|
if self.downsample is not None:
|
||||||
x = self.downsample(x)
|
x = self.downsample(x)
|
||||||
|
|
||||||
return self.relu(x + y)
|
return self.relu(x+y)
|
||||||
|
|
||||||
|
|
||||||
class SmallEncoder(nn.Module):
|
class SmallEncoder(nn.Module):
|
||||||
def __init__(self, output_dim=128, norm_fn="batch", dropout=0.0):
|
def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
|
||||||
super(SmallEncoder, self).__init__()
|
super(SmallEncoder, self).__init__()
|
||||||
self.norm_fn = norm_fn
|
self.norm_fn = norm_fn
|
||||||
|
|
||||||
if self.norm_fn == "group":
|
if self.norm_fn == 'group':
|
||||||
self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32)
|
self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32)
|
||||||
|
|
||||||
elif self.norm_fn == "batch":
|
elif self.norm_fn == 'batch':
|
||||||
self.norm1 = nn.BatchNorm2d(32)
|
self.norm1 = nn.BatchNorm2d(32)
|
||||||
|
|
||||||
elif self.norm_fn == "instance":
|
elif self.norm_fn == 'instance':
|
||||||
self.norm1 = nn.InstanceNorm2d(32)
|
self.norm1 = nn.InstanceNorm2d(32)
|
||||||
|
|
||||||
elif self.norm_fn == "none":
|
elif self.norm_fn == 'none':
|
||||||
self.norm1 = nn.Sequential()
|
self.norm1 = nn.Sequential()
|
||||||
|
|
||||||
self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3)
|
self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3)
|
||||||
self.relu1 = nn.ReLU(inplace=True)
|
self.relu1 = nn.ReLU(inplace=True)
|
||||||
|
|
||||||
self.in_planes = 32
|
self.in_planes = 32
|
||||||
self.layer1 = self._make_layer(32, stride=1)
|
self.layer1 = self._make_layer(32, stride=1)
|
||||||
self.layer2 = self._make_layer(64, stride=2)
|
self.layer2 = self._make_layer(64, stride=2)
|
||||||
self.layer3 = self._make_layer(96, stride=2)
|
self.layer3 = self._make_layer(96, stride=2)
|
||||||
|
|
||||||
self.dropout = None
|
self.dropout = None
|
||||||
if dropout > 0:
|
if dropout > 0:
|
||||||
self.dropout = nn.Dropout2d(p=dropout)
|
self.dropout = nn.Dropout2d(p=dropout)
|
||||||
|
|
||||||
self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1)
|
self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1)
|
||||||
|
|
||||||
for m in self.modules():
|
for m in self.modules():
|
||||||
if isinstance(m, nn.Conv2d):
|
if isinstance(m, nn.Conv2d):
|
||||||
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
|
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
||||||
elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
|
elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
|
||||||
if m.weight is not None:
|
if m.weight is not None:
|
||||||
nn.init.constant_(m.weight, 1)
|
nn.init.constant_(m.weight, 1)
|
||||||
@@ -164,19 +158,18 @@ class SmallEncoder(nn.Module):
|
|||||||
layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride)
|
layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride)
|
||||||
layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1)
|
layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1)
|
||||||
layers = (layer1, layer2)
|
layers = (layer1, layer2)
|
||||||
|
|
||||||
self.in_planes = dim
|
self.in_planes = dim
|
||||||
return nn.Sequential(*layers)
|
return nn.Sequential(*layers)
|
||||||
|
|
||||||
def forward(
|
|
||||||
self, x: torch.Tensor | list[torch.Tensor] | tuple[torch.Tensor, ...]
|
def forward(self, x):
|
||||||
):
|
|
||||||
|
|
||||||
# if input is list, combine batch dimension
|
# if input is list, combine batch dimension
|
||||||
batch_dim = None
|
is_list = isinstance(x, tuple) or isinstance(x, list)
|
||||||
if is_list := isinstance(x, tuple) or isinstance(x, list):
|
if is_list:
|
||||||
batch_dim = x[0].shape[0]
|
batch_dim = x[0].shape[0]
|
||||||
x: torch.Tensor = torch.cat(x, dim=0)
|
x = torch.cat(x, dim=0)
|
||||||
|
|
||||||
x = self.conv1(x)
|
x = self.conv1(x)
|
||||||
x = self.norm1(x)
|
x = self.norm1(x)
|
||||||
@@ -190,37 +183,33 @@ class SmallEncoder(nn.Module):
|
|||||||
if self.training and self.dropout is not None:
|
if self.training and self.dropout is not None:
|
||||||
x = self.dropout(x)
|
x = self.dropout(x)
|
||||||
|
|
||||||
if is_list and batch_dim is not None:
|
if is_list:
|
||||||
return torch.split(x, [batch_dim, batch_dim], dim=0)
|
x = torch.split(x, [batch_dim, batch_dim], dim=0)
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def __call__(self, *args: Any, **kwds: Any) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
|
||||||
return super().__call__(*args, **kwds)
|
|
||||||
|
|
||||||
|
|
||||||
class BasicEncoder(nn.Module):
|
class BasicEncoder(nn.Module):
|
||||||
def __init__(self, output_dim=128, norm_fn="batch", dropout=0.0):
|
def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
|
||||||
super(BasicEncoder, self).__init__()
|
super(BasicEncoder, self).__init__()
|
||||||
self.norm_fn = norm_fn
|
self.norm_fn = norm_fn
|
||||||
|
|
||||||
if self.norm_fn == "group":
|
if self.norm_fn == 'group':
|
||||||
self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)
|
self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)
|
||||||
|
|
||||||
elif self.norm_fn == "batch":
|
elif self.norm_fn == 'batch':
|
||||||
self.norm1 = nn.BatchNorm2d(64)
|
self.norm1 = nn.BatchNorm2d(64)
|
||||||
|
|
||||||
elif self.norm_fn == "instance":
|
elif self.norm_fn == 'instance':
|
||||||
self.norm1 = nn.InstanceNorm2d(64)
|
self.norm1 = nn.InstanceNorm2d(64)
|
||||||
|
|
||||||
elif self.norm_fn == "none":
|
elif self.norm_fn == 'none':
|
||||||
self.norm1 = nn.Sequential()
|
self.norm1 = nn.Sequential()
|
||||||
|
|
||||||
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
|
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
|
||||||
self.relu1 = nn.ReLU(inplace=True)
|
self.relu1 = nn.ReLU(inplace=True)
|
||||||
|
|
||||||
self.in_planes = 64
|
self.in_planes = 64
|
||||||
self.layer1 = self._make_layer(64, stride=1)
|
self.layer1 = self._make_layer(64, stride=1)
|
||||||
self.layer2 = self._make_layer(72, stride=2)
|
self.layer2 = self._make_layer(72, stride=2)
|
||||||
self.layer3 = self._make_layer(128, stride=2)
|
self.layer3 = self._make_layer(128, stride=2)
|
||||||
|
|
||||||
@@ -233,7 +222,7 @@ class BasicEncoder(nn.Module):
|
|||||||
|
|
||||||
for m in self.modules():
|
for m in self.modules():
|
||||||
if isinstance(m, nn.Conv2d):
|
if isinstance(m, nn.Conv2d):
|
||||||
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
|
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
||||||
elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
|
elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
|
||||||
if m.weight is not None:
|
if m.weight is not None:
|
||||||
nn.init.constant_(m.weight, 1)
|
nn.init.constant_(m.weight, 1)
|
||||||
@@ -244,10 +233,11 @@ class BasicEncoder(nn.Module):
|
|||||||
layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
|
layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
|
||||||
layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
|
layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
|
||||||
layers = (layer1, layer2)
|
layers = (layer1, layer2)
|
||||||
|
|
||||||
self.in_planes = dim
|
self.in_planes = dim
|
||||||
return nn.Sequential(*layers)
|
return nn.Sequential(*layers)
|
||||||
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
|
|
||||||
# if input is list, combine batch dimension
|
# if input is list, combine batch dimension
|
||||||
@@ -274,22 +264,21 @@ class BasicEncoder(nn.Module):
|
|||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
class LargeEncoder(nn.Module):
|
class LargeEncoder(nn.Module):
|
||||||
def __init__(self, output_dim=128, norm_fn="batch", dropout=0.0):
|
def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
|
||||||
super(LargeEncoder, self).__init__()
|
super(LargeEncoder, self).__init__()
|
||||||
self.norm_fn = norm_fn
|
self.norm_fn = norm_fn
|
||||||
|
|
||||||
if self.norm_fn == "group":
|
if self.norm_fn == 'group':
|
||||||
self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)
|
self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)
|
||||||
|
|
||||||
elif self.norm_fn == "batch":
|
elif self.norm_fn == 'batch':
|
||||||
self.norm1 = nn.BatchNorm2d(64)
|
self.norm1 = nn.BatchNorm2d(64)
|
||||||
|
|
||||||
elif self.norm_fn == "instance":
|
elif self.norm_fn == 'instance':
|
||||||
self.norm1 = nn.InstanceNorm2d(64)
|
self.norm1 = nn.InstanceNorm2d(64)
|
||||||
|
|
||||||
elif self.norm_fn == "none":
|
elif self.norm_fn == 'none':
|
||||||
self.norm1 = nn.Sequential()
|
self.norm1 = nn.Sequential()
|
||||||
|
|
||||||
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
|
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
|
||||||
@@ -310,7 +299,7 @@ class LargeEncoder(nn.Module):
|
|||||||
|
|
||||||
for m in self.modules():
|
for m in self.modules():
|
||||||
if isinstance(m, nn.Conv2d):
|
if isinstance(m, nn.Conv2d):
|
||||||
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
|
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
||||||
elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
|
elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
|
||||||
if m.weight is not None:
|
if m.weight is not None:
|
||||||
nn.init.constant_(m.weight, 1)
|
nn.init.constant_(m.weight, 1)
|
||||||
@@ -321,10 +310,11 @@ class LargeEncoder(nn.Module):
|
|||||||
layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
|
layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
|
||||||
layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
|
layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
|
||||||
layers = (layer1, layer2)
|
layers = (layer1, layer2)
|
||||||
|
|
||||||
self.in_planes = dim
|
self.in_planes = dim
|
||||||
return nn.Sequential(*layers)
|
return nn.Sequential(*layers)
|
||||||
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
|
|
||||||
# if input is list, combine batch dimension
|
# if input is list, combine batch dimension
|
||||||
|
|||||||
@@ -4,131 +4,85 @@ import torch.nn.functional as F
|
|||||||
from src.utils.flow_utils import warp
|
from src.utils.flow_utils import warp
|
||||||
|
|
||||||
|
|
||||||
def resize(x: torch.Tensor, scale_factor: float) -> torch.Tensor:
|
def resize(x, scale_factor):
|
||||||
return F.interpolate(
|
return F.interpolate(x, scale_factor=scale_factor, mode="bilinear", align_corners=False)
|
||||||
x, scale_factor=scale_factor, mode="bilinear", align_corners=False
|
|
||||||
)
|
|
||||||
|
|
||||||
|
def convrelu(in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=True):
|
||||||
def convrelu(
|
|
||||||
in_channels,
|
|
||||||
out_channels,
|
|
||||||
kernel_size=3,
|
|
||||||
stride=1,
|
|
||||||
padding=1,
|
|
||||||
dilation=1,
|
|
||||||
groups=1,
|
|
||||||
bias=True,
|
|
||||||
):
|
|
||||||
return nn.Sequential(
|
return nn.Sequential(
|
||||||
nn.Conv2d(
|
nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias=bias),
|
||||||
in_channels,
|
nn.PReLU(out_channels)
|
||||||
out_channels,
|
|
||||||
kernel_size,
|
|
||||||
stride,
|
|
||||||
padding,
|
|
||||||
dilation,
|
|
||||||
groups,
|
|
||||||
bias=bias,
|
|
||||||
),
|
|
||||||
nn.PReLU(out_channels),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ResBlock(nn.Module):
|
class ResBlock(nn.Module):
|
||||||
def __init__(self, in_channels, side_channels, bias=True):
|
def __init__(self, in_channels, side_channels, bias=True):
|
||||||
super(ResBlock, self).__init__()
|
super(ResBlock, self).__init__()
|
||||||
self.side_channels = side_channels
|
self.side_channels = side_channels
|
||||||
self.conv1 = nn.Sequential(
|
self.conv1 = nn.Sequential(
|
||||||
nn.Conv2d(
|
nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias),
|
||||||
in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias
|
nn.PReLU(in_channels)
|
||||||
),
|
|
||||||
nn.PReLU(in_channels),
|
|
||||||
)
|
)
|
||||||
self.conv2 = nn.Sequential(
|
self.conv2 = nn.Sequential(
|
||||||
nn.Conv2d(
|
nn.Conv2d(side_channels, side_channels, kernel_size=3, stride=1, padding=1, bias=bias),
|
||||||
side_channels,
|
nn.PReLU(side_channels)
|
||||||
side_channels,
|
|
||||||
kernel_size=3,
|
|
||||||
stride=1,
|
|
||||||
padding=1,
|
|
||||||
bias=bias,
|
|
||||||
),
|
|
||||||
nn.PReLU(side_channels),
|
|
||||||
)
|
)
|
||||||
self.conv3 = nn.Sequential(
|
self.conv3 = nn.Sequential(
|
||||||
nn.Conv2d(
|
nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias),
|
||||||
in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias
|
nn.PReLU(in_channels)
|
||||||
),
|
|
||||||
nn.PReLU(in_channels),
|
|
||||||
)
|
)
|
||||||
self.conv4 = nn.Sequential(
|
self.conv4 = nn.Sequential(
|
||||||
nn.Conv2d(
|
nn.Conv2d(side_channels, side_channels, kernel_size=3, stride=1, padding=1, bias=bias),
|
||||||
side_channels,
|
nn.PReLU(side_channels)
|
||||||
side_channels,
|
|
||||||
kernel_size=3,
|
|
||||||
stride=1,
|
|
||||||
padding=1,
|
|
||||||
bias=bias,
|
|
||||||
),
|
|
||||||
nn.PReLU(side_channels),
|
|
||||||
)
|
|
||||||
self.conv5 = nn.Conv2d(
|
|
||||||
in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias
|
|
||||||
)
|
)
|
||||||
|
self.conv5 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias)
|
||||||
self.prelu = nn.PReLU(in_channels)
|
self.prelu = nn.PReLU(in_channels)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
out = self.conv1(x)
|
out = self.conv1(x)
|
||||||
|
|
||||||
res_feat = out[:, : -self.side_channels, ...]
|
res_feat = out[:, :-self.side_channels, ...]
|
||||||
side_feat = out[:, -self.side_channels :, :, :]
|
side_feat = out[:, -self.side_channels:, :, :]
|
||||||
side_feat = self.conv2(side_feat)
|
side_feat = self.conv2(side_feat)
|
||||||
out = self.conv3(torch.cat([res_feat, side_feat], 1))
|
out = self.conv3(torch.cat([res_feat, side_feat], 1))
|
||||||
|
|
||||||
res_feat = out[:, : -self.side_channels, ...]
|
res_feat = out[:, :-self.side_channels, ...]
|
||||||
side_feat = out[:, -self.side_channels :, :, :]
|
side_feat = out[:, -self.side_channels:, :, :]
|
||||||
side_feat = self.conv4(side_feat)
|
side_feat = self.conv4(side_feat)
|
||||||
out = self.conv5(torch.cat([res_feat, side_feat], 1))
|
out = self.conv5(torch.cat([res_feat, side_feat], 1))
|
||||||
|
|
||||||
out = self.prelu(x + out)
|
out = self.prelu(x + out)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
class Encoder(nn.Module):
|
class Encoder(nn.Module):
|
||||||
def __init__(self, channels, large=False):
|
def __init__(self, channels, large=False):
|
||||||
super(Encoder, self).__init__()
|
super(Encoder, self).__init__()
|
||||||
self.channels = channels
|
self.channels = channels
|
||||||
prev_ch = 3
|
prev_ch = 3
|
||||||
for idx, ch in enumerate(channels, 1):
|
for idx, ch in enumerate(channels, 1):
|
||||||
k = 7 if large and idx == 1 else 3
|
k = 7 if large and idx == 1 else 3
|
||||||
p = 3 if k == 7 else 1
|
p = 3 if k ==7 else 1
|
||||||
self.register_module(
|
self.register_module(f'pyramid{idx}',
|
||||||
f"pyramid{idx}",
|
nn.Sequential(
|
||||||
nn.Sequential(
|
convrelu(prev_ch, ch, k, 2, p),
|
||||||
convrelu(prev_ch, ch, k, 2, p), convrelu(ch, ch, 3, 1, 1)
|
convrelu(ch, ch, 3, 1, 1)
|
||||||
),
|
))
|
||||||
)
|
|
||||||
prev_ch = ch
|
prev_ch = ch
|
||||||
|
|
||||||
def forward(self, in_x):
|
def forward(self, in_x):
|
||||||
fs = []
|
fs = []
|
||||||
for idx in range(len(self.channels)):
|
for idx in range(len(self.channels)):
|
||||||
out_x = getattr(self, f"pyramid{idx + 1}")(in_x)
|
out_x = getattr(self, f'pyramid{idx+1}')(in_x)
|
||||||
fs.append(out_x)
|
fs.append(out_x)
|
||||||
in_x = out_x
|
in_x = out_x
|
||||||
return fs
|
return fs
|
||||||
|
|
||||||
|
|
||||||
class InitDecoder(nn.Module):
|
class InitDecoder(nn.Module):
|
||||||
def __init__(self, in_ch, out_ch, skip_ch) -> None:
|
def __init__(self, in_ch, out_ch, skip_ch) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.convblock = nn.Sequential(
|
self.convblock = nn.Sequential(
|
||||||
convrelu(in_ch * 2 + 1, in_ch * 2),
|
convrelu(in_ch*2+1, in_ch*2),
|
||||||
ResBlock(in_ch * 2, skip_ch),
|
ResBlock(in_ch*2, skip_ch),
|
||||||
nn.ConvTranspose2d(in_ch * 2, out_ch + 4, 4, 2, 1, bias=True),
|
nn.ConvTranspose2d(in_ch*2, out_ch+4, 4, 2, 1, bias=True)
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, f0, f1, embt):
|
def forward(self, f0, f1, embt):
|
||||||
h, w = f0.shape[2:]
|
h, w = f0.shape[2:]
|
||||||
embt = embt.repeat(1, 1, h, w)
|
embt = embt.repeat(1, 1, h, w)
|
||||||
@@ -136,17 +90,15 @@ class InitDecoder(nn.Module):
|
|||||||
flow0, flow1 = torch.chunk(out[:, :4, ...], 2, 1)
|
flow0, flow1 = torch.chunk(out[:, :4, ...], 2, 1)
|
||||||
ft_ = out[:, 4:, ...]
|
ft_ = out[:, 4:, ...]
|
||||||
return flow0, flow1, ft_
|
return flow0, flow1, ft_
|
||||||
|
|
||||||
|
|
||||||
class IntermediateDecoder(nn.Module):
|
class IntermediateDecoder(nn.Module):
|
||||||
def __init__(self, in_ch, out_ch, skip_ch) -> None:
|
def __init__(self, in_ch, out_ch, skip_ch) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.convblock = nn.Sequential(
|
self.convblock = nn.Sequential(
|
||||||
convrelu(in_ch * 3 + 4, in_ch * 3),
|
convrelu(in_ch*3+4, in_ch*3),
|
||||||
ResBlock(in_ch * 3, skip_ch),
|
ResBlock(in_ch*3, skip_ch),
|
||||||
nn.ConvTranspose2d(in_ch * 3, out_ch + 4, 4, 2, 1, bias=True),
|
nn.ConvTranspose2d(in_ch*3, out_ch+4, 4, 2, 1, bias=True)
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, ft_, f0, f1, flow0_in, flow1_in):
|
def forward(self, ft_, f0, f1, flow0_in, flow1_in):
|
||||||
f0_warp = warp(f0, flow0_in)
|
f0_warp = warp(f0, flow0_in)
|
||||||
f1_warp = warp(f1, flow1_in)
|
f1_warp = warp(f1, flow1_in)
|
||||||
@@ -156,4 +108,4 @@ class IntermediateDecoder(nn.Module):
|
|||||||
ft_ = out[:, 4:, ...]
|
ft_ = out[:, 4:, ...]
|
||||||
flow0 = flow0 + 2.0 * resize(flow0_in, scale_factor=2.0)
|
flow0 = flow0 + 2.0 * resize(flow0_in, scale_factor=2.0)
|
||||||
flow1 = flow1 + 2.0 * resize(flow1_in, scale_factor=2.0)
|
flow1 = flow1 + 2.0 * resize(flow1_in, scale_factor=2.0)
|
||||||
return flow0, flow1, ft_
|
return flow0, flow1, ft_
|
||||||
@@ -4,17 +4,15 @@ import torch.nn.functional as F
|
|||||||
|
|
||||||
|
|
||||||
def resize(x, scale_factor):
|
def resize(x, scale_factor):
|
||||||
return F.interpolate(
|
return F.interpolate(x, scale_factor=scale_factor, mode="bilinear", align_corners=False)
|
||||||
x, scale_factor=scale_factor, mode="bilinear", align_corners=False
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def bilinear_sampler(img: torch.Tensor, coords: torch.Tensor, mask=False):
|
def bilinear_sampler(img, coords, mask=False):
|
||||||
"""Wrapper for grid_sample, uses pixel coordinates"""
|
""" Wrapper for grid_sample, uses pixel coordinates """
|
||||||
H, W = img.shape[-2:]
|
H, W = img.shape[-2:]
|
||||||
xgrid, ygrid = coords.split([1, 1], dim=-1)
|
xgrid, ygrid = coords.split([1,1], dim=-1)
|
||||||
xgrid = 2 * xgrid / (W - 1) - 1
|
xgrid = 2*xgrid/(W-1) - 1
|
||||||
ygrid = 2 * ygrid / (H - 1) - 1
|
ygrid = 2*ygrid/(H-1) - 1
|
||||||
|
|
||||||
grid = torch.cat([xgrid, ygrid], dim=-1)
|
grid = torch.cat([xgrid, ygrid], dim=-1)
|
||||||
img = F.grid_sample(img, grid, align_corners=True)
|
img = F.grid_sample(img, grid, align_corners=True)
|
||||||
@@ -27,36 +25,27 @@ def bilinear_sampler(img: torch.Tensor, coords: torch.Tensor, mask=False):
|
|||||||
|
|
||||||
|
|
||||||
def coords_grid(batch, ht, wd, device):
|
def coords_grid(batch, ht, wd, device):
|
||||||
coords = torch.meshgrid(
|
coords = torch.meshgrid(torch.arange(ht, device=device),
|
||||||
torch.arange(ht, device=device), torch.arange(wd, device=device), indexing="ij"
|
torch.arange(wd, device=device),
|
||||||
)
|
indexing='ij')
|
||||||
coords = torch.stack(coords[::-1], dim=0).float()
|
coords = torch.stack(coords[::-1], dim=0).float()
|
||||||
return coords[None].repeat(batch, 1, 1, 1)
|
return coords[None].repeat(batch, 1, 1, 1)
|
||||||
|
|
||||||
|
|
||||||
class SmallUpdateBlock(nn.Module):
|
class SmallUpdateBlock(nn.Module):
|
||||||
def __init__(
|
def __init__(self, cdim, hidden_dim, flow_dim, corr_dim, fc_dim,
|
||||||
self,
|
corr_levels=4, radius=3, scale_factor=None):
|
||||||
cdim,
|
|
||||||
hidden_dim,
|
|
||||||
flow_dim,
|
|
||||||
corr_dim,
|
|
||||||
fc_dim,
|
|
||||||
corr_levels=4,
|
|
||||||
radius=3,
|
|
||||||
scale_factor=None,
|
|
||||||
):
|
|
||||||
super(SmallUpdateBlock, self).__init__()
|
super(SmallUpdateBlock, self).__init__()
|
||||||
cor_planes = corr_levels * (2 * radius + 1) ** 2
|
cor_planes = corr_levels * (2 * radius + 1) **2
|
||||||
self.scale_factor = scale_factor
|
self.scale_factor = scale_factor
|
||||||
|
|
||||||
self.convc1 = nn.Conv2d(2 * cor_planes, corr_dim, 1, padding=0)
|
self.convc1 = nn.Conv2d(2 * cor_planes, corr_dim, 1, padding=0)
|
||||||
self.convf1 = nn.Conv2d(4, flow_dim * 2, 7, padding=3)
|
self.convf1 = nn.Conv2d(4, flow_dim*2, 7, padding=3)
|
||||||
self.convf2 = nn.Conv2d(flow_dim * 2, flow_dim, 3, padding=1)
|
self.convf2 = nn.Conv2d(flow_dim*2, flow_dim, 3, padding=1)
|
||||||
self.conv = nn.Conv2d(corr_dim + flow_dim, fc_dim, 3, padding=1)
|
self.conv = nn.Conv2d(corr_dim+flow_dim, fc_dim, 3, padding=1)
|
||||||
|
|
||||||
self.gru = nn.Sequential(
|
self.gru = nn.Sequential(
|
||||||
nn.Conv2d(fc_dim + 4 + cdim, hidden_dim, 3, padding=1),
|
nn.Conv2d(fc_dim+4+cdim, hidden_dim, 3, padding=1),
|
||||||
nn.LeakyReLU(negative_slope=0.1, inplace=True),
|
nn.LeakyReLU(negative_slope=0.1, inplace=True),
|
||||||
nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
|
nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
|
||||||
)
|
)
|
||||||
@@ -74,11 +63,10 @@ class SmallUpdateBlock(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
|
self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
|
||||||
|
|
||||||
def forward(self, net, flow, corr):
|
def forward(self, net, flow, corr):
|
||||||
net = (
|
net = resize(net, 1 / self.scale_factor
|
||||||
resize(net, 1 / self.scale_factor) if self.scale_factor is not None else net
|
) if self.scale_factor is not None else net
|
||||||
)
|
|
||||||
cor = self.lrelu(self.convc1(corr))
|
cor = self.lrelu(self.convc1(corr))
|
||||||
flo = self.lrelu(self.convf1(flow))
|
flo = self.lrelu(self.convf1(flow))
|
||||||
flo = self.lrelu(self.convf2(flo))
|
flo = self.lrelu(self.convf2(flo))
|
||||||
@@ -89,42 +77,29 @@ class SmallUpdateBlock(nn.Module):
|
|||||||
out = self.gru(inp)
|
out = self.gru(inp)
|
||||||
delta_net = self.feat_head(out)
|
delta_net = self.feat_head(out)
|
||||||
delta_flow = self.flow_head(out)
|
delta_flow = self.flow_head(out)
|
||||||
|
|
||||||
if self.scale_factor is not None:
|
if self.scale_factor is not None:
|
||||||
delta_net = resize(delta_net, scale_factor=self.scale_factor)
|
delta_net = resize(delta_net, scale_factor=self.scale_factor)
|
||||||
delta_flow = self.scale_factor * resize(
|
delta_flow = self.scale_factor * resize(delta_flow, scale_factor=self.scale_factor)
|
||||||
delta_flow, scale_factor=self.scale_factor
|
|
||||||
)
|
|
||||||
|
|
||||||
return delta_net, delta_flow
|
return delta_net, delta_flow
|
||||||
|
|
||||||
|
|
||||||
class BasicUpdateBlock(nn.Module):
|
class BasicUpdateBlock(nn.Module):
|
||||||
def __init__(
|
def __init__(self, cdim, hidden_dim, flow_dim, corr_dim, corr_dim2,
|
||||||
self,
|
fc_dim, corr_levels=4, radius=3, scale_factor=None, out_num=1):
|
||||||
cdim,
|
|
||||||
hidden_dim,
|
|
||||||
flow_dim,
|
|
||||||
corr_dim,
|
|
||||||
corr_dim2,
|
|
||||||
fc_dim,
|
|
||||||
corr_levels=4,
|
|
||||||
radius=3,
|
|
||||||
scale_factor=None,
|
|
||||||
out_num=1,
|
|
||||||
):
|
|
||||||
super(BasicUpdateBlock, self).__init__()
|
super(BasicUpdateBlock, self).__init__()
|
||||||
cor_planes = corr_levels * (2 * radius + 1) ** 2
|
cor_planes = corr_levels * (2 * radius + 1) **2
|
||||||
|
|
||||||
self.scale_factor = scale_factor
|
self.scale_factor = scale_factor
|
||||||
self.convc1 = nn.Conv2d(2 * cor_planes, corr_dim, 1, padding=0)
|
self.convc1 = nn.Conv2d(2 * cor_planes, corr_dim, 1, padding=0)
|
||||||
self.convc2 = nn.Conv2d(corr_dim, corr_dim2, 3, padding=1)
|
self.convc2 = nn.Conv2d(corr_dim, corr_dim2, 3, padding=1)
|
||||||
self.convf1 = nn.Conv2d(4, flow_dim * 2, 7, padding=3)
|
self.convf1 = nn.Conv2d(4, flow_dim*2, 7, padding=3)
|
||||||
self.convf2 = nn.Conv2d(flow_dim * 2, flow_dim, 3, padding=1)
|
self.convf2 = nn.Conv2d(flow_dim*2, flow_dim, 3, padding=1)
|
||||||
self.conv = nn.Conv2d(flow_dim + corr_dim2, fc_dim, 3, padding=1)
|
self.conv = nn.Conv2d(flow_dim+corr_dim2, fc_dim, 3, padding=1)
|
||||||
|
|
||||||
self.gru = nn.Sequential(
|
self.gru = nn.Sequential(
|
||||||
nn.Conv2d(fc_dim + 4 + cdim, hidden_dim, 3, padding=1),
|
nn.Conv2d(fc_dim+4+cdim, hidden_dim, 3, padding=1),
|
||||||
nn.LeakyReLU(negative_slope=0.1, inplace=True),
|
nn.LeakyReLU(negative_slope=0.1, inplace=True),
|
||||||
nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
|
nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
|
||||||
)
|
)
|
||||||
@@ -138,15 +113,14 @@ class BasicUpdateBlock(nn.Module):
|
|||||||
self.flow_head = nn.Sequential(
|
self.flow_head = nn.Sequential(
|
||||||
nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
|
nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
|
||||||
nn.LeakyReLU(negative_slope=0.1, inplace=True),
|
nn.LeakyReLU(negative_slope=0.1, inplace=True),
|
||||||
nn.Conv2d(hidden_dim, 4 * out_num, 3, padding=1),
|
nn.Conv2d(hidden_dim, 4*out_num, 3, padding=1),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
|
self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
|
||||||
|
|
||||||
def forward(self, net, flow, corr):
|
def forward(self, net, flow, corr):
|
||||||
net = (
|
net = resize(net, 1 / self.scale_factor
|
||||||
resize(net, 1 / self.scale_factor) if self.scale_factor is not None else net
|
) if self.scale_factor is not None else net
|
||||||
)
|
|
||||||
cor = self.lrelu(self.convc1(corr))
|
cor = self.lrelu(self.convc1(corr))
|
||||||
cor = self.lrelu(self.convc2(cor))
|
cor = self.lrelu(self.convc2(cor))
|
||||||
flo = self.lrelu(self.convf1(flow))
|
flo = self.lrelu(self.convf1(flow))
|
||||||
@@ -158,47 +132,41 @@ class BasicUpdateBlock(nn.Module):
|
|||||||
out = self.gru(inp)
|
out = self.gru(inp)
|
||||||
delta_net = self.feat_head(out)
|
delta_net = self.feat_head(out)
|
||||||
delta_flow = self.flow_head(out)
|
delta_flow = self.flow_head(out)
|
||||||
|
|
||||||
if self.scale_factor is not None:
|
if self.scale_factor is not None:
|
||||||
delta_net = resize(delta_net, scale_factor=self.scale_factor)
|
delta_net = resize(delta_net, scale_factor=self.scale_factor)
|
||||||
delta_flow = self.scale_factor * resize(
|
delta_flow = self.scale_factor * resize(delta_flow, scale_factor=self.scale_factor)
|
||||||
delta_flow, scale_factor=self.scale_factor
|
|
||||||
)
|
|
||||||
return delta_net, delta_flow
|
return delta_net, delta_flow
|
||||||
|
|
||||||
|
|
||||||
class BidirCorrBlock:
|
class BidirCorrBlock:
|
||||||
def __init__(
|
def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
|
||||||
self, fmap1: torch.Tensor, fmap2: torch.Tensor, num_levels=4, radius=4
|
|
||||||
):
|
|
||||||
self.num_levels = num_levels
|
self.num_levels = num_levels
|
||||||
self.radius = radius
|
self.radius = radius
|
||||||
self.corr_pyramid: list[torch.Tensor] = []
|
self.corr_pyramid = []
|
||||||
self.corr_pyramid_T: list[torch.Tensor] = []
|
self.corr_pyramid_T = []
|
||||||
|
|
||||||
corr = BidirCorrBlock.corr(fmap1, fmap2)
|
corr = BidirCorrBlock.corr(fmap1, fmap2)
|
||||||
batch, h1, w1, dim, h2, w2 = corr.shape
|
batch, h1, w1, dim, h2, w2 = corr.shape
|
||||||
corr_T = corr.clone().permute(0, 4, 5, 3, 1, 2)
|
corr_T = corr.clone().permute(0, 4, 5, 3, 1, 2)
|
||||||
|
|
||||||
corr = corr.reshape(batch * h1 * w1, dim, h2, w2)
|
corr = corr.reshape(batch*h1*w1, dim, h2, w2)
|
||||||
corr_T = corr_T.reshape(batch * h2 * w2, dim, h1, w1)
|
corr_T = corr_T.reshape(batch*h2*w2, dim, h1, w1)
|
||||||
|
|
||||||
self.corr_pyramid.append(corr)
|
self.corr_pyramid.append(corr)
|
||||||
self.corr_pyramid_T.append(corr_T)
|
self.corr_pyramid_T.append(corr_T)
|
||||||
|
|
||||||
for _ in range(self.num_levels - 1):
|
for _ in range(self.num_levels-1):
|
||||||
corr = F.avg_pool2d(corr, 2, stride=2)
|
corr = F.avg_pool2d(corr, 2, stride=2)
|
||||||
corr_T = F.avg_pool2d(corr_T, 2, stride=2)
|
corr_T = F.avg_pool2d(corr_T, 2, stride=2)
|
||||||
self.corr_pyramid.append(corr)
|
self.corr_pyramid.append(corr)
|
||||||
self.corr_pyramid_T.append(corr_T)
|
self.corr_pyramid_T.append(corr_T)
|
||||||
|
|
||||||
def __call__(self, coords0: torch.Tensor, coords1: torch.Tensor):
|
def __call__(self, coords0, coords1):
|
||||||
r = self.radius
|
r = self.radius
|
||||||
coords0 = coords0.permute(0, 2, 3, 1)
|
coords0 = coords0.permute(0, 2, 3, 1)
|
||||||
coords1 = coords1.permute(0, 2, 3, 1)
|
coords1 = coords1.permute(0, 2, 3, 1)
|
||||||
assert coords0.shape == coords1.shape, (
|
assert coords0.shape == coords1.shape, f"coords0 shape: [{coords0.shape}] is not equal to [{coords1.shape}]"
|
||||||
f"coords0 shape: [{coords0.shape}] is not equal to [{coords1.shape}]"
|
|
||||||
)
|
|
||||||
batch, h1, w1, _ = coords0.shape
|
batch, h1, w1, _ = coords0.shape
|
||||||
|
|
||||||
out_pyramid = []
|
out_pyramid = []
|
||||||
@@ -207,15 +175,15 @@ class BidirCorrBlock:
|
|||||||
corr = self.corr_pyramid[i]
|
corr = self.corr_pyramid[i]
|
||||||
corr_T = self.corr_pyramid_T[i]
|
corr_T = self.corr_pyramid_T[i]
|
||||||
|
|
||||||
dx = torch.linspace(-r, r, 2 * r + 1, device=coords0.device)
|
dx = torch.linspace(-r, r, 2*r+1, device=coords0.device)
|
||||||
dy = torch.linspace(-r, r, 2 * r + 1, device=coords0.device)
|
dy = torch.linspace(-r, r, 2*r+1, device=coords0.device)
|
||||||
delta: torch.Tensor = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), axis=-1)
|
delta = torch.stack(torch.meshgrid(dy, dx, indexing='ij'), axis=-1)
|
||||||
delta_lvl: torch.Tensor = delta.view(1, 2 * r + 1, 2 * r + 1, 2)
|
delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2)
|
||||||
|
|
||||||
centroid_lvl_0: torch.Tensor = coords0.reshape(batch * h1 * w1, 1, 1, 2) / 2**i
|
centroid_lvl_0 = coords0.reshape(batch*h1*w1, 1, 1, 2) / 2**i
|
||||||
centroid_lvl_1: torch.Tensor = coords1.reshape(batch * h1 * w1, 1, 1, 2) / 2**i
|
centroid_lvl_1 = coords1.reshape(batch*h1*w1, 1, 1, 2) / 2**i
|
||||||
coords_lvl_0: torch.Tensor = centroid_lvl_0 + delta_lvl
|
coords_lvl_0 = centroid_lvl_0 + delta_lvl
|
||||||
coords_lvl_1: torch.Tensor = centroid_lvl_1 + delta_lvl
|
coords_lvl_1 = centroid_lvl_1 + delta_lvl
|
||||||
|
|
||||||
corr = bilinear_sampler(corr, coords_lvl_0)
|
corr = bilinear_sampler(corr, coords_lvl_0)
|
||||||
corr_T = bilinear_sampler(corr_T, coords_lvl_1)
|
corr_T = bilinear_sampler(corr_T, coords_lvl_1)
|
||||||
@@ -226,16 +194,14 @@ class BidirCorrBlock:
|
|||||||
|
|
||||||
out = torch.cat(out_pyramid, dim=-1)
|
out = torch.cat(out_pyramid, dim=-1)
|
||||||
out_T = torch.cat(out_pyramid_T, dim=-1)
|
out_T = torch.cat(out_pyramid_T, dim=-1)
|
||||||
return out.permute(0, 3, 1, 2).contiguous().float(), out_T.permute(
|
return out.permute(0, 3, 1, 2).contiguous().float(), out_T.permute(0, 3, 1, 2).contiguous().float()
|
||||||
0, 3, 1, 2
|
|
||||||
).contiguous().float()
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def corr(fmap1: torch.Tensor, fmap2: torch.Tensor):
|
def corr(fmap1, fmap2):
|
||||||
batch, dim, ht, wd = fmap1.shape
|
batch, dim, ht, wd = fmap1.shape
|
||||||
fmap1 = fmap1.view(batch, dim, ht * wd)
|
fmap1 = fmap1.view(batch, dim, ht*wd)
|
||||||
fmap2 = fmap2.view(batch, dim, ht * wd)
|
fmap2 = fmap2.view(batch, dim, ht*wd)
|
||||||
|
|
||||||
corr = torch.matmul(fmap1.transpose(1, 2), fmap2)
|
corr = torch.matmul(fmap1.transpose(1,2), fmap2)
|
||||||
corr = corr.view(batch, ht, wd, 1, ht, wd)
|
corr = corr.view(batch, ht, wd, 1, ht, wd)
|
||||||
return corr * (dim**-0.5)
|
return corr / torch.sqrt(torch.tensor(dim).float())
|
||||||
@@ -5,26 +5,23 @@ import numpy as np
|
|||||||
|
|
||||||
|
|
||||||
def tensor2img(tensor: torch.Tensor):
|
def tensor2img(tensor: torch.Tensor):
|
||||||
tensor = (
|
return (
|
||||||
tensor.mul(255.0)
|
(tensor * 255.0)
|
||||||
.clamp_(0, 255)
|
.detach()
|
||||||
.to(torch.uint8)
|
|
||||||
.squeeze(0)
|
.squeeze(0)
|
||||||
.permute(1, 2, 0)
|
.permute(1, 2, 0)
|
||||||
|
.cpu()
|
||||||
|
.numpy()
|
||||||
|
.clip(0, 255)
|
||||||
|
.astype(np.uint8)
|
||||||
)
|
)
|
||||||
|
|
||||||
return tensor.cpu().numpy()
|
|
||||||
|
|
||||||
|
def img2tensor(img: np.ndarray) -> torch.Tensor:
|
||||||
def img2tensor(img: np.ndarray, device: torch.device) -> torch.Tensor:
|
|
||||||
logging.debug(f"Converting image of shape {img.shape} to tensor")
|
logging.debug(f"Converting image of shape {img.shape} to tensor")
|
||||||
if img.shape[-1] > 3:
|
if img.shape[-1] > 3:
|
||||||
img = img[:, :, :3]
|
img = img[:, :, :3]
|
||||||
tensor = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0)
|
return torch.tensor(img).permute(2, 0, 1).unsqueeze(0) / 255.0
|
||||||
if device.type != "cuda":
|
|
||||||
return tensor.float() / 255.0
|
|
||||||
|
|
||||||
return tensor.cuda(non_blocking=True).float().div_(255.0)
|
|
||||||
|
|
||||||
|
|
||||||
def check_dim_and_resize(*args: torch.Tensor) -> list[torch.Tensor]:
|
def check_dim_and_resize(*args: torch.Tensor) -> list[torch.Tensor]:
|
||||||
|
|||||||
@@ -2,10 +2,9 @@ import os
|
|||||||
import logging
|
import logging
|
||||||
import subprocess
|
import subprocess
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Generator, Iterable
|
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
from typing import Generator
|
||||||
|
|
||||||
|
|
||||||
class VideoMaker:
|
class VideoMaker:
|
||||||
@@ -36,7 +35,7 @@ class VideoMaker:
|
|||||||
with open(file, "w") as f:
|
with open(file, "w") as f:
|
||||||
for video in videos:
|
for video in videos:
|
||||||
f.write(f"file '{video}'\n")
|
f.write(f"file '{video}'\n")
|
||||||
cmd = f"ffmpeg -y -f concat -safe 0 -i {file} -c copy {output_path}"
|
cmd = f"ffmpeg -f concat -safe 0 -i {file} -c copy {output_path}"
|
||||||
logging.info(f"Running command: {cmd}")
|
logging.info(f"Running command: {cmd}")
|
||||||
result = self.run_command(cmd)
|
result = self.run_command(cmd)
|
||||||
if result != 0:
|
if result != 0:
|
||||||
@@ -67,13 +66,7 @@ class VideoMaker:
|
|||||||
|
|
||||||
def run_command(self, cmd: str) -> int:
|
def run_command(self, cmd: str) -> int:
|
||||||
try:
|
try:
|
||||||
subprocess.run(
|
subprocess.run(cmd, shell=True, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
|
||||||
cmd,
|
|
||||||
shell=True,
|
|
||||||
check=True,
|
|
||||||
stdout=subprocess.DEVNULL,
|
|
||||||
stderr=subprocess.DEVNULL,
|
|
||||||
)
|
|
||||||
return 0
|
return 0
|
||||||
except subprocess.CalledProcessError as e:
|
except subprocess.CalledProcessError as e:
|
||||||
logging.error(f"Command failed with error: {e}")
|
logging.error(f"Command failed with error: {e}")
|
||||||
@@ -81,7 +74,7 @@ class VideoMaker:
|
|||||||
|
|
||||||
def video_to_frames_generator(
|
def video_to_frames_generator(
|
||||||
self, video_path: Path, output_dir: Path, chunk_seconds: int = 10
|
self, video_path: Path, output_dir: Path, chunk_seconds: int = 10
|
||||||
) -> Generator[tuple[np.ndarray, ...], None, None]:
|
) -> Generator[tuple[Path, ...], None, None]:
|
||||||
"""Extracts frames from a video and saves them to disk, yielding paths to the saved frames."""
|
"""Extracts frames from a video and saves them to disk, yielding paths to the saved frames."""
|
||||||
|
|
||||||
cap = cv2.VideoCapture(str(video_path))
|
cap = cv2.VideoCapture(str(video_path))
|
||||||
@@ -92,56 +85,21 @@ class VideoMaker:
|
|||||||
fps = cap.get(cv2.CAP_PROP_FPS)
|
fps = cap.get(cv2.CAP_PROP_FPS)
|
||||||
frames_per_chunk = int(fps * chunk_seconds)
|
frames_per_chunk = int(fps * chunk_seconds)
|
||||||
|
|
||||||
|
frame_index = 0
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
paths = []
|
paths = []
|
||||||
|
|
||||||
for _ in range(frames_per_chunk):
|
for _ in range(frames_per_chunk):
|
||||||
ret, frame = cap.read()
|
ret, frame = cap.read()
|
||||||
if not ret:
|
if not ret:
|
||||||
cap.release()
|
cap.release()
|
||||||
return
|
return
|
||||||
paths.append(frame)
|
|
||||||
|
frame_path = output_dir / f"img_{frame_index:08d}.png"
|
||||||
|
cv2.imwrite(str(frame_path), frame)
|
||||||
|
|
||||||
|
paths.append(frame_path)
|
||||||
|
frame_index += 1
|
||||||
|
|
||||||
yield tuple(paths)
|
yield tuple(paths)
|
||||||
|
|
||||||
def images_to_video_pipeline(
|
|
||||||
self,
|
|
||||||
frames: Iterable[np.ndarray],
|
|
||||||
output_path: Path,
|
|
||||||
width: int,
|
|
||||||
height: int,
|
|
||||||
fps: float,
|
|
||||||
):
|
|
||||||
pipeline = subprocess.Popen(
|
|
||||||
[
|
|
||||||
"ffmpeg",
|
|
||||||
"-y",
|
|
||||||
"-f", "rawvideo",
|
|
||||||
"-vcodec", "rawvideo",
|
|
||||||
"-pix_fmt", "bgr24",
|
|
||||||
"-s", f"{width}x{height}",
|
|
||||||
"-r", str(fps),
|
|
||||||
"-i", "-",
|
|
||||||
"-an",
|
|
||||||
"-vcodec", "libx264",
|
|
||||||
"-pix_fmt", "yuv420p",
|
|
||||||
str(output_path),
|
|
||||||
],
|
|
||||||
stdin=subprocess.PIPE,
|
|
||||||
stderr=subprocess.DEVNULL
|
|
||||||
)
|
|
||||||
if pipeline.stdin is None:
|
|
||||||
raise Exception("STDIN closed")
|
|
||||||
for frame in frames:
|
|
||||||
pipeline.stdin.write(frame.tobytes())
|
|
||||||
|
|
||||||
pipeline.stdin.close()
|
|
||||||
pipeline.wait()
|
|
||||||
|
|
||||||
def get_size(self, video_path: Path) -> tuple[int, int]:
|
|
||||||
cap = cv2.VideoCapture(str(video_path))
|
|
||||||
if not cap.isOpened():
|
|
||||||
raise ValueError(f"Cannot open video: {video_path}")
|
|
||||||
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
|
||||||
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
|
||||||
|
|
||||||
cap.release()
|
|
||||||
return width, height
|
|
||||||
|
|||||||
Reference in New Issue
Block a user