Commit

Author: Viner Abubakirov
SHA1: 0c871c2314
Message: Refactor image tensor conversion and model inference in interpolator.py and torch.py
Date: 2026-04-13 18:56:00 +05:00

14 changed files with 243 additions and 262 deletions

main.py

@@ -1,7 +1,199 @@
import logging
from pathlib import Path
-from src.runner import run
from typing import TYPE_CHECKING
from cv2 import imwrite
import tqdm
from src.config import presets
from src.utils.fs import FileSystem
from src.utils.video import VideoMaker
from src.interpolator import (
ImageInterpolator,
Anchor,
get_device,
get_vram_available,
ModelRunner,
)
if TYPE_CHECKING:
import torch
import numpy as np
def performing_warning_message(device: "torch.device"):
if device.type in ("cpu", "mps"):
if device.type == "mps":
logging.warning(
"Running on Apple Silicon GPU (MPS) may have limited performance. Consider using a CUDA-enabled GPU for better performance."
)
else:
logging.warning(
"Running on CPU may be very slow. Consider using a GPU for better performance."
)
elif device.type == "cuda":
pass
else:
raise Exception(f"Unsupported device type: {device.type}")
def init_fs(base_path: Path) -> FileSystem:
fs = FileSystem(base_path)
fs.clear_directory(fs.frames_path)
fs.clear_directory(fs.interpolated_path)
fs.clear_directory(fs.moved_path)
fs.clear_directory(fs.video_part_path)
return fs
def init_video_maker() -> VideoMaker:
return VideoMaker()
def init_device() -> "torch.device":
device = get_device()
performing_warning_message(device)
vram_available = get_vram_available(device)
logging.info(f"Available VRAM: {vram_available / (1024**3):.2f} GB")
return device
def init_anchor(device: "torch.device") -> Anchor:
if device.type in ("cpu", "mps"):
return Anchor(resolution=8192 * 8192, memory=1, memory_bias=0)
elif device.type == "cuda":
return Anchor(
resolution=1024 * 512, memory=1500 * 1024**2, memory_bias=2500 * 1024**2
)
else:
raise Exception(f"Unsupported device type: {device.type}")
def init_model_runner(
config: Path, checkpoint_path: Path, device: "torch.device"
) -> ModelRunner:
return ModelRunner(config, checkpoint_path, device)
def init_interpolator(
model_runner: ModelRunner, device: "torch.device"
) -> ImageInterpolator:
anchor = init_anchor(device)
return ImageInterpolator(device, anchor, model_runner)
class InterpolationPipeline:
def __init__(
self,
config: Path,
checkpoint_path: Path,
base_path: Path,
):
self.fs = init_fs(base_path)
self.video_maker = init_video_maker()
self.device = init_device()
self.model_runner = init_model_runner(config, checkpoint_path, self.device)
self.interpolator = init_interpolator(self.model_runner, self.device)
def run(self, video_path: Path, output_video: str):
prev_frames = tuple()
interpolated_frames: list["np.ndarray"] = []
part = 0
chunk_seconds = 10
length = self.video_maker.get_video_duration(video_path)
last_part_seconds = 1 if length % chunk_seconds else 0
total_parts = int(length // chunk_seconds) + last_part_seconds
fps = self.video_maker.get_fps(video_path)
logging.info(f"Video FPS: {fps}")
fps *= 2 # Doubling FPS
width, height = self.video_maker.get_size(video_path)
for frames in self.video_maker.video_to_frames_generator(
video_path, self.fs.frames_path, chunk_seconds
):
logging.info(f"Processing frames: {len(frames)}")
if prev_frames:
img1 = prev_frames[-1]
img2 = frames[0]
img1_2 = self.interpolator.interpolate(img1, img2)
interpolated_frames.append(img1_2)
generator = self._frame_generator(prev_frames, interpolated_frames)
part_path = self.fs.video_part_path / f"video_{part:08d}.mp4"
self.video_maker.images_to_video_pipeline(
generator, part_path, width, height, fps
)
interpolated_frames = []
logging.info(f"Finished processing part {part:08d}")
part += 1
for i in tqdm.tqdm(
range(len(frames) - 1),
desc=f"Processing video frames {part + 1} / {total_parts}",
):
img1 = frames[i]
img2 = frames[i + 1]
img1_2 = self.interpolator.interpolate(img1, img2)
interpolated_frames.append(img1_2)
prev_frames = frames
generator = self._frame_generator(prev_frames, interpolated_frames)
part_path = self.fs.video_part_path / f"video_{part:08d}.mp4"
self.video_maker.images_to_video_pipeline(
generator, part_path, width, height, fps
)
logging.info(f"Finished processing part {part:08d}")
self._merge_video_parts(self.fs.output_path / output_video)
logging.info(
f"Video interpolation completed. Output saved to: {self.fs.output_path / output_video}"
)
def _save_images(
self,
source: tuple["np.ndarray", ...],
interpolated: list["np.ndarray"],
):
logging.info("Saving images...")
self.fs.clear_directory(self.fs.moved_path)
index = 0
for i, frame in enumerate(source):
name = self.fs.moved_path / f"img_{index:08d}.png"
index += 1
imwrite(name, frame)
if i < len(interpolated):
name = self.fs.moved_path / f"img_{index:08d}.png"
index += 1
imwrite(name, interpolated[i])
logging.info("Success...")
def _merge_frames_to_video(self, output_video: Path, fps: float):
self.video_maker.images_to_video(self.fs.moved_path, output_video, fps)
def _merge_video_parts(self, output_video: Path):
self.video_maker.concatenate_videos(self.fs.video_part_path, output_video)
self.fs.clear_directory(self.fs.video_part_path)
def _frame_generator(
self,
source: tuple["np.ndarray", ...],
interpolated: list["np.ndarray"],
):
for i, frame in enumerate(source):
yield frame
if i < len(interpolated):
yield interpolated[i]
def runner(
base_path: Path,
video_path: Path,
output_video: str,
preset: presets.Preset = presets.LARGE,
):
pipeline = InterpolationPipeline(
config=preset.config,
checkpoint_path=preset.checkpoint,
base_path=base_path,
)
pipeline.run(video_path, output_video)
def main():
@@ -30,7 +222,7 @@ def main():
default="global",
)
args = parser.parse_args()
-run(
+runner(
base_path=Path(args.base_path),
video_path=Path(args.video_path),
output_video=args.output,
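
The doubled FPS above comes from _frame_generator, which alternates source frames with their interpolated midpoints. A minimal self-contained sketch of that interleaving, with strings standing in for the np.ndarray frames:

def interleave(source, interpolated):
    # yields source[0], interpolated[0], source[1], interpolated[1], ..., source[-1]
    for i, frame in enumerate(source):
        yield frame
        if i < len(interpolated):
            yield interpolated[i]

frames = ("f0", "f1", "f2")   # stand-ins for decoded frames
mids = ["f0_1", "f1_2"]       # midpoints interpolated between consecutive pairs
assert list(interleave(frames, mids)) == ["f0", "f0_1", "f1", "f1_2", "f2"]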


@@ -10,7 +10,7 @@ save_dir: work_dir
eval_interval: 1
network:
-name: AMT-G.Model
+name: src.networks.AMT-G.Model
params:
corr_radius: 3
corr_lvls: 4


@@ -10,7 +10,7 @@ save_dir: work_dir
eval_interval: 1
network:
-name: AMT-L.Model
+name: src.networks.AMT-L.Model
params:
corr_radius: 3
corr_lvls: 4


@@ -10,7 +10,7 @@ save_dir: work_dir
eval_interval: 1
network:
-name: AMT-S.Model
+name: src.networks.AMT-S.Model
params:
corr_radius: 3
corr_lvls: 4
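
Each of the three config hunks above makes the same one-line change: network.name now carries the fully qualified module path instead of a bare model alias, so the loader can resolve it dynamically (see the build.py hunk below). A hedged sketch of how such a section is read, using OmegaConf as interpolator.py already does (the inline YAML string is only illustrative):

from omegaconf import OmegaConf

cfg = OmegaConf.create(
    """
    network:
      name: src.networks.AMT-G.Model
      params:
        corr_radius: 3
        corr_lvls: 4
    """
)
module, cls = cfg.network["name"].rsplit(".", 1)
print(module, cls)  # src.networks.AMT-G Model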


@@ -1,13 +1,14 @@
import logging
from pathlib import Path
from typing import Optional
import torch
import numpy as np
from omegaconf import OmegaConf, DictConfig
-from .utils.torch import img2tensor, check_dim_and_resize, tensor2img
-from .utils.build import build_from_cfg
-from .utils.padder import InputPadder
+from src.utils.torch import img2tensor, tensor2img
+from src.utils.build import build_from_cfg
+from src.utils.padder import InputPadder
class Anchor:
@@ -39,7 +40,7 @@ class ModelRunner:
model.load_state_dict(checkpoint["state_dict"])
model = model.to(get_device())
model.eval()
-self.model = model
+self.model = torch.compile(model)
def get_vram_available(device: torch.device) -> int:
@@ -91,35 +92,19 @@ class ImageInterpolator:
output_path (Path): Path to save the interpolated image (only png and jpg formats are supported)
"""
logging.debug(f"Reading images: {image1} and {image2}")
-tensor1 = img2tensor(image1).to(self.device)
-tensor2 = img2tensor(image2).to(self.device)
+tensor1 = img2tensor(image1, self.device)
+tensor2 = img2tensor(image2, self.device)
logging.debug(
f"Image shapes after conversion to tensors: {tensor1.shape}, {tensor2.shape}"
)
-tensor1, tensor2 = check_dim_and_resize(tensor1, tensor2)
-logging.debug(f"Image shapes after resizing: {tensor1.shape}, {tensor2.shape}")
-h, w = tensor1.shape[2], tensor1.shape[3]
-logging.debug(f"Interpolating images of size: {h}x{w}")
-scale = self.scale(h, w)
-logging.debug(f"Calculated scale factor: {scale:.2f}")
-padding = int(16 / scale)
-logging.debug(f"Calculated padding: {padding} pixels")
-padder = InputPadder(tensor1.shape, divisor=padding)
-tensor1_padded, tensor2_padded = padder.pad(tensor1, tensor2)
-logging.debug(
-f"Image shapes after padding: {tensor1_padded.shape}, {tensor2_padded.shape}"
-)
-tensor1_padded = tensor1_padded.to(self.device)
-tensor2_padded = tensor2_padded.to(self.device)
logging.debug("Running model inference for interpolation")
with torch.no_grad():
+with torch.amp.autocast(self.device.type):
interpolated = self.model_runner.model(
-tensor1_padded, tensor2_padded, self.embt, scale_factor=scale, eval=True
+tensor1, tensor2, self.embt
)["imgt_pred"]
-logging.debug(f"Interpolated image shape before unpadding: {interpolated.shape}")
-(interpolated,) = padder.unpad(interpolated)
-logging.debug(f"Interpolated image shape after unpadding: {interpolated.shape}")
return tensor2img(interpolated.cpu())
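
With the padder, resize, and scale plumbing removed, interpolate() now assumes its inputs already satisfy the network's size constraints and runs the compiled model under autocast. A minimal sketch of that inference pattern, assuming PyTorch 2.x and a toy module in place of the AMT model (whose real forward takes two frames plus self.embt):

import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

net = nn.Conv2d(3, 3, kernel_size=3, padding=1).to(device).eval()
net = torch.compile(net)  # one-time compilation, as in ModelRunner above

x = torch.rand(1, 3, 64, 64, device=device)
with torch.no_grad():
    with torch.amp.autocast(device.type):  # mixed precision, as in interpolate()
        y = net(x)
print(y.shape)  # torch.Size([1, 3, 64, 64])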


@@ -2,10 +2,10 @@ from typing import Optional
import torch
import torch.nn as nn
-from .blocks.raft import coords_grid, BasicUpdateBlock, BidirCorrBlock
-from .blocks.feat_enc import LargeEncoder
-from .blocks.ifrnet import resize, Encoder, InitDecoder, IntermediateDecoder
-from .blocks.multi_flow import multi_flow_combine, MultiFlowDecoder
+from src.networks.blocks.raft import coords_grid, BasicUpdateBlock, BidirCorrBlock
+from src.networks.blocks.feat_enc import LargeEncoder
+from src.networks.blocks.ifrnet import resize, Encoder, InitDecoder, IntermediateDecoder
+from src.networks.blocks.multi_flow import multi_flow_combine, MultiFlowDecoder
class Model(nn.Module):


@@ -1,10 +1,10 @@
import torch
import torch.nn as nn
-from .blocks.raft import coords_grid, BasicUpdateBlock, BidirCorrBlock
-from .blocks.feat_enc import BasicEncoder
-from .blocks.ifrnet import resize, Encoder, InitDecoder, IntermediateDecoder
+from src.networks.blocks.raft import coords_grid, BasicUpdateBlock, BidirCorrBlock
+from src.networks.blocks.feat_enc import BasicEncoder
+from src.networks.blocks.ifrnet import resize, Encoder, InitDecoder, IntermediateDecoder
-from .blocks.multi_flow import multi_flow_combine, MultiFlowDecoder
+from src.networks.blocks.multi_flow import multi_flow_combine, MultiFlowDecoder
class Model(nn.Module):


@@ -1,9 +1,9 @@
import torch
import torch.nn as nn
-from .blocks.raft import coords_grid, SmallUpdateBlock, BidirCorrBlock
-from .blocks.feat_enc import SmallEncoder
-from .blocks.ifrnet import resize, Encoder, InitDecoder, IntermediateDecoder
-from .blocks.multi_flow import multi_flow_combine, MultiFlowDecoder
+from src.networks.blocks.raft import coords_grid, SmallUpdateBlock, BidirCorrBlock
+from src.networks.blocks.feat_enc import SmallEncoder
+from src.networks.blocks.ifrnet import resize, Encoder, InitDecoder, IntermediateDecoder
+from src.networks.blocks.multi_flow import multi_flow_combine, MultiFlowDecoder
class Model(nn.Module):


@@ -1,7 +1,7 @@
import torch
import torch.nn as nn
-from ..utils.flow_utils import warp
-from .blocks.ifrnet import convrelu, resize, ResBlock
+from src.utils.flow_utils import warp
+from src.networks.blocks.ifrnet import convrelu, resize, ResBlock
class Encoder(nn.Module):


@@ -1,7 +1,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
-from ...utils.flow_utils import warp
+from src.utils.flow_utils import warp
def resize(x, scale_factor):


@@ -1,7 +1,7 @@
import torch
import torch.nn as nn
-from ...utils.flow_utils import warp
-from .ifrnet import convrelu, resize, ResBlock
+from src.utils.flow_utils import warp
+from src.networks.blocks.ifrnet import convrelu, resize, ResBlock
def multi_flow_combine(
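
The network and block modules above all switch from relative to absolute imports, presumably because build_from_cfg (below) now loads modules through importlib.import_module with package=None, and in that mode only absolute names resolve. A quick stdlib-only check:

import importlib

importlib.import_module("json")  # absolute names resolve directly

try:
    importlib.import_module(".decoder", package=None)  # relative name, no anchor
except TypeError as err:
    print(err)  # the 'package' argument is required ... relative import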


@@ -1,196 +0,0 @@
import logging
from pathlib import Path
from typing import TYPE_CHECKING
from cv2 import imwrite
import tqdm
from .config import presets
from .utils.fs import FileSystem
from .utils.video import VideoMaker
from .interpolator import (
ImageInterpolator,
Anchor,
get_device,
get_vram_available,
ModelRunner,
)
if TYPE_CHECKING:
import torch
import numpy as np
def performing_warning_message(device: "torch.device"):
if device.type in ("cpu", "mps"):
if device.type == "mps":
logging.warning(
"Running on Apple Silicon GPU (MPS) may have limited performance. Consider using a CUDA-enabled GPU for better performance."
)
else:
logging.warning(
"Running on CPU may be very slow. Consider using a GPU for better performance."
)
elif device.type == "cuda":
pass
else:
raise Exception(f"Unsupported device type: {device.type}")
def init_fs(base_path: Path) -> FileSystem:
fs = FileSystem(base_path)
fs.clear_directory(fs.frames_path)
fs.clear_directory(fs.interpolated_path)
fs.clear_directory(fs.moved_path)
fs.clear_directory(fs.video_part_path)
return fs
def init_video_maker() -> VideoMaker:
return VideoMaker()
def init_device() -> "torch.device":
device = get_device()
performing_warning_message(device)
vram_available = get_vram_available(device)
logging.info(f"Available VRAM: {vram_available / (1024**3):.2f} GB")
return device
def init_anchor(device: "torch.device") -> Anchor:
if device.type in ("cpu", "mps"):
return Anchor(resolution=8192 * 8192, memory=1, memory_bias=0)
elif device.type == "cuda":
return Anchor(
resolution=1024 * 512, memory=1500 * 1024**2, memory_bias=2500 * 1024**2
)
else:
raise Exception(f"Unsupported device type: {device.type}")
def init_model_runner(
config: Path, checkpoint_path: Path, device: "torch.device"
) -> ModelRunner:
return ModelRunner(config, checkpoint_path, device)
def init_interpolator(
model_runner: ModelRunner, device: "torch.device"
) -> ImageInterpolator:
anchor = init_anchor(device)
return ImageInterpolator(device, anchor, model_runner)
class InterpolationPipeline:
def __init__(
self,
config: Path,
checkpoint_path: Path,
base_path: Path,
):
self.fs = init_fs(base_path)
self.video_maker = init_video_maker()
self.device = init_device()
self.model_runner = init_model_runner(config, checkpoint_path, self.device)
self.interpolator = init_interpolator(self.model_runner, self.device)
def run(self, video_path: Path, output_video: str):
prev_frames: tuple["np.ndarray", ...] = tuple()
interpolated_frames: list["np.ndarray"] = []
part = 0
chunk_seconds = 10
length = self.video_maker.get_video_duration(video_path)
last_part_seconds = 1 if length % chunk_seconds else 0
total_parts = int(length // chunk_seconds) + last_part_seconds
fps = self.video_maker.get_fps(video_path)
logging.info(f"Video FPS: {fps}")
fps *= 2 # Doubling FPS
width, height = self.video_maker.get_size(video_path)
for frames in self.video_maker.video_to_frames_generator(
video_path, self.fs.frames_path, chunk_seconds
):
logging.info(f"Processing frames: {len(frames)}")
if prev_frames:
img1 = prev_frames[-1]
img2 = frames[0]
img1_2 = self.interpolator.interpolate(img1, img2)
interpolated_frames.append(img1_2)
generator = self._frame_generator(prev_frames, interpolated_frames)
part_path = self.fs.video_part_path / f"video_{part:08d}.mp4"
self.video_maker.images_to_video_pipeline(
generator, part_path, width, height, fps
)
interpolated_frames = []
logging.info(f"Finished processing part {part:08d}")
part += 1
for i in tqdm.tqdm(
range(len(frames) - 1),
desc=f"Processing video frames {part + 1} / {total_parts}",
):
img1 = frames[i]
img2 = frames[i + 1]
img1_2 = self.interpolator.interpolate(img1, img2)
interpolated_frames.append(img1_2)
prev_frames = frames
generator = self._frame_generator(prev_frames, interpolated_frames)
part_path = self.fs.video_part_path / f"video_{part:08d}.mp4"
self.video_maker.images_to_video_pipeline(
generator, part_path, width, height, fps
)
logging.info(f"Finished processing part {part:08d}")
self._merge_video_parts(self.fs.output_path / output_video)
logging.info(
f"Video interpolation completed. Output saved to: {self.fs.output_path / output_video}"
)
def _save_images(
self,
source: tuple["np.ndarray", ...],
interpolated: list["np.ndarray"],
):
logging.info("Saving images...")
self.fs.clear_directory(self.fs.moved_path)
index = 0
for i, frame in enumerate(source):
name = self.fs.moved_path / f"img_{index:08d}.png"
index += 1
imwrite(name, frame)
if i < len(interpolated):
name = self.fs.moved_path / f"img_{index:08d}.png"
index += 1
imwrite(name, interpolated[i])
logging.info("Success...")
def _merge_frames_to_video(self, output_video: Path, fps: float):
self.video_maker.images_to_video(self.fs.moved_path, output_video, fps)
def _merge_video_parts(self, output_video: Path):
self.video_maker.concatenate_videos(self.fs.video_part_path, output_video)
self.fs.clear_directory(self.fs.video_part_path)
def _frame_generator(
self,
source: tuple["np.ndarray", ...],
interpolated: list["np.ndarray"],
):
for i, frame in enumerate(source):
yield frame
if i < len(interpolated):
yield interpolated[i]
def run(
base_path: Path,
video_path: Path,
output_video: str,
preset: presets.Preset = presets.LARGE,
):
pipeline = InterpolationPipeline(
config=preset.config,
checkpoint_path=preset.checkpoint,
base_path=base_path,
)
pipeline.run(video_path, output_video)


@@ -1,19 +1,16 @@
import importlib
from typing import TYPE_CHECKING
-from ..networks import AMT_G, AMT_L, AMT_S
if TYPE_CHECKING:
from omegaconf import DictConfig
-def build_from_cfg(config: "DictConfig"):
-packages = {
-"AMT-G": AMT_G,
-"AMT-L": AMT_L,
-"AMT-S": AMT_S
-}
+def base_build_fn(module: str, cls: str, params: dict):
+return getattr(importlib.import_module(module, package=None), cls)(**params)
+def build_from_cfg(config: "DictConfig"):
module, cls = config["name"].rsplit(".", 1)
params: dict = config.get("params", {})
-return getattr(packages[module], cls)(**params)
+return base_build_fn(module, cls, params)
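
The new base_build_fn resolves the dotted name straight from the config. Note that a module file like AMT-G.py cannot be imported with an import statement (a hyphen is not a valid identifier), but importlib accepts the string form. A sketch of the resolution path, runnable only inside this repository's layout, reusing the network.name value from the YAML hunks above:

import importlib

name = "src.networks.AMT-G.Model"   # network.name from the updated config
module, cls = name.rsplit(".", 1)   # "src.networks.AMT-G", "Model"

# `import src.networks.AMT-G` would be a SyntaxError; the string form works:
model_cls = getattr(importlib.import_module(module), cls)
model = model_cls(corr_radius=3, corr_lvls=4)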


@@ -5,23 +5,26 @@ import numpy as np
def tensor2img(tensor: torch.Tensor):
-return (
-(tensor * 255.0)
-.detach()
+tensor = (
+tensor.mul(255.0)
+.clamp_(0, 255)
+.to(torch.uint8)
.squeeze(0)
.permute(1, 2, 0)
-.cpu()
-.numpy()
-.clip(0, 255)
-.astype(np.uint8)
)
+return tensor.cpu().numpy()
-def img2tensor(img: np.ndarray) -> torch.Tensor:
+def img2tensor(img: np.ndarray, device: torch.device) -> torch.Tensor:
logging.debug(f"Converting image of shape {img.shape} to tensor")
if img.shape[-1] > 3:
img = img[:, :, :3]
-return torch.tensor(img).permute(2, 0, 1).unsqueeze(0) / 255.0
+tensor = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0)
+if device.type != "cuda":
+return tensor.float() / 255.0
+return tensor.cuda(non_blocking=True).float().div_(255.0)
def check_dim_and_resize(*args: torch.Tensor) -> list[torch.Tensor]:
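
The new helpers form a simple round trip: img2tensor maps an HxWx3 uint8 image to a 1x3xHxW float tensor in [0, 1] (uploading to the GPU first on CUDA), and tensor2img maps it back. A CPU-only sketch, assuming a cv2-style uint8 input:

import numpy as np
import torch

img = np.random.randint(0, 256, (4, 4, 3), dtype=np.uint8)  # HxWx3, as cv2 returns

# img2tensor, non-CUDA branch: HWC uint8 -> 1x3xHxW float in [0, 1]
t = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0).float() / 255.0

# tensor2img: back to HxWx3 uint8
back = t.mul(255.0).clamp_(0, 255).to(torch.uint8).squeeze(0).permute(1, 2, 0).cpu().numpy()
print(back.shape, back.dtype)  # (4, 4, 3) uint8
# .to(torch.uint8) truncates rather than rounds, so round-tripped values can
# land one below the original after the float division.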