Init code
This commit is contained in:
62
src/config/AMT-G.yaml
Executable file
62
src/config/AMT-G.yaml
Executable file
@@ -0,0 +1,62 @@
|
||||
exp_name: floloss1e-2_300epoch_bs24_lr1p5e-4
|
||||
seed: 2023
|
||||
epochs: 300
|
||||
distributed: true
|
||||
lr: 1.5e-4
|
||||
lr_min: 2e-5
|
||||
weight_decay: 0.0
|
||||
resume_state: null
|
||||
save_dir: work_dir
|
||||
eval_interval: 1
|
||||
|
||||
network:
|
||||
name: networks.AMT-G.Model
|
||||
params:
|
||||
corr_radius: 3
|
||||
corr_lvls: 4
|
||||
num_flows: 5
|
||||
data:
|
||||
train:
|
||||
name: datasets.vimeo_datasets.Vimeo90K_Train_Dataset
|
||||
params:
|
||||
dataset_dir: data/vimeo_triplet
|
||||
val:
|
||||
name: datasets.vimeo_datasets.Vimeo90K_Test_Dataset
|
||||
params:
|
||||
dataset_dir: data/vimeo_triplet
|
||||
train_loader:
|
||||
batch_size: 24
|
||||
num_workers: 12
|
||||
val_loader:
|
||||
batch_size: 24
|
||||
num_workers: 3
|
||||
|
||||
logger:
|
||||
use_wandb: true
|
||||
resume_id: null
|
||||
|
||||
losses:
|
||||
- {
|
||||
name: losses.loss.CharbonnierLoss,
|
||||
nickname: l_rec,
|
||||
params: {
|
||||
loss_weight: 1.0,
|
||||
keys: [imgt_pred, imgt]
|
||||
}
|
||||
}
|
||||
- {
|
||||
name: losses.loss.TernaryLoss,
|
||||
nickname: l_ter,
|
||||
params: {
|
||||
loss_weight: 1.0,
|
||||
keys: [imgt_pred, imgt]
|
||||
}
|
||||
}
|
||||
- {
|
||||
name: losses.loss.MultipleFlowLoss,
|
||||
nickname: l_flo,
|
||||
params: {
|
||||
loss_weight: 0.005,
|
||||
keys: [flow0_pred, flow1_pred, flow]
|
||||
}
|
||||
}
|
||||
63
src/config/AMT-S.yaml
Executable file
63
src/config/AMT-S.yaml
Executable file
@@ -0,0 +1,63 @@
|
||||
exp_name: floloss1e-2_300epoch_bs24_lr2e-4
|
||||
seed: 2023
|
||||
epochs: 300
|
||||
distributed: true
|
||||
lr: 2e-4
|
||||
lr_min: 2e-5
|
||||
weight_decay: 0.0
|
||||
resume_state: null
|
||||
save_dir: work_dir
|
||||
eval_interval: 1
|
||||
|
||||
network:
|
||||
name: networks.AMT-S.Model
|
||||
params:
|
||||
corr_radius: 3
|
||||
corr_lvls: 4
|
||||
num_flows: 3
|
||||
|
||||
data:
|
||||
train:
|
||||
name: datasets.vimeo_datasets.Vimeo90K_Train_Dataset
|
||||
params:
|
||||
dataset_dir: data/vimeo_triplet
|
||||
val:
|
||||
name: datasets.vimeo_datasets.Vimeo90K_Test_Dataset
|
||||
params:
|
||||
dataset_dir: data/vimeo_triplet
|
||||
train_loader:
|
||||
batch_size: 24
|
||||
num_workers: 12
|
||||
val_loader:
|
||||
batch_size: 24
|
||||
num_workers: 3
|
||||
|
||||
logger:
|
||||
use_wandb: false
|
||||
resume_id: null
|
||||
|
||||
losses:
|
||||
- {
|
||||
name: losses.loss.CharbonnierLoss,
|
||||
nickname: l_rec,
|
||||
params: {
|
||||
loss_weight: 1.0,
|
||||
keys: [imgt_pred, imgt]
|
||||
}
|
||||
}
|
||||
- {
|
||||
name: losses.loss.TernaryLoss,
|
||||
nickname: l_ter,
|
||||
params: {
|
||||
loss_weight: 1.0,
|
||||
keys: [imgt_pred, imgt]
|
||||
}
|
||||
}
|
||||
- {
|
||||
name: losses.loss.MultipleFlowLoss,
|
||||
nickname: l_flo,
|
||||
params: {
|
||||
loss_weight: 0.002,
|
||||
keys: [flow0_pred, flow1_pred, flow]
|
||||
}
|
||||
}
|
||||
BIN
src/pretrained/amt-g.pth
Normal file
BIN
src/pretrained/amt-g.pth
Normal file
Binary file not shown.
15
src/utils/build.py
Normal file
15
src/utils/build.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from typing import TYPE_CHECKING
|
||||
import importlib
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from omegaconf import DictConfig
|
||||
|
||||
|
||||
def base_build_fn(module: str, cls: str, params: dict):
    """Import *module*, look up the attribute *cls* in it, and call it with **params.

    Returns whatever the looked-up callable (typically a class) produces.
    """
    mod = importlib.import_module(module, package=None)
    factory = getattr(mod, cls)
    return factory(**params)
