Перенес networks внутрь src

This commit is contained in:
Viner Abubakirov
2026-04-01 21:41:05 +05:00
parent 829d0c8c59
commit 888cdb3151
18 changed files with 341 additions and 522 deletions

1
.gitignore vendored
View File

@@ -175,5 +175,6 @@ cython_debug/
.pypirc .pypirc
.DS_Store
source/ source/
output/ output/

View File

@@ -4,8 +4,8 @@ from pathlib import Path
import torch import torch
import numpy as np import numpy as np
from omegaconf import OmegaConf, DictConfig from omegaconf import OmegaConf, DictConfig
from imageio import imread, imwrite
from src.utils import utils
from src.utils.torch import img2tensor, check_dim_and_resize, tensor2img from src.utils.torch import img2tensor, check_dim_and_resize, tensor2img
from src.utils.build import build_from_cfg from src.utils.build import build_from_cfg
from src.utils.padder import InputPadder from src.utils.padder import InputPadder
@@ -84,9 +84,16 @@ class ImageInterpolator:
) )
def interpolate(self, image1: Path, image2: Path, output_path: Path): def interpolate(self, image1: Path, image2: Path, output_path: Path):
"""
Interpolates between two images and saves the result.
Args:
image1 (Path): Path to the first input image (only png and jpg formats are supported)
image2 (Path): Path to the second input image (only png and jpg formats are supported)
output_path (Path): Path to save the interpolated image (only png and jpg formats are supported)
"""
logging.debug(f"Reading images: {image1} and {image2}") logging.debug(f"Reading images: {image1} and {image2}")
tensor1 = img2tensor(utils.read(image1)).to(self.device) tensor1 = img2tensor(imread(image1)).to(self.device)
tensor2 = img2tensor(utils.read(image2)).to(self.device) tensor2 = img2tensor(imread(image2)).to(self.device)
logging.debug( logging.debug(
f"Image shapes after conversion to tensors: {tensor1.shape}, {tensor2.shape}" f"Image shapes after conversion to tensors: {tensor1.shape}, {tensor2.shape}"
) )
@@ -115,7 +122,7 @@ class ImageInterpolator:
logging.debug(f"Interpolated image shape before unpadding: {interpolated.shape}") logging.debug(f"Interpolated image shape before unpadding: {interpolated.shape}")
(interpolated,) = padder.unpad(interpolated) (interpolated,) = padder.unpad(interpolated)
logging.debug(f"Interpolated image shape after unpadding: {interpolated.shape}") logging.debug(f"Interpolated image shape after unpadding: {interpolated.shape}")
utils.write(output_path, tensor2img(interpolated.cpu())) imwrite(output_path, tensor2img(interpolated.cpu()))
logging.debug(f"Saved interpolated image to: {output_path}") logging.debug(f"Saved interpolated image to: {output_path}")
def scale(self, height: int, width: int) -> float: def scale(self, height: int, width: int) -> float:

61
main.py
View File

@@ -1,11 +1,11 @@
import logging import logging
import subprocess import subprocess
from pathlib import Path from pathlib import Path
from typing import Generator
import cv2 import cv2
from tqdm import tqdm from tqdm import tqdm
from time import perf_counter from time import perf_counter
from decimal import Decimal
from interpolator import get_device from interpolator import get_device
from interpolator import ImageInterpolator from interpolator import ImageInterpolator
@@ -39,68 +39,11 @@ def move_images(src_dir: str, interpolated_dir: str, output_dir: str):
index += 1 index += 1
def build_file_list(moved_dir: str, list_path: str):
import os
moved_dir = Path(moved_dir)
frames = sorted(moved_dir.glob("img_*.png"))
print(frames[0])
with open(list_path, "w") as f:
for frame in frames:
f.write(f"file '{os.path.abspath(frame)}'\n")
def build_ffmpeg_file_list(frames_dir: str, interpolated_dir: str, list_path: str):
frames = sorted(Path(frames_dir).glob("img_*.png"))
interps = sorted(Path(interpolated_dir).glob("img_*.png"))
if len(interps) != len(frames) - 1:
raise ValueError("Interpolated frames must be N-1")
with open(list_path, "w") as f:
for i in range(len(frames)):
f.write(f"file '{frames[i].resolve().as_posix()}'\n")
if i < len(interps):
f.write(f"file '{interps[i].resolve().as_posix()}'\n")
def merge_with_ffmpeg(
original_video: str,
file_list: str,
output_video: str,
):
cap = cv2.VideoCapture(original_video)
if not cap.isOpened():
raise ValueError("Cannot open original video")
fps = cap.get(cv2.CAP_PROP_FPS)
cap.release()
new_fps = Decimal(fps * 2)
cmd = [
"ffmpeg",
"-y",
"-r", str(new_fps.quantize(Decimal("1.0000000000"))),
"-f", "concat",
"-safe", "0",
"-i", file_list,
"-c:v", "libx264rgb",
output_video,
]
print("Running ffmpeg command:", " ".join(cmd))
subprocess.run(cmd, check=True)
def video_frames_to_disk_generator( def video_frames_to_disk_generator(
video_path: str | Path, video_path: str | Path,
output_dir: str | Path, output_dir: str | Path,
chunk_seconds: int = 10 chunk_seconds: int = 10
): ) -> Generator[tuple[Path, ...], None, None]:
output_dir = Path(output_dir) output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True) output_dir.mkdir(parents=True, exist_ok=True)

View File

