Попытка добавить onnx в работе с nvidia

This commit is contained in:
Viner Abubakirov
2026-04-15 17:53:06 +05:00
parent 0c871c2314
commit 7addcf051c
13 changed files with 2484 additions and 371 deletions

16
main.py
View File

@@ -70,10 +70,8 @@ def init_anchor(device: "torch.device") -> Anchor:
raise Exception(f"Unsupported device type: {device.type}")
def init_model_runner(
config: Path, checkpoint_path: Path, device: "torch.device"
) -> ModelRunner:
return ModelRunner(config, checkpoint_path, device)
def init_model_runner(preset: presets.Preset, device: "torch.device") -> ModelRunner:
return ModelRunner(preset, device)
def init_interpolator(
@@ -86,21 +84,20 @@ def init_interpolator(
class InterpolationPipeline:
def __init__(
self,
config: Path,
checkpoint_path: Path,
preset: presets.Preset,
base_path: Path,
):
self.fs = init_fs(base_path)
self.video_maker = init_video_maker()
self.device = init_device()
self.model_runner = init_model_runner(config, checkpoint_path, self.device)
self.model_runner = init_model_runner(preset, self.device)
self.interpolator = init_interpolator(self.model_runner, self.device)
def run(self, video_path: Path, output_video: str):
prev_frames = tuple()
interpolated_frames: list["np.ndarray"] = []
part = 0
chunk_seconds = 10
chunk_seconds = 1
length = self.video_maker.get_video_duration(video_path)
last_part_seconds = 1 if length % chunk_seconds else 0
total_parts = int(length // chunk_seconds) + last_part_seconds
@@ -189,8 +186,7 @@ def runner(
preset: presets.Preset = presets.LARGE,
):
pipeline = InterpolationPipeline(
config=preset.config,
checkpoint_path=preset.checkpoint,
preset=preset,
base_path=base_path,
)
pipeline.run(video_path, output_video)