Перенес networks внутрь src

This commit is contained in:
Viner Abubakirov
2026-04-01 21:41:05 +05:00
parent 829d0c8c59
commit 888cdb3151
18 changed files with 341 additions and 522 deletions

View File

@@ -0,0 +1,80 @@
import torch
import torch.nn as nn
from src.utils.flow_utils import warp
from src.networks.blocks.ifrnet import convrelu, resize, ResBlock
def multi_flow_combine(
comb_block, img0, img1, flow0, flow1, mask=None, img_res=None, mean=None
):
"""
A parallel implementation of multiple flow field warping
comb_block: An nn.Seqential object.
img shape: [b, c, h, w]
flow shape: [b, 2*num_flows, h, w]
mask (opt):
If 'mask' is None, the function conduct a simple average.
img_res (opt):
If 'img_res' is None, the function adds zero instead.
mean (opt):
If 'mean' is None, the function adds zero instead.
"""
b, c, h, w = flow0.shape
num_flows = c // 2
flow0 = flow0.reshape(b, num_flows, 2, h, w).reshape(-1, 2, h, w)
flow1 = flow1.reshape(b, num_flows, 2, h, w).reshape(-1, 2, h, w)
mask = (
mask.reshape(b, num_flows, 1, h, w).reshape(-1, 1, h, w)
if mask is not None
else None
)
img_res = (
img_res.reshape(b, num_flows, 3, h, w).reshape(-1, 3, h, w)
if img_res is not None
else 0
)
img0 = torch.stack([img0] * num_flows, 1).reshape(-1, 3, h, w)
img1 = torch.stack([img1] * num_flows, 1).reshape(-1, 3, h, w)
mean = (
torch.stack([mean] * num_flows, 1).reshape(-1, 1, 1, 1)
if mean is not None
else 0
)
img0_warp = warp(img0, flow0)
img1_warp = warp(img1, flow1)
img_warps = mask * img0_warp + (1 - mask) * img1_warp + mean + img_res
img_warps = img_warps.reshape(b, num_flows, 3, h, w)
imgt_pred = img_warps.mean(1) + comb_block(img_warps.view(b, -1, h, w))
return imgt_pred
class MultiFlowDecoder(nn.Module):
def __init__(self, in_ch, skip_ch, num_flows=3):
super(MultiFlowDecoder, self).__init__()
self.num_flows = num_flows
self.convblock = nn.Sequential(
convrelu(in_ch * 3 + 4, in_ch * 3),
ResBlock(in_ch * 3, skip_ch),
nn.ConvTranspose2d(in_ch * 3, 8 * num_flows, 4, 2, 1, bias=True),
)
def forward(self, ft_, f0, f1, flow0, flow1):
n = self.num_flows
f0_warp = warp(f0, flow0)
f1_warp = warp(f1, flow1)
out = self.convblock(torch.cat([ft_, f0_warp, f1_warp, flow0, flow1], 1))
delta_flow0, delta_flow1, mask, img_res = torch.split(
out, [2 * n, 2 * n, n, 3 * n], 1
)
mask = torch.sigmoid(mask)
flow0 = delta_flow0 + 2.0 * resize(flow0, scale_factor=2.0).repeat(
1, self.num_flows, 1, 1
)
flow1 = delta_flow1 + 2.0 * resize(flow1, scale_factor=2.0).repeat(
1, self.num_flows, 1, 1
)
return flow0, flow1, mask, img_res