Перевел импорты модулей в относительные пути

This commit is contained in:
Viner Abubakirov
2026-04-04 11:57:41 +05:00
parent c91cf6b53a
commit 2d67b72128
13 changed files with 229 additions and 222 deletions

194
main.py
View File

@@ -1,199 +1,7 @@
import logging
from pathlib import Path
from typing import TYPE_CHECKING
from cv2 import imwrite
import tqdm
from src.runner import runner
from src.config import presets
from src.utils.fs import FileSystem
from src.utils.video import VideoMaker
from src.interpolator import (
ImageInterpolator,
Anchor,
get_device,
get_vram_available,
ModelRunner,
)
if TYPE_CHECKING:
import torch
import numpy as np
def performing_warning_message(device: "torch.device"):
    """Log a performance warning for slow backends and reject unknown ones.

    Args:
        device: Object whose ``type`` attribute names the backend
            (``"cpu"``, ``"mps"`` or ``"cuda"``).

    Raises:
        Exception: If ``device.type`` is not one of the three known backends.
    """
    backend = device.type
    if backend == "cuda":
        # Nothing to warn about on CUDA.
        return
    if backend == "mps":
        logging.warning(
            "Running on Apple Silicon GPU (MPS) may have limited performance. Consider using a CUDA-enabled GPU for better performance."
        )
        return
    if backend == "cpu":
        logging.warning(
            "Running on CPU may be very slow. Consider using a GPU for better performance."
        )
        return
    raise Exception(f"Unsupported device type: {device.type}")
def init_fs(base_path: Path) -> FileSystem:
    """Create the ``FileSystem`` helper and empty its working directories.

    Args:
        base_path: Root path handed to ``FileSystem``.

    Returns:
        The ``FileSystem`` instance after its frames, interpolated, moved and
        video-part directories have been cleared.
    """
    workspace = FileSystem(base_path)
    # Each attribute is looked up right before it is cleared, in this order.
    for attr in (
        "frames_path",
        "interpolated_path",
        "moved_path",
        "video_part_path",
    ):
        workspace.clear_directory(getattr(workspace, attr))
    return workspace
def init_video_maker() -> VideoMaker:
    """Build a ``VideoMaker`` with its default constructor arguments."""
    maker = VideoMaker()
    return maker
def init_device() -> "torch.device":
    """Select the compute device, warn if it is slow, and log available memory.

    Returns:
        The device returned by ``get_device``.

    Raises:
        Exception: Propagated from ``performing_warning_message`` when the
            device type is unsupported.
    """
    selected = get_device()
    performing_warning_message(selected)
    # Presumably a byte count: it is scaled by 1024**3 and reported as GB.
    available = get_vram_available(selected)
    logging.info(f"Available VRAM: {available / (1024**3):.2f} GB")
    return selected
def init_anchor(device: "torch.device") -> Anchor:
    """Return the ``Anchor`` preset that matches the device backend.

    Args:
        device: Object whose ``type`` attribute names the backend.

    Returns:
        An ``Anchor`` built with the fixed values for CUDA, or with the
        shared CPU/MPS values.

    Raises:
        Exception: If ``device.type`` is not ``"cpu"``, ``"mps"`` or ``"cuda"``.
    """
    backend = device.type
    if backend == "cuda":
        return Anchor(
            resolution=1024 * 512,
            memory=1500 * 1024**2,
            memory_bias=2500 * 1024**2,
        )
    if backend in ("cpu", "mps"):
        return Anchor(resolution=8192 * 8192, memory=1, memory_bias=0)
    raise Exception(f"Unsupported device type: {device.type}")
def init_model_runner(
    config: Path, checkpoint_path: Path, device: "torch.device"
) -> ModelRunner:
    """Construct the ``ModelRunner`` for the given config, checkpoint and device.

    Args:
        config: Model configuration path.
        checkpoint_path: Model checkpoint path.
        device: Device the runner is created for.

    Returns:
        The newly created ``ModelRunner``.
    """
    model_runner = ModelRunner(config, checkpoint_path, device)
    return model_runner
