diff --git a/.gitignore b/.gitignore index 7d74ea7..85d2544 100644 --- a/.gitignore +++ b/.gitignore @@ -175,5 +175,6 @@ cython_debug/ .pypirc +.DS_Store source/ output/ \ No newline at end of file diff --git a/main.py b/main.py index fd8256f..c4bf696 100644 --- a/main.py +++ b/main.py @@ -1,149 +1,26 @@ import logging -import subprocess from pathlib import Path +from typing import TYPE_CHECKING -import cv2 -from tqdm import tqdm -from time import perf_counter -from decimal import Decimal +import tqdm -from interpolator import get_device -from interpolator import ImageInterpolator -from interpolator import ModelRunner, Anchor - - -logging.basicConfig( - level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" +from src.config import presets +from src.utils.fs import FileSystem +from src.utils.video import VideoMaker +from src.interpolator import ( + ImageInterpolator, + Anchor, + get_device, + get_vram_available, + ModelRunner, ) -from pathlib import Path + +if TYPE_CHECKING: + import torch -def move_images(src_dir: str, interpolated_dir: str, output_dir: str): - src_dir = Path(src_dir) - interpolated_dir = Path(interpolated_dir) - output_dir = Path(output_dir) - output_dir.mkdir(parents=True, exist_ok=True) - - index = 0 - src_frames = sorted(src_dir.glob("img_*.png")) - interp_frames = sorted(interpolated_dir.glob("img_*.png")) - for i in range(len(src_frames)): - output_frame = output_dir / f"img_{index:08d}.png" - src_frames[i].rename(output_frame) - index += 1 - - if i < len(interp_frames): - output_interp = output_dir / f"img_{index:08d}.png" - interp_frames[i].rename(output_interp) - index += 1 - - -def build_file_list(moved_dir: str, list_path: str): - import os - moved_dir = Path(moved_dir) - frames = sorted(moved_dir.glob("img_*.png")) - print(frames[0]) - - with open(list_path, "w") as f: - for frame in frames: - f.write(f"file '{os.path.abspath(frame)}'\n") - - -def build_ffmpeg_file_list(frames_dir: str, interpolated_dir: str, list_path: str): - frames = sorted(Path(frames_dir).glob("img_*.png")) - interps = sorted(Path(interpolated_dir).glob("img_*.png")) - - if len(interps) != len(frames) - 1: - raise ValueError("Interpolated frames must be N-1") - - with open(list_path, "w") as f: - for i in range(len(frames)): - f.write(f"file '{frames[i].resolve().as_posix()}'\n") - - if i < len(interps): - f.write(f"file '{interps[i].resolve().as_posix()}'\n") - - -def merge_with_ffmpeg( - original_video: str, - file_list: str, - output_video: str, -): - cap = cv2.VideoCapture(original_video) - - if not cap.isOpened(): - raise ValueError("Cannot open original video") - - fps = cap.get(cv2.CAP_PROP_FPS) - cap.release() - - new_fps = Decimal(fps * 2) - - cmd = [ - "ffmpeg", - "-y", - "-r", str(new_fps.quantize(Decimal("1.0000000000"))), - "-f", "concat", - "-safe", "0", - "-i", file_list, - "-c:v", "libx264rgb", - output_video, - ] - print("Running ffmpeg command:", " ".join(cmd)) - - subprocess.run(cmd, check=True) - - - -def video_frames_to_disk_generator( - video_path: str | Path, - output_dir: str | Path, - chunk_seconds: int = 10 -): - output_dir = Path(output_dir) - output_dir.mkdir(parents=True, exist_ok=True) - - cap = cv2.VideoCapture(str(video_path)) - - if not cap.isOpened(): - raise ValueError(f"Cannot open video: {video_path}") - - fps = cap.get(cv2.CAP_PROP_FPS) - frames_per_chunk = int(fps * chunk_seconds) - - frame_index = 0 - - while True: - paths = [] - - for _ in range(frames_per_chunk): - ret, frame = cap.read() - if not ret: - cap.release() - return - - 
frame_path = output_dir / f"img_{frame_index:08d}.png" - cv2.imwrite(str(frame_path), frame) - - paths.append(frame_path) - frame_index += 1 - - yield tuple(paths) - - -def main(): - start = perf_counter() - logging.info("Starting video interpolation process") - config_path = Path("src/config/AMT-G.yaml") - ckpt_path = Path("src/pretrained/amt-g.pth") - video_path = Path("example/video.mp4") - output_dir = Path("output/frames") - output_interpolated_dir = Path("output/interpolated") - output_interpolated_dir.mkdir(parents=True, exist_ok=True) - - device = get_device() - model_runner = ModelRunner(config_path, ckpt_path, device) +def warn_if_slow_device(device: "torch.device"): if device.type in ("cpu", "mps"): if device.type == "mps": logging.warning( @@ -153,87 +30,199 @@ def main(): logging.warning( "Running on CPU may be very slow. Consider using a GPU for better performance." ) - anchor = Anchor(resolution=8192 * 8192, memory=1, memory_bias=0) elif device.type == "cuda": - anchor = Anchor( + pass + else: + raise Exception(f"Unsupported device type: {device.type}") + + +def init_fs(base_path: Path) -> FileSystem: + fs = FileSystem(base_path) + fs.clear_directory(fs.frames_path) + fs.clear_directory(fs.interpolated_path) + fs.clear_directory(fs.moved_path) + fs.clear_directory(fs.video_part_path) + return fs + + +def init_video_maker() -> VideoMaker: + return VideoMaker() + + +def init_device() -> "torch.device": + device = get_device() + warn_if_slow_device(device) + vram_available = get_vram_available(device) + logging.info(f"Available VRAM: {vram_available / (1024 ** 3):.2f} GB") + return device + + +def init_anchor(device: "torch.device") -> Anchor: + if device.type in ("cpu", "mps"): + return Anchor(resolution=8192 * 8192, memory=1, memory_bias=0) + elif device.type == "cuda": + return Anchor( resolution=1024 * 512, memory=1500 * 1024**2, memory_bias=2500 * 1024**2 ) else: raise Exception(f"Unsupported device type: {device.type}") - - interpolator = ImageInterpolator(device, anchor, model_runner) - - loaded_time = perf_counter() - start - logging.info(f"Model loaded and initialized in {loaded_time:.2f} seconds") - - prev_frame_path = None - frame_count = 0 - for frame_paths in video_frames_to_disk_generator(video_path, output_dir): - logging.info(f"Processing frames: {len(frame_paths)}") - - if prev_frame_path is not None: - img1 = prev_frame_path[-1] - img2 = frame_paths[0] - output_path = output_interpolated_dir / f"img_{frame_count:08d}.png" - interpolator.interpolate(img1, img2, output_path) - logging.debug(f"Interpolated image saved to: {output_path}") - frame_count += 1 - for i in tqdm(range(len(frame_paths) - 1), desc="Interpolating frames"): - img1 = frame_paths[i] - img2 = frame_paths[i + 1] - output_path = output_interpolated_dir / f"img_{frame_count:08d}.png" - interpolator.interpolate(img1, img2, output_path) - logging.debug(f"Interpolated image saved to: {output_path}") - frame_count += 1 - prev_frame_path = frame_paths - total_time = perf_counter() - start - logging.info(f"Video interpolation completed in {total_time:.2f} seconds") -def builder(): - frames_dir = "output/frames" - interpolated_dir = "output/interpolated" - moved_dir = "output/moved" - video_path = "example/video.mp4" - output_video = "output/interpolated_video.mp4" - move_images(frames_dir, interpolated_dir, moved_dir) - - cap = cv2.VideoCapture(video_path) - - if not cap.isOpened(): - raise ValueError("Cannot open original video") - - fps = cap.get(cv2.CAP_PROP_FPS) - cmd = [ - "ffmpeg",
- "-y", - "-framerate", str(fps * 2), - "-i", f"{moved_dir}/img_%08d.png", - "-i", video_path, - "-c:v", "libx264", - "-c:a", "copy", - "-shortest", - output_video, - ] - logging.info("Running ffmpeg command to build final video: " + " ".join(cmd)) - subprocess.run(cmd, check=True) +def init_model_runner( + config: Path, checkpoint_path: Path, device: "torch.device" +) -> ModelRunner: + return ModelRunner(config, checkpoint_path, device) -def cleanup(): - import os - import shutil - frames_dir = "output/frames" - interpolated_dir = "output/interpolated" - moved_dir = "output/moved" - os.makedirs(frames_dir, exist_ok=True) - os.makedirs(interpolated_dir, exist_ok=True) - os.makedirs(moved_dir, exist_ok=True) - shutil.rmtree(frames_dir) - shutil.rmtree(interpolated_dir) - shutil.rmtree(moved_dir) +def init_interpolator( + model_runner: ModelRunner, device: "torch.device" +) -> ImageInterpolator: + anchor = init_anchor(device) + return ImageInterpolator(device, anchor, model_runner) + + +class InterpolationPipeline: + def __init__( + self, + config: Path, + checkpoint_path: Path, + base_path: Path, + ): + self.fs = init_fs(base_path) + self.video_maker = init_video_maker() + self.device = init_device() + self.model_runner = init_model_runner(config, checkpoint_path, self.device) + self.interpolator = init_interpolator(self.model_runner, self.device) + + def run(self, video_path: Path, output_video: str): + prev_frame_path = None + frame_count = 0 + part = 0 + source_frame_length = 0 + chunk_seconds = 10 + length = self.video_maker.get_video_duration(video_path) + last_part_seconds = 1 if length % chunk_seconds else 0 + total_parts = int(length // chunk_seconds) + last_part_seconds + fps = self.video_maker.get_fps(video_path) + logging.info(f"Video FPS: {fps}") + fps *= 2 # Doubling FPS + for frame_paths in self.video_maker.video_to_frames_generator( + video_path, self.fs.frames_path, chunk_seconds + ): + logging.info(f"Processing frames: {len(frame_paths)}") + if prev_frame_path is not None: + img1 = prev_frame_path[-1] + img2 = frame_paths[0] + output_path = self.fs.interpolated_path / f"img_{frame_count:08d}.png" + self.interpolator.interpolate(img1, img2, output_path) + logging.debug(f"Interpolated image saved to: {output_path}") + self._merge_frames_to_video( + self.fs.video_part_path / f"video_{part:08d}.mp4", + fps, + source_frame_length=source_frame_length, + ) + logging.info(f"Finished processing part {part:08d}") + frame_count += 1 + part += 1 + for i in tqdm.tqdm( + range(len(frame_paths) - 1), + desc=f"Processing video frames {part + 1} / {total_parts}", + ): + img1 = frame_paths[i] + img2 = frame_paths[i + 1] + output_path = self.fs.interpolated_path / f"img_{i:08d}.png" + self.interpolator.interpolate(img1, img2, output_path) + logging.debug(f"Interpolated image saved to: {output_path}") + frame_count += 1 + source_frame_length = len(frame_paths) + prev_frame_path = frame_paths + + self._merge_frames_to_video( + self.fs.video_part_path / f"video_{part:08d}.mp4", + fps, + source_frame_length=source_frame_length, + ) + logging.info(f"Finished processing part {part:08d}") + self._merge_video_parts(self.fs.output_path / output_video) + logging.info( + f"Video interpolation completed. 
Output saved to: {self.fs.output_path / output_video}" ) + + def _merge_frames_to_video( + self, output_video: Path, fps: float, source_frame_length: int = 0 + ): + self._move_frames(source_frame_length) + self.video_maker.images_to_video(self.fs.moved_path, output_video, fps) + + def _merge_video_parts(self, output_video: Path): + self.video_maker.concatenate_videos(self.fs.video_part_path, output_video) + self.fs.clear_directory(self.fs.video_part_path) + + def _move_frames(self, source_frame_length: int = 0): + self.fs.clear_directory(self.fs.moved_path) + src_frames = sorted(self.fs.frames_path.glob("*.png")) + interpolated_frames = sorted(self.fs.interpolated_path.glob("*.png")) + index = 0 + for i in range(source_frame_length): + moved_frame_path = self.fs.moved_path / f"img_{index:08d}.png" + src_frames[i].rename(moved_frame_path) + index += 1 + if i < len(interpolated_frames): + moved_interpolated_path = self.fs.moved_path / f"img_{index:08d}.png" + interpolated_frames[i].rename(moved_interpolated_path) + index += 1 + logging.info( + f"Moved {source_frame_length} source frames and {index - source_frame_length} interpolated frames to {self.fs.moved_path}" + ) + + +def runner( + base_path: Path, + video_path: Path, + output_video: str, + preset: presets.Preset = presets.LARGE, +): + pipeline = InterpolationPipeline( + config=preset.config, + checkpoint_path=preset.checkpoint, + base_path=base_path, + ) + pipeline.run(video_path, output_video) + + +def main(): + import argparse + + logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" + ) + + parser = argparse.ArgumentParser() + parser.add_argument("-b", "--base_path", help="Base path", default="output") + parser.add_argument( + "-v", "--video_path", help="Video path", default="example/video.mp4" + ) + parser.add_argument( + "-o", + "--output", + help="Output video name (example: 'interpolated_video.mp4')", + default="interpolated_video.mp4", + ) + parser.add_argument( + "-p", + "--preset", + help="Model preset", + choices=["small", "large", "global"], + default="global", + ) + args = parser.parse_args() + runner( + base_path=Path(args.base_path), + video_path=Path(args.video_path), + output_video=args.output, + preset=getattr(presets, args.preset.upper()), + ) + if __name__ == "__main__": - cleanup() main() - builder() - cleanup() diff --git a/networks/blocks/multi_flow.py b/networks/blocks/multi_flow.py deleted file mode 100755 index 734fab8..0000000 --- a/networks/blocks/multi_flow.py +++ /dev/null @@ -1,69 +0,0 @@ -import torch -import torch.nn as nn -from src.utils.flow_utils import warp -from networks.blocks.ifrnet import ( - convrelu, resize, - ResBlock, -) - - -def multi_flow_combine(comb_block, img0, img1, flow0, flow1, - mask=None, img_res=None, mean=None): - ''' - A parallel implementation of multiple flow field warping - comb_block: An nn.Seqential object. - img shape: [b, c, h, w] - flow shape: [b, 2*num_flows, h, w] - mask (opt): - If 'mask' is None, the function conduct a simple average. - img_res (opt): - If 'img_res' is None, the function adds zero instead. - mean (opt): - If 'mean' is None, the function adds zero instead.
- ''' - b, c, h, w = flow0.shape - num_flows = c // 2 - flow0 = flow0.reshape(b, num_flows, 2, h, w).reshape(-1, 2, h, w) - flow1 = flow1.reshape(b, num_flows, 2, h, w).reshape(-1, 2, h, w) - - mask = mask.reshape(b, num_flows, 1, h, w - ).reshape(-1, 1, h, w) if mask is not None else None - img_res = img_res.reshape(b, num_flows, 3, h, w - ).reshape(-1, 3, h, w) if img_res is not None else 0 - img0 = torch.stack([img0] * num_flows, 1).reshape(-1, 3, h, w) - img1 = torch.stack([img1] * num_flows, 1).reshape(-1, 3, h, w) - mean = torch.stack([mean] * num_flows, 1).reshape(-1, 1, 1, 1 - ) if mean is not None else 0 - - img0_warp = warp(img0, flow0) - img1_warp = warp(img1, flow1) - img_warps = mask * img0_warp + (1 - mask) * img1_warp + mean + img_res - img_warps = img_warps.reshape(b, num_flows, 3, h, w) - imgt_pred = img_warps.mean(1) + comb_block(img_warps.view(b, -1, h, w)) - return imgt_pred - - -class MultiFlowDecoder(nn.Module): - def __init__(self, in_ch, skip_ch, num_flows=3): - super(MultiFlowDecoder, self).__init__() - self.num_flows = num_flows - self.convblock = nn.Sequential( - convrelu(in_ch*3+4, in_ch*3), - ResBlock(in_ch*3, skip_ch), - nn.ConvTranspose2d(in_ch*3, 8*num_flows, 4, 2, 1, bias=True) - ) - - def forward(self, ft_, f0, f1, flow0, flow1): - n = self.num_flows - f0_warp = warp(f0, flow0) - f1_warp = warp(f1, flow1) - out = self.convblock(torch.cat([ft_, f0_warp, f1_warp, flow0, flow1], 1)) - delta_flow0, delta_flow1, mask, img_res = torch.split(out, [2*n, 2*n, n, 3*n], 1) - mask = torch.sigmoid(mask) - - flow0 = delta_flow0 + 2.0 * resize(flow0, scale_factor=2.0 - ).repeat(1, self.num_flows, 1, 1) - flow1 = delta_flow1 + 2.0 * resize(flow1, scale_factor=2.0 - ).repeat(1, self.num_flows, 1, 1) - - return flow0, flow1, mask, img_res \ No newline at end of file diff --git a/src/config/AMT-G.yaml b/src/config/AMT-G.yaml index 7b3bb39..4570e2b 100755 --- a/src/config/AMT-G.yaml +++ b/src/config/AMT-G.yaml @@ -10,7 +10,7 @@ save_dir: work_dir eval_interval: 1 network: - name: networks.AMT-G.Model + name: src.networks.AMT-G.Model params: corr_radius: 3 corr_lvls: 4 diff --git a/src/config/AMT-L.yaml b/src/config/AMT-L.yaml new file mode 100644 index 0000000..19d70ca --- /dev/null +++ b/src/config/AMT-L.yaml @@ -0,0 +1,62 @@ +exp_name: floloss1e-2_300epoch_bs24_lr2e-4 +seed: 2023 +epochs: 300 +distributed: true +lr: 2e-4 +lr_min: 2e-5 +weight_decay: 0.0 +resume_state: null +save_dir: work_dir +eval_interval: 1 + +network: + name: src.networks.AMT-L.Model + params: + corr_radius: 3 + corr_lvls: 4 + num_flows: 5 +data: + train: + name: datasets.vimeo_datasets.Vimeo90K_Train_Dataset + params: + dataset_dir: data/vimeo_triplet + val: + name: datasets.vimeo_datasets.Vimeo90K_Test_Dataset + params: + dataset_dir: data/vimeo_triplet + train_loader: + batch_size: 24 + num_workers: 12 + val_loader: + batch_size: 24 + num_workers: 3 + +logger: + use_wandb: true + resume_id: null + +losses: + - { + name: losses.loss.CharbonnierLoss, + nickname: l_rec, + params: { + loss_weight: 1.0, + keys: [imgt_pred, imgt] + } + } + - { + name: losses.loss.TernaryLoss, + nickname: l_ter, + params: { + loss_weight: 1.0, + keys: [imgt_pred, imgt] + } + } + - { + name: losses.loss.MultipleFlowLoss, + nickname: l_flo, + params: { + loss_weight: 0.002, + keys: [flow0_pred, flow1_pred, flow] + } + } \ No newline at end of file diff --git a/src/config/AMT-S.yaml b/src/config/AMT-S.yaml index f067355..d763a5a 100755 --- a/src/config/AMT-S.yaml +++ b/src/config/AMT-S.yaml @@ -10,7 +10,7 @@ save_dir: work_dir 
eval_interval: 1 network: - name: networks.AMT-S.Model + name: src.networks.AMT-S.Model params: corr_radius: 3 corr_lvls: 4 diff --git a/src/config/presets.py b/src/config/presets.py new file mode 100644 index 0000000..7687a1b --- /dev/null +++ b/src/config/presets.py @@ -0,0 +1,24 @@ +from pathlib import Path +from dataclasses import dataclass + + +@dataclass(frozen=True) +class Preset: + config: Path + checkpoint: Path + + +SMALL = Preset( + config=Path("src/config/AMT-S.yaml"), + checkpoint=Path("src/pretrained/amt-s.pth"), +) + +LARGE = Preset( + config=Path("src/config/AMT-L.yaml"), + checkpoint=Path("src/pretrained/amt-l.pth"), +) + +GLOBAL = Preset( + config=Path("src/config/AMT-G.yaml"), + checkpoint=Path("src/pretrained/amt-g.pth"), +) diff --git a/interpolator.py b/src/interpolator.py similarity index 89% rename from interpolator.py rename to src/interpolator.py index a65f07e..41a5848 100644 --- a/interpolator.py +++ b/src/interpolator.py @@ -4,8 +4,8 @@ from pathlib import Path import torch import numpy as np from omegaconf import OmegaConf, DictConfig +from imageio import imread, imwrite -from src.utils import utils from src.utils.torch import img2tensor, check_dim_and_resize, tensor2img from src.utils.build import build_from_cfg from src.utils.padder import InputPadder @@ -84,9 +84,16 @@ class ImageInterpolator: ) def interpolate(self, image1: Path, image2: Path, output_path: Path): + """ + Interpolates between two images and saves the result. + Args: + image1 (Path): Path to the first input image (only png and jpg formats are supported) + image2 (Path): Path to the second input image (only png and jpg formats are supported) + output_path (Path): Path to save the interpolated image (only png and jpg formats are supported) + """ logging.debug(f"Reading images: {image1} and {image2}") - tensor1 = img2tensor(utils.read(image1)).to(self.device) - tensor2 = img2tensor(utils.read(image2)).to(self.device) + tensor1 = img2tensor(imread(image1)).to(self.device) + tensor2 = img2tensor(imread(image2)).to(self.device) logging.debug( f"Image shapes after conversion to tensors: {tensor1.shape}, {tensor2.shape}" ) @@ -115,7 +122,7 @@ class ImageInterpolator: logging.debug(f"Interpolated image shape before unpadding: {interpolated.shape}") (interpolated,) = padder.unpad(interpolated) logging.debug(f"Interpolated image shape after unpadding: {interpolated.shape}") - utils.write(output_path, tensor2img(interpolated.cpu())) + imwrite(output_path, tensor2img(interpolated.cpu())) logging.debug(f"Saved interpolated image to: {output_path}") def scale(self, height: int, width: int) -> float: diff --git a/networks/AMT-G.py b/src/networks/AMT-G.py similarity index 89% rename from networks/AMT-G.py rename to src/networks/AMT-G.py index 157f8b8..3f3de9e 100755 --- a/networks/AMT-G.py +++ b/src/networks/AMT-G.py @@ -1,9 +1,11 @@ +from typing import Optional + import torch import torch.nn as nn -from networks.blocks.raft import coords_grid, BasicUpdateBlock, BidirCorrBlock -from networks.blocks.feat_enc import LargeEncoder -from networks.blocks.ifrnet import resize, Encoder, InitDecoder, IntermediateDecoder -from networks.blocks.multi_flow import multi_flow_combine, MultiFlowDecoder +from src.networks.blocks.raft import coords_grid, BasicUpdateBlock, BidirCorrBlock +from src.networks.blocks.feat_enc import LargeEncoder +from src.networks.blocks.ifrnet import resize, Encoder, InitDecoder, IntermediateDecoder +from src.networks.blocks.multi_flow import multi_flow_combine, MultiFlowDecoder class 
Model(nn.Module): @@ -42,7 +44,7 @@ class Model(nn.Module): nn.Conv2d(6 * self.num_flows, 3, 7, 1, 3), ) - def _get_updateblock(self, cdim, scale_factor=None): + def _get_updateblock(self, cdim: int, scale_factor: Optional[float] = None): return BasicUpdateBlock( cdim=cdim, hidden_dim=192, @@ -55,7 +57,15 @@ class Model(nn.Module): radius=self.radius, ) - def _corr_scale_lookup(self, corr_fn, coord, flow0, flow1, embt, downsample=1): + def _corr_scale_lookup( + self, + corr_fn: BidirCorrBlock, + coord: torch.Tensor, + flow0: torch.Tensor, + flow1: torch.Tensor, + embt: torch.Tensor, + downsample: int = 1, + ): # convert t -> 0 to 0 -> 1 | convert t -> 1 to 1 -> 0 # based on linear assumption t1_scale = 1.0 / embt @@ -70,7 +80,15 @@ class Model(nn.Module): flow = torch.cat([flow0, flow1], dim=1) return corr, flow - def forward(self, img0, img1, embt, scale_factor=1.0, eval=False, **kwargs): + def forward( + self, + img0: torch.Tensor, + img1: torch.Tensor, + embt: torch.Tensor, + scale_factor: float = 1.0, + eval: bool = False, + **kwargs, + ): mean_ = ( torch.cat([img0, img1], 2) .mean(1, keepdim=True) diff --git a/networks/AMT-L.py b/src/networks/AMT-L.py similarity index 53% rename from networks/AMT-L.py rename to src/networks/AMT-L.py index f992688..6243a4d 100755 --- a/networks/AMT-L.py +++ b/src/networks/AMT-L.py @@ -1,40 +1,31 @@ import torch import torch.nn as nn -from networks.blocks.raft import ( - coords_grid, - BasicUpdateBlock, BidirCorrBlock -) -from networks.blocks.feat_enc import ( - BasicEncoder -) -from networks.blocks.ifrnet import ( - resize, - Encoder, - InitDecoder, - IntermediateDecoder -) -from networks.blocks.multi_flow import ( - multi_flow_combine, - MultiFlowDecoder -) +from src.networks.blocks.raft import coords_grid, BasicUpdateBlock, BidirCorrBlock +from src.networks.blocks.feat_enc import BasicEncoder +from src.networks.blocks.ifrnet import resize, Encoder, InitDecoder, IntermediateDecoder + +from src.networks.blocks.multi_flow import multi_flow_combine, MultiFlowDecoder class Model(nn.Module): - def __init__(self, - corr_radius=3, - corr_lvls=4, - num_flows=5, - channels=[48, 64, 72, 128], - skip_channels=48 - ): + def __init__( + self, + corr_radius=3, + corr_lvls=4, + num_flows=5, + channels=[48, 64, 72, 128], + skip_channels=48, + ): super(Model, self).__init__() self.radius = corr_radius self.corr_levels = corr_lvls self.num_flows = num_flows - self.feat_encoder = BasicEncoder(output_dim=128, norm_fn='instance', dropout=0.) 
+ self.feat_encoder = BasicEncoder( + output_dim=128, norm_fn="instance", dropout=0.0 + ) self.encoder = Encoder([48, 64, 72, 128], large=True) - + self.decoder4 = InitDecoder(channels[3], channels[2], skip_channels) self.decoder3 = IntermediateDecoder(channels[2], channels[1], skip_channels) self.decoder2 = IntermediateDecoder(channels[1], channels[0], skip_channels) @@ -43,45 +34,59 @@ class Model(nn.Module): self.update4 = self._get_updateblock(72, None) self.update3 = self._get_updateblock(64, 2.0) self.update2 = self._get_updateblock(48, 4.0) - + self.comb_block = nn.Sequential( - nn.Conv2d(3*self.num_flows, 6*self.num_flows, 7, 1, 3), - nn.PReLU(6*self.num_flows), - nn.Conv2d(6*self.num_flows, 3, 7, 1, 3), + nn.Conv2d(3 * self.num_flows, 6 * self.num_flows, 7, 1, 3), + nn.PReLU(6 * self.num_flows), + nn.Conv2d(6 * self.num_flows, 3, 7, 1, 3), ) def _get_updateblock(self, cdim, scale_factor=None): - return BasicUpdateBlock(cdim=cdim, hidden_dim=128, flow_dim=48, - corr_dim=256, corr_dim2=160, fc_dim=124, - scale_factor=scale_factor, corr_levels=self.corr_levels, - radius=self.radius) + return BasicUpdateBlock( + cdim=cdim, + hidden_dim=128, + flow_dim=48, + corr_dim=256, + corr_dim2=160, + fc_dim=124, + scale_factor=scale_factor, + corr_levels=self.corr_levels, + radius=self.radius, + ) def _corr_scale_lookup(self, corr_fn, coord, flow0, flow1, embt, downsample=1): # convert t -> 0 to 0 -> 1 | convert t -> 1 to 1 -> 0 # based on linear assumption - t1_scale = 1. / embt - t0_scale = 1. / (1. - embt) + t1_scale = 1.0 / embt + t0_scale = 1.0 / (1.0 - embt) if downsample != 1: inv = 1 / downsample flow0 = inv * resize(flow0, scale_factor=inv) flow1 = inv * resize(flow1, scale_factor=inv) - - corr0, corr1 = corr_fn(coord + flow1 * t1_scale, coord + flow0 * t0_scale) + + corr0, corr1 = corr_fn(coord + flow1 * t1_scale, coord + flow0 * t0_scale) corr = torch.cat([corr0, corr1], dim=1) flow = torch.cat([flow0, flow1], dim=1) return corr, flow - + def forward(self, img0, img1, embt, scale_factor=1.0, eval=False, **kwargs): - mean_ = torch.cat([img0, img1], 2).mean(1, keepdim=True).mean(2, keepdim=True).mean(3, keepdim=True) + mean_ = ( + torch.cat([img0, img1], 2) + .mean(1, keepdim=True) + .mean(2, keepdim=True) + .mean(3, keepdim=True) + ) img0 = img0 - mean_ img1 = img1 - mean_ img0_ = resize(img0, scale_factor) if scale_factor != 1.0 else img0 img1_ = resize(img1, scale_factor) if scale_factor != 1.0 else img1 b, _, h, w = img0_.shape coord = coords_grid(b, h // 8, w // 8, img0.device) - - fmap0, fmap1 = self.feat_encoder([img0_, img1_]) # [1, 128, H//8, W//8] - corr_fn = BidirCorrBlock(fmap0, fmap1, radius=self.radius, num_levels=self.corr_levels) + + fmap0, fmap1 = self.feat_encoder([img0_, img1_]) # [1, 128, H//8, W//8] + corr_fn = BidirCorrBlock( + fmap0, fmap1, radius=self.radius, num_levels=self.corr_levels + ) # f0_1: [1, c0, H//2, W//2] | f0_2: [1, c1, H//4, W//4] # f0_3: [1, c2, H//8, W//8] | f0_4: [1, c3, H//16, W//16] @@ -90,9 +95,9 @@ class Model(nn.Module): ######################################### the 4th decoder ######################################### up_flow0_4, up_flow1_4, ft_3_ = self.decoder4(f0_4, f1_4, embt) - corr_4, flow_4 = self._corr_scale_lookup(corr_fn, coord, - up_flow0_4, up_flow1_4, - embt, downsample=1) + corr_4, flow_4 = self._corr_scale_lookup( + corr_fn, coord, up_flow0_4, up_flow1_4, embt, downsample=1 + ) # residue update with lookup corr delta_ft_3_, delta_flow_4 = self.update4(ft_3_, flow_4, corr_4) @@ -102,10 +107,12 @@ class Model(nn.Module): ft_3_ 
= ft_3_ + delta_ft_3_ ######################################### the 3rd decoder ######################################### - up_flow0_3, up_flow1_3, ft_2_ = self.decoder3(ft_3_, f0_3, f1_3, up_flow0_4, up_flow1_4) - corr_3, flow_3 = self._corr_scale_lookup(corr_fn, - coord, up_flow0_3, up_flow1_3, - embt, downsample=2) + up_flow0_3, up_flow1_3, ft_2_ = self.decoder3( + ft_3_, f0_3, f1_3, up_flow0_4, up_flow1_4 + ) + corr_3, flow_3 = self._corr_scale_lookup( + corr_fn, coord, up_flow0_3, up_flow1_3, embt, downsample=2 + ) # residue update with lookup corr delta_ft_2_, delta_flow_3 = self.update3(ft_2_, flow_3, corr_3) @@ -115,11 +122,13 @@ class Model(nn.Module): ft_2_ = ft_2_ + delta_ft_2_ ######################################### the 2nd decoder ######################################### - up_flow0_2, up_flow1_2, ft_1_ = self.decoder2(ft_2_, f0_2, f1_2, up_flow0_3, up_flow1_3) - corr_2, flow_2 = self._corr_scale_lookup(corr_fn, - coord, up_flow0_2, up_flow1_2, - embt, downsample=4) - + up_flow0_2, up_flow1_2, ft_1_ = self.decoder2( + ft_2_, f0_2, f1_2, up_flow0_3, up_flow1_3 + ) + corr_2, flow_2 = self._corr_scale_lookup( + corr_fn, coord, up_flow0_2, up_flow1_2, embt, downsample=4 + ) + # residue update with lookup corr delta_ft_1_, delta_flow_2 = self.update2(ft_1_, flow_2, corr_2) delta_flow0_2, delta_flow1_2 = torch.chunk(delta_flow_2, 2, 1) @@ -128,28 +137,36 @@ class Model(nn.Module): ft_1_ = ft_1_ + delta_ft_1_ ######################################### the 1st decoder ######################################### - up_flow0_1, up_flow1_1, mask, img_res = self.decoder1(ft_1_, f0_1, f1_1, up_flow0_2, up_flow1_2) - - if scale_factor != 1.0: - up_flow0_1 = resize(up_flow0_1, scale_factor=(1.0/scale_factor)) * (1.0/scale_factor) - up_flow1_1 = resize(up_flow1_1, scale_factor=(1.0/scale_factor)) * (1.0/scale_factor) - mask = resize(mask, scale_factor=(1.0/scale_factor)) - img_res = resize(img_res, scale_factor=(1.0/scale_factor)) + up_flow0_1, up_flow1_1, mask, img_res = self.decoder1( + ft_1_, f0_1, f1_1, up_flow0_2, up_flow1_2 + ) - # Merge multiple predictions - imgt_pred = multi_flow_combine(self.comb_block, img0, img1, up_flow0_1, up_flow1_1, - mask, img_res, mean_) + if scale_factor != 1.0: + up_flow0_1 = resize(up_flow0_1, scale_factor=(1.0 / scale_factor)) * ( + 1.0 / scale_factor + ) + up_flow1_1 = resize(up_flow1_1, scale_factor=(1.0 / scale_factor)) * ( + 1.0 / scale_factor + ) + mask = resize(mask, scale_factor=(1.0 / scale_factor)) + img_res = resize(img_res, scale_factor=(1.0 / scale_factor)) + + # Merge multiple predictions + imgt_pred = multi_flow_combine( + self.comb_block, img0, img1, up_flow0_1, up_flow1_1, mask, img_res, mean_ + ) imgt_pred = torch.clamp(imgt_pred, 0, 1) if eval: - return { 'imgt_pred': imgt_pred, } + return { + "imgt_pred": imgt_pred, + } else: up_flow0_1 = up_flow0_1.reshape(b, self.num_flows, 2, h, w) up_flow1_1 = up_flow1_1.reshape(b, self.num_flows, 2, h, w) return { - 'imgt_pred': imgt_pred, - 'flow0_pred': [up_flow0_1, up_flow0_2, up_flow0_3, up_flow0_4], - 'flow1_pred': [up_flow1_1, up_flow1_2, up_flow1_3, up_flow1_4], - 'ft_pred': [ft_1_, ft_2_, ft_3_], + "imgt_pred": imgt_pred, + "flow0_pred": [up_flow0_1, up_flow0_2, up_flow0_3, up_flow0_4], + "flow1_pred": [up_flow1_1, up_flow1_2, up_flow1_3, up_flow1_4], + "ft_pred": [ft_1_, ft_2_, ft_3_], } - \ No newline at end of file diff --git a/networks/AMT-S.py b/src/networks/AMT-S.py similarity index 55% rename from networks/AMT-S.py rename to src/networks/AMT-S.py index a7155bb..133b14d 100755 --- 
a/networks/AMT-S.py +++ b/src/networks/AMT-S.py @@ -1,31 +1,20 @@ import torch import torch.nn as nn -from networks.blocks.raft import ( - coords_grid, - SmallUpdateBlock, BidirCorrBlock -) -from networks.blocks.feat_enc import ( - SmallEncoder -) -from networks.blocks.ifrnet import ( - resize, - Encoder, - InitDecoder, - IntermediateDecoder -) -from networks.blocks.multi_flow import ( - multi_flow_combine, - MultiFlowDecoder -) +from src.networks.blocks.raft import coords_grid, SmallUpdateBlock, BidirCorrBlock +from src.networks.blocks.feat_enc import SmallEncoder +from src.networks.blocks.ifrnet import resize, Encoder, InitDecoder, IntermediateDecoder +from src.networks.blocks.multi_flow import multi_flow_combine, MultiFlowDecoder class Model(nn.Module): - def __init__(self, - corr_radius=3, - corr_lvls=4, - num_flows=3, - channels=[20, 32, 44, 56], - skip_channels=20): + def __init__( + self, + corr_radius=3, + corr_lvls=4, + num_flows=3, + channels=[20, 32, 44, 56], + skip_channels=20, + ): super(Model, self).__init__() self.radius = corr_radius self.corr_levels = corr_lvls @@ -33,7 +22,7 @@ class Model(nn.Module): self.channels = channels self.skip_channels = skip_channels - self.feat_encoder = SmallEncoder(output_dim=84, norm_fn='instance', dropout=0.) + self.feat_encoder = SmallEncoder(output_dim=84, norm_fn="instance", dropout=0.0) self.encoder = Encoder(channels) self.decoder4 = InitDecoder(channels[3], channels[2], skip_channels) @@ -44,44 +33,58 @@ class Model(nn.Module): self.update4 = self._get_updateblock(44) self.update3 = self._get_updateblock(32, 2) self.update2 = self._get_updateblock(20, 4) - + self.comb_block = nn.Sequential( - nn.Conv2d(3*num_flows, 6*num_flows, 3, 1, 1), - nn.PReLU(6*num_flows), - nn.Conv2d(6*num_flows, 3, 3, 1, 1), + nn.Conv2d(3 * num_flows, 6 * num_flows, 3, 1, 1), + nn.PReLU(6 * num_flows), + nn.Conv2d(6 * num_flows, 3, 3, 1, 1), ) def _get_updateblock(self, cdim, scale_factor=None): - return SmallUpdateBlock(cdim=cdim, hidden_dim=76, flow_dim=20, corr_dim=64, - fc_dim=68, scale_factor=scale_factor, - corr_levels=self.corr_levels, radius=self.radius) + return SmallUpdateBlock( + cdim=cdim, + hidden_dim=76, + flow_dim=20, + corr_dim=64, + fc_dim=68, + scale_factor=scale_factor, + corr_levels=self.corr_levels, + radius=self.radius, + ) def _corr_scale_lookup(self, corr_fn, coord, flow0, flow1, embt, downsample=1): # convert t -> 0 to 0 -> 1 | convert t -> 1 to 1 -> 0 # based on linear assumption - t1_scale = 1. / embt - t0_scale = 1. / (1. 
- embt) + t1_scale = 1.0 / embt + t0_scale = 1.0 / (1.0 - embt) if downsample != 1: inv = 1 / downsample flow0 = inv * resize(flow0, scale_factor=inv) flow1 = inv * resize(flow1, scale_factor=inv) - - corr0, corr1 = corr_fn(coord + flow1 * t1_scale, coord + flow0 * t0_scale) + + corr0, corr1 = corr_fn(coord + flow1 * t1_scale, coord + flow0 * t0_scale) corr = torch.cat([corr0, corr1], dim=1) flow = torch.cat([flow0, flow1], dim=1) return corr, flow def forward(self, img0, img1, embt, scale_factor=1.0, eval=False, **kwargs): - mean_ = torch.cat([img0, img1], 2).mean(1, keepdim=True).mean(2, keepdim=True).mean(3, keepdim=True) + mean_ = ( + torch.cat([img0, img1], 2) + .mean(1, keepdim=True) + .mean(2, keepdim=True) + .mean(3, keepdim=True) + ) img0 = img0 - mean_ img1 = img1 - mean_ img0_ = resize(img0, scale_factor) if scale_factor != 1.0 else img0 img1_ = resize(img1, scale_factor) if scale_factor != 1.0 else img1 b, _, h, w = img0_.shape coord = coords_grid(b, h // 8, w // 8, img0.device) - - fmap0, fmap1 = self.feat_encoder([img0_, img1_]) # [1, 128, H//8, W//8] - corr_fn = BidirCorrBlock(fmap0, fmap1, radius=self.radius, num_levels=self.corr_levels) + + fmap0, fmap1 = self.feat_encoder([img0_, img1_]) # [1, 128, H//8, W//8] + corr_fn = BidirCorrBlock( + fmap0, fmap1, radius=self.radius, num_levels=self.corr_levels + ) # f0_1: [1, c0, H//2, W//2] | f0_2: [1, c1, H//4, W//4] # f0_3: [1, c2, H//8, W//8] | f0_4: [1, c3, H//16, W//16] @@ -90,9 +93,9 @@ class Model(nn.Module): ######################################### the 4th decoder ######################################### up_flow0_4, up_flow1_4, ft_3_ = self.decoder4(f0_4, f1_4, embt) - corr_4, flow_4 = self._corr_scale_lookup(corr_fn, coord, - up_flow0_4, up_flow1_4, - embt, downsample=1) + corr_4, flow_4 = self._corr_scale_lookup( + corr_fn, coord, up_flow0_4, up_flow1_4, embt, downsample=1 + ) # residue update with lookup corr delta_ft_3_, delta_flow_4 = self.update4(ft_3_, flow_4, corr_4) @@ -102,10 +105,12 @@ class Model(nn.Module): ft_3_ = ft_3_ + delta_ft_3_ ######################################### the 3rd decoder ######################################### - up_flow0_3, up_flow1_3, ft_2_ = self.decoder3(ft_3_, f0_3, f1_3, up_flow0_4, up_flow1_4) - corr_3, flow_3 = self._corr_scale_lookup(corr_fn, - coord, up_flow0_3, up_flow1_3, - embt, downsample=2) + up_flow0_3, up_flow1_3, ft_2_ = self.decoder3( + ft_3_, f0_3, f1_3, up_flow0_4, up_flow1_4 + ) + corr_3, flow_3 = self._corr_scale_lookup( + corr_fn, coord, up_flow0_3, up_flow1_3, embt, downsample=2 + ) # residue update with lookup corr delta_ft_2_, delta_flow_3 = self.update3(ft_2_, flow_3, corr_3) @@ -115,11 +120,13 @@ class Model(nn.Module): ft_2_ = ft_2_ + delta_ft_2_ ######################################### the 2nd decoder ######################################### - up_flow0_2, up_flow1_2, ft_1_ = self.decoder2(ft_2_, f0_2, f1_2, up_flow0_3, up_flow1_3) - corr_2, flow_2 = self._corr_scale_lookup(corr_fn, - coord, up_flow0_2, up_flow1_2, - embt, downsample=4) - + up_flow0_2, up_flow1_2, ft_1_ = self.decoder2( + ft_2_, f0_2, f1_2, up_flow0_3, up_flow1_3 + ) + corr_2, flow_2 = self._corr_scale_lookup( + corr_fn, coord, up_flow0_2, up_flow1_2, embt, downsample=4 + ) + # residue update with lookup corr delta_ft_1_, delta_flow_2 = self.update2(ft_1_, flow_2, corr_2) delta_flow0_2, delta_flow1_2 = torch.chunk(delta_flow_2, 2, 1) @@ -128,27 +135,36 @@ class Model(nn.Module): ft_1_ = ft_1_ + delta_ft_1_ ######################################### the 1st decoder 
######################################### - up_flow0_1, up_flow1_1, mask, img_res = self.decoder1(ft_1_, f0_1, f1_1, up_flow0_2, up_flow1_2) - - if scale_factor != 1.0: - up_flow0_1 = resize(up_flow0_1, scale_factor=(1.0/scale_factor)) * (1.0/scale_factor) - up_flow1_1 = resize(up_flow1_1, scale_factor=(1.0/scale_factor)) * (1.0/scale_factor) - mask = resize(mask, scale_factor=(1.0/scale_factor)) - img_res = resize(img_res, scale_factor=(1.0/scale_factor)) - - # Merge multiple predictions - imgt_pred = multi_flow_combine(self.comb_block, img0, img1, up_flow0_1, up_flow1_1, - mask, img_res, mean_) + up_flow0_1, up_flow1_1, mask, img_res = self.decoder1( + ft_1_, f0_1, f1_1, up_flow0_2, up_flow1_2 + ) + + if scale_factor != 1.0: + up_flow0_1 = resize(up_flow0_1, scale_factor=(1.0 / scale_factor)) * ( + 1.0 / scale_factor + ) + up_flow1_1 = resize(up_flow1_1, scale_factor=(1.0 / scale_factor)) * ( + 1.0 / scale_factor + ) + mask = resize(mask, scale_factor=(1.0 / scale_factor)) + img_res = resize(img_res, scale_factor=(1.0 / scale_factor)) + + # Merge multiple predictions + imgt_pred = multi_flow_combine( + self.comb_block, img0, img1, up_flow0_1, up_flow1_1, mask, img_res, mean_ + ) imgt_pred = torch.clamp(imgt_pred, 0, 1) if eval: - return { 'imgt_pred': imgt_pred, } + return { + "imgt_pred": imgt_pred, + } else: up_flow0_1 = up_flow0_1.reshape(b, self.num_flows, 2, h, w) up_flow1_1 = up_flow1_1.reshape(b, self.num_flows, 2, h, w) return { - 'imgt_pred': imgt_pred, - 'flow0_pred': [up_flow0_1, up_flow0_2, up_flow0_3, up_flow0_4], - 'flow1_pred': [up_flow1_1, up_flow1_2, up_flow1_3, up_flow1_4], - 'ft_pred': [ft_1_, ft_2_, ft_3_], + "imgt_pred": imgt_pred, + "flow0_pred": [up_flow0_1, up_flow0_2, up_flow0_3, up_flow0_4], + "flow1_pred": [up_flow1_1, up_flow1_2, up_flow1_3, up_flow1_4], + "ft_pred": [ft_1_, ft_2_, ft_3_], } diff --git a/networks/IFRNet.py b/src/networks/IFRNet.py similarity index 69% rename from networks/IFRNet.py rename to src/networks/IFRNet.py index 2f02f35..23a27fd 100755 --- a/networks/IFRNet.py +++ b/src/networks/IFRNet.py @@ -1,32 +1,25 @@ import torch import torch.nn as nn from src.utils.flow_utils import warp -from networks.blocks.ifrnet import ( - convrelu, resize, - ResBlock, -) +from src.networks.blocks.ifrnet import convrelu, resize, ResBlock class Encoder(nn.Module): def __init__(self): super(Encoder, self).__init__() self.pyramid1 = nn.Sequential( - convrelu(3, 32, 3, 2, 1), - convrelu(32, 32, 3, 1, 1) + convrelu(3, 32, 3, 2, 1), convrelu(32, 32, 3, 1, 1) ) self.pyramid2 = nn.Sequential( - convrelu(32, 48, 3, 2, 1), - convrelu(48, 48, 3, 1, 1) + convrelu(32, 48, 3, 2, 1), convrelu(48, 48, 3, 1, 1) ) self.pyramid3 = nn.Sequential( - convrelu(48, 72, 3, 2, 1), - convrelu(72, 72, 3, 1, 1) + convrelu(48, 72, 3, 2, 1), convrelu(72, 72, 3, 1, 1) ) self.pyramid4 = nn.Sequential( - convrelu(72, 96, 3, 2, 1), - convrelu(96, 96, 3, 1, 1) + convrelu(72, 96, 3, 2, 1), convrelu(96, 96, 3, 1, 1) ) - + def forward(self, img): f1 = self.pyramid1(img) f2 = self.pyramid2(f1) @@ -39,11 +32,11 @@ class Decoder4(nn.Module): def __init__(self): super(Decoder4, self).__init__() self.convblock = nn.Sequential( - convrelu(192+1, 192), - ResBlock(192, 32), - nn.ConvTranspose2d(192, 76, 4, 2, 1, bias=True) + convrelu(192 + 1, 192), + ResBlock(192, 32), + nn.ConvTranspose2d(192, 76, 4, 2, 1, bias=True), ) - + def forward(self, f0, f1, embt): b, c, h, w = f0.shape embt = embt.repeat(1, 1, h, w) @@ -56,9 +49,9 @@ class Decoder3(nn.Module): def __init__(self): super(Decoder3, 
self).__init__() self.convblock = nn.Sequential( - convrelu(220, 216), - ResBlock(216, 32), - nn.ConvTranspose2d(216, 52, 4, 2, 1, bias=True) + convrelu(220, 216), + ResBlock(216, 32), + nn.ConvTranspose2d(216, 52, 4, 2, 1, bias=True), ) def forward(self, ft_, f0, f1, up_flow0, up_flow1): @@ -73,9 +66,9 @@ class Decoder2(nn.Module): def __init__(self): super(Decoder2, self).__init__() self.convblock = nn.Sequential( - convrelu(148, 144), - ResBlock(144, 32), - nn.ConvTranspose2d(144, 36, 4, 2, 1, bias=True) + convrelu(148, 144), + ResBlock(144, 32), + nn.ConvTranspose2d(144, 36, 4, 2, 1, bias=True), ) def forward(self, ft_, f0, f1, up_flow0, up_flow1): @@ -90,11 +83,11 @@ class Decoder1(nn.Module): def __init__(self): super(Decoder1, self).__init__() self.convblock = nn.Sequential( - convrelu(100, 96), - ResBlock(96, 32), - nn.ConvTranspose2d(96, 8, 4, 2, 1, bias=True) + convrelu(100, 96), + ResBlock(96, 32), + nn.ConvTranspose2d(96, 8, 4, 2, 1, bias=True), ) - + def forward(self, ft_, f0, f1, up_flow0, up_flow1): f0_warp = warp(f0, up_flow0) f1_warp = warp(f1, up_flow1) @@ -113,13 +106,18 @@ class Model(nn.Module): self.decoder1 = Decoder1() def forward(self, img0, img1, embt, scale_factor=1.0, eval=False, **kwargs): - mean_ = torch.cat([img0, img1], 2).mean(1, keepdim=True).mean(2, keepdim=True).mean(3, keepdim=True) + mean_ = ( + torch.cat([img0, img1], 2) + .mean(1, keepdim=True) + .mean(2, keepdim=True) + .mean(3, keepdim=True) + ) img0 = img0 - mean_ img1 = img1 - mean_ - + img0_ = resize(img0, scale_factor) if scale_factor != 1.0 else img0 img1_ = resize(img1, scale_factor) if scale_factor != 1.0 else img1 - + f0_1, f0_2, f0_3, f0_4 = self.encoder(img0_) f1_1, f1_2, f1_3, f1_4 = self.encoder(img1_) @@ -143,13 +141,17 @@ class Model(nn.Module): up_flow1_1 = out1[:, 2:4] + 2.0 * resize(up_flow1_2, scale_factor=2.0) up_mask_1 = torch.sigmoid(out1[:, 4:5]) up_res_1 = out1[:, 5:] - + if scale_factor != 1.0: - up_flow0_1 = resize(up_flow0_1, scale_factor=(1.0/scale_factor)) * (1.0/scale_factor) - up_flow1_1 = resize(up_flow1_1, scale_factor=(1.0/scale_factor)) * (1.0/scale_factor) - up_mask_1 = resize(up_mask_1, scale_factor=(1.0/scale_factor)) - up_res_1 = resize(up_res_1, scale_factor=(1.0/scale_factor)) - + up_flow0_1 = resize(up_flow0_1, scale_factor=(1.0 / scale_factor)) * ( + 1.0 / scale_factor + ) + up_flow1_1 = resize(up_flow1_1, scale_factor=(1.0 / scale_factor)) * ( + 1.0 / scale_factor + ) + up_mask_1 = resize(up_mask_1, scale_factor=(1.0 / scale_factor)) + up_res_1 = resize(up_res_1, scale_factor=(1.0 / scale_factor)) + img0_warp = warp(img0, up_flow0_1) img1_warp = warp(img1, up_flow1_1) imgt_merge = up_mask_1 * img0_warp + (1 - up_mask_1) * img1_warp + mean_ @@ -157,13 +159,15 @@ class Model(nn.Module): imgt_pred = torch.clamp(imgt_pred, 0, 1) if eval: - return { 'imgt_pred': imgt_pred, } + return { + "imgt_pred": imgt_pred, + } else: return { - 'imgt_pred': imgt_pred, - 'flow0_pred': [up_flow0_1, up_flow0_2, up_flow0_3, up_flow0_4], - 'flow1_pred': [up_flow1_1, up_flow1_2, up_flow1_3, up_flow1_4], - 'ft_pred': [ft_1_, ft_2_, ft_3_], - 'img0_warp': img0_warp, - 'img1_warp': img1_warp + "imgt_pred": imgt_pred, + "flow0_pred": [up_flow0_1, up_flow0_2, up_flow0_3, up_flow0_4], + "flow1_pred": [up_flow1_1, up_flow1_2, up_flow1_3, up_flow1_4], + "ft_pred": [ft_1_, ft_2_, ft_3_], + "img0_warp": img0_warp, + "img1_warp": img1_warp, } diff --git a/networks/__init__.py b/src/networks/__init__.py similarity index 100% rename from networks/__init__.py rename to src/networks/__init__.py 
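The network.name edits in the YAML configs above (networks.* -> src.networks.*) have to track this package move because the model class is resolved dynamically from that config string. build_from_cfg itself is not shown in this diff; the following is only a minimal sketch of the importlib-based lookup it presumably performs, with resolve_model as a hypothetical stand-in:

    import importlib

    def resolve_model(name: str, params: dict):
        # "src.networks.AMT-G.Model" -> module "src.networks.AMT-G", attribute "Model".
        # importlib.import_module accepts the hyphenated module name as a plain
        # string, even though an `import` statement could not spell it.
        module_name, cls_name = name.rsplit(".", 1)
        module = importlib.import_module(module_name)
        return getattr(module, cls_name)(**params)

With the old value networks.AMT-G.Model, import_module would now raise ModuleNotFoundError, which is why the config edits and the file moves must land together.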
diff --git a/networks/blocks/__init__.py b/src/networks/blocks/__init__.py similarity index 100% rename from networks/blocks/__init__.py rename to src/networks/blocks/__init__.py diff --git a/networks/blocks/feat_enc.py b/src/networks/blocks/feat_enc.py similarity index 100% rename from networks/blocks/feat_enc.py rename to src/networks/blocks/feat_enc.py diff --git a/networks/blocks/ifrnet.py b/src/networks/blocks/ifrnet.py similarity index 100% rename from networks/blocks/ifrnet.py rename to src/networks/blocks/ifrnet.py diff --git a/src/networks/blocks/multi_flow.py b/src/networks/blocks/multi_flow.py new file mode 100755 index 0000000..21167b2 --- /dev/null +++ b/src/networks/blocks/multi_flow.py @@ -0,0 +1,80 @@ +import torch +import torch.nn as nn +from src.utils.flow_utils import warp +from src.networks.blocks.ifrnet import convrelu, resize, ResBlock + + +def multi_flow_combine( + comb_block, img0, img1, flow0, flow1, mask=None, img_res=None, mean=None +): + """ + A parallel implementation of multiple flow field warping. + comb_block: An nn.Sequential object. + img shape: [b, c, h, w] + flow shape: [b, 2*num_flows, h, w] + mask (opt): + If 'mask' is None, the function conducts a simple average. + img_res (opt): + If 'img_res' is None, the function adds zero instead. + mean (opt): + If 'mean' is None, the function adds zero instead. + """ + b, c, h, w = flow0.shape + num_flows = c // 2 + flow0 = flow0.reshape(b, num_flows, 2, h, w).reshape(-1, 2, h, w) + flow1 = flow1.reshape(b, num_flows, 2, h, w).reshape(-1, 2, h, w) + + mask = ( + mask.reshape(b, num_flows, 1, h, w).reshape(-1, 1, h, w) + if mask is not None + else None + ) + img_res = ( + img_res.reshape(b, num_flows, 3, h, w).reshape(-1, 3, h, w) + if img_res is not None + else 0 + ) + img0 = torch.stack([img0] * num_flows, 1).reshape(-1, 3, h, w) + img1 = torch.stack([img1] * num_flows, 1).reshape(-1, 3, h, w) + mean = ( + torch.stack([mean] * num_flows, 1).reshape(-1, 1, 1, 1) + if mean is not None + else 0 + ) + + img0_warp = warp(img0, flow0) + img1_warp = warp(img1, flow1) + img_warps = mask * img0_warp + (1 - mask) * img1_warp + mean + img_res + img_warps = img_warps.reshape(b, num_flows, 3, h, w) + imgt_pred = img_warps.mean(1) + comb_block(img_warps.view(b, -1, h, w)) + return imgt_pred + + +class MultiFlowDecoder(nn.Module): + def __init__(self, in_ch, skip_ch, num_flows=3): + super(MultiFlowDecoder, self).__init__() + self.num_flows = num_flows + self.convblock = nn.Sequential( + convrelu(in_ch * 3 + 4, in_ch * 3), + ResBlock(in_ch * 3, skip_ch), + nn.ConvTranspose2d(in_ch * 3, 8 * num_flows, 4, 2, 1, bias=True), + ) + + def forward(self, ft_, f0, f1, flow0, flow1): + n = self.num_flows + f0_warp = warp(f0, flow0) + f1_warp = warp(f1, flow1) + out = self.convblock(torch.cat([ft_, f0_warp, f1_warp, flow0, flow1], 1)) + delta_flow0, delta_flow1, mask, img_res = torch.split( + out, [2 * n, 2 * n, n, 3 * n], 1 + ) + mask = torch.sigmoid(mask) + + flow0 = delta_flow0 + 2.0 * resize(flow0, scale_factor=2.0).repeat( + 1, self.num_flows, 1, 1 + ) + flow1 = delta_flow1 + 2.0 * resize(flow1, scale_factor=2.0).repeat( + 1, self.num_flows, 1, 1 + ) + + return flow0, flow1, mask, img_res diff --git a/networks/blocks/raft.py b/src/networks/blocks/raft.py similarity index 100% rename from networks/blocks/raft.py rename to src/networks/blocks/raft.py diff --git a/src/pretrained/amt-l.pth b/src/pretrained/amt-l.pth new file mode 100644 index 0000000..d67c49a Binary files /dev/null and b/src/pretrained/amt-l.pth differ diff --git 
a/src/pretrained/amt-s.pth b/src/pretrained/amt-s.pth new file mode 100644 index 0000000..dbfe53e Binary files /dev/null and b/src/pretrained/amt-s.pth differ diff --git a/src/utils/build.py b/src/utils/build.py index f3f91e9..61dd4ac 100644 --- a/src/utils/build.py +++ b/src/utils/build.py @@ -1,6 +1,7 @@ -from typing import TYPE_CHECKING import importlib +from typing import TYPE_CHECKING + if TYPE_CHECKING: from omegaconf import DictConfig diff --git a/src/utils/fs.py b/src/utils/fs.py new file mode 100644 index 0000000..d6e3016 --- /dev/null +++ b/src/utils/fs.py @@ -0,0 +1,53 @@ +from pathlib import Path + + +class FileSystem: + SOURCE_PATH = "source" + OUTPUT_PATH = "output" + FRAMES_PATH = "frames" + INTERPOLATED_PATH = "interpolated" + MOVED_PATH = "moved" + VIDEO_PART_PATH = "video_parts" + + def __init__(self, base_path: Path): + self.base_path = base_path + self.base_path.mkdir(parents=True, exist_ok=True) + + def create_directory(self, dir_name: str) -> Path: + """Creates a directory under the base path.""" + dir_path = self.base_path / dir_name + dir_path.mkdir(parents=True, exist_ok=True) + return dir_path + + def clear_directory(self, dir_path: Path): + """Clears all files in the specified directory.""" + for item in dir_path.iterdir(): + if item.is_file(): + item.unlink() + elif item.is_dir(): + self.clear_directory(item) + item.rmdir() + + @property + def source_path(self) -> Path: + return self.create_directory(self.SOURCE_PATH) + + @property + def output_path(self) -> Path: + return self.create_directory(self.OUTPUT_PATH) + + @property + def frames_path(self) -> Path: + return self.create_directory(self.FRAMES_PATH) + + @property + def interpolated_path(self) -> Path: + return self.create_directory(self.INTERPOLATED_PATH) + + @property + def moved_path(self) -> Path: + return self.create_directory(self.MOVED_PATH) + + @property + def video_part_path(self) -> Path: + return self.create_directory(self.VIDEO_PART_PATH) \ No newline at end of file diff --git a/src/utils/utils.py b/src/utils/utils.py deleted file mode 100644 index 51346b8..0000000 --- a/src/utils/utils.py +++ /dev/null @@ -1,199 +0,0 @@ -import re -import sys -from pathlib import Path - -import numpy as np -from imageio import imread, imwrite - - -def read(file: Path) -> np.ndarray: - readers = { - ".float3": readFloat, - ".flo": readFlow, - ".ppm": readImage, - ".pgm": readImage, - ".png": readImage, - ".jpg": readImage, - ".pfm": lambda f: readPFM(f)[0], - } - func = readers.get(file.suffix.lower()) - if func is None: - raise Exception("don't know how to read %s" % file) - return func(file) - - -def write(file: Path, data: np.ndarray) -> None: - writers = { - ".float3": writeFloat, - ".flo": writeFlow, - ".ppm": writeImage, - ".pgm": writeImage, - ".png": writeImage, - ".jpg": writeImage, - ".pfm": writePFM, - } - func = writers.get(file.suffix.lower()) - if func is None: - raise Exception("don't know how to write %s" % file) - return func(file, data) - - -def readPFM(file: Path): - data = open(file, "rb") - - color = None - width = None - height = None - scale = None - endian = None - - header = data.readline().rstrip() - if header.decode("ascii") == "PF": - color = True - elif header.decode("ascii") == "Pf": - color = False - else: - raise Exception("Not a PFM file.") - - dim_match = re.match(r"^(\d+)\s(\d+)\s$", data.readline().decode("ascii")) - if dim_match: - width, height = list(map(int, dim_match.groups())) - else: - raise Exception("Malformed PFM header.") - - scale = 
float(data.readline().decode("ascii").rstrip()) - if scale < 0: - endian = "<" - scale = -scale - else: - endian = ">" - - result = np.fromfile(data, endian + "f") - shape = (height, width, 3) if color else (height, width) - - result = np.reshape(result, shape) - result = np.flipud(result) - return result, scale - - -def writePFM(file: Path, image: np.ndarray, scale=1): - data = open(file, "wb") - - color = None - - if image.dtype.name != "float32": - raise Exception("Image dtype must be float32.") - - image = np.flipud(image) - - if len(image.shape) == 3 and image.shape[2] == 3: - color = True - elif len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1: - color = False - else: - raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.") - - data.write("PF\n" if color else "Pf\n".encode()) # type: ignore - data.write("%d %d\n".encode() % (image.shape[1], image.shape[0])) - - endian = image.dtype.byteorder - - if endian == "<" or endian == "=" and sys.byteorder == "little": - scale = -scale - - data.write("%f\n".encode() % scale) - - image.tofile(data) - - -def readFlow(file: Path): - if file.suffix.lower() == ".pfm": - return readPFM(file)[0][:, :, 0:2] - - f = open(file, "rb") - - header = f.read(4) - if header.decode("utf-8") != "PIEH": - raise Exception("Flow file header does not contain PIEH") - - width = np.fromfile(f, np.int32, 1).squeeze() - height = np.fromfile(f, np.int32, 1).squeeze() - - flow = np.fromfile(f, np.float32, width * height * 2).reshape((height, width, 2)) - - return flow.astype(np.float32) - - -def readImage(file: Path): - if file.suffix.lower() == ".pfm": - data = readPFM(file)[0] - if len(data.shape) == 3: - return data[:, :, 0:3] - else: - return data - return imread(file) - - -def writeImage(file: Path, data: np.ndarray): - if file.suffix.lower() == ".pfm": - return writePFM(file, data, 1) - return imwrite(file, data) - - -def writeFlow(file: Path, flow: np.ndarray): - f = open(file, "wb") - f.write("PIEH".encode("utf-8")) - np.array([flow.shape[1], flow.shape[0]], dtype=np.int32).tofile(f) - flow = flow.astype(np.float32) - flow.tofile(f) - - -def readFloat(file: Path): - f = open(file, "rb") - - if (f.readline().decode("utf-8")) != "float\n": - raise Exception("float file %s did not contain keyword" % file) - - dim = int(f.readline()) - - dims = [] - count = 1 - for _ in range(0, dim): - d = int(f.readline()) - dims.append(d) - count *= d - - dims = list(reversed(dims)) - - data = np.fromfile(f, np.float32, count).reshape(dims) - if dim > 2: - data = np.transpose(data, (2, 1, 0)) - data = np.transpose(data, (1, 0, 2)) - - return data - - -def writeFloat(file: Path, data: np.ndarray): - f = open(file, "wb") - - dim = len(data.shape) - if dim > 3: - raise Exception("bad float file dimension: %d" % dim) - - f.write(("float\n").encode("ascii")) - f.write(("%d\n" % dim).encode("ascii")) - - if dim == 1: - f.write(("%d\n" % data.shape[0]).encode("ascii")) - else: - f.write(("%d\n" % data.shape[1]).encode("ascii")) - f.write(("%d\n" % data.shape[0]).encode("ascii")) - for i in range(2, dim): - f.write(("%d\n" % data.shape[i]).encode("ascii")) - - data = data.astype(np.float32) - if dim == 2: - data.tofile(f) - - else: - np.transpose(data, (2, 0, 1)).tofile(f) diff --git a/src/utils/video.py b/src/utils/video.py new file mode 100644 index 0000000..c9ab60d --- /dev/null +++ b/src/utils/video.py @@ -0,0 +1,105 @@ +import os +import logging +import subprocess +from pathlib import Path + +import cv2 +from typing import Generator + + 
+class VideoMaker: + def images_to_video( + self, + images_path: Path, + output_path: Path, + fps: float, + image_pattern: str = "img_%08d.png", + ): + """Converts a sequence of images to a video using ffmpeg.""" + cmd = f'ffmpeg -framerate {fps} -i "{images_path / image_pattern}" -c:v libx264 -pix_fmt yuv420p "{output_path}"' + logging.info(f"Running command: {cmd}") + result = self.run_command(cmd) + if result != 0: + logging.error(f"Failed to create video. Command returned {result}") + + def concatenate_videos( + self, + videos_path: Path, + output_path: Path, + video_glob: str = "*.mp4", + ): + """Concatenates a sequence of videos using ffmpeg.""" + + videos = sorted(videos_path.glob(video_glob)) + list_file = "file.txt" + with open(list_file, "w") as f: + for video in videos: + f.write(f"file '{video}'\n") + cmd = f'ffmpeg -f concat -safe 0 -i {list_file} -c copy "{output_path}"' + logging.info(f"Running command: {cmd}") + result = self.run_command(cmd) + if result != 0: + logging.error(f"Failed to concatenate videos. Command returned {result}") + os.remove(list_file) + + def get_fps(self, video_path: Path) -> float: + """Gets the frames per second (FPS) of a video.""" + cap = cv2.VideoCapture(str(video_path)) + if not cap.isOpened(): + raise ValueError(f"Cannot open video: {video_path}") + fps = cap.get(cv2.CAP_PROP_FPS) + cap.release() + logging.debug(f"FPS of video {video_path}: {fps}") + return fps + + def get_video_duration(self, video_path: Path) -> float: + """Gets the duration of a video in seconds.""" + cap = cv2.VideoCapture(str(video_path)) + if not cap.isOpened(): + raise ValueError(f"Cannot open video: {video_path}") + fps = cap.get(cv2.CAP_PROP_FPS) + frame_count = cap.get(cv2.CAP_PROP_FRAME_COUNT) + cap.release() + duration = frame_count / fps + logging.debug(f"Duration of video {video_path}: {duration:.2f} seconds") + return duration + + def run_command(self, cmd: str) -> int: + try: + subprocess.run(cmd, shell=True, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) + return 0 + except subprocess.CalledProcessError as e: + logging.error(f"Command failed with error: {e}") + return e.returncode + + def video_to_frames_generator( + self, video_path: Path, output_dir: Path, chunk_seconds: int = 10 + ) -> Generator[tuple[Path, ...], None, None]: + """Extracts frames from a video and saves them to disk, yielding paths to the saved frames.""" + + cap = cv2.VideoCapture(str(video_path)) + + if not cap.isOpened(): + raise ValueError(f"Cannot open video: {video_path}") + + fps = cap.get(cv2.CAP_PROP_FPS) + frames_per_chunk = int(fps * chunk_seconds) + + frame_index = 0 + + while True: + paths = [] + + for _ in range(frames_per_chunk): + ret, frame = cap.read() + if not ret: + cap.release() + return + + frame_path = output_dir / f"img_{frame_index:08d}.png" + cv2.imwrite(str(frame_path), frame) + + paths.append(frame_path) + frame_index += 1 + + yield tuple(paths)
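For reference, the refactored entry point can be exercised end to end through the new CLI (python main.py -b output -v example/video.mp4 -o interpolated_video.mp4 -p global) or called directly from Python. A minimal usage sketch, assuming the pretrained checkpoints under src/pretrained/ are present:

    from pathlib import Path

    from main import runner
    from src.config import presets

    # Doubles the frame rate of the example clip. FileSystem nests its
    # OUTPUT_PATH directory under the base path, so with base_path="output"
    # the final file lands at output/output/interpolated_video.mp4.
    runner(
        base_path=Path("output"),
        video_path=Path("example/video.mp4"),
        output_video="interpolated_video.mp4",
        preset=presets.GLOBAL,
    )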