6 Commits

Author SHA1 Message Date
Viner Abubakirov
1615cbc60d Attempt to optimize the model for faster computation 2026-04-19 11:57:11 +05:00
Viner Abubakirov
c7acd66974 Renamed runner to run 2026-04-04 22:06:27 +05:00
Viner Abubakirov
2d67b72128 Converted module imports to relative paths 2026-04-04 11:57:41 +05:00
c91cf6b53a Merge pull request 'dev' (#2) from dev into main
Reviewed-on: #2
2026-04-03 18:28:31 +05:00
Viner Abubakirov
c72e34f9dc checkout presets.py from dev 2026-04-02 18:31:54 +05:00
359f20c3c4 Merge pull request 'dev' (#1) from dev into main
Reviewed-on: #1
2026-04-02 12:17:05 +05:00
24 changed files with 612 additions and 2838 deletions

Binary file not shown (removed image, 50 KiB).

Binary file not shown (removed image, 57 KiB).

main.py

@@ -1,195 +1,7 @@
import logging
from pathlib import Path
-from typing import TYPE_CHECKING
-from cv2 import imwrite
-import tqdm
+from src.runner import run
from src.config import presets
from src.utils.fs import FileSystem
from src.utils.video import VideoMaker
from src.interpolator import (
ImageInterpolator,
Anchor,
get_device,
get_vram_available,
ModelRunner,
)
if TYPE_CHECKING:
import torch
import numpy as np
def performing_warning_message(device: "torch.device"):
if device.type in ("cpu", "mps"):
if device.type == "mps":
logging.warning(
"Running on Apple Silicon GPU (MPS) may have limited performance. Consider using a CUDA-enabled GPU for better performance."
)
else:
logging.warning(
"Running on CPU may be very slow. Consider using a GPU for better performance."
)
elif device.type == "cuda":
pass
else:
raise Exception(f"Unsupported device type: {device.type}")
def init_fs(base_path: Path) -> FileSystem:
fs = FileSystem(base_path)
fs.clear_directory(fs.frames_path)
fs.clear_directory(fs.interpolated_path)
fs.clear_directory(fs.moved_path)
fs.clear_directory(fs.video_part_path)
return fs
def init_video_maker() -> VideoMaker:
return VideoMaker()
def init_device() -> "torch.device":
device = get_device()
performing_warning_message(device)
vram_available = get_vram_available(device)
logging.info(f"Available VRAM: {vram_available / (1024**3):.2f} GB")
return device
def init_anchor(device: "torch.device") -> Anchor:
if device.type in ("cpu", "mps"):
return Anchor(resolution=8192 * 8192, memory=1, memory_bias=0)
elif device.type == "cuda":
return Anchor(
resolution=1024 * 512, memory=1500 * 1024**2, memory_bias=2500 * 1024**2
)
else:
raise Exception(f"Unsupported device type: {device.type}")
def init_model_runner(preset: presets.Preset, device: "torch.device") -> ModelRunner:
return ModelRunner(preset, device)
def init_interpolator(
model_runner: ModelRunner, device: "torch.device"
) -> ImageInterpolator:
anchor = init_anchor(device)
return ImageInterpolator(device, anchor, model_runner)
class InterpolationPipeline:
def __init__(
self,
preset: presets.Preset,
base_path: Path,
):
self.fs = init_fs(base_path)
self.video_maker = init_video_maker()
self.device = init_device()
self.model_runner = init_model_runner(preset, self.device)
self.interpolator = init_interpolator(self.model_runner, self.device)
def run(self, video_path: Path, output_video: str):
prev_frames = tuple()
interpolated_frames: list["np.ndarray"] = []
part = 0
chunk_seconds = 1
length = self.video_maker.get_video_duration(video_path)
last_part_seconds = 1 if length % chunk_seconds else 0
total_parts = int(length // chunk_seconds) + last_part_seconds
fps = self.video_maker.get_fps(video_path)
logging.info(f"Video FPS: {fps}")
fps *= 2 # Doubling FPS
width, height = self.video_maker.get_size(video_path)
for frames in self.video_maker.video_to_frames_generator(
video_path, self.fs.frames_path, chunk_seconds
):
logging.info(f"Processing frames: {len(frames)}")
if prev_frames:
img1 = prev_frames[-1]
img2 = frames[0]
img1_2 = self.interpolator.interpolate(img1, img2)
interpolated_frames.append(img1_2)
generator = self._frame_generator(prev_frames, interpolated_frames)
part_path = self.fs.video_part_path / f"video_{part:08d}.mp4"
self.video_maker.images_to_video_pipeline(
generator, part_path, width, height, fps
)
interpolated_frames = []
logging.info(f"Finished processing part {part:08d}")
part += 1
for i in tqdm.tqdm(
range(len(frames) - 1),
desc=f"Processing video frames {part + 1} / {total_parts}",
):
img1 = frames[i]
img2 = frames[i + 1]
img1_2 = self.interpolator.interpolate(img1, img2)
interpolated_frames.append(img1_2)
prev_frames = frames
generator = self._frame_generator(prev_frames, interpolated_frames)
part_path = self.fs.video_part_path / f"video_{part:08d}.mp4"
self.video_maker.images_to_video_pipeline(
generator, part_path, width, height, fps
)
logging.info(f"Finished processing part {part:08d}")
self._merge_video_parts(self.fs.output_path / output_video)
logging.info(
f"Video interpolation completed. Output saved to: {self.fs.output_path / output_video}"
)
def _save_images(
self,
source: tuple["np.ndarray", ...],
interpolated: list["np.ndarray"],
):
logging.info("Saving images...")
self.fs.clear_directory(self.fs.moved_path)
index = 0
for i, frame in enumerate(source):
name = self.fs.moved_path / f"img_{index:08d}.png"
index += 1
imwrite(name, frame)
if i < len(interpolated):
name = self.fs.moved_path / f"img_{index:08d}.png"
index += 1
imwrite(name, interpolated[i])
logging.info("Success...")
def _merge_frames_to_video(self, output_video: Path, fps: float):
self.video_maker.images_to_video(self.fs.moved_path, output_video, fps)
def _merge_video_parts(self, output_video: Path):
self.video_maker.concatenate_videos(self.fs.video_part_path, output_video)
self.fs.clear_directory(self.fs.video_part_path)
def _frame_generator(
self,
source: tuple["np.ndarray", ...],
interpolated: list["np.ndarray"],
):
for i, frame in enumerate(source):
yield frame
if i < len(interpolated):
yield interpolated[i]
def runner(
base_path: Path,
video_path: Path,
output_video: str,
preset: presets.Preset = presets.LARGE,
):
pipeline = InterpolationPipeline(
preset=preset,
base_path=base_path,
)
pipeline.run(video_path, output_video)
def main():
@@ -218,7 +30,7 @@ def main():
default="global", default="global",
) )
args = parser.parse_args() args = parser.parse_args()
runner( run(
base_path=Path(args.base_path), base_path=Path(args.base_path),
video_path=Path(args.video_path), video_path=Path(args.video_path),
output_video=args.output, output_video=args.output,


@@ -1,8 +0,0 @@
import torch
from src.export_to_onnx import export_to_onnx
from src.config.presets import SMALL
if __name__ == "__main__":
device = torch.device("cuda")
export_to_onnx(SMALL, "src/pretrained/amt_s.onnx", device)


@@ -5,16 +5,9 @@ description = "Add your description here"
readme = "README.md" readme = "README.md"
requires-python = ">=3.12" requires-python = ">=3.12"
dependencies = [ dependencies = [
"imageio>=2.37.3",
"numpy>=2.4.4", "numpy>=2.4.4",
"nvidia-modelopt[all]>=0.33.1",
"omegaconf>=2.3.0", "omegaconf>=2.3.0",
"onnx>=1.21.0",
"onnxscript>=0.6.2",
"opencv-python>=4.13.0.92", "opencv-python>=4.13.0.92",
"tensorrt>=10.16.1.11", "torch>=2.11.0",
"torch==2.5.1",
"torch-tensorrt>=2.5.0",
"torchvision>=0.20.1",
"tqdm>=4.67.3", "tqdm>=4.67.3",
] ]


@@ -10,7 +10,7 @@ save_dir: work_dir
eval_interval: 1
network:
-  name: src.networks.AMT-G.Model
+  name: AMT-G.Model
  params:
    corr_radius: 3
    corr_lvls: 4


@@ -10,7 +10,7 @@ save_dir: work_dir
eval_interval: 1
network:
-  name: src.networks.AMT-L.Model
+  name: AMT-L.Model
  params:
    corr_radius: 3
    corr_lvls: 4


@@ -10,7 +10,7 @@ save_dir: work_dir
eval_interval: 1
network:
-  name: src.networks.AMT-S.Model
+  name: AMT-S.Model
  params:
    corr_radius: 3
    corr_lvls: 4


@@ -1,4 +1,3 @@
-from typing import Literal
from pathlib import Path
from dataclasses import dataclass
@@ -7,7 +6,6 @@ from dataclasses import dataclass
class Preset:
    config: Path
    checkpoint: Path
-    onnx: Path | None = None

SMALL = Preset(


@@ -1,29 +0,0 @@
import torch
import torchvision
torchvision.disable_beta_transforms_warning()
torch.backends.cudnn.enabled = False
from .interpolator import ModelRunner
from .config.presets import Preset
def export_to_onnx(preset: Preset, output_path: str, device: torch.device):
model_runner = ModelRunner(preset, device)
# model_runner.model.eval()
dummy_input = model_runner.get_dummy_input()
torch.onnx.export(
model_runner.model,
dummy_input,
output_path,
opset_version=17,
input_names=['img0', 'img1', 'embt'],
output_names=["imgt_pred"],
dynamic_axes={
"img0": {0: "batch", 2: "height", 3: "width"},
"img1": {0: "batch", 2: "height", 3: "width"},
"embt": {0: "batch", 2: "height", 3: "width"},
"imgt_pred": {0: "batch", 2: "height", 3: "width"},
},
dynamo=True,
use_external_data_format=False,
)


@@ -1,16 +1,13 @@
import logging
from pathlib import Path
-from typing import Optional
-from cv2 import imread
import torch
-import onnxruntime as ort
import numpy as np
from omegaconf import OmegaConf, DictConfig
-from src.config.presets import Preset
-from src.utils.torch import img2tensor, tensor2img
-from src.utils.build import build_from_cfg
+from .utils.torch import img2tensor, check_dim_and_resize, tensor2img
+from .utils.build import build_from_cfg
+from .utils.padder import InputPadder

class Anchor:
@@ -23,27 +20,8 @@ class Anchor:
return f"Anchor(resolution={self.resolution}, memory={self.memory}, memory_bias={self.memory_bias})" return f"Anchor(resolution={self.resolution}, memory={self.memory}, memory_bias={self.memory_bias})"
class ONNXWrapper:
def __init__(self, path):
self.session = ort.InferenceSession(path)
self.input_names = [i.name for i in self.session.get_inputs()]
self.output_names = [o.name for o in self.session.get_outputs()]
def __call__(self, tensor1, tensor2, embt):
inputs = {
self.input_names[0]: tensor1.cpu().numpy(),
self.input_names[1]: tensor2.cpu().numpy(),
self.input_names[2]: embt.cpu().numpy(),
}
outputs = self.session.run(self.output_names, inputs)
return {"imgt_pred": torch.from_numpy(outputs[0])}

class ModelRunner:
-    def __init__(self, preset: Preset, device: torch.device) -> None:
+    def __init__(self, config: Path, ckpt_path: Path, device: torch.device) -> None:
        """Initializes the ModelRunner with configuration and checkpoint.

        Args:
@@ -51,73 +29,18 @@ class ModelRunner:
            ckpt_path (Path): Path to model checkpoint in .pth format
            device (torch.device): Device to load the model on
        """
-        self.model: Optional[torch.nn.Module] = None
-        self.session: Optional[ONNXWrapper] = None
-        self.embt = torch.tensor(1 / 2).float().view(1, 1, 1, 1).to(device)
-        if preset.onnx:
-            self.session = ONNXWrapper(preset.onnx)
-            self.device = device
-            self.embt = self.embt.cpu().numpy()
-            return
-        omega_config = OmegaConf.load(preset.config)
+        torch.set_float32_matmul_precision("high")
+        omega_config = OmegaConf.load(config)
        network_config: DictConfig = omega_config.network
        logging.info(
-            f"Loaded network configuration: {network_config} from [{preset.checkpoint}]"
+            f"Loaded network configuration: {network_config} from [{ckpt_path}]"
        )
        model = build_from_cfg(network_config)
-        checkpoint = torch.load(
-            preset.checkpoint, map_location=device, weights_only=False
-        )
+        checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False)
        model.load_state_dict(checkpoint["state_dict"])
-        model = model.to(device)
+        model = model.to(get_device())
        model.eval()
-        # self.model = torch.compile(model)
-        # self.model = torch.compile(model, backend="tensorrt")
-        self.model = model
-        self.device = device
+        self.model = torch.compile(model, mode="max-autotune")
if logging.getLogger().isEnabledFor(logging.DEBUG):
for name, param in self.model.named_parameters():
logging.debug(
f"Parameter: {name}, shape: {param.shape}, dtype: {param.dtype}"
)
def get_dummy_input(self):
"""Generates a dummy input tensor for ONNX export."""
return (
img2tensor(imread(filename="example/frame_01.png"), self.device),
img2tensor(imread(filename="example/frame_02.png"), self.device),
self.embt,
)
def run(self, image1: np.ndarray, image2: np.ndarray) -> np.ndarray:
"""Runs the model inference to interpolate between two images.
Args:
image1 (np.ndarray): First input image as a NumPy array
image2 (np.ndarray): Second input image as a NumPy array
Returns:
np.ndarray: Interpolated image as a NumPy array
"""
if self.session:
image1 = img2tensor(image1, self.device).cpu().numpy()
image2 = img2tensor(image2, self.device).cpu().numpy()
inputs = {
"img0": image1,
"img1": image2,
"embt": self.embt,
}
outputs = self.session.session.run(None, inputs)
return outputs[0]
tensor1 = img2tensor(image1, self.device)
tensor2 = img2tensor(image2, self.device)
with torch.no_grad():
with torch.amp.autocast(self.device.type):
interpolated = self.model(tensor1, tensor2, self.embt)["imgt_pred"]
return tensor2img(interpolated.cpu())
def get_vram_available(device: torch.device) -> int:
@@ -154,22 +77,33 @@ class ImageInterpolator:
        self.device = device
        self.anchor = anchor
        self.vram_available = get_vram_available(device)
+        self._scale = None
+        self._padder = None
+        self.embt = torch.tensor(1 / 2).float().view(1, 1, 1, 1).to(device)
        self.model_runner = model_runner
        logging.debug(
            f"Initialized ImageInterpolator with device: {device}, anchor: {anchor}, available VRAM: {self.vram_available} bytes"
        )
-    def interpolate(self, image1: np.ndarray, image2: np.ndarray) -> np.ndarray:
-        """
-        Interpolates between two images and saves the result.
-        Args:
-            image1 (Path): Path to the first input image (only png and jpg formats are supported)
-            image2 (Path): Path to the second input image (only png and jpg formats are supported)
-            output_path (Path): Path to save the interpolated image (only png and jpg formats are supported)
-        """
-        return self.model_runner.run(image1, image2)
+    def interpolate(self, image1: torch.Tensor, image2: torch.Tensor) -> torch.Tensor:
+        interpolated = self.model_runner.model(
+            image1, image2, self.embt, scale_factor=self._scale, eval=True
+        )["imgt_pred"]
+        if not self._padder:
+            raise NotImplementedError("Padder not initialized")
+        return self._padder.unpad(interpolated)[0]
+
+    def make_tensor(self, img: np.ndarray) -> torch.Tensor:
+        tensor = img2tensor(img).to(self.device)
+        h, w = tensor.shape[2], tensor.shape[3]
+        scale = self.scale(h, w)
+        padding = int(16 / scale)
+        if self._padder is None:
+            self._padder = InputPadder(tensor.shape, padding)
+        return self._padder.pad(tensor)[0]
    def scale(self, height: int, width: int) -> float:
+        if self._scale is None:
        scale = (
            self.anchor.resolution
            / (height * width)

@@ -180,7 +114,9 @@ class ImageInterpolator:
        scale = 1 if scale > 1 else scale
        scale = 1 / np.floor(1 / np.sqrt(scale) * 16) * 16
        if scale < 1:
-            logging.info(
+            logging.debug(
                f"Due to the limited VRAM, the video will be scaled by {scale:.2f}"
            )
-        return scale
+        self._scale = float(scale)
+        logging.info(f"Calculated scale factor: {self._scale:.2f}")
+        return self._scale
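As an aside, the snapping arithmetic in scale() is easy to sanity-check by hand. A minimal sketch, assuming the CUDA anchor resolution of 1280 * 720 used in src/runner.py, a 4K input, and ignoring the memory terms elided by the hunk above:

import numpy as np

anchor_resolution = 1280 * 720
h, w = 2160, 3840                                    # 4K frame
scale = anchor_resolution / (h * w)                  # raw area ratio, ~0.111
scale = 1 if scale > 1 else scale                    # never upscale
scale = 1 / np.floor(1 / np.sqrt(scale) * 16) * 16   # snap so padded dims stay multiples of 16
print(scale)                                         # 0.333..., the frame is processed at a third of its size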


@@ -2,10 +2,10 @@ from typing import Optional
import torch
import torch.nn as nn
-from src.networks.blocks.raft import coords_grid, BasicUpdateBlock, BidirCorrBlock
-from src.networks.blocks.feat_enc import LargeEncoder
-from src.networks.blocks.ifrnet import resize, Encoder, InitDecoder, IntermediateDecoder
-from src.networks.blocks.multi_flow import multi_flow_combine, MultiFlowDecoder
+from .blocks.raft import coords_grid, BasicUpdateBlock, BidirCorrBlock
+from .blocks.feat_enc import LargeEncoder
+from .blocks.ifrnet import resize, Encoder, InitDecoder, IntermediateDecoder
+from .blocks.multi_flow import multi_flow_combine, MultiFlowDecoder

class Model(nn.Module):
@@ -177,14 +177,11 @@ class Model(nn.Module):
        )
        if scale_factor != 1.0:
-            up_flow0_1 = resize(up_flow0_1, scale_factor=(1.0 / scale_factor)) * (
-                1.0 / scale_factor
-            )
-            up_flow1_1 = resize(up_flow1_1, scale_factor=(1.0 / scale_factor)) * (
-                1.0 / scale_factor
-            )
-            mask = resize(mask, scale_factor=(1.0 / scale_factor))
-            img_res = resize(img_res, scale_factor=(1.0 / scale_factor))
+            factor = 1.0 / scale_factor
+            up_flow0_1 = resize(up_flow0_1, factor) * factor
+            up_flow1_1 = resize(up_flow1_1, factor) * factor
+            mask = resize(mask, factor)
+            img_res = resize(img_res, factor)

        # Merge multiple predictions
        imgt_pred = multi_flow_combine(
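The rewrite above is behavior-preserving: naming factor = 1.0 / scale_factor and reusing it changes nothing numerically, since flow values must be multiplied by the same factor the spatial grid is resized by to stay in pixel units. A quick self-contained check with a dummy tensor (not the real flow):

import torch
import torch.nn.functional as F

def resize(x, scale_factor):
    return F.interpolate(x, scale_factor=scale_factor, mode="bilinear", align_corners=False)

flow = torch.rand(1, 2, 64, 64)
scale_factor = 0.5
factor = 1.0 / scale_factor
old = resize(flow, scale_factor=(1.0 / scale_factor)) * (1.0 / scale_factor)
new = resize(flow, factor) * factor
assert torch.equal(old, new)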


@@ -1,10 +1,10 @@
import torch
import torch.nn as nn
-from src.networks.blocks.raft import coords_grid, BasicUpdateBlock, BidirCorrBlock
-from src.networks.blocks.feat_enc import BasicEncoder
-from src.networks.blocks.ifrnet import resize, Encoder, InitDecoder, IntermediateDecoder
-from src.networks.blocks.multi_flow import multi_flow_combine, MultiFlowDecoder
+from .blocks.raft import coords_grid, BasicUpdateBlock, BidirCorrBlock
+from .blocks.feat_enc import BasicEncoder
+from .blocks.ifrnet import resize, Encoder, InitDecoder, IntermediateDecoder
+from .blocks.multi_flow import multi_flow_combine, MultiFlowDecoder

class Model(nn.Module):
@@ -142,14 +142,11 @@ class Model(nn.Module):
        )
        if scale_factor != 1.0:
-            up_flow0_1 = resize(up_flow0_1, scale_factor=(1.0 / scale_factor)) * (
-                1.0 / scale_factor
-            )
-            up_flow1_1 = resize(up_flow1_1, scale_factor=(1.0 / scale_factor)) * (
-                1.0 / scale_factor
-            )
-            mask = resize(mask, scale_factor=(1.0 / scale_factor))
-            img_res = resize(img_res, scale_factor=(1.0 / scale_factor))
+            factor = 1.0 / scale_factor
+            up_flow0_1 = resize(up_flow0_1, factor) * factor
+            up_flow1_1 = resize(up_flow1_1, factor) * factor
+            mask = resize(mask, factor)
+            img_res = resize(img_res, factor)

        # Merge multiple predictions
        imgt_pred = multi_flow_combine(


@@ -1,9 +1,9 @@
import torch
import torch.nn as nn
-from src.networks.blocks.raft import coords_grid, SmallUpdateBlock, BidirCorrBlock
-from src.networks.blocks.feat_enc import SmallEncoder
-from src.networks.blocks.ifrnet import resize, Encoder, InitDecoder, IntermediateDecoder
-from src.networks.blocks.multi_flow import multi_flow_combine, MultiFlowDecoder
+from .blocks.raft import coords_grid, SmallUpdateBlock, BidirCorrBlock
+from .blocks.feat_enc import SmallEncoder
+from .blocks.ifrnet import resize, Encoder, InitDecoder, IntermediateDecoder
+from .blocks.multi_flow import multi_flow_combine, MultiFlowDecoder

class Model(nn.Module):
@@ -67,14 +67,7 @@ class Model(nn.Module):
        flow = torch.cat([flow0, flow1], dim=1)
        return corr, flow

-    def forward(
-        self,
-        img0: torch.Tensor,
-        img1: torch.Tensor,
-        embt: torch.Tensor,
-    ):
-        scale_factor = 1.0
-        eval = False
+    def forward(self, img0, img1, embt, scale_factor=1.0, eval=False, **kwargs):
        mean_ = (
            torch.cat([img0, img1], 2)
            .mean(1, keepdim=True)
@@ -147,14 +140,11 @@ class Model(nn.Module):
        )
        if scale_factor != 1.0:
-            up_flow0_1 = resize(up_flow0_1, scale_factor=(1.0 / scale_factor)) * (
-                1.0 / scale_factor
-            )
-            up_flow1_1 = resize(up_flow1_1, scale_factor=(1.0 / scale_factor)) * (
-                1.0 / scale_factor
-            )
-            mask = resize(mask, scale_factor=(1.0 / scale_factor))
-            img_res = resize(img_res, scale_factor=(1.0 / scale_factor))
+            factor = 1.0 / scale_factor
+            up_flow0_1 = resize(up_flow0_1, factor) * factor
+            up_flow1_1 = resize(up_flow1_1, factor) * factor
+            mask = resize(mask, factor)
+            img_res = resize(img_res, factor)

        # Merge multiple predictions
        imgt_pred = multi_flow_combine(
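The AMT-S signature change is backward compatible: the old forward hard-coded scale_factor = 1.0 and eval = False, and the new signature keeps exactly those values as defaults. A stand-in sketch (not the real Model.forward) of why existing call sites are unaffected:

def forward(img0, img1, embt, scale_factor=1.0, eval=False, **kwargs):
    return {"scale_factor": scale_factor, "eval": eval}

assert forward(0, 0, 0) == {"scale_factor": 1.0, "eval": False}        # old-style call
assert forward(0, 0, 0, scale_factor=0.5, eval=True)["eval"] is True   # new pipeline call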


@@ -1,7 +1,7 @@
import torch
import torch.nn as nn
-from src.utils.flow_utils import warp
+from ..utils.flow_utils import warp
-from src.networks.blocks.ifrnet import convrelu, resize, ResBlock
+from .blocks.ifrnet import convrelu, resize, ResBlock

class Encoder(nn.Module):


@@ -1,44 +1,40 @@
-from typing import Any
import torch
import torch.nn as nn

class BottleneckBlock(nn.Module):
-    def __init__(self, in_planes, planes, norm_fn="group", stride=1):
+    def __init__(self, in_planes, planes, norm_fn='group', stride=1):
        super(BottleneckBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0)
-        self.conv2 = nn.Conv2d(
-            planes // 4, planes // 4, kernel_size=3, padding=1, stride=stride
-        )
+        self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride)
        self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0)
        self.relu = nn.ReLU(inplace=True)
        num_groups = planes // 8
-        if norm_fn == "group":
+        if norm_fn == 'group':
            self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
            self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
            self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            if not stride == 1:
                self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
-        elif norm_fn == "batch":
+        elif norm_fn == 'batch':
            self.norm1 = nn.BatchNorm2d(planes//4)
            self.norm2 = nn.BatchNorm2d(planes//4)
            self.norm3 = nn.BatchNorm2d(planes)
            if not stride == 1:
                self.norm4 = nn.BatchNorm2d(planes)
-        elif norm_fn == "instance":
+        elif norm_fn == 'instance':
            self.norm1 = nn.InstanceNorm2d(planes//4)
            self.norm2 = nn.InstanceNorm2d(planes//4)
            self.norm3 = nn.InstanceNorm2d(planes)
            if not stride == 1:
                self.norm4 = nn.InstanceNorm2d(planes)
-        elif norm_fn == "none":
+        elif norm_fn == 'none':
            self.norm1 = nn.Sequential()
            self.norm2 = nn.Sequential()
            self.norm3 = nn.Sequential()
@@ -50,8 +46,8 @@ class BottleneckBlock(nn.Module):
        else:
            self.downsample = nn.Sequential(
-                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4
-            )
+                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4)

    def forward(self, x):
        y = x
@@ -66,36 +62,34 @@ class BottleneckBlock(nn.Module):
class ResidualBlock(nn.Module):
-    def __init__(self, in_planes, planes, norm_fn="group", stride=1):
+    def __init__(self, in_planes, planes, norm_fn='group', stride=1):
        super(ResidualBlock, self).__init__()
-        self.conv1 = nn.Conv2d(
-            in_planes, planes, kernel_size=3, padding=1, stride=stride
-        )
+        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)
        num_groups = planes // 8
-        if norm_fn == "group":
+        if norm_fn == 'group':
            self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            if not stride == 1:
                self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
-        elif norm_fn == "batch":
+        elif norm_fn == 'batch':
            self.norm1 = nn.BatchNorm2d(planes)
            self.norm2 = nn.BatchNorm2d(planes)
            if not stride == 1:
                self.norm3 = nn.BatchNorm2d(planes)
-        elif norm_fn == "instance":
+        elif norm_fn == 'instance':
            self.norm1 = nn.InstanceNorm2d(planes)
            self.norm2 = nn.InstanceNorm2d(planes)
            if not stride == 1:
                self.norm3 = nn.InstanceNorm2d(planes)
-        elif norm_fn == "none":
+        elif norm_fn == 'none':
            self.norm1 = nn.Sequential()
            self.norm2 = nn.Sequential()
            if not stride == 1:
@@ -106,8 +100,8 @@ class ResidualBlock(nn.Module):
        else:
            self.downsample = nn.Sequential(
-                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3
-            )
+                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)

    def forward(self, x):
        y = x
@@ -121,20 +115,20 @@ class ResidualBlock(nn.Module):
class SmallEncoder(nn.Module):
-    def __init__(self, output_dim=128, norm_fn="batch", dropout=0.0):
+    def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
        super(SmallEncoder, self).__init__()
        self.norm_fn = norm_fn

-        if self.norm_fn == "group":
+        if self.norm_fn == 'group':
            self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32)
-        elif self.norm_fn == "batch":
+        elif self.norm_fn == 'batch':
            self.norm1 = nn.BatchNorm2d(32)
-        elif self.norm_fn == "instance":
+        elif self.norm_fn == 'instance':
            self.norm1 = nn.InstanceNorm2d(32)
-        elif self.norm_fn == "none":
+        elif self.norm_fn == 'none':
            self.norm1 = nn.Sequential()

        self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3)
@@ -153,7 +147,7 @@ class SmallEncoder(nn.Module):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
-                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
+                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
@@ -168,15 +162,14 @@ class SmallEncoder(nn.Module):
        self.in_planes = dim
        return nn.Sequential(*layers)

-    def forward(
-        self, x: torch.Tensor | list[torch.Tensor] | tuple[torch.Tensor, ...]
-    ):
+    def forward(self, x):
        # if input is list, combine batch dimension
-        batch_dim = None
-        if is_list := isinstance(x, tuple) or isinstance(x, list):
+        is_list = isinstance(x, tuple) or isinstance(x, list)
+        if is_list:
            batch_dim = x[0].shape[0]
-            x: torch.Tensor = torch.cat(x, dim=0)
+            x = torch.cat(x, dim=0)

        x = self.conv1(x)
        x = self.norm1(x)
@@ -190,30 +183,26 @@ class SmallEncoder(nn.Module):
        if self.training and self.dropout is not None:
            x = self.dropout(x)
-        if is_list and batch_dim is not None:
-            return torch.split(x, [batch_dim, batch_dim], dim=0)
+        if is_list:
+            x = torch.split(x, [batch_dim, batch_dim], dim=0)
        return x

-    def __call__(self, *args: Any, **kwds: Any) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
-        return super().__call__(*args, **kwds)

class BasicEncoder(nn.Module):
-    def __init__(self, output_dim=128, norm_fn="batch", dropout=0.0):
+    def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
        super(BasicEncoder, self).__init__()
        self.norm_fn = norm_fn

-        if self.norm_fn == "group":
+        if self.norm_fn == 'group':
            self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)
-        elif self.norm_fn == "batch":
+        elif self.norm_fn == 'batch':
            self.norm1 = nn.BatchNorm2d(64)
-        elif self.norm_fn == "instance":
+        elif self.norm_fn == 'instance':
            self.norm1 = nn.InstanceNorm2d(64)
-        elif self.norm_fn == "none":
+        elif self.norm_fn == 'none':
            self.norm1 = nn.Sequential()

        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
@@ -233,7 +222,7 @@ class BasicEncoder(nn.Module):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
-                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
+                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
@@ -248,6 +237,7 @@ class BasicEncoder(nn.Module):
        self.in_planes = dim
        return nn.Sequential(*layers)

    def forward(self, x):
        # if input is list, combine batch dimension
@@ -274,22 +264,21 @@ class BasicEncoder(nn.Module):
        return x

class LargeEncoder(nn.Module):
-    def __init__(self, output_dim=128, norm_fn="batch", dropout=0.0):
+    def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
        super(LargeEncoder, self).__init__()
        self.norm_fn = norm_fn

-        if self.norm_fn == "group":
+        if self.norm_fn == 'group':
            self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)
-        elif self.norm_fn == "batch":
+        elif self.norm_fn == 'batch':
            self.norm1 = nn.BatchNorm2d(64)
-        elif self.norm_fn == "instance":
+        elif self.norm_fn == 'instance':
            self.norm1 = nn.InstanceNorm2d(64)
-        elif self.norm_fn == "none":
+        elif self.norm_fn == 'none':
            self.norm1 = nn.Sequential()

        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
@@ -310,7 +299,7 @@ class LargeEncoder(nn.Module):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
-                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
+                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
@@ -325,6 +314,7 @@ class LargeEncoder(nn.Module):
        self.in_planes = dim
        return nn.Sequential(*layers)

    def forward(self, x):
        # if input is list, combine batch dimension
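The is_list branch kept in these encoders implements a small optimization worth spelling out: the two input feature maps are stacked along the batch dimension, pushed through the network once, and split back into a pair. A minimal self-contained sketch of the idea (a multiply stands in for the conv stack):

import torch

x0, x1 = torch.rand(2, 3, 32, 32), torch.rand(2, 3, 32, 32)
batch_dim = x0.shape[0]
x = torch.cat([x0, x1], dim=0)                           # one forward pass instead of two
y = x * 2.0                                              # stand-in for conv1/norm1/layers
y0, y1 = torch.split(y, [batch_dim, batch_dim], dim=0)   # recover the pair
assert y0.shape == x0.shape and torch.equal(y0, x0 * 2.0)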


@@ -1,81 +1,39 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
-from src.utils.flow_utils import warp
+from ...utils.flow_utils import warp

-def resize(x: torch.Tensor, scale_factor: float) -> torch.Tensor:
-    return F.interpolate(
-        x, scale_factor=scale_factor, mode="bilinear", align_corners=False
-    )
+def resize(x, scale_factor):
+    return F.interpolate(x, scale_factor=scale_factor, mode="bilinear", align_corners=False)

-def convrelu(
-    in_channels,
-    out_channels,
-    kernel_size=3,
-    stride=1,
-    padding=1,
-    dilation=1,
-    groups=1,
-    bias=True,
-):
+def convrelu(in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=True):
    return nn.Sequential(
-        nn.Conv2d(
-            in_channels,
-            out_channels,
-            kernel_size,
-            stride,
-            padding,
-            dilation,
-            groups,
-            bias=bias,
-        ),
-        nn.PReLU(out_channels),
+        nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias=bias),
+        nn.PReLU(out_channels)
    )
class ResBlock(nn.Module):
    def __init__(self, in_channels, side_channels, bias=True):
        super(ResBlock, self).__init__()
        self.side_channels = side_channels
        self.conv1 = nn.Sequential(
-            nn.Conv2d(
-                in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias
-            ),
-            nn.PReLU(in_channels),
+            nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias),
+            nn.PReLU(in_channels)
        )
        self.conv2 = nn.Sequential(
-            nn.Conv2d(
-                side_channels,
-                side_channels,
-                kernel_size=3,
-                stride=1,
-                padding=1,
-                bias=bias,
-            ),
-            nn.PReLU(side_channels),
+            nn.Conv2d(side_channels, side_channels, kernel_size=3, stride=1, padding=1, bias=bias),
+            nn.PReLU(side_channels)
        )
        self.conv3 = nn.Sequential(
-            nn.Conv2d(
-                in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias
-            ),
-            nn.PReLU(in_channels),
+            nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias),
+            nn.PReLU(in_channels)
        )
        self.conv4 = nn.Sequential(
-            nn.Conv2d(
-                side_channels,
-                side_channels,
-                kernel_size=3,
-                stride=1,
-                padding=1,
-                bias=bias,
-            ),
-            nn.PReLU(side_channels),
+            nn.Conv2d(side_channels, side_channels, kernel_size=3, stride=1, padding=1, bias=bias),
+            nn.PReLU(side_channels)
        )
-        self.conv5 = nn.Conv2d(
-            in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias
-        )
+        self.conv5 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias)
        self.prelu = nn.PReLU(in_channels)

    def forward(self, x):
@@ -94,7 +52,6 @@ class ResBlock(nn.Module):
        out = self.prelu(x + out)
        return out

class Encoder(nn.Module):
    def __init__(self, channels, large=False):
        super(Encoder, self).__init__()
@@ -103,32 +60,29 @@ class Encoder(nn.Module):
        for idx, ch in enumerate(channels, 1):
            k = 7 if large and idx == 1 else 3
            p = 3 if k ==7 else 1
-            self.register_module(
-                f"pyramid{idx}",
-                nn.Sequential(
-                    convrelu(prev_ch, ch, k, 2, p), convrelu(ch, ch, 3, 1, 1)
-                ),
-            )
+            self.register_module(f'pyramid{idx}',
+                nn.Sequential(
+                    convrelu(prev_ch, ch, k, 2, p),
+                    convrelu(ch, ch, 3, 1, 1)
+                ))
            prev_ch = ch

    def forward(self, in_x):
        fs = []
        for idx in range(len(self.channels)):
-            out_x = getattr(self, f"pyramid{idx + 1}")(in_x)
+            out_x = getattr(self, f'pyramid{idx+1}')(in_x)
            fs.append(out_x)
            in_x = out_x
        return fs

class InitDecoder(nn.Module):
    def __init__(self, in_ch, out_ch, skip_ch) -> None:
        super().__init__()
        self.convblock = nn.Sequential(
            convrelu(in_ch*2+1, in_ch*2),
            ResBlock(in_ch*2, skip_ch),
-            nn.ConvTranspose2d(in_ch * 2, out_ch + 4, 4, 2, 1, bias=True),
+            nn.ConvTranspose2d(in_ch*2, out_ch+4, 4, 2, 1, bias=True)
        )

    def forward(self, f0, f1, embt):
        h, w = f0.shape[2:]
        embt = embt.repeat(1, 1, h, w)
@@ -137,16 +91,14 @@ class InitDecoder(nn.Module):
        ft_ = out[:, 4:, ...]
        return flow0, flow1, ft_

class IntermediateDecoder(nn.Module):
    def __init__(self, in_ch, out_ch, skip_ch) -> None:
        super().__init__()
        self.convblock = nn.Sequential(
            convrelu(in_ch*3+4, in_ch*3),
            ResBlock(in_ch*3, skip_ch),
-            nn.ConvTranspose2d(in_ch * 3, out_ch + 4, 4, 2, 1, bias=True),
+            nn.ConvTranspose2d(in_ch*3, out_ch+4, 4, 2, 1, bias=True)
        )

    def forward(self, ft_, f0, f1, flow0_in, flow1_in):
        f0_warp = warp(f0, flow0_in)
        f1_warp = warp(f1, flow1_in)


@@ -1,7 +1,7 @@
import torch
import torch.nn as nn
-from src.utils.flow_utils import warp
+from ...utils.flow_utils import warp
-from src.networks.blocks.ifrnet import convrelu, resize, ResBlock
+from .ifrnet import convrelu, resize, ResBlock

def multi_flow_combine(


@@ -4,12 +4,10 @@ import torch.nn.functional as F
def resize(x, scale_factor):
-    return F.interpolate(
-        x, scale_factor=scale_factor, mode="bilinear", align_corners=False
-    )
+    return F.interpolate(x, scale_factor=scale_factor, mode="bilinear", align_corners=False)

-def bilinear_sampler(img: torch.Tensor, coords: torch.Tensor, mask=False):
+def bilinear_sampler(img, coords, mask=False):
    """ Wrapper for grid_sample, uses pixel coordinates """
    H, W = img.shape[-2:]
    xgrid, ygrid = coords.split([1,1], dim=-1)
@@ -27,25 +25,16 @@ def bilinear_sampler(img: torch.Tensor, coords: torch.Tensor, mask=False):
def coords_grid(batch, ht, wd, device):
-    coords = torch.meshgrid(
-        torch.arange(ht, device=device), torch.arange(wd, device=device), indexing="ij"
-    )
+    coords = torch.meshgrid(torch.arange(ht, device=device),
+                            torch.arange(wd, device=device),
+                            indexing='ij')
    coords = torch.stack(coords[::-1], dim=0).float()
    return coords[None].repeat(batch, 1, 1, 1)

class SmallUpdateBlock(nn.Module):
-    def __init__(
-        self,
-        cdim,
-        hidden_dim,
-        flow_dim,
-        corr_dim,
-        fc_dim,
-        corr_levels=4,
-        radius=3,
-        scale_factor=None,
-    ):
+    def __init__(self, cdim, hidden_dim, flow_dim, corr_dim, fc_dim,
+                 corr_levels=4, radius=3, scale_factor=None):
        super(SmallUpdateBlock, self).__init__()
        cor_planes = corr_levels * (2 * radius + 1) **2
        self.scale_factor = scale_factor
@@ -76,9 +65,8 @@ class SmallUpdateBlock(nn.Module):
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

    def forward(self, net, flow, corr):
-        net = (
-            resize(net, 1 / self.scale_factor) if self.scale_factor is not None else net
-        )
+        net = resize(net, 1 / self.scale_factor
+                     ) if self.scale_factor is not None else net
        cor = self.lrelu(self.convc1(corr))
        flo = self.lrelu(self.convf1(flow))
        flo = self.lrelu(self.convf2(flo))
@@ -92,27 +80,14 @@ class SmallUpdateBlock(nn.Module):
        if self.scale_factor is not None:
            delta_net = resize(delta_net, scale_factor=self.scale_factor)
-            delta_flow = self.scale_factor * resize(
-                delta_flow, scale_factor=self.scale_factor
-            )
+            delta_flow = self.scale_factor * resize(delta_flow, scale_factor=self.scale_factor)
        return delta_net, delta_flow

class BasicUpdateBlock(nn.Module):
-    def __init__(
-        self,
-        cdim,
-        hidden_dim,
-        flow_dim,
-        corr_dim,
-        corr_dim2,
-        fc_dim,
-        corr_levels=4,
-        radius=3,
-        scale_factor=None,
-        out_num=1,
-    ):
+    def __init__(self, cdim, hidden_dim, flow_dim, corr_dim, corr_dim2,
+                 fc_dim, corr_levels=4, radius=3, scale_factor=None, out_num=1):
        super(BasicUpdateBlock, self).__init__()
        cor_planes = corr_levels * (2 * radius + 1) **2
@@ -144,9 +119,8 @@ class BasicUpdateBlock(nn.Module):
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

    def forward(self, net, flow, corr):
-        net = (
-            resize(net, 1 / self.scale_factor) if self.scale_factor is not None else net
-        )
+        net = resize(net, 1 / self.scale_factor
+                     ) if self.scale_factor is not None else net
        cor = self.lrelu(self.convc1(corr))
        cor = self.lrelu(self.convc2(cor))
        flo = self.lrelu(self.convf1(flow))
@@ -161,20 +135,16 @@ class BasicUpdateBlock(nn.Module):
        if self.scale_factor is not None:
            delta_net = resize(delta_net, scale_factor=self.scale_factor)
-            delta_flow = self.scale_factor * resize(
-                delta_flow, scale_factor=self.scale_factor
-            )
+            delta_flow = self.scale_factor * resize(delta_flow, scale_factor=self.scale_factor)
        return delta_net, delta_flow

class BidirCorrBlock:
-    def __init__(
-        self, fmap1: torch.Tensor, fmap2: torch.Tensor, num_levels=4, radius=4
-    ):
+    def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
        self.num_levels = num_levels
        self.radius = radius
-        self.corr_pyramid: list[torch.Tensor] = []
-        self.corr_pyramid_T: list[torch.Tensor] = []
+        self.corr_pyramid = []
+        self.corr_pyramid_T = []
        corr = BidirCorrBlock.corr(fmap1, fmap2)
        batch, h1, w1, dim, h2, w2 = corr.shape
@@ -192,13 +162,11 @@ class BidirCorrBlock:
            self.corr_pyramid.append(corr)
            self.corr_pyramid_T.append(corr_T)

-    def __call__(self, coords0: torch.Tensor, coords1: torch.Tensor):
+    def __call__(self, coords0, coords1):
        r = self.radius
        coords0 = coords0.permute(0, 2, 3, 1)
        coords1 = coords1.permute(0, 2, 3, 1)
-        assert coords0.shape == coords1.shape, (
-            f"coords0 shape: [{coords0.shape}] is not equal to [{coords1.shape}]"
-        )
+        assert coords0.shape == coords1.shape, f"coords0 shape: [{coords0.shape}] is not equal to [{coords1.shape}]"
        batch, h1, w1, _ = coords0.shape
        out_pyramid = []
@@ -209,13 +177,13 @@ class BidirCorrBlock:
            dx = torch.linspace(-r, r, 2*r+1, device=coords0.device)
            dy = torch.linspace(-r, r, 2*r+1, device=coords0.device)
-            delta: torch.Tensor = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), axis=-1)
-            delta_lvl: torch.Tensor = delta.view(1, 2 * r + 1, 2 * r + 1, 2)
-            centroid_lvl_0: torch.Tensor = coords0.reshape(batch * h1 * w1, 1, 1, 2) / 2**i
-            centroid_lvl_1: torch.Tensor = coords1.reshape(batch * h1 * w1, 1, 1, 2) / 2**i
-            coords_lvl_0: torch.Tensor = centroid_lvl_0 + delta_lvl
-            coords_lvl_1: torch.Tensor = centroid_lvl_1 + delta_lvl
+            delta = torch.stack(torch.meshgrid(dy, dx, indexing='ij'), axis=-1)
+            delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2)
+            centroid_lvl_0 = coords0.reshape(batch*h1*w1, 1, 1, 2) / 2**i
+            centroid_lvl_1 = coords1.reshape(batch*h1*w1, 1, 1, 2) / 2**i
+            coords_lvl_0 = centroid_lvl_0 + delta_lvl
+            coords_lvl_1 = centroid_lvl_1 + delta_lvl
            corr = bilinear_sampler(corr, coords_lvl_0)
            corr_T = bilinear_sampler(corr_T, coords_lvl_1)
@@ -226,16 +194,14 @@ class BidirCorrBlock:
        out = torch.cat(out_pyramid, dim=-1)
        out_T = torch.cat(out_pyramid_T, dim=-1)
-        return out.permute(0, 3, 1, 2).contiguous().float(), out_T.permute(
-            0, 3, 1, 2
-        ).contiguous().float()
+        return out.permute(0, 3, 1, 2).contiguous().float(), out_T.permute(0, 3, 1, 2).contiguous().float()

    @staticmethod
-    def corr(fmap1: torch.Tensor, fmap2: torch.Tensor):
+    def corr(fmap1, fmap2):
        batch, dim, ht, wd = fmap1.shape
        fmap1 = fmap1.view(batch, dim, ht*wd)
        fmap2 = fmap2.view(batch, dim, ht*wd)
        corr = torch.matmul(fmap1.transpose(1,2), fmap2)
        corr = corr.view(batch, ht, wd, 1, ht, wd)
-        return corr * (dim**-0.5)
+        return corr / torch.sqrt(torch.tensor(dim).float())
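The last change in corr() swaps dim**-0.5 for an explicit division by sqrt(dim); the two normalizations are numerically equivalent up to floating-point rounding (both scale the dot-product correlation volume by 1/sqrt(dim), as in scaled dot-product attention). A quick check with dummy feature maps:

import torch

dim = 128
fmap = torch.rand(1, dim, 4, 6).view(1, dim, 24)
corr = torch.matmul(fmap.transpose(1, 2), fmap)
assert torch.allclose(corr * (dim ** -0.5), corr / torch.sqrt(torch.tensor(dim).float()))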

src/runner.py (new file)

@@ -0,0 +1,179 @@
import logging
from pathlib import Path
from typing import TYPE_CHECKING
from cv2 import imwrite
import tqdm
import torch
from .config import presets
from .utils.fs import FileSystem
from .utils.video import VideoMaker
from .utils.torch import tensor2img
from .interpolator import (
ImageInterpolator,
Anchor,
get_device,
get_vram_available,
ModelRunner,
)
if TYPE_CHECKING:
import torch
import numpy as np
def performing_warning_message(device: "torch.device"):
if device.type in ("cpu", "mps"):
if device.type == "mps":
logging.warning(
"Running on Apple Silicon GPU (MPS) may have limited performance. Consider using a CUDA-enabled GPU for better performance."
)
else:
logging.warning(
"Running on CPU may be very slow. Consider using a GPU for better performance."
)
elif device.type == "cuda":
pass
else:
raise Exception(f"Unsupported device type: {device.type}")
def init_fs(base_path: Path) -> FileSystem:
fs = FileSystem(base_path)
fs.clear_directory(fs.frames_path)
fs.clear_directory(fs.interpolated_path)
fs.clear_directory(fs.moved_path)
fs.clear_directory(fs.video_part_path)
return fs
def init_video_maker() -> VideoMaker:
return VideoMaker()
def init_device() -> "torch.device":
device = get_device()
performing_warning_message(device)
vram_available = get_vram_available(device)
logging.info(f"Available VRAM: {vram_available / (1024**3):.2f} GB")
return device
def init_anchor(device: "torch.device") -> Anchor:
if device.type in ("cpu", "mps"):
return Anchor(resolution=8192 * 8192, memory=1, memory_bias=0)
elif device.type == "cuda":
# return Anchor(
# resolution=1024 * 512, memory=1500 * 1024**2, memory_bias=2500 * 1024**2
# )
return Anchor(
resolution=1280 * 720, memory=6500 * 1024**2, memory_bias=7500 * 1024**2
)
else:
raise Exception(f"Unsupported device type: {device.type}")
def init_model_runner(
config: Path, checkpoint_path: Path, device: "torch.device"
) -> ModelRunner:
return ModelRunner(config, checkpoint_path, device)
def init_interpolator(
model_runner: ModelRunner, device: "torch.device"
) -> ImageInterpolator:
anchor = init_anchor(device)
return ImageInterpolator(device, anchor, model_runner)
class InterpolationPipeline:
def __init__(
self,
config: Path,
checkpoint_path: Path,
base_path: Path,
):
self.fs = init_fs(base_path)
self.video_maker = init_video_maker()
self.device = init_device()
self.model_runner = init_model_runner(config, checkpoint_path, self.device)
self.interpolator = init_interpolator(self.model_runner, self.device)
def run(self, video_path: Path, output_video: str):
prev_frames: tuple["np.ndarray", ...] = tuple()
interpolated_frames: list["np.ndarray"] = []
part = 0
chunk_seconds = 10
length = self.video_maker.get_video_duration(video_path)
last_part_seconds = 1 if length % chunk_seconds else 0
total_parts = int(length // chunk_seconds) + last_part_seconds
fps = self.video_maker.get_fps(video_path)
logging.info(f"Video FPS: {fps}")
fps *= 2 # Doubling FPS
width, height = self.video_maker.get_size(video_path)
with torch.autocast(self.device.type, torch.float16):
with torch.no_grad():
prev_tensor = None
for idx, frames in enumerate(
self.video_maker.video_to_frames_generator(
video_path, self.fs.frames_path, chunk_seconds
)
):
interpolated_frames: list["np.ndarray"] = []
for frame in tqdm.tqdm(frames):
tensor = self.interpolator.make_tensor(frame)
if prev_tensor is None:
prev_tensor = tensor
continue
interpolated_frames.append(
tensor2img(
self.interpolator.interpolate(
prev_tensor, tensor
)
)
)
prev_tensor = tensor
generator = self._frame_generator(frames, interpolated_frames)
part_path = self.fs.video_part_path / f"video_{idx:08d}.mp4"
self.video_maker.images_to_video_pipeline(
generator, part_path, width, height, fps
)
self._merge_video_parts(self.fs.output_path / output_video)
def _merge_video_parts(self, output_video: Path):
self.video_maker.concatenate_videos(self.fs.video_part_path, output_video)
self.fs.clear_directory(self.fs.video_part_path)
def _frame_generator(
self,
source: tuple["np.ndarray", ...],
interpolated: list["np.ndarray"],
):
if len(source) == len(interpolated):
first = interpolated
second = source
else:
first = source
second = interpolated
for i, frame in enumerate(first):
yield frame
if i < len(second):
yield second[i]
def run(
base_path: Path,
video_path: Path,
output_video: str,
preset: presets.Preset = presets.LARGE,
):
pipeline = InterpolationPipeline(
config=preset.config,
checkpoint_path=preset.checkpoint,
base_path=base_path,
)
pipeline.run(video_path, output_video)
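The _frame_generator above doubles the frame rate by weaving interpolated frames between source frames; with N source frames and N-1 interpolated ones, the output is 2N-1 frames per chunk. A minimal sketch of the interleaving order, using string labels in place of np.ndarray frames:

source = ["s0", "s1", "s2"]
interpolated = ["i01", "i12"]

def interleave(first, second):
    for i, frame in enumerate(first):
        yield frame
        if i < len(second):
            yield second[i]

assert list(interleave(source, interpolated)) == ["s0", "i01", "s1", "i12", "s2"]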


@@ -1,16 +1,19 @@
-import importlib
from typing import TYPE_CHECKING
+from ..networks import AMT_G, AMT_L, AMT_S

if TYPE_CHECKING:
    from omegaconf import DictConfig

-def base_build_fn(module: str, cls: str, params: dict):
-    return getattr(importlib.import_module(module, package=None), cls)(**params)

def build_from_cfg(config: "DictConfig"):
+    packages = {
+        "AMT-G": AMT_G,
+        "AMT-L": AMT_L,
+        "AMT-S": AMT_S
+    }
    module, cls = config["name"].rsplit(".", 1)
    params: dict = config.get("params", {})
-    return base_build_fn(module, cls, params)
+    return getattr(packages[module], cls)(**params)
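The new build_from_cfg trades importlib reflection for an explicit registry, which is also why the configs changed from src.networks.AMT-S.Model to AMT-S.Model: the prefix now selects a module from the registry and the suffix is looked up with getattr. A self-contained sketch with a stand-in module (hypothetical, not the real networks package):

import types

class _Model:                                  # stand-in for AMT_S.Model
    def __init__(self, corr_radius=3, corr_lvls=4):
        self.corr_radius, self.corr_lvls = corr_radius, corr_lvls

AMT_S = types.SimpleNamespace(Model=_Model)
packages = {"AMT-S": AMT_S}

config = {"name": "AMT-S.Model", "params": {"corr_radius": 3, "corr_lvls": 4}}
module, cls = config["name"].rsplit(".", 1)    # -> "AMT-S", "Model"
model = getattr(packages[module], cls)(**config["params"])
assert model.corr_radius == 3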


@@ -21,9 +21,6 @@ class InputPadder:
        ]

    def pad(self, *inputs: "torch.Tensor"):
-        if len(inputs) == 1:
-            return F.pad(inputs[0], self._pad, mode="replicate")
-        else:
-            return [F.pad(x, self._pad, mode="replicate") for x in inputs]
+        return [F.pad(x, self._pad, mode="replicate") for x in inputs]

    def unpad(self, *inputs: "torch.Tensor"):
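InputPadder's internals are not shown in this hunk, but the padding argument computed in make_tensor (int(16 / scale)) suggests it grows H and W to the next multiple of that divisor before inference. A hypothetical reconstruction of the idea, not the project's actual class:

import torch
import torch.nn.functional as F

def pad_to_multiple(x: torch.Tensor, divisor: int = 16) -> torch.Tensor:
    h, w = x.shape[-2:]
    ph = (divisor - h % divisor) % divisor
    pw = (divisor - w % divisor) % divisor
    # F.pad order for the last two dims: (left, right, top, bottom)
    return F.pad(x, [0, pw, 0, ph], mode="replicate")

x = torch.rand(1, 3, 719, 1275)
assert pad_to_multiple(x).shape[-2:] == (720, 1280)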


@@ -5,26 +5,23 @@ import numpy as np
def tensor2img(tensor: torch.Tensor):
-    tensor = (
-        tensor.mul(255.0)
-        .clamp_(0, 255)
-        .to(torch.uint8)
-        .squeeze(0)
-        .permute(1, 2, 0)
-    )
-    return tensor.cpu().numpy()
+    return (
+        (tensor * 255.0)
+        .detach()
+        .squeeze(0)
+        .permute(1, 2, 0)
+        .cpu()
+        .numpy()
+        .clip(0, 255)
+        .astype(np.uint8)
+    )

-def img2tensor(img: np.ndarray, device: torch.device) -> torch.Tensor:
+def img2tensor(img: np.ndarray) -> torch.Tensor:
    logging.debug(f"Converting image of shape {img.shape} to tensor")
    if img.shape[-1] > 3:
        img = img[:, :, :3]
-    tensor = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0)
-    if device.type != "cuda":
-        return tensor.float() / 255.0
-    return tensor.cuda(non_blocking=True).float().div_(255.0)
+    return torch.tensor(img).permute(2, 0, 1).unsqueeze(0) / 255.0

def check_dim_and_resize(*args: torch.Tensor) -> list[torch.Tensor]:
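The rewritten helpers form a round trip between OpenCV's HWC uint8 frames and the model's 1x3xHxW float tensors in [0, 1]. A shape-level sanity check using the same chains as above:

import numpy as np
import torch

img = np.zeros((720, 1280, 3), dtype=np.uint8)                 # HWC uint8 frame
t = torch.tensor(img).permute(2, 0, 1).unsqueeze(0) / 255.0    # img2tensor chain
back = (
    (t * 255.0)
    .detach()
    .squeeze(0)
    .permute(1, 2, 0)
    .cpu()
    .numpy()
    .clip(0, 255)
    .astype(np.uint8)
)                                                              # tensor2img chain
assert back.shape == img.shape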

uv.lock (generated)

File diff suppressed because it is too large.