Compare commits
2 Commits
c91cf6b53a...main

| Author | SHA1 | Date |
|---|---|---|
| | c7acd66974 | |
| | 2d67b72128 | |

196  main.py
@@ -1,199 +1,7 @@
 import logging
 from pathlib import Path
-from typing import TYPE_CHECKING
-
-from cv2 import imwrite
-import tqdm
-
+from src.runner import run
 from src.config import presets
-from src.utils.fs import FileSystem
-from src.utils.video import VideoMaker
-from src.interpolator import (
-    ImageInterpolator,
-    Anchor,
-    get_device,
-    get_vram_available,
-    ModelRunner,
-)
-
-
-if TYPE_CHECKING:
-    import torch
-    import numpy as np
-
-
-def performing_warning_message(device: "torch.device"):
-    if device.type in ("cpu", "mps"):
-        if device.type == "mps":
-            logging.warning(
-                "Running on Apple Silicon GPU (MPS) may have limited performance. Consider using a CUDA-enabled GPU for better performance."
-            )
-        else:
-            logging.warning(
-                "Running on CPU may be very slow. Consider using a GPU for better performance."
-            )
-    elif device.type == "cuda":
-        pass
-    else:
-        raise Exception(f"Unsupported device type: {device.type}")
-
-
-def init_fs(base_path: Path) -> FileSystem:
-    fs = FileSystem(base_path)
-    fs.clear_directory(fs.frames_path)
-    fs.clear_directory(fs.interpolated_path)
-    fs.clear_directory(fs.moved_path)
-    fs.clear_directory(fs.video_part_path)
-    return fs
-
-
-def init_video_maker() -> VideoMaker:
-    return VideoMaker()
-
-
-def init_device() -> "torch.device":
-    device = get_device()
-    performing_warning_message(device)
-    vram_available = get_vram_available(device)
-    logging.info(f"Available VRAM: {vram_available / (1024**3):.2f} GB")
-    return device
-
-
-def init_anchor(device: "torch.device") -> Anchor:
-    if device.type in ("cpu", "mps"):
-        return Anchor(resolution=8192 * 8192, memory=1, memory_bias=0)
-    elif device.type == "cuda":
-        return Anchor(
-            resolution=1024 * 512, memory=1500 * 1024**2, memory_bias=2500 * 1024**2
-        )
-    else:
-        raise Exception(f"Unsupported device type: {device.type}")
-
-
-def init_model_runner(
-    config: Path, checkpoint_path: Path, device: "torch.device"
-) -> ModelRunner:
-    return ModelRunner(config, checkpoint_path, device)
-
-
-def init_interpolator(
-    model_runner: ModelRunner, device: "torch.device"
-) -> ImageInterpolator:
-    anchor = init_anchor(device)
-    return ImageInterpolator(device, anchor, model_runner)
-
-
-class InterpolationPipeline:
-    def __init__(
-        self,
-        config: Path,
-        checkpoint_path: Path,
-        base_path: Path,
-    ):
-        self.fs = init_fs(base_path)
-        self.video_maker = init_video_maker()
-        self.device = init_device()
-        self.model_runner = init_model_runner(config, checkpoint_path, self.device)
-        self.interpolator = init_interpolator(self.model_runner, self.device)
-
-    def run(self, video_path: Path, output_video: str):
-        prev_frames = tuple()
-        interpolated_frames: list["np.ndarray"] = []
-        part = 0
-        chunk_seconds = 10
-        length = self.video_maker.get_video_duration(video_path)
-        last_part_seconds = 1 if length % chunk_seconds else 0
-        total_parts = int(length // chunk_seconds) + last_part_seconds
-        fps = self.video_maker.get_fps(video_path)
-        logging.info(f"Video FPS: {fps}")
-        fps *= 2  # Doubling FPS
-        width, height = self.video_maker.get_size(video_path)
-        for frames in self.video_maker.video_to_frames_generator(
-            video_path, self.fs.frames_path, chunk_seconds
-        ):
-            logging.info(f"Processing frames: {len(frames)}")
-            if prev_frames:
-                img1 = prev_frames[-1]
-                img2 = frames[0]
-                img1_2 = self.interpolator.interpolate(img1, img2)
-                interpolated_frames.append(img1_2)
-                generator = self._frame_generator(prev_frames, interpolated_frames)
-                part_path = self.fs.video_part_path / f"video_{part:08d}.mp4"
-                self.video_maker.images_to_video_pipeline(
-                    generator, part_path, width, height, fps
-                )
-                interpolated_frames = []
-                logging.info(f"Finished processing part {part:08d}")
-                part += 1
-            for i in tqdm.tqdm(
-                range(len(frames) - 1),
-                desc=f"Processing video frames {part + 1} / {total_parts}",
-            ):
-                img1 = frames[i]
-                img2 = frames[i + 1]
-                img1_2 = self.interpolator.interpolate(img1, img2)
-                interpolated_frames.append(img1_2)
-            prev_frames = frames
-
-        generator = self._frame_generator(prev_frames, interpolated_frames)
-        part_path = self.fs.video_part_path / f"video_{part:08d}.mp4"
-        self.video_maker.images_to_video_pipeline(
-            generator, part_path, width, height, fps
-        )
-        logging.info(f"Finished processing part {part:08d}")
-        self._merge_video_parts(self.fs.output_path / output_video)
-        logging.info(
-            f"Video interpolation completed. Output saved to: {self.fs.output_path / output_video}"
-        )
-
-    def _save_images(
-        self,
-        source: tuple["np.ndarray", ...],
-        interpolated: list["np.ndarray"],
-    ):
-        logging.info("Saving images...")
-        self.fs.clear_directory(self.fs.moved_path)
-        index = 0
-        for i, frame in enumerate(source):
-            name = self.fs.moved_path / f"img_{index:08d}.png"
-            index += 1
-            imwrite(name, frame)
-            if i < len(interpolated):
-                name = self.fs.moved_path / f"img_{index:08d}.png"
-                index += 1
-                imwrite(name, interpolated[i])
-        logging.info("Success...")
-
-    def _merge_frames_to_video(self, output_video: Path, fps: float):
-        self.video_maker.images_to_video(self.fs.moved_path, output_video, fps)
-
-    def _merge_video_parts(self, output_video: Path):
-        self.video_maker.concatenate_videos(self.fs.video_part_path, output_video)
-        self.fs.clear_directory(self.fs.video_part_path)
-
-    def _frame_generator(
-        self,
-        source: tuple["np.ndarray", ...],
-        interpolated: list["np.ndarray"],
-    ):
-        for i, frame in enumerate(source):
-            yield frame
-            if i < len(interpolated):
-                yield interpolated[i]
-
-
-def runner(
-    base_path: Path,
-    video_path: Path,
-    output_video: str,
-    preset: presets.Preset = presets.LARGE,
-):
-    pipeline = InterpolationPipeline(
-        config=preset.config,
-        checkpoint_path=preset.checkpoint,
-        base_path=base_path,
-    )
-    pipeline.run(video_path, output_video)
 
 
 def main():
@@ -222,7 +30,7 @@ def main():
         default="global",
     )
     args = parser.parse_args()