@@ -1,69 +0,0 @@
import torch
import torch.nn as nn
from src.utils.flow_utils import warp
from networks.blocks.ifrnet import (
convrelu, resize,
ResBlock,
)
def multi_flow_combine(comb_block, img0, img1, flow0, flow1,
mask=None, img_res=None, mean=None):
'''
A parallel implementation of multiple flow field warping
comb_block: An nn.Seqential object.
img shape: [b, c, h, w]
flow shape: [b, 2*num_flows, h, w]
mask (opt):
If 'mask' is None, the function conduct a simple average.
img_res (opt):
If 'img_res' is None, the function adds zero instead.
mean (opt):
If 'mean' is None, the function adds zero instead.
'''
b, c, h, w = flow0.shape
num_flows = c // 2
flow0 = flow0.reshape(b, num_flows, 2, h, w).reshape(-1, 2, h, w)
flow1 = flow1.reshape(b, num_flows, 2, h, w).reshape(-1, 2, h, w)
mask = mask.reshape(b, num_flows, 1, h, w
).reshape(-1, 1, h, w) if mask is not None else None
img_res = img_res.reshape(b, num_flows, 3, h, w
).reshape(-1, 3, h, w) if img_res is not None else 0
img0 = torch.stack([img0] * num_flows, 1).reshape(-1, 3, h, w)
img1 = torch.stack([img1] * num_flows, 1).reshape(-1, 3, h, w)
mean = torch.stack([mean] * num_flows, 1).reshape(-1, 1, 1, 1
) if mean is not None else 0
img0_warp = warp(img0, flow0)
img1_warp = warp(img1, flow1)
img_warps = mask * img0_warp + (1 - mask) * img1_warp + mean + img_res
img_warps = img_warps.reshape(b, num_flows, 3, h, w)
imgt_pred = img_warps.mean(1) + comb_block(img_warps.view(b, -1, h, w))
return imgt_pred
class MultiFlowDecoder(nn.Module):
def __init__(self, in_ch, skip_ch, num_flows=3):
super(MultiFlowDecoder, self).__init__()
self.num_flows = num_flows
self.convblock = nn.Sequential(
convrelu(in_ch*3+4, in_ch*3),
ResBlock(in_ch*3, skip_ch),
nn.ConvTranspose2d(in_ch*3, 8*num_flows, 4, 2, 1, bias=True)
)
def forward(self, ft_, f0, f1, flow0, flow1):
n = self.num_flows
f0_warp = warp(f0, flow0)
f1_warp = warp(f1, flow1)
out = self.convblock(torch.cat([ft_, f0_warp, f1_warp, flow0, flow1], 1))
delta_flow0, delta_flow1, mask, img_res = torch.split(out, [2*n, 2*n, n, 3*n], 1)
mask = torch.sigmoid(mask)
flow0 = delta_flow0 + 2.0 * resize(flow0, scale_factor=2.0
).repeat(1, self.num_flows, 1, 1)
flow1 = delta_flow1 + 2.0 * resize(flow1, scale_factor=2.0
).repeat(1, self.num_flows, 1, 1)
return flow0, flow1, mask, img_res

View File

@@ -10,7 +10,7 @@ save_dir: work_dir
eval_interval: 1 eval_interval: 1
network: network:
name: networks.AMT-G.Model name: src.networks.AMT-G.Model
params: params:
corr_radius: 3 corr_radius: 3
corr_lvls: 4 corr_lvls: 4

View File

@@ -10,7 +10,7 @@ save_dir: work_dir
eval_interval: 1 eval_interval: 1
network: network:
name: networks.AMT-S.Model name: src.networks.AMT-S.Model
params: params:
corr_radius: 3 corr_radius: 3
corr_lvls: 4 corr_lvls: 4

View File

