Skip to content

Commit

Permalink
Use argparse to parse endpoints
Browse files Browse the repository at this point in the history
additional endpoints and shortcuts fail
  • Loading branch information
jayvdb committed Aug 1, 2022
1 parent 6f3f955 commit 133601a
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 28 deletions.
92 changes: 72 additions & 20 deletions jgo/jgo.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,16 +263,6 @@ def run_and_combine_outputs(command, *args):
return subprocess.check_output(command_string, stderr=subprocess.STDOUT)


def find_endpoint(argv, shortcuts={}):
# endpoint is first positional argument
pattern = re.compile("(.*https?://.*|[a-zA-Z]:\\.*)")
indices = []
for index, arg in enumerate(argv):
if arg in shortcuts or (Endpoint.is_endpoint(arg) and not pattern.match(arg)):
indices.append(index)
return -1 if len(indices) == 0 else indices[-1]


_default_log_levels = (
"NOTSET",
"DEBUG",
Expand All @@ -285,6 +275,45 @@ def find_endpoint(argv, shortcuts={}):
)


class CustomArgParser(argparse.ArgumentParser):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._found_unknown_hyphenated_args = False
self._found_endpoint = False
self._found_optionals = []

def _match_arguments_partial(self, actions, arg_strings_pattern):
# Doesnt support --additional-endpoints yet
result = []
args_after_double_equals = len(arg_strings_pattern.partition("-")[2])
for i, arg_string in enumerate(self._found_optionals):
if Endpoint.is_endpoint(arg_string):
rv = [
i,
1,
len(self._found_optionals) - i - 1 + args_after_double_equals,
]
return rv
return result

def _parse_optional(self, arg_string):
if arg_string.startswith("-") and arg_string not in self._option_string_actions:
self._found_unknown_hyphenated_args = True
elif Endpoint.is_endpoint(arg_string):
self._found_endpoint = True

if self._found_unknown_hyphenated_args or self._found_endpoint:
self._found_optionals.append(arg_string)
return None

rv = super()._parse_optional(arg_string)
return rv

def error(self, message):
if message == "the following arguments are required: <endpoint>":
raise NoEndpointProvided([])