-    runner(
+    run(
         base_path=Path(args.base_path),
         video_path=Path(args.video_path),
         output_video=args.output,

@@ -10,7 +10,7 @@ save_dir: work_dir
 eval_interval: 1
 
 network:
-  name: src.networks.AMT-G.Model
+  name: AMT-G.Model
   params:
     corr_radius: 3
     corr_lvls: 4

@@ -10,7 +10,7 @@ save_dir: work_dir
 eval_interval: 1
 
 network:
-  name: src.networks.AMT-L.Model
+  name: AMT-L.Model
   params:
     corr_radius: 3
     corr_lvls: 4

@@ -10,7 +10,7 @@ save_dir: work_dir
 eval_interval: 1
 
 network:
-  name: src.networks.AMT-S.Model
+  name: AMT-S.Model
   params:
     corr_radius: 3
     corr_lvls: 4

@@ -5,9 +5,9 @@ import torch
 import numpy as np
 from omegaconf import OmegaConf, DictConfig
 
-from src.utils.torch import img2tensor, check_dim_and_resize, tensor2img
-from src.utils.build import build_from_cfg
-from src.utils.padder import InputPadder
+from .utils.torch import img2tensor, check_dim_and_resize, tensor2img
+from .utils.build import build_from_cfg
+from .utils.padder import InputPadder
 
 
 class Anchor:

@@ -2,10 +2,10 @@ from typing import Optional
 
 import torch
 import torch.nn as nn
-from src.networks.blocks.raft import coords_grid, BasicUpdateBlock, BidirCorrBlock
-from src.networks.blocks.feat_enc import LargeEncoder
-from src.networks.blocks.ifrnet import resize, Encoder, InitDecoder, IntermediateDecoder
-from src.networks.blocks.multi_flow import multi_flow_combine, MultiFlowDecoder
+from .blocks.raft import coords_grid, BasicUpdateBlock, BidirCorrBlock
+from .blocks.feat_enc import LargeEncoder
+from .blocks.ifrnet import resize, Encoder, InitDecoder, IntermediateDecoder
+from .blocks.multi_flow import multi_flow_combine, MultiFlowDecoder
 
 
 class Model(nn.Module):

@@ -1,10 +1,10 @@
 import torch
 import torch.nn as nn
-from src.networks.blocks.raft import coords_grid, BasicUpdateBlock, BidirCorrBlock
-from src.networks.blocks.feat_enc import BasicEncoder
-from src.networks.blocks.ifrnet import resize, Encoder, InitDecoder, IntermediateDecoder
+from .blocks.raft import coords_grid, BasicUpdateBlock, BidirCorrBlock
+from .blocks.feat_enc import BasicEncoder
+from .blocks.ifrnet import resize, Encoder, InitDecoder, IntermediateDecoder
 
-from src.networks.blocks.multi_flow import multi_flow_combine, MultiFlowDecoder
+from .blocks.multi_flow import multi_flow_combine, MultiFlowDecoder
 
 
 class Model(nn.Module):

@@ -1,9 +1,9 @@
 import torch
 import torch.nn as nn
-from src.networks.blocks.raft import coords_grid, SmallUpdateBlock, BidirCorrBlock
-from src.networks.blocks.feat_enc import SmallEncoder
-from src.networks.blocks.ifrnet import resize, Encoder, InitDecoder, IntermediateDecoder
-from src.networks.blocks.multi_flow import multi_flow_combine, MultiFlowDecoder
+from .blocks.raft import coords_grid, SmallUpdateBlock, BidirCorrBlock
+from .blocks.feat_enc import SmallEncoder
+from .blocks.ifrnet import resize, Encoder, InitDecoder, IntermediateDecoder
+from .blocks.multi_flow import multi_flow_combine, MultiFlowDecoder
 
 
 class Model(nn.Module):

@@ -1,7 +1,7 @@
 import torch
 import torch.nn as nn
-from src.utils.flow_utils import warp
-from src.networks.blocks.ifrnet import convrelu, resize, ResBlock
+from ..utils.flow_utils import warp
+from .blocks.ifrnet import convrelu, resize, ResBlock
 
 
 class Encoder(nn.Module):

@@ -1,7 +1,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from src.utils.flow_utils import warp
+from ...utils.flow_utils import warp
 
 
 def resize(x, scale_factor):

@@ -1,7 +1,7 @@
 import torch
 import torch.nn as nn
-from src.utils.flow_utils import warp
-from src.networks.blocks.ifrnet import convrelu, resize, ResBlock
+from ...utils.flow_utils import warp
+from .ifrnet import convrelu, resize, ResBlock
 
 
 def multi_flow_combine(
196  src/runner.py  Normal file

@@ -0,0 +1,196 @@
+import logging
+from pathlib import Path
+from typing import TYPE_CHECKING
+
+from cv2 import imwrite
+import tqdm
+
+from .config import presets
+from .utils.fs import FileSystem
+from .utils.video import VideoMaker
+from .interpolator import (
+    ImageInterpolator,
+    Anchor,
+    get_device,
+    get_vram_available,
+    ModelRunner,
+)
+
+
+if TYPE_CHECKING:
+    import torch
+    import numpy as np
+
+
+def performing_warning_message(device: "torch.device"):
+    if device.type in ("cpu", "mps"):
+        if device.type == "mps":
+            logging.warning(
+                "Running on Apple Silicon GPU (MPS) may have limited performance. Consider using a CUDA-enabled GPU for better performance."
+            )
+        else:
+            logging.warning(
+                "Running on CPU may be very slow. Consider using a GPU for better performance."
+            )
+    elif device.type == "cuda":
+        pass
+    else:
+        raise Exception(f"Unsupported device type: {device.type}")
+
+
+def init_fs(base_path: Path) -> FileSystem:
+    fs = FileSystem(base_path)
+    fs.clear_directory(fs.frames_path)
+    fs.clear_directory(fs.interpolated_path)
+    fs.clear_directory(fs.moved_path)
+    fs.clear_directory(fs.video_part_path)
+    return fs
+
+
+def init_video_maker() -> VideoMaker:
+    return VideoMaker()
+
+
+def init_device() -> "torch.device":
+    device = get_device()
+    performing_warning_message(device)
+    vram_available = get_vram_available(device)
+    logging.info(f"Available VRAM: {vram_available / (1024**3):.2f} GB")
+    return device
+
+
+def init_anchor(device: "torch.device") -> Anchor:
+    if device.type in ("cpu", "mps"):
+        return Anchor(resolution=8192 * 8192, memory=1, memory_bias=0)
+    elif device.type == "cuda":
+        return Anchor(
+            resolution=1024 * 512, memory=1500 * 1024**2, memory_bias=2500 * 1024**2
+        )
+    else:
+        raise Exception(f"Unsupported device type: {device.type}")
+
+
+def init_model_runner(
+    config: Path, checkpoint_path: Path, device: "torch.device"
+) -> ModelRunner:
+    return ModelRunner(config, checkpoint_path, device)
+
+
+def init_interpolator(
+    model_runner: ModelRunner, device: "torch.device"
+) -> ImageInterpolator:
+    anchor = init_anchor(device)
+    return ImageInterpolator(device, anchor, model_runner)
+
+
+class InterpolationPipeline:
+    def __init__(
+        self,
+        config: Path,
+        checkpoint_path: Path,
+        base_path: Path,
+    ):
+        self.fs = init_fs(base_path)
+        self.video_maker = init_video_maker()
+        self.device = init_device()
+        self.model_runner = init_model_runner(config, checkpoint_path, self.device)
+        self.interpolator = init_interpolator(self.model_runner, self.device)
+
+    def run(self, video_path: Path, output_video: str):
+        prev_frames: tuple["np.ndarray", ...] = tuple()
+        interpolated_frames: list["np.ndarray"] = []
+        part = 0
+        chunk_seconds = 10
+        length = self.video_maker.get_video_duration(video_path)
+        last_part_seconds = 1 if length % chunk_seconds else 0
+        total_parts = int(length // chunk_seconds) + last_part_seconds
+        fps = self.video_maker.get_fps(video_path)
+        logging.info(f"Video FPS: {fps}")
+        fps *= 2  # Doubling FPS
+        width, height = self.video_maker.get_size(video_path)
+        for frames in self.video_maker.video_to_frames_generator(
+            video_path, self.fs.frames_path, chunk_seconds
+        ):
+            logging.info(f"Processing frames: {len(frames)}")
+            if prev_frames:
+                img1 = prev_frames[-1]
+                img2 = frames[0]
+                img1_2 = self.interpolator.interpolate(img1, img2)
+                interpolated_frames.append(img1_2)
+                generator = self._frame_generator(prev_frames, interpolated_frames)
+                part_path = self.fs.video_part_path / f"video_{part:08d}.mp4"
+                self.video_maker.images_to_video_pipeline(
+                    generator, part_path, width, height, fps
+                )
+                interpolated_frames = []
+                logging.info(f"Finished processing part {part:08d}")
+                part += 1
+            for i in tqdm.tqdm(
+                range(len(frames) - 1),
+                desc=f"Processing video frames {part + 1} / {total_parts}",
+            ):
+                img1 = frames[i]
+                img2 = frames[i + 1]
+                img1_2 = self.interpolator.interpolate(img1, img2)
+                interpolated_frames.append(img1_2)
+            prev_frames = frames
+
+        generator = self._frame_generator(prev_frames, interpolated_frames)
+        part_path = self.fs.video_part_path / f"video_{part:08d}.mp4"
+        self.video_maker.images_to_video_pipeline(
+            generator, part_path, width, height, fps
+        )
+        logging.info(f"Finished processing part {part:08d}")
+        self._merge_video_parts(self.fs.output_path / output_video)
+        logging.info(
+            f"Video interpolation completed. Output saved to: {self.fs.output_path / output_video}"
+        )
+
+    def _save_images(
+        self,
+        source: tuple["np.ndarray", ...],
+        interpolated: list["np.ndarray"],
+    ):
+        logging.info("Saving images...")
+        self.fs.clear_directory(self.fs.moved_path)
+        index = 0
+        for i, frame in enumerate(source):
+            name = self.fs.moved_path / f"img_{index:08d}.png"
+            index += 1
+            imwrite(name, frame)
+            if i < len(interpolated):
+                name = self.fs.moved_path / f"img_{index:08d}.png"
+                index += 1
+                imwrite(name, interpolated[i])
+        logging.info("Success...")
+
+    def _merge_frames_to_video(self, output_video: Path, fps: float):
+        self.video_maker.images_to_video(self.fs.moved_path, output_video, fps)
+
+    def _merge_video_parts(self, output_video: Path):
+        self.video_maker.concatenate_videos(self.fs.video_part_path, output_video)
+        self.fs.clear_directory(self.fs.video_part_path)
+
+    def _frame_generator(
+        self,
+        source: tuple["np.ndarray", ...],
+        interpolated: list["np.ndarray"],
+    ):
+        for i, frame in enumerate(source):
+            yield frame
+            if i < len(interpolated):
+                yield interpolated[i]
+
+
+def run(
+    base_path: Path,
+    video_path: Path,
+    output_video: str,
+    preset: presets.Preset = presets.LARGE,
+):
+    pipeline = InterpolationPipeline(
+        config=preset.config,
+        checkpoint_path=preset.checkpoint,
+        base_path=base_path,
+    )
+    pipeline.run(video_path, output_video)
@@ -1,16 +1,19 @@
-import importlib
-
 from typing import TYPE_CHECKING
+from ..networks import AMT_G, AMT_L, AMT_S
 
 if TYPE_CHECKING:
     from omegaconf import DictConfig
 
 
-def base_build_fn(module: str, cls: str, params: dict):
-    return getattr(importlib.import_module(module, package=None), cls)(**params)
-
-
 def build_from_cfg(config: "DictConfig"):
+    packages = {
+        "AMT-G": AMT_G,
+        "AMT-L": AMT_L,
+        "AMT-S": AMT_S
+    }
+
     module, cls = config["name"].rsplit(".", 1)
     params: dict = config.get("params", {})
-    return base_build_fn(module, cls, params)
+    return getattr(packages[module], cls)(**params)