|
||||
|
||||
|
||||
def build_from_cfg(config: "DictConfig"):
    """Instantiate the object described by *config*.

    ``config["name"]`` is a dotted path ``package.module.Class``; the optional
    ``config["params"]`` mapping is passed to the constructor as keyword args.
    """
    qualified_name: str = config["name"]
    module_path, cls_name = qualified_name.rsplit(".", 1)
    build_params: dict = config.get("params", {})
    return base_build_fn(module_path, cls_name, build_params)
|
||||
25
src/utils/flow_utils.py
Normal file
25
src/utils/flow_utils.py
Normal file
@@ -0,0 +1,25 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def warp(img, flow):
    """Backward-warp *img* with the per-pixel displacement *flow* (B, 2, H, W)."""
    batch, _, height, width = flow.shape
    # Base sampling grid in normalized [-1, 1] coordinates, matched to img's
    # device/dtype via .to(img).
    xs = torch.linspace(-1.0, 1.0, width).view(1, 1, 1, width)
    ys = torch.linspace(-1.0, 1.0, height).view(1, 1, height, 1)
    base_grid = torch.cat(
        [xs.expand(batch, -1, height, -1), ys.expand(batch, -1, -1, width)], 1
    ).to(img)
    # Convert the pixel-space flow into the same normalized coordinate system.
    normalized_flow = torch.cat(
        [
            flow[:, 0:1, :, :] / ((width - 1.0) / 2.0),
            flow[:, 1:2, :, :] / ((height - 1.0) / 2.0),
        ],
        1,
    )
    sample_grid = (base_grid + normalized_flow).permute(0, 2, 3, 1)
    return F.grid_sample(
        input=img,
        grid=sample_grid,
        mode="bilinear",
        padding_mode="border",
        align_corners=True,
    )
|
||||
35
src/utils/padder.py
Normal file
35
src/utils/padder.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch.nn.functional as F
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import torch
|
||||
|
||||
|
||||
class InputPadder:
    """Pads images such that dimensions are divisible by divisor"""

    def __init__(self, dims: "torch.Size", divisor=16):
        self.ht, self.wd = dims[-2:]
        # Extra rows/cols needed to round each spatial dim up to a multiple
        # of divisor (zero when already divisible).
        extra_ht = -self.ht % divisor
        extra_wd = -self.wd % divisor
        left = extra_wd // 2
        top = extra_ht // 2
        # F.pad order for the last two dims: (left, right, top, bottom).
        self._pad = [left, extra_wd - left, top, extra_ht - top]

    def pad(self, *inputs: "torch.Tensor"):
        """Replicate-pad the inputs; a single tensor in gives a single tensor out."""
        padded = [F.pad(t, self._pad, mode="replicate") for t in inputs]
        return padded[0] if len(padded) == 1 else padded

    def unpad(self, *inputs: "torch.Tensor"):
        """Crop the padding off each input; always returns a list."""
        return [self._unpad(t) for t in inputs]

    def _unpad(self, x: "torch.Tensor"):
        # Invert pad() by slicing away the stored margins.
        ht, wd = x.shape[-2:]
        left, right, top, bottom = self._pad
        return x[..., top : ht - bottom, left : wd - right]