@@ -1,9 +1,11 @@
from typing import Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
from networks.blocks.raft import coords_grid, BasicUpdateBlock, BidirCorrBlock from src.networks.blocks.raft import coords_grid, BasicUpdateBlock, BidirCorrBlock
from networks.blocks.feat_enc import LargeEncoder from src.networks.blocks.feat_enc import LargeEncoder
from networks.blocks.ifrnet import resize, Encoder, InitDecoder, IntermediateDecoder from src.networks.blocks.ifrnet import resize, Encoder, InitDecoder, IntermediateDecoder
from networks.blocks.multi_flow import multi_flow_combine, MultiFlowDecoder from src.networks.blocks.multi_flow import multi_flow_combine, MultiFlowDecoder
class Model(nn.Module): class Model(nn.Module):
@@ -42,7 +44,7 @@ class Model(nn.Module):
nn.Conv2d(6 * self.num_flows, 3, 7, 1, 3), nn.Conv2d(6 * self.num_flows, 3, 7, 1, 3),
) )
def _get_updateblock(self, cdim, scale_factor=None): def _get_updateblock(self, cdim: int, scale_factor: Optional[float] = None):
return BasicUpdateBlock( return BasicUpdateBlock(
cdim=cdim, cdim=cdim,
hidden_dim=192, hidden_dim=192,
@@ -55,7 +57,15 @@ class Model(nn.Module):
radius=self.radius, radius=self.radius,
) )
def _corr_scale_lookup(self, corr_fn, coord, flow0, flow1, embt, downsample=1): def _corr_scale_lookup(
self,
corr_fn: BidirCorrBlock,
coord: torch.Tensor,
flow0: torch.Tensor,
flow1: torch.Tensor,
embt: torch.Tensor,
downsample: int = 1,
):
# convert t -> 0 to 0 -> 1 | convert t -> 1 to 1 -> 0 # convert t -> 0 to 0 -> 1 | convert t -> 1 to 1 -> 0
# based on linear assumption # based on linear assumption
t1_scale = 1.0 / embt t1_scale = 1.0 / embt
@@ -70,7 +80,15 @@ class Model(nn.Module):
flow = torch.cat([flow0, flow1], dim=1) flow = torch.cat([flow0, flow1], dim=1)
return corr, flow return corr, flow
def forward(self, img0, img1, embt, scale_factor=1.0, eval=False, **kwargs): def forward(
self,
img0: torch.Tensor,
img1: torch.Tensor,
embt: torch.Tensor,
scale_factor: float = 1.0,
eval: bool = False,
**kwargs,
):
mean_ = ( mean_ = (
torch.cat([img0, img1], 2) torch.cat([img0, img1], 2)
.mean(1, keepdim=True) .mean(1, keepdim=True)

View File

@@ -1,38 +1,29 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
from networks.blocks.raft import ( from src.networks.blocks.raft import coords_grid, BasicUpdateBlock, BidirCorrBlock
coords_grid, from src.networks.blocks.feat_enc import BasicEncoder
BasicUpdateBlock, BidirCorrBlock from src.networks.blocks.ifrnet import resize, Encoder, InitDecoder, IntermediateDecoder
)
from networks.blocks.feat_enc import ( from src.networks.blocks.multi_flow import multi_flow_combine, MultiFlowDecoder
BasicEncoder
)
from networks.blocks.ifrnet import (
resize,
Encoder,
InitDecoder,
IntermediateDecoder
)
from networks.blocks.multi_flow import (
multi_flow_combine,
MultiFlowDecoder
)
class Model(nn.Module): class Model(nn.Module):
def __init__(self, def __init__(
corr_radius=3, self,
corr_lvls=4, corr_radius=3,
num_flows=5, corr_lvls=4,
channels=[48, 64, 72, 128], num_flows=5,
skip_channels=48 channels=[48, 64, 72, 128],
): skip_channels=48,
):
super(Model, self).__init__() super(Model, self).__init__()
self.radius = corr_radius self.radius = corr_radius
self.corr_levels = corr_lvls self.corr_levels = corr_lvls
self.num_flows = num_flows self.num_flows = num_flows
self.feat_encoder = BasicEncoder(output_dim=128, norm_fn='instance', dropout=0.) self.feat_encoder = BasicEncoder(
output_dim=128, norm_fn="instance", dropout=0.0
)
self.encoder = Encoder([48, 64, 72, 128], large=True) self.encoder = Encoder([48, 64, 72, 128], large=True)
self.decoder4 = InitDecoder(channels[3], channels[2], skip_channels) self.decoder4 = InitDecoder(channels[3], channels[2], skip_channels)
@@ -45,22 +36,29 @@ class Model(nn.Module):
self.update2 = self._get_updateblock(48, 4.0) self.update2 = self._get_updateblock(48, 4.0)
self.comb_block = nn.Sequential( self.comb_block = nn.Sequential(
nn.Conv2d(3*self.num_flows, 6*self.num_flows, 7, 1, 3), nn.Conv2d(3 * self.num_flows, 6 * self.num_flows, 7, 1, 3),
nn.PReLU(6*self.num_flows), nn.PReLU(6 * self.num_flows),
nn.Conv2d(6*self.num_flows, 3, 7, 1, 3), nn.Conv2d(6 * self.num_flows, 3, 7, 1, 3),
) )
def _get_updateblock(self, cdim, scale_factor=None): def _get_updateblock(self, cdim, scale_factor=None):
return BasicUpdateBlock(cdim=cdim, hidden_dim=128, flow_dim=48, return BasicUpdateBlock(
corr_dim=256, corr_dim2=160, fc_dim=124, cdim=cdim,
scale_factor=scale_factor, corr_levels=self.corr_levels, hidden_dim=128,
radius=self.radius) flow_dim=48,
corr_dim=256,
corr_dim2=160,
fc_dim=124,
scale_factor=scale_factor,
corr_levels=self.corr_levels,
radius=self.radius,
)
def _corr_scale_lookup(self, corr_fn, coord, flow0, flow1, embt, downsample=1): def _corr_scale_lookup(self, corr_fn, coord, flow0, flow1, embt, downsample=1):
# convert t -> 0 to 0 -> 1 | convert t -> 1 to 1 -> 0 # convert t -> 0 to 0 -> 1 | convert t -> 1 to 1 -> 0
# based on linear assumption # based on linear assumption
t1_scale = 1. / embt t1_scale = 1.0 / embt
t0_scale = 1. / (1. - embt) t0_scale = 1.0 / (1.0 - embt)
if downsample != 1: if downsample != 1:
inv = 1 / downsample inv = 1 / downsample
flow0 = inv * resize(flow0, scale_factor=inv) flow0 = inv * resize(flow0, scale_factor=inv)
@@ -72,7 +70,12 @@ class Model(nn.Module):
return corr, flow return corr, flow
def forward(self, img0, img1, embt, scale_factor=1.0, eval=False, **kwargs): def forward(self, img0, img1, embt, scale_factor=1.0, eval=False, **kwargs):
mean_ = torch.cat([img0, img1], 2).mean(1, keepdim=True).mean(2, keepdim=True).mean(3, keepdim=True) mean_ = (
torch.cat([img0, img1], 2)
.mean(1, keepdim=True)
.mean(2, keepdim=True)
.mean(3, keepdim=True)
)
img0 = img0 - mean_ img0 = img0 - mean_
img1 = img1 - mean_ img1 = img1 - mean_
img0_ = resize(img0, scale_factor) if scale_factor != 1.0 else img0 img0_ = resize(img0, scale_factor) if scale_factor != 1.0 else img0
@@ -80,8 +83,10 @@ class Model(nn.Module):
b, _, h, w = img0_.shape b, _, h, w = img0_.shape
coord = coords_grid(b, h // 8, w // 8, img0.device) coord = coords_grid(b, h // 8, w // 8, img0.device)
fmap0, fmap1 = self.feat_encoder([img0_, img1_]) # [1, 128, H//8, W//8] fmap0, fmap1 = self.feat_encoder([img0_, img1_]) # [1, 128, H//8, W//8]
corr_fn = BidirCorrBlock(fmap0, fmap1, radius=self.radius, num_levels=self.corr_levels) corr_fn = BidirCorrBlock(
fmap0, fmap1, radius=self.radius, num_levels=self.corr_levels
)
# f0_1: [1, c0, H//2, W//2] | f0_2: [1, c1, H//4, W//4] # f0_1: [1, c0, H//2, W//2] | f0_2: [1, c1, H//4, W//4]
# f0_3: [1, c2, H//8, W//8] | f0_4: [1, c3, H//16, W//16] # f0_3: [1, c2, H//8, W//8] | f0_4: [1, c3, H//16, W//16]
@@ -90,9 +95,9 @@ class Model(nn.Module):
######################################### the 4th decoder ######################################### ######################################### the 4th decoder #########################################
up_flow0_4, up_flow1_4, ft_3_ = self.decoder4(f0_4, f1_4, embt) up_flow0_4, up_flow1_4, ft_3_ = self.decoder4(f0_4, f1_4, embt)
corr_4, flow_4 = self._corr_scale_lookup(corr_fn, coord, corr_4, flow_4 = self._corr_scale_lookup(
up_flow0_4, up_flow1_4, corr_fn, coord, up_flow0_4, up_flow1_4, embt, downsample=1
embt, downsample=1) )
# residue update with lookup corr # residue update with lookup corr
delta_ft_3_, delta_flow_4 = self.update4(ft_3_, flow_4, corr_4) delta_ft_3_, delta_flow_4 = self.update4(ft_3_, flow_4, corr_4)
@@ -102,10 +107,12 @@ class Model(nn.Module):
ft_3_ = ft_3_ + delta_ft_3_ ft_3_ = ft_3_ + delta_ft_3_
######################################### the 3rd decoder ######################################### ######################################### the 3rd decoder #########################################
up_flow0_3, up_flow1_3, ft_2_ = self.decoder3(ft_3_, f0_3, f1_3, up_flow0_4, up_flow1_4) up_flow0_3, up_flow1_3, ft_2_ = self.decoder3(
corr_3, flow_3 = self._corr_scale_lookup(corr_fn, ft_3_, f0_3, f1_3, up_flow0_4, up_flow1_4
coord, up_flow0_3, up_flow1_3, )
embt, downsample=2) corr_3, flow_3 = self._corr_scale_lookup(
corr_fn, coord, up_flow0_3, up_flow1_3, embt, downsample=2
)
# residue update with lookup corr # residue update with lookup corr
delta_ft_2_, delta_flow_3 = self.update3(ft_2_, flow_3, corr_3) delta_ft_2_, delta_flow_3 = self.update3(ft_2_, flow_3, corr_3)
@@ -115,10 +122,12 @@ class Model(nn.Module):
ft_2_ = ft_2_ + delta_ft_2_ ft_2_ = ft_2_ + delta_ft_2_
######################################### the 2nd decoder ######################################### ######################################### the 2nd decoder #########################################
up_flow0_2, up_flow1_2, ft_1_ = self.decoder2(ft_2_, f0_2, f1_2, up_flow0_3, up_flow1_3) up_flow0_2, up_flow1_2, ft_1_ = self.decoder2(
corr_2, flow_2 = self._corr_scale_lookup(corr_fn, ft_2_, f0_2, f1_2, up_flow0_3, up_flow1_3
coord, up_flow0_2, up_flow1_2, )
embt, downsample=4) corr_2, flow_2 = self._corr_scale_lookup(
corr_fn, coord, up_flow0_2, up_flow1_2, embt, downsample=4
)
# residue update with lookup corr # residue update with lookup corr
delta_ft_1_, delta_flow_2 = self.update2(ft_1_, flow_2, corr_2) delta_ft_1_, delta_flow_2 = self.update2(ft_1_, flow_2, corr_2)
@@ -128,28 +137,36 @@ class Model(nn.Module):
ft_1_ = ft_1_ + delta_ft_1_ ft_1_ = ft_1_ + delta_ft_1_
######################################### the 1st decoder ######################################### ######################################### the 1st decoder #########################################
up_flow0_1, up_flow1_1, mask, img_res = self.decoder1(ft_1_, f0_1, f1_1, up_flow0_2, up_flow1_2) up_flow0_1, up_flow1_1, mask, img_res = self.decoder1(
ft_1_, f0_1, f1_1, up_flow0_2, up_flow1_2
)
if scale_factor != 1.0: if scale_factor != 1.0:
up_flow0_1 = resize(up_flow0_1, scale_factor=(1.0/scale_factor)) * (1.0/scale_factor) up_flow0_1 = resize(up_flow0_1, scale_factor=(1.0 / scale_factor)) * (
up_flow1_1 = resize(up_flow1_1, scale_factor=(1.0/scale_factor)) * (1.0/scale_factor) 1.0 / scale_factor
mask = resize(mask, scale_factor=(1.0/scale_factor)) )
img_res = resize(img_res, scale_factor=(1.0/scale_factor)) up_flow1_1 = resize(up_flow1_1, scale_factor=(1.0 / scale_factor)) * (
1.0 / scale_factor
)
mask = resize(mask, scale_factor=(1.0 / scale_factor))
img_res = resize(img_res, scale_factor=(1.0 / scale_factor))
# Merge multiple predictions # Merge multiple predictions
imgt_pred = multi_flow_combine(self.comb_block, img0, img1, up_flow0_1, up_flow1_1, imgt_pred = multi_flow_combine(
mask, img_res, mean_) self.comb_block, img0, img1, up_flow0_1, up_flow1_1, mask, img_res, mean_
)
imgt_pred = torch.clamp(imgt_pred, 0, 1) imgt_pred = torch.clamp(imgt_pred, 0, 1)
if eval: if eval:
return { 'imgt_pred': imgt_pred, } return {
"imgt_pred": imgt_pred,
}
else: else:
up_flow0_1 = up_flow0_1.reshape(b, self.num_flows, 2, h, w) up_flow0_1 = up_flow0_1.reshape(b, self.num_flows, 2, h, w)
up_flow1_1 = up_flow1_1.reshape(b, self.num_flows, 2, h, w) up_flow1_1 = up_flow1_1.reshape(b, self.num_flows, 2, h, w)
return { return {
'imgt_pred': imgt_pred, "imgt_pred": imgt_pred,
'flow0_pred': [up_flow0_1, up_flow0_2, up_flow0_3, up_flow0_4], "flow0_pred": [up_flow0_1, up_flow0_2, up_flow0_3, up_flow0_4],
'flow1_pred': [up_flow1_1, up_flow1_2, up_flow1_3, up_flow1_4], "flow1_pred": [up_flow1_1, up_flow1_2, up_flow1_3, up_flow1_4],
'ft_pred': [ft_1_, ft_2_, ft_3_], "ft_pred": [ft_1_, ft_2_, ft_3_],
} }

View File

@@ -1,31 +1,20 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
from networks.blocks.raft import ( from src.networks.blocks.raft import coords_grid, SmallUpdateBlock, BidirCorrBlock
coords_grid, from src.networks.blocks.feat_enc import SmallEncoder
SmallUpdateBlock, BidirCorrBlock from src.networks.blocks.ifrnet import resize, Encoder, InitDecoder, IntermediateDecoder
) from src.networks.blocks.multi_flow import multi_flow_combine, MultiFlowDecoder
from networks.blocks.feat_enc import (
SmallEncoder
)
from networks.blocks.ifrnet import (
resize,
Encoder,
InitDecoder,
IntermediateDecoder
)
from networks.blocks.multi_flow import (
multi_flow_combine,
MultiFlowDecoder
)
class Model(nn.Module): class Model(nn.Module):
def __init__(self, def __init__(
corr_radius=3, self,
corr_lvls=4, corr_radius=3,
num_flows=3, corr_lvls=4,
channels=[20, 32, 44, 56], num_flows=3,
skip_channels=20): channels=[20, 32, 44, 56],
skip_channels=20,
):
super(Model, self).__init__() super(Model, self).__init__()
self.radius = corr_radius self.radius = corr_radius
self.corr_levels = corr_lvls self.corr_levels = corr_lvls
@@ -33,7 +22,7 @@ class Model(nn.Module):
self.channels = channels self.channels = channels
self.skip_channels = skip_channels self.skip_channels = skip_channels
self.feat_encoder = SmallEncoder(output_dim=84, norm_fn='instance', dropout=0.) self.feat_encoder = SmallEncoder(output_dim=84, norm_fn="instance", dropout=0.0)
self.encoder = Encoder(channels) self.encoder = Encoder(channels)
self.decoder4 = InitDecoder(channels[3], channels[2], skip_channels) self.decoder4 = InitDecoder(channels[3], channels[2], skip_channels)
@@ -46,21 +35,28 @@ class Model(nn.Module):
self.update2 = self._get_updateblock(20, 4) self.update2 = self._get_updateblock(20, 4)
self.comb_block = nn.Sequential( self.comb_block = nn.Sequential(
nn.Conv2d(3*num_flows, 6*num_flows, 3, 1, 1), nn.Conv2d(3 * num_flows, 6 * num_flows, 3, 1, 1),
nn.PReLU(6*num_flows), nn.PReLU(6 * num_flows),
nn.Conv2d(6*num_flows, 3, 3, 1, 1), nn.Conv2d(6 * num_flows, 3, 3, 1, 1),
) )
def _get_updateblock(self, cdim, scale_factor=None): def _get_updateblock(self, cdim, scale_factor=None):
return SmallUpdateBlock(cdim=cdim, hidden_dim=76, flow_dim=20, corr_dim=64, return SmallUpdateBlock(
fc_dim=68, scale_factor=scale_factor, cdim=cdim,
corr_levels=self.corr_levels, radius=self.radius) hidden_dim=76,
flow_dim=20,
corr_dim=64,
fc_dim=68,
scale_factor=scale_factor,
corr_levels=self.corr_levels,
radius=self.radius,
)
def _corr_scale_lookup(self, corr_fn, coord, flow0, flow1, embt, downsample=1): def _corr_scale_lookup(self, corr_fn, coord, flow0, flow1, embt, downsample=1):
# convert t -> 0 to 0 -> 1 | convert t -> 1 to 1 -> 0 # convert t -> 0 to 0 -> 1 | convert t -> 1 to 1 -> 0
# based on linear assumption # based on linear assumption
t1_scale = 1. / embt t1_scale = 1.0 / embt
t0_scale = 1. / (1. - embt) t0_scale = 1.0 / (1.0 - embt)
if downsample != 1: if downsample != 1:
inv = 1 / downsample inv = 1 / downsample
flow0 = inv * resize(flow0, scale_factor=inv) flow0 = inv * resize(flow0, scale_factor=inv)
@@ -72,7 +68,12 @@ class Model(nn.Module):
return corr, flow return corr, flow
def forward(self, img0, img1, embt, scale_factor=1.0, eval=False, **kwargs): def forward(self, img0, img1, embt, scale_factor=1.0, eval=False, **kwargs):
mean_ = torch.cat([img0, img1], 2).mean(1, keepdim=True).mean(2, keepdim=True).mean(3, keepdim=True) mean_ = (
torch.cat([img0, img1], 2)
.mean(1, keepdim=True)
.mean(2, keepdim=True)
.mean(3, keepdim=True)
)
img0 = img0 - mean_ img0 = img0 - mean_
img1 = img1 - mean_ img1 = img1 - mean_
img0_ = resize(img0, scale_factor) if scale_factor != 1.0 else img0 img0_ = resize(img0, scale_factor) if scale_factor != 1.0 else img0
@@ -80,8 +81,10 @@ class Model(nn.Module):
b, _, h, w = img0_.shape b, _, h, w = img0_.shape
coord = coords_grid(b, h // 8, w // 8, img0.device) coord = coords_grid(b, h // 8, w // 8, img0.device)
fmap0, fmap1 = self.feat_encoder([img0_, img1_]) # [1, 128, H//8, W//8] fmap0, fmap1 = self.feat_encoder([img0_, img1_]) # [1, 128, H//8, W//8]
corr_fn = BidirCorrBlock(fmap0, fmap1, radius=self.radius, num_levels=self.corr_levels) corr_fn = BidirCorrBlock(
fmap0, fmap1, radius=self.radius, num_levels=self.corr_levels
)
# f0_1: [1, c0, H//2, W//2] | f0_2: [1, c1, H//4, W//4] # f0_1: [1, c0, H//2, W//2] | f0_2: [1, c1, H//4, W//4]
# f0_3: [1, c2, H//8, W//8] | f0_4: [1, c3, H//16, W//16] # f0_3: [1, c2, H//8, W//8] | f0_4: [1, c3, H//16, W//16]
@@ -90,9 +93,9 @@ class Model(nn.Module):
######################################### the 4th decoder ######################################### ######################################### the 4th decoder #########################################
up_flow0_4, up_flow1_4, ft_3_ = self.decoder4(f0_4, f1_4, embt) up_flow0_4, up_flow1_4, ft_3_ = self.decoder4(f0_4, f1_4, embt)
corr_4, flow_4 = self._corr_scale_lookup(corr_fn, coord, corr_4, flow_4 = self._corr_scale_lookup(
up_flow0_4, up_flow1_4, corr_fn, coord, up_flow0_4, up_flow1_4, embt, downsample=1
embt, downsample=1) )
# residue update with lookup corr # residue update with lookup corr
delta_ft_3_, delta_flow_4 = self.update4(ft_3_, flow_4, corr_4) delta_ft_3_, delta_flow_4 = self.update4(ft_3_, flow_4, corr_4)
@@ -102,10 +105,12 @@ class Model(nn.Module):
ft_3_ = ft_3_ + delta_ft_3_ ft_3_ = ft_3_ + delta_ft_3_
######################################### the 3rd decoder ######################################### ######################################### the 3rd decoder #########################################
up_flow0_3, up_flow1_3, ft_2_ = self.decoder3(ft_3_, f0_3, f1_3, up_flow0_4, up_flow1_4) up_flow0_3, up_flow1_3, ft_2_ = self.decoder3(
corr_3, flow_3 = self._corr_scale_lookup(corr_fn, ft_3_, f0_3, f1_3, up_flow0_4, up_flow1_4
coord, up_flow0_3, up_flow1_3, )
embt, downsample=2) corr_3, flow_3 = self._corr_scale_lookup(
corr_fn, coord, up_flow0_3, up_flow1_3, embt, downsample=2
)
# residue update with lookup corr # residue update with lookup corr
delta_ft_2_, delta_flow_3 = self.update3(ft_2_, flow_3, corr_3) delta_ft_2_, delta_flow_3 = self.update3(ft_2_, flow_3, corr_3)
@@ -115,10 +120,12 @@ class Model(nn.Module):
ft_2_ = ft_2_ + delta_ft_2_ ft_2_ = ft_2_ + delta_ft_2_
######################################### the 2nd decoder ######################################### ######################################### the 2nd decoder #########################################
up_flow0_2, up_flow1_2, ft_1_ = self.decoder2(ft_2_, f0_2, f1_2, up_flow0_3, up_flow1_3) up_flow0_2, up_flow1_2, ft_1_ = self.decoder2(
corr_2, flow_2 = self._corr_scale_lookup(corr_fn, ft_2_, f0_2, f1_2, up_flow0_3, up_flow1_3
coord, up_flow0_2, up_flow1_2, )
embt, downsample=4) corr_2, flow_2 = self._corr_scale_lookup(
corr_fn, coord, up_flow0_2, up_flow1_2, embt, downsample=4
)
# residue update with lookup corr # residue update with lookup corr
delta_ft_1_, delta_flow_2 = self.update2(ft_1_, flow_2, corr_2) delta_ft_1_, delta_flow_2 = self.update2(ft_1_, flow_2, corr_2)
@@ -128,27 +135,36 @@ class Model(nn.Module):
ft_1_ = ft_1_ + delta_ft_1_ ft_1_ = ft_1_ + delta_ft_1_
######################################### the 1st decoder ######################################### ######################################### the 1st decoder #########################################
up_flow0_1, up_flow1_1, mask, img_res = self.decoder1(ft_1_, f0_1, f1_1, up_flow0_2, up_flow1_2) up_flow0_1, up_flow1_1, mask, img_res = self.decoder1(
ft_1_, f0_1, f1_1, up_flow0_2, up_flow1_2
)
if scale_factor != 1.0: if scale_factor != 1.0:
up_flow0_1 = resize(up_flow0_1, scale_factor=(1.0/scale_factor)) * (1.0/scale_factor) up_flow0_1 = resize(up_flow0_1, scale_factor=(1.0 / scale_factor)) * (
up_flow1_1 = resize(up_flow1_1, scale_factor=(1.0/scale_factor)) * (1.0/scale_factor) 1.0 / scale_factor
mask = resize(mask, scale_factor=(1.0/scale_factor)) )
img_res = resize(img_res, scale_factor=(1.0/scale_factor)) up_flow1_1 = resize(up_flow1_1, scale_factor=(1.0 / scale_factor)) * (
1.0 / scale_factor
)
mask = resize(mask, scale_factor=(1.0 / scale_factor))
img_res = resize(img_res, scale_factor=(1.0 / scale_factor))
# Merge multiple predictions # Merge multiple predictions
imgt_pred = multi_flow_combine(self.comb_block, img0, img1, up_flow0_1, up_flow1_1, imgt_pred = multi_flow_combine(
mask, img_res, mean_) self.comb_block, img0, img1, up_flow0_1, up_flow1_1, mask, img_res, mean_
)
imgt_pred = torch.clamp(imgt_pred, 0, 1) imgt_pred = torch.clamp(imgt_pred, 0, 1)
if eval: if eval:
return { 'imgt_pred': imgt_pred, } return {
"imgt_pred": imgt_pred,
}
else: else:
up_flow0_1 = up_flow0_1.reshape(b, self.num_flows, 2, h, w) up_flow0_1 = up_flow0_1.reshape(b, self.num_flows, 2, h, w)
up_flow1_1 = up_flow1_1.reshape(b, self.num_flows, 2, h, w) up_flow1_1 = up_flow1_1.reshape(b, self.num_flows, 2, h, w)
return { return {
'imgt_pred': imgt_pred, "imgt_pred": imgt_pred,
'flow0_pred': [up_flow0_1, up_flow0_2, up_flow0_3, up_flow0_4], "flow0_pred": [up_flow0_1, up_flow0_2, up_flow0_3, up_flow0_4],
'flow1_pred': [up_flow1_1, up_flow1_2, up_flow1_3, up_flow1_4], "flow1_pred": [up_flow1_1, up_flow1_2, up_flow1_3, up_flow1_4],
'ft_pred': [ft_1_, ft_2_, ft_3_], "ft_pred": [ft_1_, ft_2_, ft_3_],
} }

View File

@@ -1,30 +1,23 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
from src.utils.flow_utils import warp from src.utils.flow_utils import warp
from networks.blocks.ifrnet import ( from src.networks.blocks.ifrnet import convrelu, resize, ResBlock
convrelu, resize,
ResBlock,
)
class Encoder(nn.Module): class Encoder(nn.Module):
def __init__(self): def __init__(self):
super(Encoder, self).__init__() super(Encoder, self).__init__()
self.pyramid1 = nn.Sequential( self.pyramid1 = nn.Sequential(
convrelu(3, 32, 3, 2, 1), convrelu(3, 32, 3, 2, 1), convrelu(32, 32, 3, 1, 1)
convrelu(32, 32, 3, 1, 1)
) )
self.pyramid2 = nn.Sequential( self.pyramid2 = nn.Sequential(
convrelu(32, 48, 3, 2, 1), convrelu(32, 48, 3, 2, 1), convrelu(48, 48, 3, 1, 1)
convrelu(48, 48, 3, 1, 1)
) )
self.pyramid3 = nn.Sequential( self.pyramid3 = nn.Sequential(
convrelu(48, 72, 3, 2, 1), convrelu(48, 72, 3, 2, 1), convrelu(72, 72, 3, 1, 1)
convrelu(72, 72, 3, 1, 1)
) )
self.pyramid4 = nn.Sequential( self.pyramid4 = nn.Sequential(
convrelu(72, 96, 3, 2, 1), convrelu(72, 96, 3, 2, 1), convrelu(96, 96, 3, 1, 1)
convrelu(96, 96, 3, 1, 1)
) )
def forward(self, img): def forward(self, img):
@@ -39,9 +32,9 @@ class Decoder4(nn.Module):
def __init__(self): def __init__(self):
super(Decoder4, self).__init__() super(Decoder4, self).__init__()
self.convblock = nn.Sequential( self.convblock = nn.Sequential(
convrelu(192+1, 192), convrelu(192 + 1, 192),
ResBlock(192, 32), ResBlock(192, 32),
nn.ConvTranspose2d(192, 76, 4, 2, 1, bias=True) nn.ConvTranspose2d(192, 76, 4, 2, 1, bias=True),
) )
def forward(self, f0, f1, embt): def forward(self, f0, f1, embt):
@@ -58,7 +51,7 @@ class Decoder3(nn.Module):
self.convblock = nn.Sequential( self.convblock = nn.Sequential(
convrelu(220, 216), convrelu(220, 216),
ResBlock(216, 32), ResBlock(216, 32),
nn.ConvTranspose2d(216, 52, 4, 2, 1, bias=True) nn.ConvTranspose2d(216, 52, 4, 2, 1, bias=True),
) )
def forward(self, ft_, f0, f1, up_flow0, up_flow1): def forward(self, ft_, f0, f1, up_flow0, up_flow1):
@@ -75,7 +68,7 @@ class Decoder2(nn.Module):
self.convblock = nn.Sequential( self.convblock = nn.Sequential(
convrelu(148, 144), convrelu(148, 144),
ResBlock(144, 32), ResBlock(144, 32),
nn.ConvTranspose2d(144, 36, 4, 2, 1, bias=True) nn.ConvTranspose2d(144, 36, 4, 2, 1, bias=True),
) )
def forward(self, ft_, f0, f1, up_flow0, up_flow1): def forward(self, ft_, f0, f1, up_flow0, up_flow1):
@@ -92,7 +85,7 @@ class Decoder1(nn.Module):
self.convblock = nn.Sequential( self.convblock = nn.Sequential(
convrelu(100, 96), convrelu(100, 96),
ResBlock(96, 32), ResBlock(96, 32),
nn.ConvTranspose2d(96, 8, 4, 2, 1, bias=True) nn.ConvTranspose2d(96, 8, 4, 2, 1, bias=True),
) )
def forward(self, ft_, f0, f1, up_flow0, up_flow1): def forward(self, ft_, f0, f1, up_flow0, up_flow1):
@@ -113,7 +106,12 @@ class Model(nn.Module):
self.decoder1 = Decoder1() self.decoder1 = Decoder1()
def forward(self, img0, img1, embt, scale_factor=1.0, eval=False, **kwargs): def forward(self, img0, img1, embt, scale_factor=1.0, eval=False, **kwargs):
mean_ = torch.cat([img0, img1], 2).mean(1, keepdim=True).mean(2, keepdim=True).mean(3, keepdim=True) mean_ = (
torch.cat([img0, img1], 2)
.mean(1, keepdim=True)
.mean(2, keepdim=True)
.mean(3, keepdim=True)
)
img0 = img0 - mean_ img0 = img0 - mean_
img1 = img1 - mean_ img1 = img1 - mean_
@@ -145,10 +143,14 @@ class Model(nn.Module):
up_res_1 = out1[:, 5:] up_res_1 = out1[:, 5:]
if scale_factor != 1.0: if scale_factor != 1.0:
up_flow0_1 = resize(up_flow0_1, scale_factor=(1.0/scale_factor)) * (1.0/scale_factor) up_flow0_1 = resize(up_flow0_1, scale_factor=(1.0 / scale_factor)) * (
up_flow1_1 = resize(up_flow1_1, scale_factor=(1.0/scale_factor)) * (1.0/scale_factor) 1.0 / scale_factor
up_mask_1 = resize(up_mask_1, scale_factor=(1.0/scale_factor)) )
up_res_1 = resize(up_res_1, scale_factor=(1.0/scale_factor)) up_flow1_1 = resize(up_flow1_1, scale_factor=(1.0 / scale_factor)) * (
1.0 / scale_factor
)
up_mask_1 = resize(up_mask_1, scale_factor=(1.0 / scale_factor))
up_res_1 = resize(up_res_1, scale_factor=(1.0 / scale_factor))
img0_warp = warp(img0, up_flow0_1) img0_warp = warp(img0, up_flow0_1)
img1_warp = warp(img1, up_flow1_1) img1_warp = warp(img1, up_flow1_1)
@@ -157,13 +159,15 @@ class Model(nn.Module):
imgt_pred = torch.clamp(imgt_pred, 0, 1) imgt_pred = torch.clamp(imgt_pred, 0, 1)
if eval: if eval:
return { 'imgt_pred': imgt_pred, } return {
"imgt_pred": imgt_pred,
}
else: else:
return { return {
'imgt_pred': imgt_pred, "imgt_pred": imgt_pred,
'flow0_pred': [up_flow0_1, up_flow0_2, up_flow0_3, up_flow0_4], "flow0_pred": [up_flow0_1, up_flow0_2, up_flow0_3, up_flow0_4],
'flow1_pred': [up_flow1_1, up_flow1_2, up_flow1_3, up_flow1_4], "flow1_pred": [up_flow1_1, up_flow1_2, up_flow1_3, up_flow1_4],
'ft_pred': [ft_1_, ft_2_, ft_3_], "ft_pred": [ft_1_, ft_2_, ft_3_],
'img0_warp': img0_warp, "img0_warp": img0_warp,
'img1_warp': img1_warp "img1_warp": img1_warp,
} }

View File

@@ -0,0 +1,80 @@
import torch
import torch.nn as nn
from src.utils.flow_utils import warp
from src.networks.blocks.ifrnet import convrelu, resize, ResBlock
def multi_flow_combine(
comb_block, img0, img1, flow0, flow1, mask=None, img_res=None, mean=None
):
"""
A parallel implementation of multiple flow field warping
comb_block: An nn.Seqential object.
img shape: [b, c, h, w]
flow shape: [b, 2*num_flows, h, w]
mask (opt):
If 'mask' is None, the function conduct a simple average.
img_res (opt):
If 'img_res' is None, the function adds zero instead.
mean (opt):
If 'mean' is None, the function adds zero instead.
"""
b, c, h, w = flow0.shape
num_flows = c // 2
flow0 = flow0.reshape(b, num_flows, 2, h, w).reshape(-1, 2, h, w)
flow1 = flow1.reshape(b, num_flows, 2, h, w).reshape(-1, 2, h, w)
mask = (
mask.reshape(b, num_flows, 1, h, w).reshape(-1, 1, h, w)
if mask is not None
else None
)
img_res = (
img_res.reshape(b, num_flows, 3, h, w).reshape(-1, 3, h, w)
if img_res is not None
else 0
)
img0 = torch.stack([img0] * num_flows, 1).reshape(-1, 3, h, w)
img1 = torch.stack([img1] * num_flows, 1).reshape(-1, 3, h, w)
mean = (
torch.stack([mean] * num_flows, 1).reshape(-1, 1, 1, 1)
if mean is not None
else 0
)
img0_warp = warp(img0, flow0)
img1_warp = warp(img1, flow1)
img_warps = mask * img0_warp + (1 - mask) * img1_warp + mean + img_res
img_warps = img_warps.reshape(b, num_flows, 3, h, w)
imgt_pred = img_warps.mean(1) + comb_block(img_warps.view(b, -1, h, w))
return imgt_pred
class MultiFlowDecoder(nn.Module):
def __init__(self, in_ch, skip_ch, num_flows=3):
super(MultiFlowDecoder, self).__init__()
self.num_flows = num_flows
self.convblock = nn.Sequential(
convrelu(in_ch * 3 + 4, in_ch * 3),
ResBlock(in_ch * 3, skip_ch),
nn.ConvTranspose2d(in_ch * 3, 8 * num_flows, 4, 2, 1, bias=True),
)
def forward(self, ft_, f0, f1, flow0, flow1):
n = self.num_flows
f0_warp = warp(f0, flow0)
f1_warp = warp(f1, flow1)
out = self.convblock(torch.cat([ft_, f0_warp, f1_warp, flow0, flow1], 1))
delta_flow0, delta_flow1, mask, img_res = torch.split(
out, [2 * n, 2 * n, n, 3 * n], 1
)
mask = torch.sigmoid(mask)
flow0 = delta_flow0 + 2.0 * resize(flow0, scale_factor=2.0).repeat(
1, self.num_flows, 1, 1
)
flow1 = delta_flow1 + 2.0 * resize(flow1, scale_factor=2.0).repeat(
1, self.num_flows, 1, 1
)
return flow0, flow1, mask, img_res

View File

@@ -1,6 +1,7 @@
from typing import TYPE_CHECKING
import importlib import importlib
from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from omegaconf import DictConfig from omegaconf import DictConfig

View File

@@ -1,199 +0,0 @@
import re
import sys
from pathlib import Path
import numpy as np
from imageio import imread, imwrite
def read(file: Path) -> np.ndarray:
readers = {
".float3": readFloat,
".flo": readFlow,
".ppm": readImage,
".pgm": readImage,
".png": readImage,
".jpg": readImage,
".pfm": lambda f: readPFM(f)[0],
}
func = readers.get(file.suffix.lower())
if func is None:
raise Exception("don't know how to read %s" % file)
return func(file)
def write(file: Path, data: np.ndarray) -> None:
writers = {
".float3": writeFloat,
".flo": writeFlow,
".ppm": writeImage,
".pgm": writeImage,
".png": writeImage,
".jpg": writeImage,
".pfm": writePFM,
}
func = writers.get(file.suffix.lower())
if func is None:
raise Exception("don't know how to write %s" % file)
return func(file, data)
def readPFM(file: Path):
data = open(file, "rb")
color = None
width = None
height = None
scale = None
endian = None
header = data.readline().rstrip()
if header.decode("ascii") == "PF":
color = True
elif header.decode("ascii") == "Pf":
color = False
else:
raise Exception("Not a PFM file.")
dim_match = re.match(r"^(\d+)\s(\d+)\s$", data.readline().decode("ascii"))
if dim_match:
width, height = list(map(int, dim_match.groups()))
else:
raise Exception("Malformed PFM header.")
scale = float(data.readline().decode("ascii").rstrip())
if scale < 0:
endian = "<"
scale = -scale
else:
endian = ">"
result = np.fromfile(data, endian + "f")
shape = (height, width, 3) if color else (height, width)
result = np.reshape(result, shape)
result = np.flipud(result)
return result, scale
def writePFM(file: Path, image: np.ndarray, scale=1):
data = open(file, "wb")
color = None
if image.dtype.name != "float32":
raise Exception("Image dtype must be float32.")
image = np.flipud(image)
if len(image.shape) == 3 and image.shape[2] == 3:
color = True
elif len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1:
color = False
else:
raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.")
data.write("PF\n" if color else "Pf\n".encode()) # type: ignore
data.write("%d %d\n".encode() % (image.shape[1], image.shape[0]))
endian = image.dtype.byteorder
if endian == "<" or endian == "=" and sys.byteorder == "little":
scale = -scale
data.write("%f\n".encode() % scale)
image.tofile(data)
def readFlow(file: Path):
if file.suffix.lower() == ".pfm":
return readPFM(file)[0][:, :, 0:2]
f = open(file, "rb")
header = f.read(4)
if header.decode("utf-8") != "PIEH":
raise Exception("Flow file header does not contain PIEH")
width = np.fromfile(f, np.int32, 1).squeeze()
height = np.fromfile(f, np.int32, 1).squeeze()
flow = np.fromfile(f, np.float32, width * height * 2).reshape((height, width, 2))
return flow.astype(np.float32)
def readImage(file: Path):
if file.suffix.lower() == ".pfm":
data = readPFM(file)[0]
if len(data.shape) == 3:
return data[:, :, 0:3]
else:
return data
return imread(file)
def writeImage(file: Path, data: np.ndarray):
if file.suffix.lower() == ".pfm":
return writePFM(file, data, 1)
return imwrite(file, data)
def writeFlow(file: Path, flow: np.ndarray):
f = open(file, "wb")
f.write("PIEH".encode("utf-8"))
np.array([flow.shape[1], flow.shape[0]], dtype=np.int32).tofile(f)
flow = flow.astype(np.float32)
flow.tofile(f)
def readFloat(file: Path):
f = open(file, "rb")
if (f.readline().decode("utf-8")) != "float\n":
raise Exception("float file %s did not contain <float> keyword" % file)
dim = int(f.readline())
dims = []
count = 1
for _ in range(0, dim):
d = int(f.readline())
dims.append(d)
count *= d
dims = list(reversed(dims))
data = np.fromfile(f, np.float32, count).reshape(dims)
if dim > 2:
data = np.transpose(data, (2, 1, 0))
data = np.transpose(data, (1, 0, 2))
return data
def writeFloat(file: Path, data: np.ndarray):
f = open(file, "wb")
dim = len(data.shape)
if dim > 3:
raise Exception("bad float file dimension: %d" % dim)
f.write(("float\n").encode("ascii"))
f.write(("%d\n" % dim).encode("ascii"))
if dim == 1:
f.write(("%d\n" % data.shape[0]).encode("ascii"))
else:
f.write(("%d\n" % data.shape[1]).encode("ascii"))
f.write(("%d\n" % data.shape[0]).encode("ascii"))
for i in range(2, dim):
f.write(("%d\n" % data.shape[i]).encode("ascii"))
data = data.astype(np.float32)
if dim == 2:
data.tofile(f)
else:
np.transpose(data, (2, 0, 1)).tofile(f)