Init code
networks/blocks/__init__.py (Executable file, 0 lines)
networks/blocks/feat_enc.py (Executable file, 343 lines)
@@ -0,0 +1,343 @@
import torch
import torch.nn as nn


class BottleneckBlock(nn.Module):
    def __init__(self, in_planes, planes, norm_fn='group', stride=1):
        super(BottleneckBlock, self).__init__()

        # 1x1 -> 3x3 -> 1x1 bottleneck; the 3x3 conv carries the stride.
        self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0)
        self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride)
        self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0)
        self.relu = nn.ReLU(inplace=True)

        num_groups = planes // 8

        if norm_fn == 'group':
            self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
            self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
            self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            if stride != 1:
                self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)

        elif norm_fn == 'batch':
            self.norm1 = nn.BatchNorm2d(planes//4)
            self.norm2 = nn.BatchNorm2d(planes//4)
            self.norm3 = nn.BatchNorm2d(planes)
            if stride != 1:
                self.norm4 = nn.BatchNorm2d(planes)

        elif norm_fn == 'instance':
            self.norm1 = nn.InstanceNorm2d(planes//4)
            self.norm2 = nn.InstanceNorm2d(planes//4)
            self.norm3 = nn.InstanceNorm2d(planes)
            if stride != 1:
                self.norm4 = nn.InstanceNorm2d(planes)

        elif norm_fn == 'none':
            self.norm1 = nn.Sequential()
            self.norm2 = nn.Sequential()
            self.norm3 = nn.Sequential()
            if stride != 1:
                self.norm4 = nn.Sequential()

        if stride == 1:
            self.downsample = None
        else:
            # norm4 normalizes the strided shortcut projection.
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4)

    def forward(self, x):
        y = x
        y = self.relu(self.norm1(self.conv1(y)))
        y = self.relu(self.norm2(self.conv2(y)))
        y = self.relu(self.norm3(self.conv3(y)))

        if self.downsample is not None:
            x = self.downsample(x)

        return self.relu(x + y)


class ResidualBlock(nn.Module):
    def __init__(self, in_planes, planes, norm_fn='group', stride=1):
        super(ResidualBlock, self).__init__()

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)

        num_groups = planes // 8

        if norm_fn == 'group':
            self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            if stride != 1:
                self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)

        elif norm_fn == 'batch':
            self.norm1 = nn.BatchNorm2d(planes)
            self.norm2 = nn.BatchNorm2d(planes)
            if stride != 1:
                self.norm3 = nn.BatchNorm2d(planes)

        elif norm_fn == 'instance':
            self.norm1 = nn.InstanceNorm2d(planes)
            self.norm2 = nn.InstanceNorm2d(planes)
            if stride != 1:
                self.norm3 = nn.InstanceNorm2d(planes)

        elif norm_fn == 'none':
            self.norm1 = nn.Sequential()
            self.norm2 = nn.Sequential()
            if stride != 1:
                self.norm3 = nn.Sequential()

        if stride == 1:
            self.downsample = None
        else:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)

    def forward(self, x):
        y = x
        y = self.relu(self.norm1(self.conv1(y)))
        y = self.relu(self.norm2(self.conv2(y)))

        if self.downsample is not None:
            x = self.downsample(x)

        return self.relu(x + y)


class SmallEncoder(nn.Module):
    def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
        super(SmallEncoder, self).__init__()
        self.norm_fn = norm_fn

        if self.norm_fn == 'group':
            self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32)

        elif self.norm_fn == 'batch':
            self.norm1 = nn.BatchNorm2d(32)

        elif self.norm_fn == 'instance':
            self.norm1 = nn.InstanceNorm2d(32)

        elif self.norm_fn == 'none':
            self.norm1 = nn.Sequential()

        self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3)
        self.relu1 = nn.ReLU(inplace=True)

        self.in_planes = 32
        self.layer1 = self._make_layer(32, stride=1)
        self.layer2 = self._make_layer(64, stride=2)
        self.layer3 = self._make_layer(96, stride=2)

        self.dropout = None
        if dropout > 0:
            self.dropout = nn.Dropout2d(p=dropout)

        self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _make_layer(self, dim, stride=1):
        layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride)
        layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1)
        layers = (layer1, layer2)

        self.in_planes = dim
        return nn.Sequential(*layers)

    def forward(self, x):
        # if input is list, combine batch dimension
        is_list = isinstance(x, (tuple, list))
        if is_list:
            batch_dim = x[0].shape[0]
            x = torch.cat(x, dim=0)

        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu1(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.conv2(x)

        if self.training and self.dropout is not None:
            x = self.dropout(x)

        if is_list:
            x = torch.split(x, [batch_dim, batch_dim], dim=0)

        return x


class BasicEncoder(nn.Module):
    def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
        super(BasicEncoder, self).__init__()
        self.norm_fn = norm_fn

        if self.norm_fn == 'group':
            self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)

        elif self.norm_fn == 'batch':
            self.norm1 = nn.BatchNorm2d(64)

        elif self.norm_fn == 'instance':
            self.norm1 = nn.InstanceNorm2d(64)

        elif self.norm_fn == 'none':
            self.norm1 = nn.Sequential()

        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.relu1 = nn.ReLU(inplace=True)

        self.in_planes = 64
        self.layer1 = self._make_layer(64, stride=1)
        self.layer2 = self._make_layer(72, stride=2)
        self.layer3 = self._make_layer(128, stride=2)

        # output convolution
        self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1)

        self.dropout = None
        if dropout > 0:
            self.dropout = nn.Dropout2d(p=dropout)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _make_layer(self, dim, stride=1):
        layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
        layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
        layers = (layer1, layer2)

        self.in_planes = dim
        return nn.Sequential(*layers)

    def forward(self, x):
        # if input is list, combine batch dimension
        is_list = isinstance(x, (tuple, list))
        if is_list:
            batch_dim = x[0].shape[0]
            x = torch.cat(x, dim=0)

        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu1(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)

        x = self.conv2(x)

        if self.training and self.dropout is not None:
            x = self.dropout(x)

        if is_list:
            x = torch.split(x, [batch_dim, batch_dim], dim=0)

        return x


class LargeEncoder(nn.Module):
    def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
        super(LargeEncoder, self).__init__()
        self.norm_fn = norm_fn

        if self.norm_fn == 'group':
            self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)

        elif self.norm_fn == 'batch':
            self.norm1 = nn.BatchNorm2d(64)

        elif self.norm_fn == 'instance':
            self.norm1 = nn.InstanceNorm2d(64)

        elif self.norm_fn == 'none':
            self.norm1 = nn.Sequential()

        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.relu1 = nn.ReLU(inplace=True)

        self.in_planes = 64
        self.layer1 = self._make_layer(64, stride=1)
        self.layer2 = self._make_layer(112, stride=2)
        self.layer3 = self._make_layer(160, stride=2)
        self.layer3_2 = self._make_layer(160, stride=1)

        # output convolution
        self.conv2 = nn.Conv2d(self.in_planes, output_dim, kernel_size=1)

        self.dropout = None
        if dropout > 0:
            self.dropout = nn.Dropout2d(p=dropout)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _make_layer(self, dim, stride=1):
        layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
        layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
        layers = (layer1, layer2)

        self.in_planes = dim
        return nn.Sequential(*layers)

    def forward(self, x):
        # if input is list, combine batch dimension
        is_list = isinstance(x, (tuple, list))
        if is_list:
            batch_dim = x[0].shape[0]
            x = torch.cat(x, dim=0)

        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu1(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer3_2(x)

        x = self.conv2(x)

        if self.training and self.dropout is not None:
            x = self.dropout(x)

        if is_list:
            x = torch.split(x, [batch_dim, batch_dim], dim=0)

        return x
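Usage note (illustrative, not part of the commit): the encoders accept either a tensor or a two-element list of tensors; a list is concatenated along the batch dimension, encoded in a single pass, and split back out. With the stride-2 stem plus two stride-2 stages, features come out at 1/8 input resolution. A minimal sketch with assumed input sizes:

import torch
from networks.blocks.feat_enc import SmallEncoder

img0 = torch.randn(2, 3, 256, 256)   # [B, C, H, W]; any H, W divisible by 8
img1 = torch.randn(2, 3, 256, 256)
enc = SmallEncoder(output_dim=128, norm_fn='instance').eval()
with torch.no_grad():
    f0, f1 = enc([img0, img1])       # each [2, 128, 32, 32]; one fused pass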
networks/blocks/ifrnet.py (Executable file, 111 lines)
@@ -0,0 +1,111 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from src.utils.flow_utils import warp


def resize(x, scale_factor):
    return F.interpolate(x, scale_factor=scale_factor, mode="bilinear", align_corners=False)


def convrelu(in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=True):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias=bias),
        nn.PReLU(out_channels)
    )


class ResBlock(nn.Module):
    def __init__(self, in_channels, side_channels, bias=True):
        super(ResBlock, self).__init__()
        self.side_channels = side_channels
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias),
            nn.PReLU(in_channels)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(side_channels, side_channels, kernel_size=3, stride=1, padding=1, bias=bias),
            nn.PReLU(side_channels)
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias),
            nn.PReLU(in_channels)
        )
        self.conv4 = nn.Sequential(
            nn.Conv2d(side_channels, side_channels, kernel_size=3, stride=1, padding=1, bias=bias),
            nn.PReLU(side_channels)
        )
        self.conv5 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias)
        self.prelu = nn.PReLU(in_channels)

    def forward(self, x):
        out = self.conv1(x)

        # Give the last `side_channels` channels an extra conv each round,
        # then merge them back before the next full-width conv.
        res_feat = out[:, :-self.side_channels, ...]
        side_feat = out[:, -self.side_channels:, ...]
        side_feat = self.conv2(side_feat)
        out = self.conv3(torch.cat([res_feat, side_feat], 1))

        res_feat = out[:, :-self.side_channels, ...]
        side_feat = out[:, -self.side_channels:, ...]
        side_feat = self.conv4(side_feat)
        out = self.conv5(torch.cat([res_feat, side_feat], 1))

        out = self.prelu(x + out)
        return out


class Encoder(nn.Module):
    def __init__(self, channels, large=False):
        super(Encoder, self).__init__()
        self.channels = channels
        prev_ch = 3
        for idx, ch in enumerate(channels, 1):
            # The `large` variant uses a 7x7 conv in the first pyramid stage.
            k = 7 if large and idx == 1 else 3
            p = 3 if k == 7 else 1
            self.register_module(f'pyramid{idx}',
                                 nn.Sequential(
                                     convrelu(prev_ch, ch, k, 2, p),
                                     convrelu(ch, ch, 3, 1, 1)
                                 ))
            prev_ch = ch

    def forward(self, in_x):
        fs = []
        for idx in range(len(self.channels)):
            out_x = getattr(self, f'pyramid{idx+1}')(in_x)
            fs.append(out_x)
            in_x = out_x
        return fs


class InitDecoder(nn.Module):
    def __init__(self, in_ch, out_ch, skip_ch) -> None:
        super().__init__()
        self.convblock = nn.Sequential(
            convrelu(in_ch*2+1, in_ch*2),
            ResBlock(in_ch*2, skip_ch),
            nn.ConvTranspose2d(in_ch*2, out_ch+4, 4, 2, 1, bias=True)
        )

    def forward(self, f0, f1, embt):
        h, w = f0.shape[2:]
        embt = embt.repeat(1, 1, h, w)
        out = self.convblock(torch.cat([f0, f1, embt], 1))
        # First 4 channels are the two bidirectional flows; the rest is the
        # intermediate feature passed to the next (finer) decoder level.
        flow0, flow1 = torch.chunk(out[:, :4, ...], 2, 1)
        ft_ = out[:, 4:, ...]
        return flow0, flow1, ft_


class IntermediateDecoder(nn.Module):
    def __init__(self, in_ch, out_ch, skip_ch) -> None:
        super().__init__()
        self.convblock = nn.Sequential(
            convrelu(in_ch*3+4, in_ch*3),
            ResBlock(in_ch*3, skip_ch),
            nn.ConvTranspose2d(in_ch*3, out_ch+4, 4, 2, 1, bias=True)
        )

    def forward(self, ft_, f0, f1, flow0_in, flow1_in):
        f0_warp = warp(f0, flow0_in)
        f1_warp = warp(f1, flow1_in)
        f_in = torch.cat([ft_, f0_warp, f1_warp, flow0_in, flow1_in], 1)
        out = self.convblock(f_in)
        flow0, flow1 = torch.chunk(out[:, :4, ...], 2, 1)
        ft_ = out[:, 4:, ...]
        # Upsample the incoming flows 2x and scale their magnitude to match.
        flow0 = flow0 + 2.0 * resize(flow0_in, scale_factor=2.0)
        flow1 = flow1 + 2.0 * resize(flow1_in, scale_factor=2.0)
        return flow0, flow1, ft_
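Usage note (illustrative sketch; assumes src.utils.flow_utils.warp is importable and uses made-up pyramid widths): the decoders chain coarse to fine. InitDecoder produces the first flow pair from the coarsest features, then each IntermediateDecoder warps the next-finer features with the current flows and refines them:

import torch
from networks.blocks.ifrnet import Encoder, InitDecoder, IntermediateDecoder

enc = Encoder(channels=(32, 48, 72, 96))        # hypothetical widths
dec4 = InitDecoder(96, 72, skip_ch=24)          # coarsest level (1/16)
dec3 = IntermediateDecoder(72, 48, skip_ch=24)

img0 = torch.randn(1, 3, 128, 128)
img1 = torch.randn(1, 3, 128, 128)
embt = torch.full((1, 1, 1, 1), 0.5)            # target time t in [0, 1]

f0_1, f0_2, f0_3, f0_4 = enc(img0)
f1_1, f1_2, f1_3, f1_4 = enc(img1)
flow0, flow1, ft_ = dec4(f0_4, f1_4, embt)               # flows at 1/8
flow0, flow1, ft_ = dec3(ft_, f0_3, f1_3, flow0, flow1)  # refined to 1/4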
networks/blocks/multi_flow.py (Executable file, 69 lines)
@@ -0,0 +1,69 @@
import torch
import torch.nn as nn
from src.utils.flow_utils import warp
from networks.blocks.ifrnet import (
    convrelu, resize,
    ResBlock,
)


def multi_flow_combine(comb_block, img0, img1, flow0, flow1,
                       mask=None, img_res=None, mean=None):
    '''
    A parallel implementation of multiple flow field warping.

    comb_block: an nn.Sequential object.
    img shape: [b, c, h, w]
    flow shape: [b, 2*num_flows, h, w]
    mask (opt): if `mask` is None, the warped images are simply averaged.
    img_res (opt): if `img_res` is None, zero is added instead.
    mean (opt): if `mean` is None, zero is added instead.
    '''
    b, c, h, w = flow0.shape
    num_flows = c // 2
    # Fold the per-flow dimension into the batch so all flows warp in one pass.
    flow0 = flow0.reshape(b, num_flows, 2, h, w).reshape(-1, 2, h, w)
    flow1 = flow1.reshape(b, num_flows, 2, h, w).reshape(-1, 2, h, w)

    mask = mask.reshape(b, num_flows, 1, h, w
                        ).reshape(-1, 1, h, w) if mask is not None else None
    img_res = img_res.reshape(b, num_flows, 3, h, w
                              ).reshape(-1, 3, h, w) if img_res is not None else 0
    img0 = torch.stack([img0] * num_flows, 1).reshape(-1, 3, h, w)
    img1 = torch.stack([img1] * num_flows, 1).reshape(-1, 3, h, w)
    mean = torch.stack([mean] * num_flows, 1).reshape(-1, 1, 1, 1
                                                      ) if mean is not None else 0

    img0_warp = warp(img0, flow0)
    img1_warp = warp(img1, flow1)
    if mask is not None:
        img_warps = mask * img0_warp + (1 - mask) * img1_warp + mean + img_res
    else:
        # No blending mask: fall back to the simple average the docstring
        # describes (the original expression would fail on `mask * img0_warp`).
        img_warps = 0.5 * (img0_warp + img1_warp) + mean + img_res
    img_warps = img_warps.reshape(b, num_flows, 3, h, w)
    # Average the candidate frames and let comb_block predict a joint residual.
    imgt_pred = img_warps.mean(1) + comb_block(img_warps.view(b, -1, h, w))
    return imgt_pred


class MultiFlowDecoder(nn.Module):
    def __init__(self, in_ch, skip_ch, num_flows=3):
        super(MultiFlowDecoder, self).__init__()
        self.num_flows = num_flows
        self.convblock = nn.Sequential(
            convrelu(in_ch*3+4, in_ch*3),
            ResBlock(in_ch*3, skip_ch),
            nn.ConvTranspose2d(in_ch*3, 8*num_flows, 4, 2, 1, bias=True)
        )

    def forward(self, ft_, f0, f1, flow0, flow1):
        n = self.num_flows
        f0_warp = warp(f0, flow0)
        f1_warp = warp(f1, flow1)
        out = self.convblock(torch.cat([ft_, f0_warp, f1_warp, flow0, flow1], 1))
        # 8n output channels: 2n + 2n flow deltas, n masks, 3n image residuals.
        delta_flow0, delta_flow1, mask, img_res = torch.split(out, [2*n, 2*n, n, 3*n], 1)
        mask = torch.sigmoid(mask)

        flow0 = delta_flow0 + 2.0 * resize(flow0, scale_factor=2.0).repeat(1, n, 1, 1)
        flow1 = delta_flow1 + 2.0 * resize(flow1, scale_factor=2.0).repeat(1, n, 1, 1)

        return flow0, flow1, mask, img_res
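Usage note (illustrative): the body of multi_flow_combine implies comb_block's shape contract: it must map the 3*num_flows stacked warp channels back to a 3-channel residual. A minimal sketch with a hypothetical stand-in comb_block (the real one is defined by the model that owns these blocks, and warp must be importable):

import torch
import torch.nn as nn
from networks.blocks.multi_flow import multi_flow_combine

num_flows = 3
comb_block = nn.Sequential(                  # hypothetical stand-in
    nn.Conv2d(3 * num_flows, 6 * num_flows, 3, 1, 1),
    nn.PReLU(6 * num_flows),
    nn.Conv2d(6 * num_flows, 3, 3, 1, 1),
)

b, h, w = 1, 64, 64
img0 = torch.randn(b, 3, h, w)
img1 = torch.randn(b, 3, h, w)
flow0 = torch.randn(b, 2 * num_flows, h, w)  # num_flows (x, y) fields
flow1 = torch.randn(b, 2 * num_flows, h, w)
mask = torch.rand(b, num_flows, h, w)
imgt_pred = multi_flow_combine(comb_block, img0, img1, flow0, flow1, mask=mask)
# imgt_pred: [1, 3, 64, 64]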
networks/blocks/raft.py (Executable file, 207 lines)
@@ -0,0 +1,207 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


def resize(x, scale_factor):
    return F.interpolate(x, scale_factor=scale_factor, mode="bilinear", align_corners=False)


def bilinear_sampler(img, coords, mask=False):
    """ Wrapper for grid_sample, uses pixel coordinates """
    H, W = img.shape[-2:]
    xgrid, ygrid = coords.split([1, 1], dim=-1)
    # Normalize pixel coordinates to [-1, 1] as grid_sample expects.
    xgrid = 2*xgrid/(W-1) - 1
    ygrid = 2*ygrid/(H-1) - 1

    grid = torch.cat([xgrid, ygrid], dim=-1)
    img = F.grid_sample(img, grid, align_corners=True)

    if mask:
        mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
        return img, mask.float()

    return img


def coords_grid(batch, ht, wd, device):
    coords = torch.meshgrid(torch.arange(ht, device=device),
                            torch.arange(wd, device=device),
                            indexing='ij')
    coords = torch.stack(coords[::-1], dim=0).float()  # (x, y) channel order
    return coords[None].repeat(batch, 1, 1, 1)


class SmallUpdateBlock(nn.Module):
    def __init__(self, cdim, hidden_dim, flow_dim, corr_dim, fc_dim,
                 corr_levels=4, radius=3, scale_factor=None):
        super(SmallUpdateBlock, self).__init__()
        cor_planes = corr_levels * (2 * radius + 1) ** 2
        self.scale_factor = scale_factor

        # convc1 takes both correlation directions concatenated, hence 2x.
        self.convc1 = nn.Conv2d(2 * cor_planes, corr_dim, 1, padding=0)
        self.convf1 = nn.Conv2d(4, flow_dim*2, 7, padding=3)
        self.convf2 = nn.Conv2d(flow_dim*2, flow_dim, 3, padding=1)
        self.conv = nn.Conv2d(corr_dim+flow_dim, fc_dim, 3, padding=1)

        self.gru = nn.Sequential(
            nn.Conv2d(fc_dim+4+cdim, hidden_dim, 3, padding=1),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
        )

        self.feat_head = nn.Sequential(
            nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(hidden_dim, cdim, 3, padding=1),
        )

        self.flow_head = nn.Sequential(
            nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(hidden_dim, 4, 3, padding=1),
        )

        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

    def forward(self, net, flow, corr):
        net = resize(net, 1 / self.scale_factor) if self.scale_factor is not None else net
        cor = self.lrelu(self.convc1(corr))
        flo = self.lrelu(self.convf1(flow))
        flo = self.lrelu(self.convf2(flo))
        cor_flo = torch.cat([cor, flo], dim=1)
        inp = self.lrelu(self.conv(cor_flo))
        inp = torch.cat([inp, flow, net], dim=1)

        out = self.gru(inp)
        delta_net = self.feat_head(out)
        delta_flow = self.flow_head(out)

        if self.scale_factor is not None:
            delta_net = resize(delta_net, scale_factor=self.scale_factor)
            # Rescale flow magnitudes along with the spatial upsampling.
            delta_flow = self.scale_factor * resize(delta_flow, scale_factor=self.scale_factor)

        return delta_net, delta_flow


class BasicUpdateBlock(nn.Module):
    def __init__(self, cdim, hidden_dim, flow_dim, corr_dim, corr_dim2,
                 fc_dim, corr_levels=4, radius=3, scale_factor=None, out_num=1):
        super(BasicUpdateBlock, self).__init__()
        cor_planes = corr_levels * (2 * radius + 1) ** 2

        self.scale_factor = scale_factor
        self.convc1 = nn.Conv2d(2 * cor_planes, corr_dim, 1, padding=0)
        self.convc2 = nn.Conv2d(corr_dim, corr_dim2, 3, padding=1)
        self.convf1 = nn.Conv2d(4, flow_dim*2, 7, padding=3)
        self.convf2 = nn.Conv2d(flow_dim*2, flow_dim, 3, padding=1)
        self.conv = nn.Conv2d(flow_dim+corr_dim2, fc_dim, 3, padding=1)

        self.gru = nn.Sequential(
            nn.Conv2d(fc_dim+4+cdim, hidden_dim, 3, padding=1),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
        )

        self.feat_head = nn.Sequential(
            nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(hidden_dim, cdim, 3, padding=1),
        )

        self.flow_head = nn.Sequential(
            nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(hidden_dim, 4*out_num, 3, padding=1),
        )

        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

    def forward(self, net, flow, corr):
        net = resize(net, 1 / self.scale_factor) if self.scale_factor is not None else net
        cor = self.lrelu(self.convc1(corr))
        cor = self.lrelu(self.convc2(cor))
        flo = self.lrelu(self.convf1(flow))
        flo = self.lrelu(self.convf2(flo))
        cor_flo = torch.cat([cor, flo], dim=1)
        inp = self.lrelu(self.conv(cor_flo))
        inp = torch.cat([inp, flow, net], dim=1)

        out = self.gru(inp)
        delta_net = self.feat_head(out)
        delta_flow = self.flow_head(out)

        if self.scale_factor is not None:
            delta_net = resize(delta_net, scale_factor=self.scale_factor)
            delta_flow = self.scale_factor * resize(delta_flow, scale_factor=self.scale_factor)
        return delta_net, delta_flow


class BidirCorrBlock:
    def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
        self.num_levels = num_levels
        self.radius = radius
        self.corr_pyramid = []
        self.corr_pyramid_T = []

        corr = BidirCorrBlock.corr(fmap1, fmap2)
        batch, h1, w1, dim, h2, w2 = corr.shape
        # The transposed volume indexes correlation in the other direction
        # (fmap2 -> fmap1).
        corr_T = corr.clone().permute(0, 4, 5, 3, 1, 2)

        corr = corr.reshape(batch*h1*w1, dim, h2, w2)
        corr_T = corr_T.reshape(batch*h2*w2, dim, h1, w1)

        self.corr_pyramid.append(corr)
        self.corr_pyramid_T.append(corr_T)

        for _ in range(self.num_levels-1):
            corr = F.avg_pool2d(corr, 2, stride=2)
            corr_T = F.avg_pool2d(corr_T, 2, stride=2)
            self.corr_pyramid.append(corr)
            self.corr_pyramid_T.append(corr_T)

    def __call__(self, coords0, coords1):
        r = self.radius
        coords0 = coords0.permute(0, 2, 3, 1)
        coords1 = coords1.permute(0, 2, 3, 1)
        assert coords0.shape == coords1.shape, f"coords0 shape: [{coords0.shape}] is not equal to [{coords1.shape}]"
        batch, h1, w1, _ = coords0.shape

        out_pyramid = []
        out_pyramid_T = []
        for i in range(self.num_levels):
            corr = self.corr_pyramid[i]
            corr_T = self.corr_pyramid_T[i]

            # (2r+1) x (2r+1) window of offsets around each centroid.
            dx = torch.linspace(-r, r, 2*r+1, device=coords0.device)
            dy = torch.linspace(-r, r, 2*r+1, device=coords0.device)
            delta = torch.stack(torch.meshgrid(dy, dx, indexing='ij'), dim=-1)
            delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2)

            centroid_lvl_0 = coords0.reshape(batch*h1*w1, 1, 1, 2) / 2**i
            centroid_lvl_1 = coords1.reshape(batch*h1*w1, 1, 1, 2) / 2**i
            coords_lvl_0 = centroid_lvl_0 + delta_lvl
            coords_lvl_1 = centroid_lvl_1 + delta_lvl

            corr = bilinear_sampler(corr, coords_lvl_0)
            corr_T = bilinear_sampler(corr_T, coords_lvl_1)
            corr = corr.view(batch, h1, w1, -1)
            corr_T = corr_T.view(batch, h1, w1, -1)
            out_pyramid.append(corr)
            out_pyramid_T.append(corr_T)

        out = torch.cat(out_pyramid, dim=-1)
        out_T = torch.cat(out_pyramid_T, dim=-1)
        return out.permute(0, 3, 1, 2).contiguous().float(), out_T.permute(0, 3, 1, 2).contiguous().float()

    @staticmethod
    def corr(fmap1, fmap2):
        batch, dim, ht, wd = fmap1.shape
        fmap1 = fmap1.view(batch, dim, ht*wd)
        fmap2 = fmap2.view(batch, dim, ht*wd)

        # All-pairs dot-product correlation, normalized by sqrt(dim).
        corr = torch.matmul(fmap1.transpose(1, 2), fmap2)
        corr = corr.view(batch, ht, wd, 1, ht, wd)
        return corr / torch.sqrt(torch.tensor(dim).float())
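Usage note (illustrative): BidirCorrBlock precomputes the all-pairs correlation pyramid once; each query then samples a (2r+1)^2 window per level in both directions. Concatenating the two returned volumes yields the 2 * cor_planes channels that the update blocks' convc1 expects. A minimal sketch with made-up dimensions:

import torch
from networks.blocks.raft import BidirCorrBlock, SmallUpdateBlock, coords_grid

b, c, h, w = 1, 128, 32, 32
fmap0 = torch.randn(b, c, h, w)
fmap1 = torch.randn(b, c, h, w)
corr_fn = BidirCorrBlock(fmap0, fmap1, num_levels=4, radius=3)

coords = coords_grid(b, h, w, fmap0.device)   # identity grid = zero flow
corr0, corr1 = corr_fn(coords, coords)        # each [1, 4*(2*3+1)**2, 32, 32]
corr = torch.cat([corr0, corr1], dim=1)       # 2 * cor_planes channels

update = SmallUpdateBlock(cdim=64, hidden_dim=96, flow_dim=32,
                          corr_dim=96, fc_dim=128)  # hypothetical dims
net = torch.randn(b, 64, h, w)                # recurrent feature state
flow = torch.zeros(b, 4, h, w)                # bidirectional flow (2 + 2 ch)
delta_net, delta_flow = update(net, flow, corr)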