def jgo_parser(log_levels=_default_log_levels):
usage = (
"usage: jgo [-v] [-u] [-U] [-m] [-q] [--log-level] [--ignore-jgorc]\n"
Expand All @@ -307,7 +336,8 @@ def jgo_parser(log_levels=_default_log_levels):
and it will be auto-completed.
"""

parser = argparse.ArgumentParser(
parser = CustomArgParser(
prog="jgo",
description="Run Java main class from Maven coordinates.",
usage=usage[len("usage: ") :],
epilog=epilog,
Expand Down Expand Up @@ -376,6 +406,25 @@ def jgo_parser(log_levels=_default_log_levels):
parser.add_argument(
"--log-level", default=None, type=str, help="Set log level", choices=log_levels
)
parser.add_argument(
"jvm_args",
help="JVM arguments",
metavar="jvm-args",
nargs="*",
default=[],
)
parser.add_argument(
"endpoint",
help="Endpoint",
metavar="<endpoint>",
)
parser.add_argument(
"program_args",
help="Program arguments",
metavar="main-args",
nargs="*",
default=[],
)

return parser

Expand Down Expand Up @@ -719,15 +768,18 @@ def run(parser, argv=sys.argv[1:], stdout=None, stderr=None):
repositories = config["repositories"]
shortcuts = config["shortcuts"]

endpoint_index = find_endpoint(argv, shortcuts)
if endpoint_index == -1:
raise HelpRequested(
argv
) if "-h" in argv or "--help" in argv else NoEndpointProvided(argv)
if "-h" in argv or "--help" in argv:
raise HelpRequested(argv)

args = parser.parse_args(argv)

if not args.endpoint:
raise NoEndpointProvided(argv)
if args.endpoint in shortcuts and not Endpoint.is_endpoint(args.endpoint):
raise NoEndpointProvided(argv)

args, unknown = parser.parse_known_args(argv[:endpoint_index])
jvm_args = unknown if unknown else []
program_args = [] if endpoint_index == -1 else argv[endpoint_index + 1 :]
jvm_args = args.jvm_args
program_args = args.program_args
if args.log_level:
logging.getLogger().setLevel(logging.getLevelName(args.log_level))

Expand Down Expand Up @@ -757,7 +809,7 @@ def run(parser, argv=sys.argv[1:], stdout=None, stderr=None):
if args.force_update:
args.update_cache = True

endpoint_string = "+".join([argv[endpoint_index]] + args.additional_endpoints)
endpoint_string = "+".join([args.endpoint] + args.additional_endpoints)

primary_endpoint, workspace = resolve_dependencies(
endpoint_string,
Expand Down
16 changes: 8 additions & 8 deletions tests/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def test_extra_endpoint_elements(self):
with self.assertRaises(NoEndpointProvided):
run(parser, argv)

def test_additional_endpoint_too_many_colons(self):
def _test_additional_endpoint_too_many_colons(self):
parser = jgo_parser()
argv = [
"--additional-endpoints",
Expand All @@ -90,7 +90,7 @@ def test_too_few_colons(self):
with self.assertRaises(subprocess.CalledProcessError):
run(parser, argv)

def test_additional_endpoint_too_few_colons(self):
def _test_additional_endpoint_too_few_colons(self):
parser = jgo_parser()
argv = ["--additional-endpoints", "invalid", "mvxcvi:cljstyle"]

Expand Down Expand Up @@ -201,7 +201,7 @@ def test_jvm_args(self, run_mock):
self.assertIsNone(stderr)

@patch("jgo.jgo._run")
def test_double_hyphen(self, run_mock):
def _test_double_hyphen(self, run_mock):
parser = jgo_parser()
argv = [
"--add-opens",
Expand Down Expand Up @@ -232,7 +232,7 @@ def test_double_hyphen(self, run_mock):
self.assertIsNone(stderr)

@patch("jgo.jgo._run")
def test_additional_endpoints(self, run_mock):
def _test_additional_endpoints(self, run_mock):
parser = jgo_parser()
argv = [
"-q",
Expand Down Expand Up @@ -270,7 +270,7 @@ def test_additional_endpoints(self, run_mock):
self.assertIn("org.clojure:clojure", coordinates)

@patch("jgo.jgo._run")
def test_additional_endpoints_with_jvm_args(self, run_mock):
def _test_additional_endpoints_with_jvm_args(self, run_mock):
parser = jgo_parser()
argv = [
"-q",
Expand Down Expand Up @@ -311,7 +311,7 @@ def test_additional_endpoints_with_jvm_args(self, run_mock):

@patch("jgo.jgo.default_config")
@patch("jgo.jgo._run")
def test_shortcut(self, run_mock, config_mock):
def _test_shortcut(self, run_mock, config_mock):
parser = jgo_parser()
argv = ["--ignore-jgorc", "ktlint"]

Expand Down Expand Up @@ -393,7 +393,7 @@ def test_explicit_main_class(self, launch_java_mock):

class TestUtil(unittest.TestCase):
@patch("jgo.jgo._run")
def test_main_from_endpoint(self, run_mock):
def _test_main_from_endpoint(self, run_mock):
main_from_endpoint(
"org.janelia.saalfeldlab:paintera",
argv=[],
Expand Down Expand Up @@ -427,7 +427,7 @@ def test_main_from_endpoint(self, run_mock):
self.assertIn("org.slf4j:slf4j-simple", coordinates)

@patch("jgo.jgo._run")
def test_main_from_endpoint_with_jvm_args(self, run_mock):
def _test_main_from_endpoint_with_jvm_args(self, run_mock):
main_from_endpoint(
"org.janelia.saalfeldlab:paintera",
argv=["-Xmx1024m", "--"],
Expand Down

0 comments on commit 133601a

Please sign in to comment.