diff --git a/main.py b/main.py index f333138..68cc20c 100644 --- a/main.py +++ b/main.py @@ -176,7 +176,12 @@ class InterpolationPipeline: ) -def main(preset: presets.Preset = presets.LARGE): +def runner( + base_path: Path, + video_path: Path, + output_video: Path, + preset: presets.Preset = presets.LARGE, +): base_path = Path("output") video_path = Path("example/video.mp4") output_video = "interpolated_video.mp4" @@ -189,8 +194,39 @@ def main(preset: presets.Preset = presets.LARGE): pipeline.run(video_path, output_video) -if __name__ == "__main__": +def main(): + import argparse + logging.basicConfig( level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" ) + + parser = argparse.ArgumentParser() + parser.add_argument("-b", "--base_path", help="Base path", default="output") + parser.add_argument( + "-v", "--video_path", help="Video path", default="example/video.mp4" + ) + parser.add_argument( + "-o", + "--output", + help="Output video name (example: 'interpolated_video.mp4')", + default="interpolated_video.mp4", + ) + parser.add_argument( + "-p", + "--preset", + help="Model preset", + choices=["small", "large", "global"], + default="global", + ) + args = parser.parse_args() + runner( + base_path=Path(args.base_path), + video_path=Path(args.video_path), + output_video=args.output, + preset=getattr(presets, args.preset.upper()) + ) + + +if __name__ == "__main__": main()