Перенес networks внутрь src

This commit is contained in:
Viner Abubakirov
2026-04-01 21:41:05 +05:00
parent 829d0c8c59
commit 888cdb3151
18 changed files with 341 additions and 522 deletions

1
.gitignore vendored
View File

@@ -175,5 +175,6 @@ cython_debug/
.pypirc
.DS_Store
source/
output/

View File

@@ -4,8 +4,8 @@ from pathlib import Path
import torch
import numpy as np
from omegaconf import OmegaConf, DictConfig
from imageio import imread, imwrite
from src.utils import utils
from src.utils.torch import img2tensor, check_dim_and_resize, tensor2img
from src.utils.build import build_from_cfg
from src.utils.padder import InputPadder
@@ -84,9 +84,16 @@ class ImageInterpolator:
)
def interpolate(self, image1: Path, image2: Path, output_path: Path):
"""
Interpolates between two images and saves the result.
Args:
image1 (Path): Path to the first input image (only png and jpg formats are supported)
image2 (Path): Path to the second input image (only png and jpg formats are supported)
output_path (Path): Path to save the interpolated image (only png and jpg formats are supported)
"""
logging.debug(f"Reading images: {image1} and {image2}")
tensor1 = img2tensor(utils.read(image1)).to(self.device)
tensor2 = img2tensor(utils.read(image2)).to(self.device)
tensor1 = img2tensor(imread(image1)).to(self.device)
tensor2 = img2tensor(imread(image2)).to(self.device)
logging.debug(
f"Image shapes after conversion to tensors: {tensor1.shape}, {tensor2.shape}"
)
@@ -115,7 +122,7 @@ class ImageInterpolator:
logging.debug(f"Interpolated image shape before unpadding: {interpolated.shape}")
(interpolated,) = padder.unpad(interpolated)
logging.debug(f"Interpolated image shape after unpadding: {interpolated.shape}")
utils.write(output_path, tensor2img(interpolated.cpu()))
imwrite(output_path, tensor2img(interpolated.cpu()))
logging.debug(f"Saved interpolated image to: {output_path}")
def scale(self, height: int, width: int) -> float:

61
main.py
View File

@@ -1,11 +1,11 @@
import logging
import subprocess
from pathlib import Path
from typing import Generator
import cv2
from tqdm import tqdm
from time import perf_counter
from decimal import Decimal
from interpolator import get_device
from interpolator import ImageInterpolator
@@ -39,68 +39,11 @@ def move_images(src_dir: str, interpolated_dir: str, output_dir: str):
index += 1
def build_file_list(moved_dir: str, list_path: str):
import os
moved_dir = Path(moved_dir)
frames = sorted(moved_dir.glob("img_*.png"))
print(frames[0])
with open(list_path, "w") as f:
for frame in frames:
f.write(f"file '{os.path.abspath(frame)}'\n")
def build_ffmpeg_file_list(frames_dir: str, interpolated_dir: str, list_path: str):
frames = sorted(Path(frames_dir).glob("img_*.png"))
interps = sorted(Path(interpolated_dir).glob("img_*.png"))
if len(interps) != len(frames) - 1:
raise ValueError("Interpolated frames must be N-1")
with open(list_path, "w") as f:
for i in range(len(frames)):
f.write(f"file '{frames[i].resolve().as_posix()}'\n")
if i < len(interps):
f.write(f"file '{interps[i].resolve().as_posix()}'\n")
def merge_with_ffmpeg(
original_video: str,
file_list: str,
output_video: str,
):
cap = cv2.VideoCapture(original_video)
if not cap.isOpened():
raise ValueError("Cannot open original video")
fps = cap.get(cv2.CAP_PROP_FPS)
cap.release()
new_fps = Decimal(fps * 2)
cmd = [
"ffmpeg",
"-y",
"-r", str(new_fps.quantize(Decimal("1.0000000000"))),
"-f", "concat",
"-safe", "0",
"-i", file_list,
"-c:v", "libx264rgb",
output_video,
]
print("Running ffmpeg command:", " ".join(cmd))
subprocess.run(cmd, check=True)
def video_frames_to_disk_generator(
video_path: str | Path,
output_dir: str | Path,
chunk_seconds: int = 10
):
) -> Generator[tuple[Path, ...], None, None]:
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)

View File

