From 2d67b72128082c5999b1cfa0da46ac0a75dbd756 Mon Sep 17 00:00:00 2001
From: Viner Abubakirov
Date: Sat, 4 Apr 2026 11:57:41 +0500
Subject: [PATCH] Convert module imports to relative paths
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 main.py                             | 194 +--------------------------
 src/config/AMT-G.yaml               |   2 +-
 src/config/AMT-L.yaml               |   2 +-
 src/config/AMT-S.yaml               |   2 +-
 src/interpolator.py                 |   6 +-
 src/networks/{AMT-G.py => AMT_G.py} |   8 +-
 src/networks/{AMT-L.py => AMT_L.py} |   8 +-
 src/networks/{AMT-S.py => AMT_S.py} |   8 +-
 src/networks/IFRNet.py              |   4 +-
 src/networks/blocks/ifrnet.py       |   2 +-
 src/networks/blocks/multi_flow.py   |   4 +-
 src/runner.py                       | 196 ++++++++++++++++++++++++++++
 src/utils/build.py                  |  15 ++-
 13 files changed, 229 insertions(+), 222 deletions(-)
 rename src/networks/{AMT-G.py => AMT_G.py} (96%)
 rename src/networks/{AMT-L.py => AMT_L.py} (95%)
 rename src/networks/{AMT-S.py => AMT_S.py} (95%)
 create mode 100644 src/runner.py

diff --git a/main.py b/main.py
index 170171b..69e3c04 100644
--- a/main.py
+++ b/main.py
@@ -1,199 +1,7 @@
 import logging
 from pathlib import Path
-from typing import TYPE_CHECKING
-
-from cv2 import imwrite
-import tqdm
-
+from src.runner import runner
 from src.config import presets
-from src.utils.fs import FileSystem
-from src.utils.video import VideoMaker
-from src.interpolator import (
-    ImageInterpolator,
-    Anchor,
-    get_device,
-    get_vram_available,
-    ModelRunner,
-)
-
-
-if TYPE_CHECKING:
-    import torch
-    import numpy as np
-
-
-def performing_warning_message(device: "torch.device"):
-    if device.type in ("cpu", "mps"):
-        if device.type == "mps":
-            logging.warning(
-                "Running on Apple Silicon GPU (MPS) may have limited performance. Consider using a CUDA-enabled GPU for better performance."
-            )
-        else:
-            logging.warning(
-                "Running on CPU may be very slow. Consider using a GPU for better performance."
-            )
-    elif device.type == "cuda":
-        pass
-    else:
-        raise Exception(f"Unsupported device type: {device.type}")
-
-
-def init_fs(base_path: Path) -> FileSystem:
-    fs = FileSystem(base_path)
-    fs.clear_directory(fs.frames_path)
-    fs.clear_directory(fs.interpolated_path)
-    fs.clear_directory(fs.moved_path)
-    fs.clear_directory(fs.video_part_path)
-    return fs
-
-
-def init_video_maker() -> VideoMaker:
-    return VideoMaker()
-
-
-def init_device() -> "torch.device":
-    device = get_device()
-    performing_warning_message(device)
-    vram_available = get_vram_available(device)
-    logging.info(f"Available VRAM: {vram_available / (1024**3):.2f} GB")
-    return device
-
-
-def init_anchor(device: "torch.device") -> Anchor:
-    if device.type in ("cpu", "mps"):
-        return Anchor(resolution=8192 * 8192, memory=1, memory_bias=0)
-    elif device.type == "cuda":
-        return Anchor(
-            resolution=1024 * 512, memory=1500 * 1024**2, memory_bias=2500 * 1024**2
-        )
-    else:
-        raise Exception(f"Unsupported device type: {device.type}")
-
-
-def init_model_runner(
-    config: Path, checkpoint_path: Path, device: "torch.device"
-) -> ModelRunner:
-    return ModelRunner(config, checkpoint_path, device)
-
-
-def init_interpolator(
-    model_runner: ModelRunner, device: "torch.device"
-) -> ImageInterpolator:
-    anchor = init_anchor(device)
-    return ImageInterpolator(device, anchor, model_runner)
-
-
-class InterpolationPipeline:
-    def __init__(
-        self,
-        config: Path,
-        checkpoint_path: Path,
-        base_path: Path,
-    ):
-        self.fs = init_fs(base_path)
-        self.video_maker = init_video_maker()
-        self.device = init_device()
-        self.model_runner = init_model_runner(config, checkpoint_path, self.device)
-        self.interpolator = init_interpolator(self.model_runner, self.device)
-
-    def run(self, video_path: Path, output_video: str):
-        prev_frames = tuple()
-        interpolated_frames: list["np.ndarray"] = []
-        part = 0
-        chunk_seconds = 10
-        length = self.video_maker.get_video_duration(video_path)
-        last_part_seconds = 1 if length % chunk_seconds else 0
-        total_parts = int(length // chunk_seconds) + last_part_seconds
-        fps = self.video_maker.get_fps(video_path)
-        logging.info(f"Video FPS: {fps}")
-        fps *= 2  # Doubling FPS
-        width, height = self.video_maker.get_size(video_path)
-        for frames in self.video_maker.video_to_frames_generator(
-            video_path, self.fs.frames_path, chunk_seconds
-        ):
-            logging.info(f"Processing frames: {len(frames)}")
-            if prev_frames:
-                img1 = prev_frames[-1]
-                img2 = frames[0]
-                img1_2 = self.interpolator.interpolate(img1, img2)
-                interpolated_frames.append(img1_2)
-                generator = self._frame_generator(prev_frames, interpolated_frames)
-                part_path = self.fs.video_part_path / f"video_{part:08d}.mp4"
-                self.video_maker.images_to_video_pipeline(
-                    generator, part_path, width, height, fps
-                )
-                interpolated_frames = []
-                logging.info(f"Finished processing part {part:08d}")
-                part += 1
-            for i in tqdm.tqdm(
-                range(len(frames) - 1),
-                desc=f"Processing video frames {part + 1} / {total_parts}",
-            ):
-                img1 = frames[i]
-                img2 = frames[i + 1]
-                img1_2 = self.interpolator.interpolate(img1, img2)
-                interpolated_frames.append(img1_2)
-            prev_frames = frames
-
-        generator = self._frame_generator(prev_frames, interpolated_frames)
-        part_path = self.fs.video_part_path / f"video_{part:08d}.mp4"
-        self.video_maker.images_to_video_pipeline(
-            generator, part_path, width, height, fps
-        )
-        logging.info(f"Finished processing part {part:08d}")
-        self._merge_video_parts(self.fs.output_path / output_video)
-        logging.info(
-            f"Video interpolation completed. Output saved to: {self.fs.output_path / output_video}"
-        )
-
-    def _save_images(
-        self,
-        source: tuple["np.ndarray", ...],
-        interpolated: list["np.ndarray"],
-    ):
-        logging.info("Saving images...")
-        self.fs.clear_directory(self.fs.moved_path)
-        index = 0
-        for i, frame in enumerate(source):
-            name = self.fs.moved_path / f"img_{index:08d}.png"
-            index += 1
-            imwrite(name, frame)
-            if i < len(interpolated):
-                name = self.fs.moved_path / f"img_{index:08d}.png"
-                index += 1
-                imwrite(name, interpolated[i])
-        logging.info("Success...")
-
-    def _merge_frames_to_video(self, output_video: Path, fps: float):
-        self.video_maker.images_to_video(self.fs.moved_path, output_video, fps)
-
-    def _merge_video_parts(self, output_video: Path):
-        self.video_maker.concatenate_videos(self.fs.video_part_path, output_video)
-        self.fs.clear_directory(self.fs.video_part_path)
-
-    def _frame_generator(
-        self,
-        source: tuple["np.ndarray", ...],
-        interpolated: list["np.ndarray"],
-    ):
-        for i, frame in enumerate(source):
-            yield frame
-            if i < len(interpolated):
-                yield interpolated[i]
-
-
-def runner(
-    base_path: Path,
-    video_path: Path,
-    output_video: str,
-    preset: presets.Preset = presets.LARGE,
-):
-    pipeline = InterpolationPipeline(
-        config=preset.config,
-        checkpoint_path=preset.checkpoint,
-        base_path=base_path,
-    )
-    pipeline.run(video_path, output_video)
 
 
 def main():
diff --git a/src/config/AMT-G.yaml b/src/config/AMT-G.yaml
index 4570e2b..581f862 100755
--- a/src/config/AMT-G.yaml
+++ b/src/config/AMT-G.yaml
@@ -10,7 +10,7 @@ save_dir: work_dir
 eval_interval: 1
 
 network:
-  name: src.networks.AMT-G.Model
+  name: AMT-G.Model
   params:
     corr_radius: 3
     corr_lvls: 4
diff --git a/src/config/AMT-L.yaml b/src/config/AMT-L.yaml
index 19d70ca..5b14366 100644
--- a/src/config/AMT-L.yaml
+++ b/src/config/AMT-L.yaml
@@ -10,7 +10,7 @@ save_dir: work_dir
 eval_interval: 1
 
 network:
-  name: src.networks.AMT-L.Model
+  name: AMT-L.Model
   params:
     corr_radius: 3
     corr_lvls: 4
diff --git a/src/config/AMT-S.yaml b/src/config/AMT-S.yaml
index d763a5a..a049055 100755
--- a/src/config/AMT-S.yaml
+++ b/src/config/AMT-S.yaml
@@ -10,7 +10,7 @@ save_dir: work_dir
 eval_interval: 1
 
 network:
-  name: src.networks.AMT-S.Model
+  name: AMT-S.Model
   params:
     corr_radius: 3
     corr_lvls: 4
diff --git a/src/interpolator.py b/src/interpolator.py
index 49d0b94..4ed71b6 100644
--- a/src/interpolator.py
+++ b/src/interpolator.py
@@ -5,9 +5,9 @@ import torch
 import numpy as np
 from omegaconf import OmegaConf, DictConfig
 
-from src.utils.torch import img2tensor, check_dim_and_resize, tensor2img
-from src.utils.build import build_from_cfg
-from src.utils.padder import InputPadder
+from .utils.torch import img2tensor, check_dim_and_resize, tensor2img
+from .utils.build import build_from_cfg
+from .utils.padder import InputPadder
 
 
 class Anchor:
diff --git a/src/networks/AMT-G.py b/src/networks/AMT_G.py
similarity index 96%
rename from src/networks/AMT-G.py
rename to src/networks/AMT_G.py
index 3f3de9e..954ff10 100755
--- a/src/networks/AMT-G.py
+++ b/src/networks/AMT_G.py
@@ -2,10 +2,10 @@ from typing import Optional
 
 import torch
 import torch.nn as nn
-from src.networks.blocks.raft import coords_grid, BasicUpdateBlock, BidirCorrBlock
-from src.networks.blocks.feat_enc import LargeEncoder
-from src.networks.blocks.ifrnet import resize, Encoder, InitDecoder, IntermediateDecoder
-from src.networks.blocks.multi_flow import multi_flow_combine, MultiFlowDecoder
+from .blocks.raft import coords_grid, BasicUpdateBlock, BidirCorrBlock
+from .blocks.feat_enc import LargeEncoder
+from .blocks.ifrnet import resize, Encoder, InitDecoder, IntermediateDecoder
+from .blocks.multi_flow import multi_flow_combine, MultiFlowDecoder
 
 
 class Model(nn.Module):
diff --git a/src/networks/AMT-L.py b/src/networks/AMT_L.py
similarity index 95%
rename from src/networks/AMT-L.py
rename to src/networks/AMT_L.py
index 6243a4d..4cf6227 100755
--- a/src/networks/AMT-L.py
+++ b/src/networks/AMT_L.py
@@ -1,10 +1,10 @@
 import torch
 import torch.nn as nn
-from src.networks.blocks.raft import coords_grid, BasicUpdateBlock, BidirCorrBlock
-from src.networks.blocks.feat_enc import BasicEncoder
-from src.networks.blocks.ifrnet import resize, Encoder, InitDecoder, IntermediateDecoder
+from .blocks.raft import coords_grid, BasicUpdateBlock, BidirCorrBlock
+from .blocks.feat_enc import BasicEncoder
+from .blocks.ifrnet import resize, Encoder, InitDecoder, IntermediateDecoder
 
-from src.networks.blocks.multi_flow import multi_flow_combine, MultiFlowDecoder
+from .blocks.multi_flow import multi_flow_combine, MultiFlowDecoder
 
 
 class Model(nn.Module):
diff --git a/src/networks/AMT-S.py b/src/networks/AMT_S.py
similarity index 95%
rename from src/networks/AMT-S.py
rename to src/networks/AMT_S.py
index 133b14d..64b9e60 100755
--- a/src/networks/AMT-S.py
+++ b/src/networks/AMT_S.py
@@ -1,9 +1,9 @@
 import torch
 import torch.nn as nn
-from src.networks.blocks.raft import coords_grid, SmallUpdateBlock, BidirCorrBlock
-from src.networks.blocks.feat_enc import SmallEncoder
-from src.networks.blocks.ifrnet import resize, Encoder, InitDecoder, IntermediateDecoder
-from src.networks.blocks.multi_flow import multi_flow_combine, MultiFlowDecoder
+from .blocks.raft import coords_grid, SmallUpdateBlock, BidirCorrBlock
+from .blocks.feat_enc import SmallEncoder
+from .blocks.ifrnet import resize, Encoder, InitDecoder, IntermediateDecoder
+from .blocks.multi_flow import multi_flow_combine, MultiFlowDecoder
 
 
 class Model(nn.Module):
diff --git a/src/networks/IFRNet.py b/src/networks/IFRNet.py
index 23a27fd..21b420d 100755
--- a/src/networks/IFRNet.py
+++ b/src/networks/IFRNet.py
@@ -1,7 +1,7 @@
 import torch
 import torch.nn as nn
-from src.utils.flow_utils import warp
-from src.networks.blocks.ifrnet import convrelu, resize, ResBlock
+from ..utils.flow_utils import warp
+from .blocks.ifrnet import convrelu, resize, ResBlock
 
 
 class Encoder(nn.Module):
diff --git a/src/networks/blocks/ifrnet.py b/src/networks/blocks/ifrnet.py
index c2866ee..848e9df 100755
--- a/src/networks/blocks/ifrnet.py
+++ b/src/networks/blocks/ifrnet.py
@@ -1,7 +1,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from src.utils.flow_utils import warp
+from ...utils.flow_utils import warp
 
 
 def resize(x, scale_factor):
diff --git a/src/networks/blocks/multi_flow.py b/src/networks/blocks/multi_flow.py
index 21167b2..32d56aa 100755
--- a/src/networks/blocks/multi_flow.py
+++ b/src/networks/blocks/multi_flow.py
@@ -1,7 +1,7 @@
 import torch
 import torch.nn as nn
-from src.utils.flow_utils import warp
-from src.networks.blocks.ifrnet import convrelu, resize, ResBlock
+from ...utils.flow_utils import warp
+from .ifrnet import convrelu, resize, ResBlock
 
 
 def multi_flow_combine(
diff --git a/src/runner.py b/src/runner.py
new file mode 100644
index 0000000..98dc63a
--- /dev/null
+++ b/src/runner.py
@@ -0,0 +1,196 @@
+import logging
+from pathlib import Path
+from typing import TYPE_CHECKING
+
+from cv2 import imwrite
+import tqdm
+
+from .config import presets
+from .utils.fs import FileSystem
+from .utils.video import VideoMaker
+from .interpolator import (
+    ImageInterpolator,
+    Anchor,
+    get_device,
+    get_vram_available,
+    ModelRunner,
+)
+
+
+if TYPE_CHECKING:
+    import torch
+    import numpy as np
+
+
+def performing_warning_message(device: "torch.device"):
+    if device.type in ("cpu", "mps"):
+        if device.type == "mps":
+            logging.warning(
+                "Running on Apple Silicon GPU (MPS) may have limited performance. Consider using a CUDA-enabled GPU for better performance."
+            )
+        else:
+            logging.warning(
+                "Running on CPU may be very slow. Consider using a GPU for better performance."
+            )
+    elif device.type == "cuda":
+        pass
+    else:
+        raise Exception(f"Unsupported device type: {device.type}")
+
+
+def init_fs(base_path: Path) -> FileSystem:
+    fs = FileSystem(base_path)
+    fs.clear_directory(fs.frames_path)
+    fs.clear_directory(fs.interpolated_path)
+    fs.clear_directory(fs.moved_path)
+    fs.clear_directory(fs.video_part_path)
+    return fs
+
+
+def init_video_maker() -> VideoMaker:
+    return VideoMaker()
+
+
+def init_device() -> "torch.device":
+    device = get_device()
+    performing_warning_message(device)
+    vram_available = get_vram_available(device)
+    logging.info(f"Available VRAM: {vram_available / (1024**3):.2f} GB")
+    return device
+
+
+def init_anchor(device: "torch.device") -> Anchor:
+    if device.type in ("cpu", "mps"):
+        return Anchor(resolution=8192 * 8192, memory=1, memory_bias=0)
+    elif device.type == "cuda":
+        return Anchor(
+            resolution=1024 * 512, memory=1500 * 1024**2, memory_bias=2500 * 1024**2
+        )
+    else:
+        raise Exception(f"Unsupported device type: {device.type}")
+
+
+def init_model_runner(
+    config: Path, checkpoint_path: Path, device: "torch.device"
+) -> ModelRunner:
+    return ModelRunner(config, checkpoint_path, device)
+
+
+def init_interpolator(
+    model_runner: ModelRunner, device: "torch.device"
+) -> ImageInterpolator:
+    anchor = init_anchor(device)
+    return ImageInterpolator(device, anchor, model_runner)
+
+
+class InterpolationPipeline:
+    def __init__(
+        self,
+        config: Path,
+        checkpoint_path: Path,
+        base_path: Path,
+    ):
+        self.fs = init_fs(base_path)
+        self.video_maker = init_video_maker()
+        self.device = init_device()
+        self.model_runner = init_model_runner(config, checkpoint_path, self.device)
+        self.interpolator = init_interpolator(self.model_runner, self.device)
+
+    def run(self, video_path: Path, output_video: str):
+        prev_frames: tuple["np.ndarray", ...] = tuple()
+        interpolated_frames: list["np.ndarray"] = []
+        part = 0
+        chunk_seconds = 10
+        length = self.video_maker.get_video_duration(video_path)
+        last_part_seconds = 1 if length % chunk_seconds else 0
+        total_parts = int(length // chunk_seconds) + last_part_seconds
+        fps = self.video_maker.get_fps(video_path)
+        logging.info(f"Video FPS: {fps}")
+        fps *= 2  # Doubling FPS
+        width, height = self.video_maker.get_size(video_path)
+        for frames in self.video_maker.video_to_frames_generator(
+            video_path, self.fs.frames_path, chunk_seconds
+        ):
+            logging.info(f"Processing frames: {len(frames)}")
+            if prev_frames:
+                img1 = prev_frames[-1]
+                img2 = frames[0]
+                img1_2 = self.interpolator.interpolate(img1, img2)
+                interpolated_frames.append(img1_2)
+                generator = self._frame_generator(prev_frames, interpolated_frames)
+                part_path = self.fs.video_part_path / f"video_{part:08d}.mp4"
+                self.video_maker.images_to_video_pipeline(
+                    generator, part_path, width, height, fps
+                )
+                interpolated_frames = []
+                logging.info(f"Finished processing part {part:08d}")
+                part += 1
+            for i in tqdm.tqdm(
+                range(len(frames) - 1),
+                desc=f"Processing video frames {part + 1} / {total_parts}",
+            ):
+                img1 = frames[i]
+                img2 = frames[i + 1]
+                img1_2 = self.interpolator.interpolate(img1, img2)
+                interpolated_frames.append(img1_2)
+            prev_frames = frames
+
+        generator = self._frame_generator(prev_frames, interpolated_frames)
+        part_path = self.fs.video_part_path / f"video_{part:08d}.mp4"
+        self.video_maker.images_to_video_pipeline(
+            generator, part_path, width, height, fps
+        )
+        logging.info(f"Finished processing part {part:08d}")
+        self._merge_video_parts(self.fs.output_path / output_video)
+        logging.info(
+            f"Video interpolation completed. Output saved to: {self.fs.output_path / output_video}"
+        )
+
+    def _save_images(
+        self,
+        source: tuple["np.ndarray", ...],
+        interpolated: list["np.ndarray"],
+    ):
+        logging.info("Saving images...")
+        self.fs.clear_directory(self.fs.moved_path)
+        index = 0
+        for i, frame in enumerate(source):
+            name = self.fs.moved_path / f"img_{index:08d}.png"
+            index += 1
+            imwrite(name, frame)
+            if i < len(interpolated):
+                name = self.fs.moved_path / f"img_{index:08d}.png"
+                index += 1
+                imwrite(name, interpolated[i])
+        logging.info("Success...")
+
+    def _merge_frames_to_video(self, output_video: Path, fps: float):
+        self.video_maker.images_to_video(self.fs.moved_path, output_video, fps)
+
+    def _merge_video_parts(self, output_video: Path):
+        self.video_maker.concatenate_videos(self.fs.video_part_path, output_video)
+        self.fs.clear_directory(self.fs.video_part_path)
+
+    def _frame_generator(
+        self,
+        source: tuple["np.ndarray", ...],
+        interpolated: list["np.ndarray"],
+    ):
+        for i, frame in enumerate(source):
+            yield frame
+            if i < len(interpolated):
+                yield interpolated[i]
+
+
+def runner(
+    base_path: Path,
+    video_path: Path,
+    output_video: str,
+    preset: presets.Preset = presets.LARGE,
+):
+    pipeline = InterpolationPipeline(
+        config=preset.config,
+        checkpoint_path=preset.checkpoint,
+        base_path=base_path,
+    )
+    pipeline.run(video_path, output_video)
diff --git a/src/utils/build.py b/src/utils/build.py
index 61dd4ac..2e66ebc 100644
--- a/src/utils/build.py
+++ b/src/utils/build.py
@@ -1,16 +1,19 @@
-import importlib
 from typing import TYPE_CHECKING
 
+from ..networks import AMT_G, AMT_L, AMT_S
 
 if TYPE_CHECKING:
     from omegaconf import DictConfig
 
 
-def base_build_fn(module: str, cls: str, params: dict):
-    return getattr(importlib.import_module(module, package=None), cls)(**params)
-
-
 def build_from_cfg(config: "DictConfig"):
+    packages = {
+        "AMT-G": AMT_G,
+        "AMT-L": AMT_L,
+        "AMT-S": AMT_S
+    }
+
     module, cls = config["name"].rsplit(".", 1)
     params: dict = config.get("params", {})
-    return base_build_fn(module, cls, params)
+    return getattr(packages[module], cls)(**params)
+
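
With this patch, a network class is resolved through the packages mapping in src/utils/build.py rather than through importlib dotted paths. Below is a minimal, illustrative sketch of that lookup, not part of the patch itself; it assumes the interpreter is started from the repository root so that the src package (and its torch dependencies) is importable.

# Illustrative only: how build_from_cfg is expected to resolve a preset config
# after this change. The loading code is an assumption about call order, not
# taken verbatim from the repository.
from omegaconf import OmegaConf

from src.utils.build import build_from_cfg

cfg = OmegaConf.load("src/config/AMT-S.yaml")
# cfg.network.name is now "AMT-S.Model"; build_from_cfg splits it into the
# package key "AMT-S" and the class name "Model", maps the key to the renamed
# src.networks.AMT_S module, and instantiates Model(**cfg.network.params).
model = build_from_cfg(cfg.network)
print(type(model))  # e.g. <class 'src.networks.AMT_S.Model'>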