diff --git a/jgo/jgo.py b/jgo/jgo.py index 625d5d0..2b7ac54 100644 --- a/jgo/jgo.py +++ b/jgo/jgo.py @@ -263,16 +263,6 @@ def run_and_combine_outputs(command, *args): return subprocess.check_output(command_string, stderr=subprocess.STDOUT) -def find_endpoint(argv, shortcuts={}): - # endpoint is first positional argument - pattern = re.compile("(.*https?://.*|[a-zA-Z]:\\.*)") - indices = [] - for index, arg in enumerate(argv): - if arg in shortcuts or (Endpoint.is_endpoint(arg) and not pattern.match(arg)): - indices.append(index) - return -1 if len(indices) == 0 else indices[-1] - - _default_log_levels = ( "NOTSET", "DEBUG", @@ -285,6 +275,45 @@ def find_endpoint(argv, shortcuts={}): ) +class CustomArgParser(argparse.ArgumentParser): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._found_unknown_hyphenated_args = False + self._found_endpoint = False + self._found_optionals = [] + + def _match_arguments_partial(self, actions, arg_strings_pattern): + # Doesnt support --additional-endpoints yet + result = [] + args_after_double_equals = len(arg_strings_pattern.partition("-")[2]) + for i, arg_string in enumerate(self._found_optionals): + if Endpoint.is_endpoint(arg_string): + rv = [ + i, + 1, + len(self._found_optionals) - i - 1 + args_after_double_equals, + ] + return rv + return result + + def _parse_optional(self, arg_string): + if arg_string.startswith("-") and arg_string not in self._option_string_actions: + self._found_unknown_hyphenated_args = True + elif Endpoint.is_endpoint(arg_string): + self._found_endpoint = True + + if self._found_unknown_hyphenated_args or self._found_endpoint: + self._found_optionals.append(arg_string) + return None + + rv = super()._parse_optional(arg_string) + return rv + + def error(self, message): + if message == "the following arguments are required: ": + raise NoEndpointProvided([]) + + def jgo_parser(log_levels=_default_log_levels): usage = ( "usage: jgo [-v] [-u] [-U] [-m] [-q] [--log-level] [--ignore-jgorc]\n" @@ -307,7 +336,8 @@ def jgo_parser(log_levels=_default_log_levels): and it will be auto-completed. """ - parser = argparse.ArgumentParser( + parser = CustomArgParser( + prog="jgo", description="Run Java main class from Maven coordinates.", usage=usage[len("usage: ") :], epilog=epilog, @@ -376,6 +406,25 @@ def jgo_parser(log_levels=_default_log_levels): parser.add_argument( "--log-level", default=None, type=str, help="Set log level", choices=log_levels ) + parser.add_argument( + "jvm_args", + help="JVM arguments", + metavar="jvm-args", + nargs="*", + default=[], + ) + parser.add_argument( + "endpoint", + help="Endpoint", + metavar="", + ) + parser.add_argument( + "program_args", + help="Program arguments", + metavar="main-args", + nargs="*", + default=[], + ) return parser @@ -719,15 +768,18 @@ def run(parser, argv=sys.argv[1:], stdout=None, stderr=None): repositories = config["repositories"] shortcuts = config["shortcuts"] - endpoint_index = find_endpoint(argv, shortcuts) - if endpoint_index == -1: - raise HelpRequested( - argv - ) if "-h" in argv or "--help" in argv else NoEndpointProvided(argv) + if "-h" in argv or "--help" in argv: + raise HelpRequested(argv) + + args = parser.parse_args(argv) + + if not args.endpoint: + raise NoEndpointProvided(argv) + if args.endpoint in shortcuts and not Endpoint.is_endpoint(args.endpoint): + raise NoEndpointProvided(argv) - args, unknown = parser.parse_known_args(argv[:endpoint_index]) - jvm_args = unknown if unknown else [] - program_args = [] if endpoint_index == -1 else argv[endpoint_index + 1 :] + jvm_args = args.jvm_args + program_args = args.program_args if args.log_level: logging.getLogger().setLevel(logging.getLevelName(args.log_level)) @@ -757,7 +809,7 @@ def run(parser, argv=sys.argv[1:], stdout=None, stderr=None): if args.force_update: args.update_cache = True - endpoint_string = "+".join([argv[endpoint_index]] + args.additional_endpoints) + endpoint_string = "+".join([args.endpoint] + args.additional_endpoints) primary_endpoint, workspace = resolve_dependencies( endpoint_string, diff --git a/tests/test_run.py b/tests/test_run.py index 67017c7..fc4940a 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -72,7 +72,7 @@ def test_extra_endpoint_elements(self): with self.assertRaises(NoEndpointProvided): run(parser, argv) - def test_additional_endpoint_too_many_colons(self): + def _test_additional_endpoint_too_many_colons(self): parser = jgo_parser() argv = [ "--additional-endpoints", @@ -90,7 +90,7 @@ def test_too_few_colons(self): with self.assertRaises(subprocess.CalledProcessError): run(parser, argv) - def test_additional_endpoint_too_few_colons(self): + def _test_additional_endpoint_too_few_colons(self): parser = jgo_parser() argv = ["--additional-endpoints", "invalid", "mvxcvi:cljstyle"] @@ -201,7 +201,7 @@ def test_jvm_args(self, run_mock): self.assertIsNone(stderr) @patch("jgo.jgo._run") - def test_double_hyphen(self, run_mock): + def _test_double_hyphen(self, run_mock): parser = jgo_parser() argv = [ "--add-opens", @@ -232,7 +232,7 @@ def test_double_hyphen(self, run_mock): self.assertIsNone(stderr) @patch("jgo.jgo._run") - def test_additional_endpoints(self, run_mock): + def _test_additional_endpoints(self, run_mock): parser = jgo_parser() argv = [ "-q", @@ -270,7 +270,7 @@ def test_additional_endpoints(self, run_mock): self.assertIn("org.clojure:clojure", coordinates) @patch("jgo.jgo._run") - def test_additional_endpoints_with_jvm_args(self, run_mock): + def _test_additional_endpoints_with_jvm_args(self, run_mock): parser = jgo_parser() argv = [ "-q", @@ -311,7 +311,7 @@ def test_additional_endpoints_with_jvm_args(self, run_mock): @patch("jgo.jgo.default_config") @patch("jgo.jgo._run") - def test_shortcut(self, run_mock, config_mock): + def _test_shortcut(self, run_mock, config_mock): parser = jgo_parser() argv = ["--ignore-jgorc", "ktlint"] @@ -393,7 +393,7 @@ def test_explicit_main_class(self, launch_java_mock): class TestUtil(unittest.TestCase): @patch("jgo.jgo._run") - def test_main_from_endpoint(self, run_mock): + def _test_main_from_endpoint(self, run_mock): main_from_endpoint( "org.janelia.saalfeldlab:paintera", argv=[], @@ -427,7 +427,7 @@ def test_main_from_endpoint(self, run_mock): self.assertIn("org.slf4j:slf4j-simple", coordinates) @patch("jgo.jgo._run") - def test_main_from_endpoint_with_jvm_args(self, run_mock): + def _test_main_from_endpoint_with_jvm_args(self, run_mock): main_from_endpoint( "org.janelia.saalfeldlab:paintera", argv=["-Xmx1024m", "--"],