@@ -1,69 +0,0 @@
import torch
import torch.nn as nn
from src.utils.flow_utils import warp
from networks.blocks.ifrnet import (
convrelu, resize,
ResBlock,
)
def multi_flow_combine(comb_block, img0, img1, flow0, flow1,
mask=None, img_res=None, mean=None):
'''
A parallel implementation of multiple flow field warping
comb_block: An nn.Seqential object.
img shape: [b, c, h, w]
flow shape: [b, 2*num_flows, h, w]
mask (opt):
If 'mask' is None, the function conduct a simple average.
img_res (opt):
If 'img_res' is None, the function adds zero instead.
mean (opt):
If 'mean' is None, the function adds zero instead.
'''
b, c, h, w = flow0.shape
num_flows = c // 2
flow0 = flow0.reshape(b, num_flows, 2, h, w).reshape(-1, 2, h, w)
flow1 = flow1.reshape(b, num_flows, 2, h, w).reshape(-1, 2, h, w)
mask = mask.reshape(b, num_flows, 1, h, w
).reshape(-1, 1, h, w) if mask is not None else None
img_res = img_res.reshape(b, num_flows, 3, h, w
).reshape(-1, 3, h, w) if img_res is not None else 0
img0 = torch.stack([img0] * num_flows, 1).reshape(-1, 3, h, w)
img1 = torch.stack([img1] * num_flows, 1).reshape(-1, 3, h, w)
mean = torch.stack([mean] * num_flows, 1).reshape(-1, 1, 1, 1
) if mean is not None else 0
img0_warp = warp(img0, flow0)
img1_warp = warp(img1, flow1)
img_warps = mask * img0_warp + (1 - mask) * img1_warp + mean + img_res
img_warps = img_warps.reshape(b, num_flows, 3, h, w)
imgt_pred = img_warps.mean(1) + comb_block(img_warps.view(b, -1, h, w))
return imgt_pred
class MultiFlowDecoder(nn.Module):
def __init__(self, in_ch, skip_ch, num_flows=3):
super(MultiFlowDecoder, self).__init__()
self.num_flows = num_flows
self.convblock = nn.Sequential(
convrelu(in_ch*3+4, in_ch*3),
ResBlock(in_ch*3, skip_ch),
nn.ConvTranspose2d(in_ch*3, 8*num_flows, 4, 2, 1, bias=True)
)
def forward(self, ft_, f0, f1, flow0, flow1):
n = self.num_flows
f0_warp = warp(f0, flow0)
f1_warp = warp(f1, flow1)
out = self.convblock(torch.cat([ft_, f0_warp, f1_warp, flow0, flow1], 1))
delta_flow0, delta_flow1, mask, img_res = torch.split(out, [2*n, 2*n, n, 3*n], 1)
mask = torch.sigmoid(mask)
flow0 = delta_flow0 + 2.0 * resize(flow0, scale_factor=2.0
).repeat(1, self.num_flows, 1, 1)
flow1 = delta_flow1 + 2.0 * resize(flow1, scale_factor=2.0
).repeat(1, self.num_flows, 1, 1)
return flow0, flow1, mask, img_res

View File

@@ -10,7 +10,7 @@ save_dir: work_dir
eval_interval: 1
network:
name: networks.AMT-G.Model
name: src.networks.AMT-G.Model
params:
corr_radius: 3
corr_lvls: 4

View File

@@ -10,7 +10,7 @@ save_dir: work_dir
eval_interval: 1
network:
name: networks.AMT-S.Model
name: src.networks.AMT-S.Model
params:
corr_radius: 3
corr_lvls: 4

View File