def init_interpolator(
    model_runner: ModelRunner, device: "torch.device"
) -> ImageInterpolator:
    """Create the ``ImageInterpolator`` with the anchor chosen for ``device``.

    Args:
        model_runner: Runner handed to the interpolator.
        device: Device used both to pick the anchor and by the interpolator.

    Returns:
        The configured ``ImageInterpolator``.
    """
    return ImageInterpolator(device, init_anchor(device), model_runner)
class InterpolationPipeline:
    """Double a video's frame rate by inserting interpolated frames.

    The video is read chunk by chunk; every chunk is written out as its own
    video part with one interpolated frame between each pair of neighbouring
    source frames, and the parts are concatenated into the final output.
    """

    def __init__(
        self,
        config: Path,
        checkpoint_path: Path,
        base_path: Path,
    ):
        """Set up file system, video helper, device, model and interpolator.

        Args:
            config: Model configuration path passed to ``init_model_runner``.
            checkpoint_path: Model checkpoint path passed to
                ``init_model_runner``.
            base_path: Root directory passed to ``init_fs``.
        """
        self.fs = init_fs(base_path)
        self.video_maker = init_video_maker()
        self.device = init_device()
        self.model_runner = init_model_runner(config, checkpoint_path, self.device)
        self.interpolator = init_interpolator(self.model_runner, self.device)

    def run(self, video_path: Path, output_video: str):
        """Interpolate ``video_path`` and write the result to the output directory.

        Args:
            video_path: Source video to process.
            output_video: File name of the result, joined onto
                ``self.fs.output_path``.
        """
        prev_frames: tuple["np.ndarray", ...] = tuple()
        interpolated_frames: list["np.ndarray"] = []
        part = 0
        chunk_seconds = 10

        length = self.video_maker.get_video_duration(video_path)
        # A trailing partial chunk counts as one more part; this total is only
        # used for the progress-bar label.
        last_part_seconds = 1 if length % chunk_seconds else 0
        total_parts = int(length // chunk_seconds) + last_part_seconds

        fps = self.video_maker.get_fps(video_path)
        logging.info(f"Video FPS: {fps}")
        fps *= 2  # Doubling FPS
        width, height = self.video_maker.get_size(video_path)

        for frames in self.video_maker.video_to_frames_generator(
            video_path, self.fs.frames_path, chunk_seconds
        ):
            logging.info(f"Processing frames: {len(frames)}")
            if prev_frames:
                # Bridge the chunk boundary: interpolate between the last frame
                # of the previous chunk and the first frame of this one, then
                # flush the previous chunk as a finished part.
                # NOTE(review): assumes every yielded chunk is non-empty;
                # frames[0] raises IndexError otherwise - confirm upstream.
                img1 = prev_frames[-1]
                img2 = frames[0]
                img1_2 = self.interpolator.interpolate(img1, img2)
                interpolated_frames.append(img1_2)
                generator = self._frame_generator(prev_frames, interpolated_frames)
                part_path = self.fs.video_part_path / f"video_{part:08d}.mp4"
                self.video_maker.images_to_video_pipeline(
                    generator, part_path, width, height, fps
                )
                interpolated_frames = []
                logging.info(f"Finished processing part {part:08d}")
                part += 1

            for i in tqdm.tqdm(
                range(len(frames) - 1),
                desc=f"Processing video frames {part + 1} / {total_parts}",
            ):
                img1 = frames[i]
                img2 = frames[i + 1]
                img1_2 = self.interpolator.interpolate(img1, img2)
                interpolated_frames.append(img1_2)
            prev_frames = frames

        # The last chunk has no successor, so it is written without a bridging
        # frame after its final source frame.
        generator = self._frame_generator(prev_frames, interpolated_frames)
        part_path = self.fs.video_part_path / f"video_{part:08d}.mp4"
        self.video_maker.images_to_video_pipeline(
            generator, part_path, width, height, fps
        )
        logging.info(f"Finished processing part {part:08d}")

        self._merge_video_parts(self.fs.output_path / output_video)
        logging.info(
            f"Video interpolation completed. Output saved to: {self.fs.output_path / output_video}"
        )

    def _save_images(
        self,
        source: tuple["np.ndarray", ...],
        interpolated: list["np.ndarray"],
    ):
        """Write source and interpolated frames, interleaved, as numbered PNGs.

        The ``moved_path`` directory is cleared first; files are named
        ``img_00000000.png``, ``img_00000001.png`` and so on in output order.

        Args:
            source: Original frames.
            interpolated: Frames to place after the source frame with the
                same index.
        """
        logging.info("Saving images...")
        self.fs.clear_directory(self.fs.moved_path)
        for index, image in enumerate(self._frame_generator(source, interpolated)):
            name = self.fs.moved_path / f"img_{index:08d}.png"
            # cv2.imwrite documents a string filename; str() keeps this working
            # on OpenCV builds that reject pathlib.Path objects.
            imwrite(str(name), image)
        logging.info("Success...")

    def _merge_frames_to_video(self, output_video: Path, fps: float):
        """Encode the images in ``moved_path`` into ``output_video`` at ``fps``."""
        self.video_maker.images_to_video(self.fs.moved_path, output_video, fps)

    def _merge_video_parts(self, output_video: Path):
        """Concatenate all video parts into ``output_video``, then clear the parts."""
        self.video_maker.concatenate_videos(self.fs.video_part_path, output_video)
        self.fs.clear_directory(self.fs.video_part_path)

    def _frame_generator(
        self,
        source: tuple["np.ndarray", ...],
        interpolated: list["np.ndarray"],
    ):
        """Yield source frames interleaved with their interpolated successors.

        For each index ``i`` the source frame is yielded first, followed by
        ``interpolated[i]`` when such an element exists.
        """
        for i, frame in enumerate(source):
            yield frame
            if i < len(interpolated):
                yield interpolated[i]
def runner(
    base_path: Path,
    video_path: Path,
    output_video: str,
    preset: presets.Preset = presets.LARGE,
):
    """Build an ``InterpolationPipeline`` from ``preset`` and run it once.

    Args:
        base_path: Working directory root for the pipeline.
        video_path: Source video to interpolate.
        output_video: File name of the resulting video.
        preset: Supplies the model ``config`` and ``checkpoint``; defaults to
            ``presets.LARGE``.
    """
    InterpolationPipeline(
        config=preset.config,
        checkpoint_path=preset.checkpoint,
        base_path=base_path,
    ).run(video_path, output_video)
def main():

View File

@@ -10,7 +10,7 @@ save_dir: work_dir
eval_interval: 1
network:
name: src.networks.AMT-G.Model
name: AMT-G.Model
params:
corr_radius: 3
corr_lvls: 4

View File

@@ -10,7 +10,7 @@ save_dir: work_dir
eval_interval: 1
network:
name: src.networks.AMT-L.Model
name: AMT-L.Model
params:
corr_radius: 3
corr_lvls: 4

View File

@@ -10,7 +10,7 @@ save_dir: work_dir
eval_interval: 1
network:
name: src.networks.AMT-S.Model
name: AMT-S.Model
params:
corr_radius: 3
corr_lvls: 4

View File

@@ -5,9 +5,9 @@ import torch
import numpy as np
from omegaconf import OmegaConf, DictConfig
from src.utils.torch import img2tensor, check_dim_and_resize, tensor2img
from src.utils.build import build_from_cfg
from src.utils.padder import InputPadder
from .utils.torch import img2tensor, check_dim_and_resize, tensor2img
from .utils.build import build_from_cfg
from .utils.padder import InputPadder
class Anchor:

View File

@@ -2,10 +2,10 @@ from typing import Optional
import torch
import torch.nn as nn
from src.networks.blocks.raft import coords_grid, BasicUpdateBlock, BidirCorrBlock
from src.networks.blocks.feat_enc import LargeEncoder
from src.networks.blocks.ifrnet import resize, Encoder, InitDecoder, IntermediateDecoder
from src.networks.blocks.multi_flow import multi_flow_combine, MultiFlowDecoder
from .blocks.raft import coords_grid, BasicUpdateBlock, BidirCorrBlock
from .blocks.feat_enc import LargeEncoder
from .blocks.ifrnet import resize, Encoder, InitDecoder, IntermediateDecoder
from .blocks.multi_flow import multi_flow_combine, MultiFlowDecoder
class Model(nn.Module):

View File

@@ -1,10 +1,10 @@
import torch
import torch.nn as nn
from src.networks.blocks.raft import coords_grid, BasicUpdateBlock, BidirCorrBlock
from src.networks.blocks.feat_enc import BasicEncoder
from src.networks.blocks.ifrnet import resize, Encoder, InitDecoder, IntermediateDecoder
from .blocks.raft import coords_grid, BasicUpdateBlock, BidirCorrBlock
from .blocks.feat_enc import BasicEncoder
from .blocks.ifrnet import resize, Encoder, InitDecoder, IntermediateDecoder
from src.networks.blocks.multi_flow import multi_flow_combine, MultiFlowDecoder
from .blocks.multi_flow import multi_flow_combine, MultiFlowDecoder
class Model(nn.Module):

View File

@@ -1,9 +1,9 @@
import torch
import torch.nn as nn
from src.networks.blocks.raft import coords_grid, SmallUpdateBlock, BidirCorrBlock
from src.networks.blocks.feat_enc import SmallEncoder
from src.networks.blocks.ifrnet import resize, Encoder, InitDecoder, IntermediateDecoder
from src.networks.blocks.multi_flow import multi_flow_combine, MultiFlowDecoder
from .blocks.raft import coords_grid, SmallUpdateBlock, BidirCorrBlock
from .blocks.feat_enc import SmallEncoder
from .blocks.ifrnet import resize, Encoder, InitDecoder, IntermediateDecoder
from .blocks.multi_flow import multi_flow_combine, MultiFlowDecoder
class Model(nn.Module):

View File

@@ -1,7 +1,7 @@
import torch
import torch.nn as nn
from src.utils.flow_utils import warp
from src.networks.blocks.ifrnet import convrelu, resize, ResBlock
from ..utils.flow_utils import warp
from .blocks.ifrnet import convrelu, resize, ResBlock
class Encoder(nn.Module):

View File

@@ -1,7 +1,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from src.utils.flow_utils import warp
from ...utils.flow_utils import warp
def resize(x, scale_factor):

View File

@@ -1,7 +1,7 @@
import torch
import torch.nn as nn
from src.utils.flow_utils import warp
from src.networks.blocks.ifrnet import convrelu, resize, ResBlock
from ...utils.flow_utils import warp
from .ifrnet import convrelu, resize, ResBlock
def multi_flow_combine(

196
src/runner.py Normal file
View File

@@ -0,0 +1,196 @@
import logging
from pathlib import Path
from typing import TYPE_CHECKING
from cv2 import imwrite
import tqdm
from .config import presets
from .utils.fs import FileSystem
from .utils.video import VideoMaker
from .interpolator import (
ImageInterpolator,
Anchor,
get_device,
get_vram_available,
ModelRunner,
)
if TYPE_CHECKING:
import torch
import numpy as np
def performing_warning_message(device: "torch.device"):
    """Log a performance warning for slow backends and reject unknown ones.

    Args:
        device: Object whose ``type`` attribute names the backend
            (``"cpu"``, ``"mps"`` or ``"cuda"``).

    Raises:
        Exception: If ``device.type`` is not one of the three known backends.
    """
    backend = device.type
    if backend == "cuda":
        # Nothing to warn about on CUDA.
        return
    if backend == "mps":
        logging.warning(
            "Running on Apple Silicon GPU (MPS) may have limited performance. Consider using a CUDA-enabled GPU for better performance."
        )
        return
    if backend == "cpu":
        logging.warning(
            "Running on CPU may be very slow. Consider using a GPU for better performance."
        )
        return
    raise Exception(f"Unsupported device type: {device.type}")
def init_fs(base_path: Path) -> FileSystem:
    """Create the ``FileSystem`` helper and empty its working directories.

    Args:
        base_path: Root path handed to ``FileSystem``.

    Returns:
        The ``FileSystem`` instance after its frames, interpolated, moved and
        video-part directories have been cleared.
    """
    workspace = FileSystem(base_path)
    # Each attribute is looked up right before it is cleared, in this order.
    for attr in (
        "frames_path",
        "interpolated_path",
        "moved_path",
        "video_part_path",
    ):
        workspace.clear_directory(getattr(workspace, attr))
    return workspace
def init_video_maker() -> VideoMaker:
    """Build a ``VideoMaker`` with its default constructor arguments."""
    maker = VideoMaker()
    return maker
def init_device() -> "torch.device":
    """Select the compute device, warn if it is slow, and log available memory.

    Returns:
        The device returned by ``get_device``.

    Raises:
        Exception: Propagated from ``performing_warning_message`` when the
            device type is unsupported.
    """
    selected = get_device()
    performing_warning_message(selected)
    # Presumably a byte count: it is scaled by 1024**3 and reported as GB.
    available = get_vram_available(selected)
    logging.info(f"Available VRAM: {available / (1024**3):.2f} GB")
    return selected
def init_anchor(device: "torch.device") -> Anchor:
    """Return the ``Anchor`` preset that matches the device backend.

    Args:
        device: Object whose ``type`` attribute names the backend.

    Returns:
        An ``Anchor`` built with the fixed values for CUDA, or with the
        shared CPU/MPS values.

    Raises:
        Exception: If ``device.type`` is not ``"cpu"``, ``"mps"`` or ``"cuda"``.
    """
    backend = device.type
    if backend == "cuda":
        return Anchor(
            resolution=1024 * 512,
            memory=1500 * 1024**2,
            memory_bias=2500 * 1024**2,
        )
    if backend in ("cpu", "mps"):
        return Anchor(resolution=8192 * 8192, memory=1, memory_bias=0)
    raise Exception(f"Unsupported device type: {device.type}")
def init_model_runner(
    config: Path, checkpoint_path: Path, device: "torch.device"
) -> ModelRunner:
    """Construct the ``ModelRunner`` for the given config, checkpoint and device.

    Args:
        config: Model configuration path.
        checkpoint_path: Model checkpoint path.
        device: Device the runner is created for.

    Returns:
        The newly created ``ModelRunner``.
    """
    model_runner = ModelRunner(config, checkpoint_path, device)
    return model_runner
def init_interpolator(
    model_runner: ModelRunner, device: "torch.device"
) -> ImageInterpolator:
    """Create the ``ImageInterpolator`` with the anchor chosen for ``device``.

    Args:
        model_runner: Runner handed to the interpolator.
        device: Device used both to pick the anchor and by the interpolator.

    Returns:
        The configured ``ImageInterpolator``.
    """
    return ImageInterpolator(device, init_anchor(device), model_runner)
class InterpolationPipeline:
    """Double a video's frame rate by inserting interpolated frames.

    The video is read chunk by chunk; every chunk is written out as its own
    video part with one interpolated frame between each pair of neighbouring
    source frames, and the parts are concatenated into the final output.
    """

    def __init__(
        self,
        config: Path,
        checkpoint_path: Path,
        base_path: Path,
    ):
        """Set up file system, video helper, device, model and interpolator.

        Args:
            config: Model configuration path passed to ``init_model_runner``.
            checkpoint_path: Model checkpoint path passed to
                ``init_model_runner``.
            base_path: Root directory passed to ``init_fs``.
        """
        self.fs = init_fs(base_path)
        self.video_maker = init_video_maker()
        self.device = init_device()
        self.model_runner = init_model_runner(config, checkpoint_path, self.device)
        self.interpolator = init_interpolator(self.model_runner, self.device)

    def run(self, video_path: Path, output_video: str):
        """Interpolate ``video_path`` and write the result to the output directory.

        Args:
            video_path: Source video to process.
            output_video: File name of the result, joined onto
                ``self.fs.output_path``.
        """
        prev_frames: tuple["np.ndarray", ...] = tuple()
        interpolated_frames: list["np.ndarray"] = []
        part = 0
        chunk_seconds = 10

        length = self.video_maker.get_video_duration(video_path)
        # A trailing partial chunk counts as one more part; this total is only
        # used for the progress-bar label.
        last_part_seconds = 1 if length % chunk_seconds else 0
        total_parts = int(length // chunk_seconds) + last_part_seconds

        fps = self.video_maker.get_fps(video_path)
        logging.info(f"Video FPS: {fps}")
        fps *= 2  # Doubling FPS
        width, height = self.video_maker.get_size(video_path)

        for frames in self.video_maker.video_to_frames_generator(
            video_path, self.fs.frames_path, chunk_seconds
        ):
            logging.info(f"Processing frames: {len(frames)}")
            if prev_frames:
                # Bridge the chunk boundary: interpolate between the last frame
                # of the previous chunk and the first frame of this one, then
                # flush the previous chunk as a finished part.
                # NOTE(review): assumes every yielded chunk is non-empty;
                # frames[0] raises IndexError otherwise - confirm upstream.
                img1 = prev_frames[-1]
                img2 = frames[0]
                img1_2 = self.interpolator.interpolate(img1, img2)
                interpolated_frames.append(img1_2)
                generator = self._frame_generator(prev_frames, interpolated_frames)
                part_path = self.fs.video_part_path / f"video_{part:08d}.mp4"
                self.video_maker.images_to_video_pipeline(
                    generator, part_path, width, height, fps
                )
                interpolated_frames = []
                logging.info(f"Finished processing part {part:08d}")
                part += 1

            for i in tqdm.tqdm(
                range(len(frames) - 1),
                desc=f"Processing video frames {part + 1} / {total_parts}",
            ):
                img1 = frames[i]
                img2 = frames[i + 1]
                img1_2 = self.interpolator.interpolate(img1, img2)
                interpolated_frames.append(img1_2)
            prev_frames = frames

        # The last chunk has no successor, so it is written without a bridging
        # frame after its final source frame.
        generator = self._frame_generator(prev_frames, interpolated_frames)
        part_path = self.fs.video_part_path / f"video_{part:08d}.mp4"
        self.video_maker.images_to_video_pipeline(
            generator, part_path, width, height, fps
        )
        logging.info(f"Finished processing part {part:08d}")

        self._merge_video_parts(self.fs.output_path / output_video)
        logging.info(
            f"Video interpolation completed. Output saved to: {self.fs.output_path / output_video}"
        )

    def _save_images(
        self,
        source: tuple["np.ndarray", ...],
        interpolated: list["np.ndarray"],
    ):
        """Write source and interpolated frames, interleaved, as numbered PNGs.

        The ``moved_path`` directory is cleared first; files are named
        ``img_00000000.png``, ``img_00000001.png`` and so on in output order.

        Args:
            source: Original frames.
            interpolated: Frames to place after the source frame with the
                same index.
        """
        logging.info("Saving images...")
        self.fs.clear_directory(self.fs.moved_path)
        for index, image in enumerate(self._frame_generator(source, interpolated)):
            name = self.fs.moved_path / f"img_{index:08d}.png"
            # cv2.imwrite documents a string filename; str() keeps this working
            # on OpenCV builds that reject pathlib.Path objects.
            imwrite(str(name), image)
        logging.info("Success...")

    def _merge_frames_to_video(self, output_video: Path, fps: float):
        """Encode the images in ``moved_path`` into ``output_video`` at ``fps``."""
        self.video_maker.images_to_video(self.fs.moved_path, output_video, fps)

    def _merge_video_parts(self, output_video: Path):
        """Concatenate all video parts into ``output_video``, then clear the parts."""
        self.video_maker.concatenate_videos(self.fs.video_part_path, output_video)
        self.fs.clear_directory(self.fs.video_part_path)

    def _frame_generator(
        self,
        source: tuple["np.ndarray", ...],
        interpolated: list["np.ndarray"],
    ):
        """Yield source frames interleaved with their interpolated successors.

        For each index ``i`` the source frame is yielded first, followed by
        ``interpolated[i]`` when such an element exists.
        """
        for i, frame in enumerate(source):
            yield frame
            if i < len(interpolated):
                yield interpolated[i]
def runner(
    base_path: Path,
    video_path: Path,
    output_video: str,
    preset: presets.Preset = presets.LARGE,
):
    """Build an ``InterpolationPipeline`` from ``preset`` and run it once.

    Args:
        base_path: Working directory root for the pipeline.
        video_path: Source video to interpolate.
        output_video: File name of the resulting video.
        preset: Supplies the model ``config`` and ``checkpoint``; defaults to
            ``presets.LARGE``.
    """
    InterpolationPipeline(
        config=preset.config,
        checkpoint_path=preset.checkpoint,
        base_path=base_path,
    ).run(video_path, output_video)

View File

@@ -1,16 +1,19 @@
import importlib
from typing import TYPE_CHECKING
from ..networks import AMT_G, AMT_L, AMT_S
if TYPE_CHECKING:
from omegaconf import DictConfig
def base_build_fn(module: str, cls: str, params: dict):
    """Import ``module``, fetch attribute ``cls`` from it and call it with ``params``.

    Args:
        module: Absolute dotted module name to import.
        cls: Name of the callable to look up on the imported module.
        params: Keyword arguments forwarded to that callable.

    Returns:
        Whatever the looked-up callable returns.
    """
    imported = importlib.import_module(module, package=None)
    factory = getattr(imported, cls)
    return factory(**params)
def build_from_cfg(config: "DictConfig"):
    """Instantiate the network class named by ``config["name"]``.

    ``config["name"]`` has the form ``"<module>.<class>"``, e.g.
    ``"AMT-L.Model"``. The module part is resolved through a fixed registry
    of the imported network modules rather than via ``importlib``, because
    names such as ``AMT-L`` contain a hyphen. A dotted prefix on the module
    part (``"src.networks.AMT-L.Model"``) is tolerated; only its last
    component is used for the lookup.

    Args:
        config: Mapping with a ``"name"`` entry and optional ``"params"``
            keyword arguments for the class.

    Returns:
        The object produced by calling the resolved class with ``params``.

    Raises:
        KeyError: If the module part is not one of the registered networks.
    """
    packages = {
        "AMT-G": AMT_G,
        "AMT-L": AMT_L,
        "AMT-S": AMT_S,
    }
    module, cls = config["name"].rsplit(".", 1)
    params: dict = config.get("params", {})
    # Keep only the final component so legacy fully-qualified names still
    # resolve through the registry.
    key = module.rpartition(".")[2]
    try:
        package = packages[key]
    except KeyError:
        raise KeyError(
            f"Unknown network module {module!r}; expected one of {sorted(packages)}"
        ) from None
    return getattr(package, cls)(**params)