|
||||
56
src/utils/torch.py
Normal file
56
src/utils/torch.py
Normal file
@@ -0,0 +1,56 @@
|
||||
import logging
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
|
||||
def tensor2img(tensor: torch.Tensor):
    """Convert a (1, C, H, W) float tensor in [0, 1] to an (H, W, C) uint8 array."""
    scaled = (tensor * 255.0).detach().squeeze(0)
    hwc = scaled.permute(1, 2, 0).cpu().numpy()
    # Clamp before the cast so out-of-range values do not wrap around.
    return hwc.clip(0, 255).astype(np.uint8)
|
||||
|
||||
|
||||
def img2tensor(img: np.ndarray) -> torch.Tensor:
    """Convert an (H, W, C) image array to a (1, C, H, W) float tensor in [0, 1]."""
    logging.debug(f"Converting image of shape {img.shape} to tensor")
    # Drop any channels beyond RGB (e.g. an alpha channel).
    if img.shape[-1] > 3:
        img = img[:, :, :3]
    chw = torch.tensor(img).permute(2, 0, 1)
    return chw.unsqueeze(0) / 255.0
|
||||
|
||||
|
||||
def check_dim_and_resize(*args: torch.Tensor) -> list[torch.Tensor]:
    """Return *args* as a list, resizing all tensors to the first one's (H, W)
    spatial size when the inputs disagree; otherwise the tensors are returned
    unchanged."""
    logging.debug("Checking dimensions of input tensors")
    shapes = []
    for tensor in args:
        logging.debug(f"Tensor shape: {tensor.shape}")
        shapes.append(tensor.shape[2:])

    if len(set(shapes)) <= 1:
        # All spatial sizes agree — nothing to resize.
        return list(args)

    logging.warning(
        "Inconsistent tensor shapes detected. Resizing tensors to the same shape."
    )
    desired_shape = shapes[0]
    logging.info(
        f"Inconsistent size of input video frames. All frames will be resized to {desired_shape}"
    )
    return [
        torch.nn.functional.interpolate(
            input=tensor,
            size=tuple(desired_shape),
            mode="bilinear",
        )
        for tensor in args
    ]
|
||||
199
src/utils/utils.py
Normal file
199
src/utils/utils.py
Normal file
@@ -0,0 +1,199 @@
|
||||
import re
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
from imageio import imread, imwrite
|
||||
|
||||
|
||||
def read(file: Path) -> np.ndarray:
    """Read *file* with the reader matching its extension.

    Raises:
        Exception: if the extension is not one of the supported formats.
    """
    dispatch = {
        ".float3": readFloat,
        ".flo": readFlow,
        ".ppm": readImage,
        ".pgm": readImage,
        ".png": readImage,
        ".jpg": readImage,
        # PFM readers return (data, scale); only the data is wanted here.
        ".pfm": lambda f: readPFM(f)[0],
    }
    reader = dispatch.get(file.suffix.lower())
    if reader is None:
        raise Exception("don't know how to read %s" % file)
    return reader(file)
|
||||
|
||||
|
||||
def write(file: Path, data: np.ndarray) -> None:
    """Write *data* to *file* with the writer matching its extension.

    Raises:
        Exception: if the extension is not one of the supported formats.
    """
    dispatch = {
        ".float3": writeFloat,
        ".flo": writeFlow,
        ".ppm": writeImage,
        ".pgm": writeImage,
        ".png": writeImage,
        ".jpg": writeImage,
        ".pfm": writePFM,
    }
    writer = dispatch.get(file.suffix.lower())
    if writer is None:
        raise Exception("don't know how to write %s" % file)
    return writer(file, data)