@@ -1,9 +1,11 @@
from typing import Optional
import torch
import torch.nn as nn
from networks.blocks.raft import coords_grid, BasicUpdateBlock, BidirCorrBlock
from networks.blocks.feat_enc import LargeEncoder
from networks.blocks.ifrnet import resize, Encoder, InitDecoder, IntermediateDecoder
from networks.blocks.multi_flow import multi_flow_combine, MultiFlowDecoder
from src.networks.blocks.raft import coords_grid, BasicUpdateBlock, BidirCorrBlock
from src.networks.blocks.feat_enc import LargeEncoder
from src.networks.blocks.ifrnet import resize, Encoder, InitDecoder, IntermediateDecoder
from src.networks.blocks.multi_flow import multi_flow_combine, MultiFlowDecoder
class Model(nn.Module):
@@ -42,7 +44,7 @@ class Model(nn.Module):
nn.Conv2d(6 * self.num_flows, 3, 7, 1, 3),
)
def _get_updateblock(self, cdim, scale_factor=None):
def _get_updateblock(self, cdim: int, scale_factor: Optional[float] = None):
return BasicUpdateBlock(
cdim=cdim,
hidden_dim=192,
@@ -55,7 +57,15 @@ class Model(nn.Module):
radius=self.radius,
)
def _corr_scale_lookup(self, corr_fn, coord, flow0, flow1, embt, downsample=1):
def _corr_scale_lookup(
self,
corr_fn: BidirCorrBlock,
coord: torch.Tensor,
flow0: torch.Tensor,
flow1: torch.Tensor,
embt: torch.Tensor,
downsample: int = 1,
):
# convert t -> 0 to 0 -> 1 | convert t -> 1 to 1 -> 0
# based on linear assumption
t1_scale = 1.0 / embt
@@ -70,7 +80,15 @@ class Model(nn.Module):
flow = torch.cat([flow0, flow1], dim=1)
return corr, flow
def forward(self, img0, img1, embt, scale_factor=1.0, eval=False, **kwargs):
def forward(
self,
img0: torch.Tensor,
img1: torch.Tensor,
embt: torch.Tensor,
scale_factor: float = 1.0,
eval: bool = False,
**kwargs,
):
mean_ = (
torch.cat([img0, img1], 2)
.mean(1, keepdim=True)

View File

@@ -1,38 +1,29 @@
import torch
import torch.nn as nn
from networks.blocks.raft import (
coords_grid,
BasicUpdateBlock, BidirCorrBlock
)
from networks.blocks.feat_enc import (
BasicEncoder
)
from networks.blocks.ifrnet import (
resize,
Encoder,
InitDecoder,
IntermediateDecoder
)
from networks.blocks.multi_flow import (
multi_flow_combine,
MultiFlowDecoder
)
from src.networks.blocks.raft import coords_grid, BasicUpdateBlock, BidirCorrBlock
from src.networks.blocks.feat_enc import BasicEncoder
from src.networks.blocks.ifrnet import resize, Encoder, InitDecoder, IntermediateDecoder
from src.networks.blocks.multi_flow import multi_flow_combine, MultiFlowDecoder
class Model(nn.Module):
def __init__(self,
def __init__(
self,
corr_radius=3,
corr_lvls=4,
num_flows=5,
channels=[48, 64, 72, 128],
skip_channels=48
skip_channels=48,
):
super(Model, self).__init__()
self.radius = corr_radius
self.corr_levels = corr_lvls
self.num_flows = num_flows
self.feat_encoder = BasicEncoder(output_dim=128, norm_fn='instance', dropout=0.)
self.feat_encoder = BasicEncoder(
output_dim=128, norm_fn="instance", dropout=0.0
)
self.encoder = Encoder([48, 64, 72, 128], large=True)
self.decoder4 = InitDecoder(channels[3], channels[2], skip_channels)
@@ -45,22 +36,29 @@ class Model(nn.Module):
self.update2 = self._get_updateblock(48, 4.0)
self.comb_block = nn.Sequential(
nn.Conv2d(3*self.num_flows, 6*self.num_flows, 7, 1, 3),
nn.PReLU(6*self.num_flows),
nn.Conv2d(6*self.num_flows, 3, 7, 1, 3),
nn.Conv2d(3 * self.num_flows, 6 * self.num_flows, 7, 1, 3),
nn.PReLU(6 * self.num_flows),
nn.Conv2d(6 * self.num_flows, 3, 7, 1, 3),
)
def _get_updateblock(self, cdim, scale_factor=None):
return BasicUpdateBlock(cdim=cdim, hidden_dim=128, flow_dim=48,
corr_dim=256, corr_dim2=160, fc_dim=124,
scale_factor=scale_factor, corr_levels=self.corr_levels,
radius=self.radius)
return BasicUpdateBlock(
cdim=cdim,
hidden_dim=128,
flow_dim=48,
corr_dim=256,
corr_dim2=160,
fc_dim=124,
scale_factor=scale_factor,
corr_levels=self.corr_levels,
radius=self.radius,
)
def _corr_scale_lookup(self, corr_fn, coord, flow0, flow1, embt, downsample=1):
# convert t -> 0 to 0 -> 1 | convert t -> 1 to 1 -> 0
# based on linear assumption
t1_scale = 1. / embt
t0_scale = 1. / (1. - embt)
t1_scale = 1.0 / embt
t0_scale = 1.0 / (1.0 - embt)
if downsample != 1:
inv = 1 / downsample
flow0 = inv * resize(flow0, scale_factor=inv)
@@ -72,7 +70,12 @@ class Model(nn.Module):
return corr, flow
def forward(self, img0, img1, embt, scale_factor=1.0, eval=False, **kwargs):
mean_ = torch.cat([img0, img1], 2).mean(1, keepdim=True).mean(2, keepdim=True).mean(3, keepdim=True)
mean_ = (
torch.cat([img0, img1], 2)
.mean(1, keepdim=True)
.mean(2, keepdim=True)
.mean(3, keepdim=True)
)
img0 = img0 - mean_
img1 = img1 - mean_
img0_ = resize(img0, scale_factor) if scale_factor != 1.0 else img0
@@ -81,7 +84,9 @@ class Model(nn.Module):
coord = coords_grid(b, h // 8, w // 8, img0.device)
fmap0, fmap1 = self.feat_encoder([img0_, img1_]) # [1, 128, H//8, W//8]
corr_fn = BidirCorrBlock(fmap0, fmap1, radius=self.radius, num_levels=self.corr_levels)
corr_fn = BidirCorrBlock(
fmap0, fmap1, radius=self.radius, num_levels=self.corr_levels
)
# f0_1: [1, c0, H//2, W//2] | f0_2: [1, c1, H//4, W//4]
# f0_3: [1, c2, H//8, W//8] | f0_4: [1, c3, H//16, W//16]
@@ -90,9 +95,9 @@ class Model(nn.Module):
######################################### the 4th decoder #########################################
up_flow0_4, up_flow1_4, ft_3_ = self.decoder4(f0_4, f1_4, embt)
corr_4, flow_4 = self._corr_scale_lookup(corr_fn, coord,
up_flow0_4, up_flow1_4,
embt, downsample=1)
corr_4, flow_4 = self._corr_scale_lookup(
corr_fn, coord, up_flow0_4, up_flow1_4, embt, downsample=1
)
# residue update with lookup corr
delta_ft_3_, delta_flow_4 = self.update4(ft_3_, flow_4, corr_4)
@@ -102,10 +107,12 @@ class Model(nn.Module):
ft_3_ = ft_3_ + delta_ft_3_
######################################### the 3rd decoder #########################################
up_flow0_3, up_flow1_3, ft_2_ = self.decoder3(ft_3_, f0_3, f1_3, up_flow0_4, up_flow1_4)
corr_3, flow_3 = self._corr_scale_lookup(corr_fn,
coord, up_flow0_3, up_flow1_3,
embt, downsample=2)
up_flow0_3, up_flow1_3, ft_2_ = self.decoder3(
ft_3_, f0_3, f1_3, up_flow0_4, up_flow1_4
)
corr_3, flow_3 = self._corr_scale_lookup(
corr_fn, coord, up_flow0_3, up_flow1_3, embt, downsample=2
)
# residue update with lookup corr
delta_ft_2_, delta_flow_3 = self.update3(ft_2_, flow_3, corr_3)
@@ -115,10 +122,12 @@ class Model(nn.Module):
ft_2_ = ft_2_ + delta_ft_2_
######################################### the 2nd decoder #########################################
up_flow0_2, up_flow1_2, ft_1_ = self.decoder2(ft_2_, f0_2, f1_2, up_flow0_3, up_flow1_3)
corr_2, flow_2 = self._corr_scale_lookup(corr_fn,
coord, up_flow0_2, up_flow1_2,
embt, downsample=4)
up_flow0_2, up_flow1_2, ft_1_ = self.decoder2(
ft_2_, f0_2, f1_2, up_flow0_3, up_flow1_3
)
corr_2, flow_2 = self._corr_scale_lookup(
corr_fn, coord, up_flow0_2, up_flow1_2, embt, downsample=4
)
# residue update with lookup corr
delta_ft_1_, delta_flow_2 = self.update2(ft_1_, flow_2, corr_2)
@@ -128,28 +137,36 @@ class Model(nn.Module):
ft_1_ = ft_1_ + delta_ft_1_
######################################### the 1st decoder #########################################
up_flow0_1, up_flow1_1, mask, img_res = self.decoder1(ft_1_, f0_1, f1_1, up_flow0_2, up_flow1_2)
up_flow0_1, up_flow1_1, mask, img_res = self.decoder1(
ft_1_, f0_1, f1_1, up_flow0_2, up_flow1_2
)
if scale_factor != 1.0:
up_flow0_1 = resize(up_flow0_1, scale_factor=(1.0/scale_factor)) * (1.0/scale_factor)
up_flow1_1 = resize(up_flow1_1, scale_factor=(1.0/scale_factor)) * (1.0/scale_factor)
mask = resize(mask, scale_factor=(1.0/scale_factor))
img_res = resize(img_res, scale_factor=(1.0/scale_factor))
up_flow0_1 = resize(up_flow0_1, scale_factor=(1.0 / scale_factor)) * (
1.0 / scale_factor
)
up_flow1_1 = resize(up_flow1_1, scale_factor=(1.0 / scale_factor)) * (
1.0 / scale_factor
)
mask = resize(mask, scale_factor=(1.0 / scale_factor))
img_res = resize(img_res, scale_factor=(1.0 / scale_factor))
# Merge multiple predictions
imgt_pred = multi_flow_combine(self.comb_block, img0, img1, up_flow0_1, up_flow1_1,
mask, img_res, mean_)
imgt_pred = multi_flow_combine(
self.comb_block, img0, img1, up_flow0_1, up_flow1_1, mask, img_res, mean_
)
imgt_pred = torch.clamp(imgt_pred, 0, 1)
if eval:
return { 'imgt_pred': imgt_pred, }
return {
"imgt_pred": imgt_pred,
}
else:
up_flow0_1 = up_flow0_1.reshape(b, self.num_flows, 2, h, w)
up_flow1_1 = up_flow1_1.reshape(b, self.num_flows, 2, h, w)
return {
'imgt_pred': imgt_pred,
'flow0_pred': [up_flow0_1, up_flow0_2, up_flow0_3, up_flow0_4],
'flow1_pred': [up_flow1_1, up_flow1_2, up_flow1_3, up_flow1_4],
'ft_pred': [ft_1_, ft_2_, ft_3_],
"imgt_pred": imgt_pred,
"flow0_pred": [up_flow0_1, up_flow0_2, up_flow0_3, up_flow0_4],
"flow1_pred": [up_flow1_1, up_flow1_2, up_flow1_3, up_flow1_4],
"ft_pred": [ft_1_, ft_2_, ft_3_],
}

View File

@@ -1,31 +1,20 @@
import torch
import torch.nn as nn
from networks.blocks.raft import (
coords_grid,
SmallUpdateBlock, BidirCorrBlock
)
from networks.blocks.feat_enc import (
SmallEncoder
)
from networks.blocks.ifrnet import (
resize,
Encoder,
InitDecoder,
IntermediateDecoder
)
from networks.blocks.multi_flow import (
multi_flow_combine,
MultiFlowDecoder
)
from src.networks.blocks.raft import coords_grid, SmallUpdateBlock, BidirCorrBlock
from src.networks.blocks.feat_enc import SmallEncoder
from src.networks.blocks.ifrnet import resize, Encoder, InitDecoder, IntermediateDecoder
from src.networks.blocks.multi_flow import multi_flow_combine, MultiFlowDecoder
class Model(nn.Module):
def __init__(self,
def __init__(
self,
corr_radius=3,
corr_lvls=4,
num_flows=3,
channels=[20, 32, 44, 56],
skip_channels=20):
skip_channels=20,
):
super(Model, self).__init__()
self.radius = corr_radius
self.corr_levels = corr_lvls
@@ -33,7 +22,7 @@ class Model(nn.Module):
self.channels = channels
self.skip_channels = skip_channels
self.feat_encoder = SmallEncoder(output_dim=84, norm_fn='instance', dropout=0.)
self.feat_encoder = SmallEncoder(output_dim=84, norm_fn="instance", dropout=0.0)
self.encoder = Encoder(channels)
self.decoder4 = InitDecoder(channels[3], channels[2], skip_channels)
@@ -46,21 +35,28 @@ class Model(nn.Module):
self.update2 = self._get_updateblock(20, 4)
self.comb_block = nn.Sequential(
nn.Conv2d(3*num_flows, 6*num_flows, 3, 1, 1),
nn.PReLU(6*num_flows),
nn.Conv2d(6*num_flows, 3, 3, 1, 1),
nn.Conv2d(3 * num_flows, 6 * num_flows, 3, 1, 1),
nn.PReLU(6 * num_flows),
nn.Conv2d(6 * num_flows, 3, 3, 1, 1),
)
def _get_updateblock(self, cdim, scale_factor=None):
return SmallUpdateBlock(cdim=cdim, hidden_dim=76, flow_dim=20, corr_dim=64,
fc_dim=68, scale_factor=scale_factor,
corr_levels=self.corr_levels, radius=self.radius)
return SmallUpdateBlock(
cdim=cdim,
hidden_dim=76,
flow_dim=20,
corr_dim=64,
fc_dim=68,
scale_factor=scale_factor,
corr_levels=self.corr_levels,
radius=self.radius,
)
def _corr_scale_lookup(self, corr_fn, coord, flow0, flow1, embt, downsample=1):
# convert t -> 0 to 0 -> 1 | convert t -> 1 to 1 -> 0
# based on linear assumption
t1_scale = 1. / embt
t0_scale = 1. / (1. - embt)
t1_scale = 1.0 / embt
t0_scale = 1.0 / (1.0 - embt)
if downsample != 1:
inv = 1 / downsample
flow0 = inv * resize(flow0, scale_factor=inv)
@@ -72,7 +68,12 @@ class Model(nn.Module):
return corr, flow
def forward(self, img0, img1, embt, scale_factor=1.0, eval=False, **kwargs):
mean_ = torch.cat([img0, img1], 2).mean(1, keepdim=True).mean(2, keepdim=True).mean(3, keepdim=True)
mean_ = (
torch.cat([img0, img1], 2)
.mean(1, keepdim=True)
.mean(2, keepdim=True)
.mean(3, keepdim=True)
)
img0 = img0 - mean_
img1 = img1 - mean_
img0_ = resize(img0, scale_factor) if scale_factor != 1.0 else img0
@@ -81,7 +82,9 @@ class Model(nn.Module):
coord = coords_grid(b, h // 8, w // 8, img0.device)
fmap0, fmap1 = self.feat_encoder([img0_, img1_]) # [1, 128, H//8, W//8]
corr_fn = BidirCorrBlock(fmap0, fmap1, radius=self.radius, num_levels=self.corr_levels)
corr_fn = BidirCorrBlock(
fmap0, fmap1, radius=self.radius, num_levels=self.corr_levels
)
# f0_1: [1, c0, H//2, W//2] | f0_2: [1, c1, H//4, W//4]
# f0_3: [1, c2, H//8, W//8] | f0_4: [1, c3, H//16, W//16]
@@ -90,9 +93,9 @@ class Model(nn.Module):
######################################### the 4th decoder #########################################
up_flow0_4, up_flow1_4, ft_3_ = self.decoder4(f0_4, f1_4, embt)
corr_4, flow_4 = self._corr_scale_lookup(corr_fn, coord,
up_flow0_4, up_flow1_4,
embt, downsample=1)
corr_4, flow_4 = self._corr_scale_lookup(
corr_fn, coord, up_flow0_4, up_flow1_4, embt, downsample=1
)
# residue update with lookup corr
delta_ft_3_, delta_flow_4 = self.update4(ft_3_, flow_4, corr_4)
@@ -102,10 +105,12 @@ class Model(nn.Module):
ft_3_ = ft_3_ + delta_ft_3_
######################################### the 3rd decoder #########################################
up_flow0_3, up_flow1_3, ft_2_ = self.decoder3(ft_3_, f0_3, f1_3, up_flow0_4, up_flow1_4)
corr_3, flow_3 = self._corr_scale_lookup(corr_fn,
coord, up_flow0_3, up_flow1_3,
embt, downsample=2)
up_flow0_3, up_flow1_3, ft_2_ = self.decoder3(
ft_3_, f0_3, f1_3, up_flow0_4, up_flow1_4
)
corr_3, flow_3 = self._corr_scale_lookup(
corr_fn, coord, up_flow0_3, up_flow1_3, embt, downsample=2
)
# residue update with lookup corr
delta_ft_2_, delta_flow_3 = self.update3(ft_2_, flow_3, corr_3)
@@ -115,10 +120,12 @@ class Model(nn.Module):
ft_2_ = ft_2_ + delta_ft_2_
######################################### the 2nd decoder #########################################
up_flow0_2, up_flow1_2, ft_1_ = self.decoder2(ft_2_, f0_2, f1_2, up_flow0_3, up_flow1_3)
corr_2, flow_2 = self._corr_scale_lookup(corr_fn,
coord, up_flow0_2, up_flow1_2,
embt, downsample=4)
up_flow0_2, up_flow1_2, ft_1_ = self.decoder2(
ft_2_, f0_2, f1_2, up_flow0_3, up_flow1_3
)
corr_2, flow_2 = self._corr_scale_lookup(
corr_fn, coord, up_flow0_2, up_flow1_2, embt, downsample=4
)
# residue update with lookup corr
delta_ft_1_, delta_flow_2 = self.update2(ft_1_, flow_2, corr_2)
@@ -128,27 +135,36 @@ class Model(nn.Module):
ft_1_ = ft_1_ + delta_ft_1_
######################################### the 1st decoder #########################################
up_flow0_1, up_flow1_1, mask, img_res = self.decoder1(ft_1_, f0_1, f1_1, up_flow0_2, up_flow1_2)
up_flow0_1, up_flow1_1, mask, img_res = self.decoder1(
ft_1_, f0_1, f1_1, up_flow0_2, up_flow1_2
)
if scale_factor != 1.0:
up_flow0_1 = resize(up_flow0_1, scale_factor=(1.0/scale_factor)) * (1.0/scale_factor)
up_flow1_1 = resize(up_flow1_1, scale_factor=(1.0/scale_factor)) * (1.0/scale_factor)
mask = resize(mask, scale_factor=(1.0/scale_factor))
img_res = resize(img_res, scale_factor=(1.0/scale_factor))
up_flow0_1 = resize(up_flow0_1, scale_factor=(1.0 / scale_factor)) * (
1.0 / scale_factor
)
up_flow1_1 = resize(up_flow1_1, scale_factor=(1.0 / scale_factor)) * (
1.0 / scale_factor
)
mask = resize(mask, scale_factor=(1.0 / scale_factor))
img_res = resize(img_res, scale_factor=(1.0 / scale_factor))
# Merge multiple predictions
imgt_pred = multi_flow_combine(self.comb_block, img0, img1, up_flow0_1, up_flow1_1,
mask, img_res, mean_)
imgt_pred = multi_flow_combine(
self.comb_block, img0, img1, up_flow0_1, up_flow1_1, mask, img_res, mean_
)
imgt_pred = torch.clamp(imgt_pred, 0, 1)
if eval:
return { 'imgt_pred': imgt_pred, }
return {
"imgt_pred": imgt_pred,
}
else:
up_flow0_1 = up_flow0_1.reshape(b, self.num_flows, 2, h, w)
up_flow1_1 = up_flow1_1.reshape(b, self.num_flows, 2, h, w)
return {
'imgt_pred': imgt_pred,
'flow0_pred': [up_flow0_1, up_flow0_2, up_flow0_3, up_flow0_4],
'flow1_pred': [up_flow1_1, up_flow1_2, up_flow1_3, up_flow1_4],
'ft_pred': [ft_1_, ft_2_, ft_3_],
"imgt_pred": imgt_pred,
"flow0_pred": [up_flow0_1, up_flow0_2, up_flow0_3, up_flow0_4],
"flow1_pred": [up_flow1_1, up_flow1_2, up_flow1_3, up_flow1_4],
"ft_pred": [ft_1_, ft_2_, ft_3_],
}

View File

@@ -1,30 +1,23 @@
import torch
import torch.nn as nn
from src.utils.flow_utils import warp
from networks.blocks.ifrnet import (
convrelu, resize,
ResBlock,
)
from src.networks.blocks.ifrnet import convrelu, resize, ResBlock
class Encoder(nn.Module):
def __init__(self):
super(Encoder, self).__init__()
self.pyramid1 = nn.Sequential(
convrelu(3, 32, 3, 2, 1),
convrelu(32, 32, 3, 1, 1)
convrelu(3, 32, 3, 2, 1), convrelu(32, 32, 3, 1, 1)
)
self.pyramid2 = nn.Sequential(
convrelu(32, 48, 3, 2, 1),
convrelu(48, 48, 3, 1, 1)
convrelu(32, 48, 3, 2, 1), convrelu(48, 48, 3, 1, 1)
)
self.pyramid3 = nn.Sequential(
convrelu(48, 72, 3, 2, 1),
convrelu(72, 72, 3, 1, 1)
convrelu(48, 72, 3, 2, 1), convrelu(72, 72, 3, 1, 1)
)
self.pyramid4 = nn.Sequential(
convrelu(72, 96, 3, 2, 1),
convrelu(96, 96, 3, 1, 1)
convrelu(72, 96, 3, 2, 1), convrelu(96, 96, 3, 1, 1)
)
def forward(self, img):
@@ -39,9 +32,9 @@ class Decoder4(nn.Module):
def __init__(self):
super(Decoder4, self).__init__()
self.convblock = nn.Sequential(
convrelu(192+1, 192),
convrelu(192 + 1, 192),
ResBlock(192, 32),
nn.ConvTranspose2d(192, 76, 4, 2, 1, bias=True)
nn.ConvTranspose2d(192, 76, 4, 2, 1, bias=True),
)
def forward(self, f0, f1, embt):
@@ -58,7 +51,7 @@ class Decoder3(nn.Module):
self.convblock = nn.Sequential(
convrelu(220, 216),
ResBlock(216, 32),
nn.ConvTranspose2d(216, 52, 4, 2, 1, bias=True)
nn.ConvTranspose2d(216, 52, 4, 2, 1, bias=True),
)
def forward(self, ft_, f0, f1, up_flow0, up_flow1):
@@ -75,7 +68,7 @@ class Decoder2(nn.Module):
self.convblock = nn.Sequential(
convrelu(148, 144),
ResBlock(144, 32),
nn.ConvTranspose2d(144, 36, 4, 2, 1, bias=True)
nn.ConvTranspose2d(144, 36, 4, 2, 1, bias=True),
)
def forward(self, ft_, f0, f1, up_flow0, up_flow1):
@@ -92,7 +85,7 @@ class Decoder1(nn.Module):
self.convblock = nn.Sequential(
convrelu(100, 96),
ResBlock(96, 32),
nn.ConvTranspose2d(96, 8, 4, 2, 1, bias=True)
nn.ConvTranspose2d(96, 8, 4, 2, 1, bias=True),
)
def forward(self, ft_, f0, f1, up_flow0, up_flow1):
@@ -113,7 +106,12 @@ class Model(nn.Module):
self.decoder1 = Decoder1()
def forward(self, img0, img1, embt, scale_factor=1.0, eval=False, **kwargs):
mean_ = torch.cat([img0, img1], 2).mean(1, keepdim=True).mean(2, keepdim=True).mean(3, keepdim=True)
mean_ = (
torch.cat([img0, img1], 2)
.mean(1, keepdim=True)
.mean(2, keepdim=True)
.mean(3, keepdim=True)
)
img0 = img0 - mean_
img1 = img1 - mean_
@@ -145,10 +143,14 @@ class Model(nn.Module):
up_res_1 = out1[:, 5:]
if scale_factor != 1.0:
up_flow0_1 = resize(up_flow0_1, scale_factor=(1.0/scale_factor)) * (1.0/scale_factor)
up_flow1_1 = resize(up_flow1_1, scale_factor=(1.0/scale_factor)) * (1.0/scale_factor)
up_mask_1 = resize(up_mask_1, scale_factor=(1.0/scale_factor))
up_res_1 = resize(up_res_1, scale_factor=(1.0/scale_factor))
up_flow0_1 = resize(up_flow0_1, scale_factor=(1.0 / scale_factor)) * (
1.0 / scale_factor
)
up_flow1_1 = resize(up_flow1_1, scale_factor=(1.0 / scale_factor)) * (
1.0 / scale_factor
)
up_mask_1 = resize(up_mask_1, scale_factor=(1.0 / scale_factor))
up_res_1 = resize(up_res_1, scale_factor=(1.0 / scale_factor))
img0_warp = warp(img0, up_flow0_1)
img1_warp = warp(img1, up_flow1_1)
@@ -157,13 +159,15 @@ class Model(nn.Module):
imgt_pred = torch.clamp(imgt_pred, 0, 1)
if eval:
return { 'imgt_pred': imgt_pred, }
return {
"imgt_pred": imgt_pred,
}
else:
return {
'imgt_pred': imgt_pred,
'flow0_pred': [up_flow0_1, up_flow0_2, up_flow0_3, up_flow0_4],
'flow1_pred': [up_flow1_1, up_flow1_2, up_flow1_3, up_flow1_4],
'ft_pred': [ft_1_, ft_2_, ft_3_],
'img0_warp': img0_warp,
'img1_warp': img1_warp
"imgt_pred": imgt_pred,
"flow0_pred": [up_flow0_1, up_flow0_2, up_flow0_3, up_flow0_4],
"flow1_pred": [up_flow1_1, up_flow1_2, up_flow1_3, up_flow1_4],
"ft_pred": [ft_1_, ft_2_, ft_3_],
"img0_warp": img0_warp,
"img1_warp": img1_warp,
}

View File

@@ -0,0 +1,80 @@
import torch
import torch.nn as nn
from src.utils.flow_utils import warp
from src.networks.blocks.ifrnet import convrelu, resize, ResBlock
def multi_flow_combine(
comb_block, img0, img1, flow0, flow1, mask=None, img_res=None, mean=None
):
"""
A parallel implementation of multiple flow field warping
comb_block: An nn.Seqential object.
img shape: [b, c, h, w]
flow shape: [b, 2*num_flows, h, w]
mask (opt):
If 'mask' is None, the function conduct a simple average.
img_res (opt):
If 'img_res' is None, the function adds zero instead.
mean (opt):
If 'mean' is None, the function adds zero instead.
"""
b, c, h, w = flow0.shape
num_flows = c // 2
flow0 = flow0.reshape(b, num_flows, 2, h, w).reshape(-1, 2, h, w)
flow1 = flow1.reshape(b, num_flows, 2, h, w).reshape(-1, 2, h, w)
mask = (
mask.reshape(b, num_flows, 1, h, w).reshape(-1, 1, h, w)
if mask is not None
else None
)
img_res = (
img_res.reshape(b, num_flows, 3, h, w).reshape(-1, 3, h, w)
if img_res is not None
else 0
)
img0 = torch.stack([img0] * num_flows, 1).reshape(-1, 3, h, w)
img1 = torch.stack([img1] * num_flows, 1).reshape(-1, 3, h, w)
mean = (
torch.stack([mean] * num_flows, 1).reshape(-1, 1, 1, 1)
if mean is not None
else 0
)
img0_warp = warp(img0, flow0)
img1_warp = warp(img1, flow1)
img_warps = mask * img0_warp + (1 - mask) * img1_warp + mean + img_res
img_warps = img_warps.reshape(b, num_flows, 3, h, w)
imgt_pred = img_warps.mean(1) + comb_block(img_warps.view(b, -1, h, w))
return imgt_pred
class MultiFlowDecoder(nn.Module):
def __init__(self, in_ch, skip_ch, num_flows=3):
super(MultiFlowDecoder, self).__init__()
self.num_flows = num_flows
self.convblock = nn.Sequential(
convrelu(in_ch * 3 + 4, in_ch * 3),
ResBlock(in_ch * 3, skip_ch),
nn.ConvTranspose2d(in_ch * 3, 8 * num_flows, 4, 2, 1, bias=True),
)
def forward(self, ft_, f0, f1, flow0, flow1):
n = self.num_flows
f0_warp = warp(f0, flow0)
f1_warp = warp(f1, flow1)
out = self.convblock(torch.cat([ft_, f0_warp, f1_warp, flow0, flow1], 1))
delta_flow0, delta_flow1, mask, img_res = torch.split(
out, [2 * n, 2 * n, n, 3 * n], 1
)
mask = torch.sigmoid(mask)
flow0 = delta_flow0 + 2.0 * resize(flow0, scale_factor=2.0).repeat(
1, self.num_flows, 1, 1
)
flow1 = delta_flow1 + 2.0 * resize(flow1, scale_factor=2.0).repeat(
1, self.num_flows, 1, 1
)
return flow0, flow1, mask, img_res

View File

@@ -1,6 +1,7 @@
from typing import TYPE_CHECKING
import importlib
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from omegaconf import DictConfig

View File

@@ -1,199 +0,0 @@
import re
import sys
from pathlib import Path
import numpy as np
from imageio import imread, imwrite
def read(file: Path) -> np.ndarray:
readers = {
".float3": readFloat,
".flo": readFlow,
".ppm": readImage,
".pgm": readImage,
".png": readImage,
".jpg": readImage,
".pfm": lambda f: readPFM(f)[0],
}
func = readers.get(file.suffix.lower())
if func is None:
raise Exception("don't know how to read %s" % file)
return func(file)
def write(file: Path, data: np.ndarray) -> None:
writers = {
".float3": writeFloat,
".flo": writeFlow,
".ppm": writeImage,
".pgm": writeImage,
".png": writeImage,
".jpg": writeImage,
".pfm": writePFM,
}
func = writers.get(file.suffix.lower())
if func is None:
raise Exception("don't know how to write %s" % file)
return func(file, data)
def readPFM(file: Path):
data = open(file, "rb")
color = None
width = None
height = None
scale = None
endian = None
header = data.readline().rstrip()
if header.decode("ascii") == "PF":
color = True
elif header.decode("ascii") == "Pf":
color = False
else:
raise Exception("Not a PFM file.")
dim_match = re.match(r"^(\d+)\s(\d+)\s$", data.readline().decode("ascii"))
if dim_match:
width, height = list(map(int, dim_match.groups()))
else:
raise Exception("Malformed PFM header.")
scale = float(data.readline().decode("ascii").rstrip())
if scale < 0:
endian = "<"
scale = -scale
else:
endian = ">"
result = np.fromfile(data, endian + "f")
shape = (height, width, 3) if color else (height, width)
result = np.reshape(result, shape)
result = np.flipud(result)
return result, scale
def writePFM(file: Path, image: np.ndarray, scale=1):
data = open(file, "wb")
color = None
if image.dtype.name != "float32":
raise Exception("Image dtype must be float32.")
image = np.flipud(image)
if len(image.shape) == 3 and image.shape[2] == 3:
color = True
elif len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1:
color = False
else:
raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.")
data.write("PF\n" if color else "Pf\n".encode()) # type: ignore
data.write("%d %d\n".encode() % (image.shape[1], image.shape[0]))
endian = image.dtype.byteorder
if endian == "<" or endian == "=" and sys.byteorder == "little":
scale = -scale
data.write("%f\n".encode() % scale)
image.tofile(data)
def readFlow(file: Path):
if file.suffix.lower() == ".pfm":
return readPFM(file)[0][:, :, 0:2]
f = open(file, "rb")
header = f.read(4)
if header.decode("utf-8") != "PIEH":
raise Exception("Flow file header does not contain PIEH")
width = np.fromfile(f, np.int32, 1).squeeze()
height = np.fromfile(f, np.int32, 1).squeeze()
flow = np.fromfile(f, np.float32, width * height * 2).reshape((height, width, 2))
return flow.astype(np.float32)
def readImage(file: Path):
if file.suffix.lower() == ".pfm":
data = readPFM(file)[0]
if len(data.shape) == 3:
return data[:, :, 0:3]
else:
return data
return imread(file)
def writeImage(file: Path, data: np.ndarray):
if file.suffix.lower() == ".pfm":
return writePFM(file, data, 1)
return imwrite(file, data)
def writeFlow(file: Path, flow: np.ndarray):
f = open(file, "wb")
f.write("PIEH".encode("utf-8"))
np.array([flow.shape[1], flow.shape[0]], dtype=np.int32).tofile(f)
flow = flow.astype(np.float32)
flow.tofile(f)
def readFloat(file: Path):
f = open(file, "rb")
if (f.readline().decode("utf-8")) != "float\n":
raise Exception("float file %s did not contain <float> keyword" % file)
dim = int(f.readline())
dims = []
count = 1
for _ in range(0, dim):
d = int(f.readline())
dims.append(d)
count *= d
dims = list(reversed(dims))
data = np.fromfile(f, np.float32, count).reshape(dims)
if dim > 2:
data = np.transpose(data, (2, 1, 0))
data = np.transpose(data, (1, 0, 2))
return data
def writeFloat(file: Path, data: np.ndarray):
f = open(file, "wb")
dim = len(data.shape)
if dim > 3:
raise Exception("bad float file dimension: %d" % dim)
f.write(("float\n").encode("ascii"))
f.write(("%d\n" % dim).encode("ascii"))
if dim == 1:
f.write(("%d\n" % data.shape[0]).encode("ascii"))
else:
f.write(("%d\n" % data.shape[1]).encode("ascii"))
f.write(("%d\n" % data.shape[0]).encode("ascii"))
for i in range(2, dim):
f.write(("%d\n" % data.shape[i]).encode("ascii"))
data = data.astype(np.float32)
if dim == 2:
data.tofile(f)
else:
np.transpose(data, (2, 0, 1)).tofile(f)