diff --git a/main.py b/main.py index 1ab35c7..f333138 100644 --- a/main.py +++ b/main.py @@ -4,9 +4,10 @@ from typing import TYPE_CHECKING import tqdm +from src.config import presets from src.utils.fs import FileSystem from src.utils.video import VideoMaker -from interpolator import ( +from src.interpolator import ( ImageInterpolator, Anchor, get_device, @@ -104,7 +105,7 @@ class InterpolationPipeline: total_parts = int(length // chunk_seconds) + last_part_seconds fps = self.video_maker.get_fps(video_path) logging.info(f"Video FPS: {fps}") - fps *= 2 # Doubling FPS + fps *= 2 # Doubling FPS for frame_paths in self.video_maker.video_to_frames_generator( video_path, self.fs.frames_path, chunk_seconds ): @@ -125,7 +126,7 @@ class InterpolationPipeline: part += 1 for i in tqdm.tqdm( range(len(frame_paths) - 1), - desc=f"Processing video frames {part} / {total_parts}", + desc=f"Processing video frames {part + 1} / {total_parts}", ): img1 = frame_paths[i] img2 = frame_paths[i + 1] @@ -175,14 +176,16 @@ class InterpolationPipeline: ) -def main(): - config = Path("src/config/AMT-G.yaml") - checkpoint_path = Path("src/pretrained/amt-g.pth") +def main(preset: presets.Preset = presets.LARGE): base_path = Path("output") video_path = Path("example/video.mp4") output_video = "interpolated_video.mp4" - pipeline = InterpolationPipeline(config, checkpoint_path, base_path) + pipeline = InterpolationPipeline( + config=preset.config, + checkpoint_path=preset.checkpoint, + base_path=base_path, + ) pipeline.run(video_path, output_video) diff --git a/src/config/AMT-L.yaml b/src/config/AMT-L.yaml new file mode 100644 index 0000000..19d70ca --- /dev/null +++ b/src/config/AMT-L.yaml @@ -0,0 +1,62 @@ +exp_name: floloss1e-2_300epoch_bs24_lr2e-4 +seed: 2023 +epochs: 300 +distributed: true +lr: 2e-4 +lr_min: 2e-5 +weight_decay: 0.0 +resume_state: null +save_dir: work_dir +eval_interval: 1 + +network: + name: src.networks.AMT-L.Model + params: + corr_radius: 3 + corr_lvls: 4 + num_flows: 5 
+data: + train: + name: datasets.vimeo_datasets.Vimeo90K_Train_Dataset + params: + dataset_dir: data/vimeo_triplet + val: + name: datasets.vimeo_datasets.Vimeo90K_Test_Dataset + params: + dataset_dir: data/vimeo_triplet + train_loader: + batch_size: 24 + num_workers: 12 + val_loader: + batch_size: 24 + num_workers: 3 + +logger: + use_wandb: true + resume_id: null + +losses: + - { + name: losses.loss.CharbonnierLoss, + nickname: l_rec, + params: { + loss_weight: 1.0, + keys: [imgt_pred, imgt] + } + } + - { + name: losses.loss.TernaryLoss, + nickname: l_ter, + params: { + loss_weight: 1.0, + keys: [imgt_pred, imgt] + } + } + - { + name: losses.loss.MultipleFlowLoss, + nickname: l_flo, + params: { + loss_weight: 0.002, + keys: [flow0_pred, flow1_pred, flow] + } + } \ No newline at end of file diff --git a/src/config/presets.py b/src/config/presets.py new file mode 100644 index 0000000..7687a1b --- /dev/null +++ b/src/config/presets.py @@ -0,0 +1,24 @@ +from pathlib import Path +from dataclasses import dataclass + + +@dataclass(frozen=True) +class Preset: + config: Path + checkpoint: Path + + +SMALL = Preset( + config=Path("src/config/AMT-S.yaml"), + checkpoint=Path("src/pretrained/amt-s.pth"), +) + +LARGE = Preset( + config=Path("src/config/AMT-L.yaml"), + checkpoint=Path("src/pretrained/amt-l.pth"), +) + +GLOBAL = Preset( + config=Path("src/config/AMT-G.yaml"), + checkpoint=Path("src/pretrained/amt-g.pth"), +) diff --git a/interpolator.py b/src/interpolator.py similarity index 100% rename from interpolator.py rename to src/interpolator.py diff --git a/src/pretrained/amt-l.pth b/src/pretrained/amt-l.pth new file mode 100644 index 0000000..d67c49a Binary files /dev/null and b/src/pretrained/amt-l.pth differ diff --git a/src/pretrained/amt-s.pth b/src/pretrained/amt-s.pth new file mode 100644 index 0000000..dbfe53e Binary files /dev/null and b/src/pretrained/amt-s.pth differ