|
||||
|
||||
|
||||
def readPFM(file: Path):
    """Read a PFM (portable float map) file.

    Returns:
        (data, scale): ``data`` is an (H, W, 3) float array for color files or
        (H, W) for grayscale, flipped to conventional top-down row order;
        ``scale`` is the absolute scale factor from the header.

    Raises:
        Exception: if the magic bytes or the dimension line are malformed.
    """
    # Context manager closes the handle even on a parse error
    # (the original opened the file and never closed it).
    with open(file, "rb") as data:
        header = data.readline().rstrip()
        if header.decode("ascii") == "PF":
            color = True
        elif header.decode("ascii") == "Pf":
            color = False
        else:
            raise Exception("Not a PFM file.")

        dim_match = re.match(r"^(\d+)\s(\d+)\s$", data.readline().decode("ascii"))
        if dim_match:
            width, height = list(map(int, dim_match.groups()))
        else:
            raise Exception("Malformed PFM header.")

        # A negative scale marks little-endian float data, positive big-endian.
        scale = float(data.readline().decode("ascii").rstrip())
        if scale < 0:
            endian = "<"
            scale = -scale
        else:
            endian = ">"

        result = np.fromfile(data, endian + "f")

    shape = (height, width, 3) if color else (height, width)
    result = np.reshape(result, shape)
    # PFM stores rows bottom-up; flip to top-down order.
    result = np.flipud(result)
    return result, scale
|
||||
|
||||
|
||||
def writePFM(file: Path, image: np.ndarray, scale=1):
    """Write *image* (float32, H x W, H x W x 1 or H x W x 3) to *file* as PFM.

    Args:
        file: destination path.
        image: float32 array; 3-channel arrays are written as color ("PF").
        scale: header scale factor; its sign is used to encode endianness.

    Raises:
        Exception: if the dtype is not float32 or the shape is unsupported.
    """
    if image.dtype.name != "float32":
        raise Exception("Image dtype must be float32.")

    # PFM stores rows bottom-up.
    image = np.flipud(image)

    if len(image.shape) == 3 and image.shape[2] == 3:
        color = True
    elif len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1:
        color = False
    else:
        raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.")

    # Context manager closes the handle (the original leaked it).
    with open(file, "wb") as data:
        # BUG FIX: the original wrote '"PF\n" if color else "Pf\n".encode()' —
        # .encode() bound only to the grayscale branch, so color images passed
        # a str to a binary file and raised TypeError. Encode the whole choice.
        data.write(("PF\n" if color else "Pf\n").encode())
        data.write(("%d %d\n" % (image.shape[1], image.shape[0])).encode())

        # A negative scale signals little-endian float data.
        endian = image.dtype.byteorder
        if endian == "<" or endian == "=" and sys.byteorder == "little":
            scale = -scale
        data.write(("%f\n" % scale).encode())

        image.tofile(data)
|
||||
|
||||
|
||||
def readFlow(file: Path):
    """Read an optical-flow field from a Middlebury .flo file (or a .pfm file).

    Returns:
        float32 array of shape (H, W, 2).

    Raises:
        Exception: if a .flo file does not start with the 'PIEH' magic tag.
    """
    if file.suffix.lower() == ".pfm":
        # PFM flow files carry the two flow channels in the first two planes.
        return readPFM(file)[0][:, :, 0:2]

    # Context manager closes the handle (the original leaked it).
    with open(file, "rb") as f:
        header = f.read(4)
        if header.decode("utf-8") != "PIEH":
            raise Exception("Flow file header does not contain PIEH")

        # Header stores width then height as int32.
        width = np.fromfile(f, np.int32, 1).squeeze()
        height = np.fromfile(f, np.int32, 1).squeeze()

        flow = np.fromfile(f, np.float32, width * height * 2).reshape(
            (height, width, 2)
        )

    return flow.astype(np.float32)
|
||||
|
||||
|
||||
def readImage(file: Path):
    """Read an image; .pfm files go through readPFM, everything else via imageio."""
    if file.suffix.lower() != ".pfm":
        return imread(file)
    data = readPFM(file)[0]
    # Keep at most the first three channels of a color PFM.
    return data[:, :, 0:3] if len(data.shape) == 3 else data
|
||||
|
||||
|
||||
def writeImage(file: Path, data: np.ndarray):
    """Write an image; .pfm files go through writePFM, everything else via imageio."""
    if file.suffix.lower() == ".pfm":
        return writePFM(file, data, 1)
    return imwrite(file, data)
|
||||
|
||||
|
||||
def writeFlow(file: Path, flow: np.ndarray):
    """Write *flow* (H, W, 2) to *file* in Middlebury .flo format.

    The format is the 'PIEH' magic tag, width and height as int32, then the
    raw float32 flow values in row-major order.
    """
    # Context manager closes the handle (the original leaked it).
    with open(file, "wb") as f:
        f.write("PIEH".encode("utf-8"))
        # Header stores width then height.
        np.array([flow.shape[1], flow.shape[0]], dtype=np.int32).tofile(f)
        flow = flow.astype(np.float32)
        flow.tofile(f)
|
||||
|
||||
|
||||
def readFloat(file: Path):
    """Read a .float3 blob: 'float' tag, dim count, per-dim sizes, raw float32 data.

    Returns:
        For 3-D data, an (H, W, C) array; 1-D/2-D data is returned as stored.

    Raises:
        Exception: if the file does not start with the 'float' keyword.
    """
    # Context manager closes the handle (the original leaked it).
    with open(file, "rb") as f:
        if f.readline().decode("utf-8") != "float\n":
            raise Exception("float file %s did not contain <float> keyword" % file)

        dim = int(f.readline())

        dims = []
        count = 1
        for _ in range(0, dim):
            d = int(f.readline())
            dims.append(d)
            count *= d

        # Sizes are stored innermost-first; reverse for numpy's row-major order.
        dims = list(reversed(dims))

        data = np.fromfile(f, np.float32, count).reshape(dims)

    if dim > 2:
        # Both transposes belong under this guard: applying (1, 0, 2) to 1-D or
        # 2-D data would raise. Net effect maps the stored layout to (H, W, C).
        data = np.transpose(data, (2, 1, 0))
        data = np.transpose(data, (1, 0, 2))

    return data
|
||||
|
||||
|
||||
def writeFloat(file: Path, data: np.ndarray):
    """Write *data* (1, 2 or 3 dims) to *file* in the .float3 format readFloat expects.

    Raises:
        Exception: if data has more than 3 dimensions.
    """
    dim = len(data.shape)
    if dim > 3:
        raise Exception("bad float file dimension: %d" % dim)

    # Context manager closes the handle (the original leaked it).
    with open(file, "wb") as f:
        f.write(("float\n").encode("ascii"))
        f.write(("%d\n" % dim).encode("ascii"))

        # Sizes are written innermost-first: width, height, then channels.
        if dim == 1:
            f.write(("%d\n" % data.shape[0]).encode("ascii"))
        else:
            f.write(("%d\n" % data.shape[1]).encode("ascii"))
            f.write(("%d\n" % data.shape[0]).encode("ascii"))
            for i in range(2, dim):
                f.write(("%d\n" % data.shape[i]).encode("ascii"))

        data = data.astype(np.float32)
        if dim == 3:
            # 3-D data is stored channel-major to match readFloat's transposes.
            np.transpose(data, (2, 0, 1)).tofile(f)
        else:
            # BUG FIX: the original routed dim != 2 (including 1-D input) into
            # the 3-axis transpose, which raised for 1-D arrays. 1-D and 2-D
            # data are already in stored order.
            data.tofile(f)
|
||||
Reference in New Issue
Block a user