From 70f24b6e2c876377401855945f04acf94d01b572 Mon Sep 17 00:00:00 2001 From: Mikhail Tavarez Date: Sat, 20 Apr 2024 18:13:44 -0500 Subject: [PATCH 1/2] Update command func to return Error, added tests. (#14) * test scaffolding * add github workflow * match_all not working for now * more scaffolding * more tests * add gitignore * fix test * update readme --- .github/workflows/test.yml | 23 +++++ .gitignore | 160 ++++++++++++++++++++++++++++++ README.md | 80 +++++++++------ examples/hello_world/printer.mojo | 5 +- examples/hello_world/root.mojo | 8 +- examples/hello_world/say.mojo | 9 +- examples/logging/root.mojo | 8 +- examples/nested/nested.mojo | 75 ++++++++------ examples/printer/printer.mojo | 15 +-- examples/read_csv/root.mojo | 25 +++-- prism/__init__.mojo | 10 +- prism/args.mojo | 10 +- prism/command.mojo | 147 +++++++++++++++------------ prism/flag.mojo | 22 ++-- tests/__init__.mojo | 0 tests/test_args.mojo | 74 ++++++++++++++ tests/test_command.mojo | 32 ++++++ tests/test_flags.mojo | 35 +++++++ tests/wrapper.mojo | 44 ++++++++ 19 files changed, 608 insertions(+), 174 deletions(-) create mode 100644 .github/workflows/test.yml create mode 100644 .gitignore create mode 100644 tests/__init__.mojo create mode 100644 tests/test_args.mojo create mode 100644 tests/test_command.mojo create mode 100644 tests/test_flags.mojo create mode 100644 tests/wrapper.mojo diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 0000000..2bb0f1b --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,23 @@ +name: Run Tests + +on: ["push"] + +jobs: + test: + runs-on: ubuntu-latest + environment: basic + steps: + - name: Check out repository code + uses: actions/checkout@v2 + - name: Install dependencies + run: | + curl https://get.modular.com | MODULAR_AUTH=${{ secrets.MODULAR_AUTH }} sh - + modular auth ${{ secrets.MODULAR_AUTH }} + modular install nightly/mojo + pip install pytest + pip install git+https://github.com/guidorice/mojo-pytest.git + - name: Unit Tests + run: | + export MODULAR_HOME="/home/runner/.modular" + export PATH="/home/runner/.modular/pkg/packages.modular.com_nightly_mojo/bin:$PATH" + pytest diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..68bc17f --- /dev/null +++ b/.gitignore @@ -0,0 +1,160 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ diff --git a/README.md b/README.md index abf3d13..a0d7b20 100644 --- a/README.md +++ b/README.md @@ -19,55 +19,68 @@ from prism import Flag, Command, CommandArc from python import Python, PythonObject -fn base(command: CommandArc, args: List[String]) raises -> None: +fn base(command: CommandArc, args: List[String]) -> Error: print("This is the base command!") + return Error() -fn print_information(command: CommandArc, args: List[String]) raises -> None: +fn print_information(command: CommandArc, args: List[String]) -> Error: print("Pass cat or dog as a subcommand, and see what you get!") + return Error() -fn get_cat_fact(command: CommandArc, args: List[String]) raises -> None: +fn get_cat_fact(command: CommandArc, args: List[String]) -> Error: var flags = command[].get_all_flags()[] var lover = flags.get_as_bool("lover") if lover and lover.value(): print("Hello fellow cat lover!") - var requests = Python.import_module("requests") - # URL you want to send a GET request to - var url = "https://cat-fact.herokuapp.com/facts/" + try: + var requests = Python.import_module("requests") - # Send the GET request - var response = requests.get(url) + # URL you want to send a GET request to + var url = "https://cat-fact.herokuapp.com/facts/" - # Check if the request was successful (status code 200) - if response.status_code == 200: - var count = flags.get_as_int("count") - if not count: - raise Error("Count flag was not found.") - var body = response.json() - for i in range(count.value()): - print(body[i]["text"]) - else: - raise Error("Request failed!") + # Send the GET request + var response = requests.get(url) + # Check if the request was successful (status code 200) + if response.status_code == 200: + var count = flags.get_as_int("count") + if not count: + return Error("Count flag was not found.") + var body = response.json() + for i in range(count.value()): + print(body[i]["text"]) + else: + return Error("Request failed!") + except e: + return e -fn get_dog_breeds(command: CommandArc, args: List[String]) raises -> None: - var requests = Python.import_module("requests") - # URL you want to send a GET request to - var url = "https://dog.ceo/api/breeds/list/all" + return Error() - # Send the GET request - var response = requests.get(url) - # Check if the request was successful (status code 200) - if response.status_code == 200: - print(response.json()["message"]) - else: - raise Error("Request failed!") +fn get_dog_breeds(command: CommandArc, args: List[String]) -> Error: + try: + var requests = Python.import_module("requests") + # URL you want to send a GET request to + var url = "https://dog.ceo/api/breeds/list/all" + # Send the GET request + var response = requests.get(url) -fn init() raises -> None: + # Check if the request was successful (status code 200) + if response.status_code == 200: + print(response.json()["message"]) + else: + return Error("Request failed!") + except e: + return e + + return Error() + + +fn init() -> None: var root_command = Command(name="nested", description="Base command.", run=base) var get_command = Command( @@ -96,8 +109,9 @@ fn init() raises -> None: root_command.execute() -fn main() raises -> None: +fn main() -> None: init() + ``` Start by navigating to the `nested` example directory. @@ -168,6 +182,8 @@ Usage information will be printed the console by passing the `--help` flag. - Flags can have values passed by using the `=` operator. Like `--count=5` OR like `--count 5`. - Commands can be created via a typical `Command()` constructor to use runtime values, or you can use `Command.new()` method to create a new `Command` using compile time `Parameters` instead (when possible). +- This library leans towards Errors as values over raising Exceptions. +- `Optional[Error]` would be much cleaner for Command run functions. For now return `Error()` if there's no `Error` to return. ## TODO @@ -183,6 +199,8 @@ Usage information will be printed the console by passing the `--help` flag. - Map `--help` flag to configurable help function. - Add find suggestion logic to `Command` struct. - Enable required flags. +- Replace print usage with writers to enable stdout/stderr/file writing. +- Split `Run` and `RunE` fields so that the primary run function `Run` can return no errors while `RunE` can return errors. ### Improvements diff --git a/examples/hello_world/printer.mojo b/examples/hello_world/printer.mojo index 5c6d2c1..fffbc0b 100644 --- a/examples/hello_world/printer.mojo +++ b/examples/hello_world/printer.mojo @@ -2,12 +2,13 @@ from prism import Flag, Command from prism.command import CommandArc -fn printer(command: CommandArc, args: List[String]) raises -> None: +fn printer(command: CommandArc, args: List[String]) -> Error: if len(args) == 0: print("No args provided.") - return None + return Error() print(args[0]) + return Error() fn build_printer_command() -> Command: diff --git a/examples/hello_world/root.mojo b/examples/hello_world/root.mojo index be51611..b46d51d 100644 --- a/examples/hello_world/root.mojo +++ b/examples/hello_world/root.mojo @@ -11,12 +11,14 @@ from memory._arc import Arc # TODO: Using CommandArc instead of Arc[Command] works. But using Arc[Command] causes a recursive relationship error? -fn test(command: CommandArc, args: List[String]) raises -> None: +fn test(command: CommandArc, args: List[String]) -> Error: for item in command[].get_all_flags()[].flags: print(item[].name, item[].value.value()) + return Error() -fn init() raises -> None: + +fn init() -> None: var root_command = Command( name="tones", description="This is a dummy command!", @@ -38,5 +40,5 @@ fn init() raises -> None: root_command.execute() -fn main() raises -> None: +fn main() -> None: init() diff --git a/examples/hello_world/say.mojo b/examples/hello_world/say.mojo index 4155b2f..6e3f6af 100644 --- a/examples/hello_world/say.mojo +++ b/examples/hello_world/say.mojo @@ -2,16 +2,19 @@ from prism import Flag, Command from prism.command import CommandArc -fn say(command: CommandArc, args: List[String]) raises -> None: +fn say(command: CommandArc, args: List[String]) -> Error: print("Shouldn't be here!") + return Error() -fn say_hello(command: CommandArc, args: List[String]) raises -> None: +fn say_hello(command: CommandArc, args: List[String]) -> Error: print("Hello World!") + return Error() -fn say_goodbye(command: CommandArc, args: List[String]) raises -> None: +fn say_goodbye(command: CommandArc, args: List[String]) -> Error: print("Goodbye World!") + return Error() # for some reason returning the command object without setting it to variable breaks the compiler diff --git a/examples/logging/root.mojo b/examples/logging/root.mojo index 3e1c12b..f4e86a9 100644 --- a/examples/logging/root.mojo +++ b/examples/logging/root.mojo @@ -2,7 +2,7 @@ from prism import Flag, Command, CommandArc, minimum_n_args from examples.logging.log import logger, default_logger, json_logger -fn handler(command: CommandArc, args: List[String]) raises -> None: +fn handler(command: CommandArc, args: List[String]) -> Error: var print_type = command[].get_all_flags()[].get_as_string("type").value() if print_type == "json": for arg in args: @@ -14,8 +14,10 @@ fn handler(command: CommandArc, args: List[String]) raises -> None: for arg in args: default_logger.info(arg[]) + return Error() -fn init() raises -> None: + +fn init() -> None: var root_command = Command( name="logger", description="Base command.", run=handler, arg_validator=minimum_n_args[1]() ) @@ -26,5 +28,5 @@ fn init() raises -> None: root_command.execute() -fn main() raises -> None: +fn main() -> None: init() diff --git a/examples/nested/nested.mojo b/examples/nested/nested.mojo index c783049..7f23130 100644 --- a/examples/nested/nested.mojo +++ b/examples/nested/nested.mojo @@ -2,55 +2,68 @@ from prism import Flag, Command, CommandArc from python import Python, PythonObject -fn base(command: CommandArc, args: List[String]) raises -> None: +fn base(command: CommandArc, args: List[String]) -> Error: print("This is the base command!") + return Error() -fn print_information(command: CommandArc, args: List[String]) raises -> None: +fn print_information(command: CommandArc, args: List[String]) -> Error: print("Pass cat or dog as a subcommand, and see what you get!") + return Error() -fn get_cat_fact(command: CommandArc, args: List[String]) raises -> None: +fn get_cat_fact(command: CommandArc, args: List[String]) -> Error: var flags = command[].get_all_flags()[] var lover = flags.get_as_bool("lover") if lover and lover.value(): print("Hello fellow cat lover!") - var requests = Python.import_module("requests") - # URL you want to send a GET request to - var url = "https://cat-fact.herokuapp.com/facts/" + try: + var requests = Python.import_module("requests") - # Send the GET request - var response = requests.get(url) + # URL you want to send a GET request to + var url = "https://cat-fact.herokuapp.com/facts/" - # Check if the request was successful (status code 200) - if response.status_code == 200: - var count = flags.get_as_int("count") - if not count: - raise Error("Count flag was not found.") - var body = response.json() - for i in range(count.value()): - print(body[i]["text"]) - else: - raise Error("Request failed!") + # Send the GET request + var response = requests.get(url) + # Check if the request was successful (status code 200) + if response.status_code == 200: + var count = flags.get_as_int("count") + if not count: + return Error("Count flag was not found.") + var body = response.json() + for i in range(count.value()): + print(body[i]["text"]) + else: + return Error("Request failed!") + except e: + return e -fn get_dog_breeds(command: CommandArc, args: List[String]) raises -> None: - var requests = Python.import_module("requests") - # URL you want to send a GET request to - var url = "https://dog.ceo/api/breeds/list/all" + return Error() - # Send the GET request - var response = requests.get(url) - # Check if the request was successful (status code 200) - if response.status_code == 200: - print(response.json()["message"]) - else: - raise Error("Request failed!") +fn get_dog_breeds(command: CommandArc, args: List[String]) -> Error: + try: + var requests = Python.import_module("requests") + # URL you want to send a GET request to + var url = "https://dog.ceo/api/breeds/list/all" + # Send the GET request + var response = requests.get(url) -fn init() raises -> None: + # Check if the request was successful (status code 200) + if response.status_code == 200: + print(response.json()["message"]) + else: + return Error("Request failed!") + except e: + return e + + return Error() + + +fn init() -> None: var root_command = Command(name="nested", description="Base command.", run=base) var get_command = Command( @@ -79,5 +92,5 @@ fn init() raises -> None: root_command.execute() -fn main() raises -> None: +fn main() -> None: init() diff --git a/examples/printer/printer.mojo b/examples/printer/printer.mojo index 178467a..0f41522 100644 --- a/examples/printer/printer.mojo +++ b/examples/printer/printer.mojo @@ -3,10 +3,10 @@ from prism import Flag, Command, CommandArc, exact_args from external.mist import TerminalStyle -fn printer(command: CommandArc, args: List[String]) raises -> None: +fn printer(command: CommandArc, args: List[String]) -> Error: if len(args) <= 0: print("No text to print! Pass in some text as a positional argument.") - return None + return Error() var flags = command[].get_all_flags()[] var color = flags.get_as_string("color") @@ -28,17 +28,20 @@ fn printer(command: CommandArc, args: List[String]) raises -> None: style = style.italic() print(style.render(args[0])) + return Error() -fn pre_hook(command: CommandArc, args: List[String]) raises -> None: +fn pre_hook(command: CommandArc, args: List[String]) -> Error: print("Pre-hook executed!") + return Error() -fn post_hook(command: CommandArc, args: List[String]) raises -> None: +fn post_hook(command: CommandArc, args: List[String]) -> Error: print("Post-hook executed!") + return Error() -fn init() raises -> None: +fn init() -> None: var start = now() var root_command = Command.new[ name="printer", @@ -54,5 +57,5 @@ fn init() raises -> None: print("duration", (now() - start) / 1e9) -fn main() raises -> None: +fn main() -> None: init() diff --git a/examples/read_csv/root.mojo b/examples/read_csv/root.mojo index 531f2d7..93228cf 100644 --- a/examples/read_csv/root.mojo +++ b/examples/read_csv/root.mojo @@ -3,20 +3,25 @@ from external.goodies import CSVReader, FileWrapper from os.path import exists -fn handler(command: CommandArc, args: List[String]) raises -> None: +fn handler(command: CommandArc, args: List[String]) -> Error: var file_path = command[].get_all_flags()[].get_as_string("file").value() if not exists(file_path): - raise Error("File does not exist.") + return Error("File does not exist.") - var file = FileWrapper(file_path, "r") - var reader = CSVReader(file^) - var lines = command[].get_all_flags()[].get_as_int("lines").value() - var csv = reader.read_lines(lines, "\n", 3) - for i in range(csv.row_count()): - print(csv.get(i, 0)) + try: + var file = FileWrapper(file_path, "r") + var reader = CSVReader(file^) + var lines = command[].get_all_flags()[].get_as_int("lines").value() + var csv = reader.read_lines(lines, "\n", 3) + for i in range(csv.row_count()): + print(csv.get(i, 0)) + except e: + return e + return Error() -fn init() raises -> None: + +fn init() -> None: var root_command = Command(name="read_csv", description="Base command.", run=handler, arg_validator=no_args) root_command.flags.add_string_flag[name="file", shorthand="f", usage="CSV file to read."]() root_command.flags.add_int_flag[name="lines", shorthand="l", usage="Lines to print.", default=3]() @@ -24,5 +29,5 @@ fn init() raises -> None: root_command.execute() -fn main() raises -> None: +fn main() -> None: init() diff --git a/prism/__init__.mojo b/prism/__init__.mojo index 929dc70..ceb48a8 100644 --- a/prism/__init__.mojo +++ b/prism/__init__.mojo @@ -4,13 +4,5 @@ from .command import ( CommandFunction, ArgValidator, ) -from .args import ( - no_args, - valid_args, - arbitrary_args, - minimum_n_args, - maximum_n_args, - exact_args, - range_args, -) +from .args import no_args, valid_args, arbitrary_args, minimum_n_args, maximum_n_args, exact_args, range_args, match_all from .flag import Flag, FlagSet diff --git a/prism/args.mojo b/prism/args.mojo index 7b70e0a..70c5089 100644 --- a/prism/args.mojo +++ b/prism/args.mojo @@ -19,7 +19,7 @@ fn no_args(args: List[String]) -> Optional[String]: args: The arguments to check. """ if len(args) > 0: - return String("Command ") + String("does not take any arguments") + return String("Command does not take any arguments.") return None @@ -45,7 +45,7 @@ fn minimum_n_args[n: Int]() -> ArgValidator: fn less_than_n_args(args: List[String]) -> Optional[String]: if len(args) < n: return sprintf( - "Command accepts at least %d arguments. Received: %d.", + "Command accepts at least %d argument(s). Received: %d.", n, len(args), ) @@ -66,7 +66,7 @@ fn maximum_n_args[n: Int]() -> ArgValidator: fn more_than_n_args(args: List[String]) -> Optional[String]: if len(args) > n: - return sprintf("Command accepts at most %d arguments. Received: %d", n, len(args)) + return sprintf("Command accepts at most %d argument(s). Received: %d.", n, len(args)) return None return more_than_n_args @@ -84,7 +84,7 @@ fn exact_args[n: Int]() -> ArgValidator: fn exactly_n_args(args: List[String]) -> Optional[String]: if len(args) != n: - return sprintf("Command accepts at exactly %d arguments. Received: %d", n, len(args)) + return sprintf("Command accepts exactly %d argument(s). Received: %d.", n, len(args)) return None return exactly_n_args @@ -121,7 +121,7 @@ fn range_args[minimum: Int, maximum: Int]() -> ArgValidator: fn range_n_args(args: List[String]) -> Optional[String]: if len(args) < minimum or len(args) > maximum: return sprintf( - "Command accepts between %d and %d arguments. Received: %d", + "Command accepts between %d to %d argument(s). Received: %d.", minimum, maximum, len(args), diff --git a/prism/command.mojo b/prism/command.mojo index e8bebda..dfecd3f 100644 --- a/prism/command.mojo +++ b/prism/command.mojo @@ -3,6 +3,7 @@ from collections.optional import Optional from collections.dict import Dict, KeyElement from memory._arc import Arc from external.gojo.fmt import sprintf +from external.gojo.builtins import panic from .flag import Flag, FlagSet, get_flags from .args import arbitrary_args, ArgValidator, get_args from .vector import join, to_string, contains @@ -21,7 +22,30 @@ fn get_args_as_list() -> List[String]: alias CommandArc = Arc[Command] -alias CommandFunction = fn (command: Arc[Command], args: List[String]) raises -> None +alias CommandFunction = fn (command: Arc[Command], args: List[String]) -> Error + + +fn parse_command_from_args(start: Command) -> (Command, List[String]): + var args = get_args_as_list() + var number_of_args = len(args) + var command = start + var children = command.children + var leftover_args_start_index = 0 # Start at 1 to start slice at the first remaining arg, not the last child command. + + for arg in args: + for command_ref in children: + if command_ref[][].name == arg[]: + command = command_ref[][] + children = command.children + leftover_args_start_index += 1 + break + + # If the there are more or equivalent args to the index, then there are remaining args to pass to the command. + var remaining_args = List[String]() + if number_of_args >= leftover_args_start_index: + remaining_args = args[leftover_args_start_index:number_of_args] + + return command, remaining_args # TODO: Add persistent flags @@ -47,7 +71,6 @@ struct Command(CollectionElement): name: String, description: String, run: CommandFunction, - # arg_validator: ArgValidator = arbitrary_args, valid_args: List[String] = List[String](), pre_run: Optional[CommandFunction] = None, post_run: Optional[CommandFunction] = None, @@ -122,7 +145,38 @@ struct Command(CollectionElement): name, description, run, - arg_validator, + valid_args, + pre_run, + post_run, + ) + + @staticmethod + fn new[ + name: String, + description: String, + run: CommandFunction, + valid_args: List[String] = List[String](), + pre_run: Optional[CommandFunction] = None, + post_run: Optional[CommandFunction] = None, + ]() -> Self: + """Experimental function to create a new Command by using parameters to offload some work to compile time. + + Params: + name: The name of the command. + description: The description of the command. + run: The function to run when the command is executed. + valid_args: The valid arguments for the command. + pre_run: The function to run before the command is executed. + post_run: The function to run after the command is executed. + + Returns: + A new Command instance. + """ + return Command( + name, + description, + run, + arbitrary_args, valid_args, pre_run, post_run, @@ -178,15 +232,15 @@ struct Command(CollectionElement): + parent_name ) - fn full_command(self) -> String: + fn _full_command(self) -> String: """Traverses up the parent command tree to build the full command as a string.""" if self.parent[]: - var ancestor: String = self.parent[].value().full_command() + var ancestor: String = self.parent[].value()._full_command() return ancestor + " " + self.name else: return self.name - fn help(self) -> None: + fn _help(self) -> None: """Prints the help information for the command.""" var child_commands: String = "" for child in self.children: @@ -214,78 +268,55 @@ struct Command(CollectionElement): if len(self.flags) > 0: usage_arguments = usage_arguments + " [flags]" - var full_command = self.full_command() + var _full_command = self._full_command() var help = self.description + "\n\n" - var usage = "Usage:\n" + " " + full_command + usage_arguments + "\n\n" + var usage = "Usage:\n" + " " + _full_command + usage_arguments + "\n\n" var available_commands = "Available commands:\n" + child_commands + "\n" var available_flags = "Available flags:\n" + flags + "\n" - var note = 'Use "' + full_command + ' [command] --help" for more information about a command.' + var note = 'Use "' + _full_command + ' [command] --help" for more information about a command.' help = help + usage + available_commands + available_flags + note print(help) - fn validate_flag_set(self, flag_set: FlagSet) raises -> None: - """Validates the flags passed to the command. Raises an error if an invalid flag is passed. - - Args: - flag_set: The flags passed to the command. - """ - var length_of_command_flags = len(self.flags) - var length_of_input_flags = len(flag_set) - - if length_of_input_flags > length_of_command_flags: - raise Error("Specified more flags than the command accepts, please check your command's flags.") - - for flag in flag_set.flags: - if flag[] not in self.flags: - raise Error(String("Invalid flags passed to command: ") + flag[].name) - - fn execute(inout self) raises -> None: + fn execute(inout self) -> None: """Traverses the arguments passed to the executable and executes the last command in the branch.""" # Traverse from the root command through the children to find a match for the current argument. # Any additional arguments past the last matched command name are considered arguments. # TODO: Tree traversal is new to me, there's probably a better way to do this. - var args = get_args_as_list() - var number_of_args = len(args) - var command = self - var children = command.children - var leftover_args_start_index = 0 # Start at 1 to start slice at the first remaining arg, not the last child command. - - for arg in args: - for command_ref in children: - if command_ref[][].name == arg[]: - command = command_ref[][] - children = command.children - leftover_args_start_index += 1 - break - - # If the there are more or equivalent args to the index, then there are remaining args to pass to the command. - var remaining_args = List[String]() - if number_of_args >= leftover_args_start_index: - remaining_args = args[leftover_args_start_index:number_of_args] + var remaining_args: List[String] + var command: Self + command, remaining_args = parse_command_from_args(self) # Get the flags for the command to be executed. - remaining_args = get_flags(command.flags, remaining_args) + var err: Error + remaining_args, err = get_flags(command.flags, remaining_args) + if err: + panic(err) # Check if the help flag was passed var help = command.flags.get_as_bool("help") if help.value() == True: - command.help() + command._help() return None # Validate the remaining arguments var error_message = self.arg_validator(remaining_args) if error_message: - raise Error(error_message.value()) - - # Check if the flags are valid - command.validate_flag_set(command.flags) + panic(error_message.value()) # Run the function's commands. if command.pre_run: - command.pre_run.value()(Arc(command), remaining_args) - command.run(Arc(command), remaining_args) + err = command.pre_run.value()(Arc(command), remaining_args) + if err: + panic(err) + + err = command.run(Arc(command), remaining_args) + if err: + panic(err) + if command.post_run: - command.post_run.value()(Arc(command), remaining_args) + err = command.post_run.value()(Arc(command), remaining_args) + if err: + panic(err) fn get_all_flags(self) -> Arc[FlagSet]: """Returns all flags for the command and persistent flags from its parent. @@ -295,14 +326,6 @@ struct Command(CollectionElement): """ return Arc(self.flags) - fn set_parent(inout self, inout parent: Command) -> None: - """Sets the command's parent attribute to the given parent. - - Args: - parent: The name of the parent command. - """ - self.parent[] = parent - fn add_command(inout self, inout command: Command): """Adds child command and set's child's parent attribute to self. @@ -310,4 +333,4 @@ struct Command(CollectionElement): command: The command to add as a child of self. """ self.children.append(Arc(command)) - command.set_parent(self) + command.parent[] = self diff --git a/prism/flag.mojo b/prism/flag.mojo index a55e53c..893a9f6 100644 --- a/prism/flag.mojo +++ b/prism/flag.mojo @@ -668,7 +668,7 @@ fn parse_shorthand_flag( # TODO: This parsing is dirty atm, will come back around and clean it up. -fn get_flags(inout flags: FlagSet, arguments: List[String]) raises -> List[String]: +fn get_flags(inout flags: FlagSet, arguments: List[String]) -> (List[String], Error): """Parses flags and args from the args passed via the command line and adds them to their appropriate collections. Args: @@ -690,15 +690,19 @@ fn get_flags(inout flags: FlagSet, arguments: List[String]) raises -> List[Strin var value: String = "" var increment_by: Int = 0 - # Full flag - if argument.startswith("--", 0, 2): - name, value, increment_by = parse_flag(i, argument, arguments, flags) + try: + # Full flag + if argument.startswith("--", 0, 2): + name, value, increment_by = parse_flag(i, argument, arguments, flags) + + # Shorthand flag + elif argument.startswith("-", 0, 1): + name, value, increment_by = parse_shorthand_flag(i, argument, arguments, flags) - # Shorthand flag - elif argument.startswith("-", 0, 1): - name, value, increment_by = parse_shorthand_flag(i, argument, arguments, flags) + flags._set_flag_value(name, value) + except e: + return remaining_args, e - flags._set_flag_value(name, value) i += increment_by - return remaining_args + return remaining_args, Error() diff --git a/tests/__init__.mojo b/tests/__init__.mojo new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_args.mojo b/tests/test_args.mojo new file mode 100644 index 0000000..cf8bb63 --- /dev/null +++ b/tests/test_args.mojo @@ -0,0 +1,74 @@ +from tests.wrapper import MojoTest +from prism.args import ( + no_args, + valid_args, + arbitrary_args, + minimum_n_args, + maximum_n_args, + exact_args, + range_args, + match_all, + ArgValidator, +) + + +fn test_no_args(): + var test = MojoTest("Testing args.no_args") + var result = no_args(List[String]("abc")) + test.assert_equal(result.value(), String("Command does not take any arguments.")) + + +fn test_valid_args(): + var test = MojoTest("Testing args.valid_args") + var result = valid_args[List[String]("Pineapple")]()(List[String]("abc")) + test.assert_equal(result.value(), "Invalid argument abc for command.") + + +fn test_arbitrary_args(): + var test = MojoTest("Testing args.arbitrary_args") + var result = arbitrary_args(List[String]("abc", "blah", "blah")) + + # If the result is anything but None, fail the test. + if result is not None: + test.assert_false(True) + + +fn test_minimum_n_args(): + var test = MojoTest("Testing args.minimum_n_args") + var result = minimum_n_args[3]()(List[String]("abc", "123")) + test.assert_equal(result.value(), "Command accepts at least 3 argument(s). Received: 2.") + + +fn test_maximum_n_args(): + var test = MojoTest("Testing args.maximum_n_args") + var result = maximum_n_args[1]()(List[String]("abc", "123")) + test.assert_equal(result.value(), "Command accepts at most 1 argument(s). Received: 2.") + + +fn test_exact_args(): + var test = MojoTest("Testing args.exact_args") + var result = exact_args[1]()(List[String]("abc", "123")) + test.assert_equal(result.value(), "Command accepts exactly 1 argument(s). Received: 2.") + + +fn test_range_args(): + var test = MojoTest("Testing args.range_args") + var result = range_args[0, 1]()(List[String]("abc", "123")) + test.assert_equal(result.value(), "Command accepts between 0 to 1 argument(s). Received: 2.") + + +# fn test_match_all(): +# var test = MojoTest("Testing args.match_all") +# var result = match_all[List[ArgValidator](range_args[0, 1](), valid_args[List[String]("Pineapple")]())]()(List[String]("abc", "123")) +# test.assert_equal(result.value(), "Command accepts between 0 to 1 argument(s). Received: 2.") + + +fn main(): + test_no_args() + test_valid_args() + test_arbitrary_args() + test_minimum_n_args() + test_maximum_n_args() + test_exact_args() + test_range_args() + # test_match_all() diff --git a/tests/test_command.mojo b/tests/test_command.mojo new file mode 100644 index 0000000..2e7d775 --- /dev/null +++ b/tests/test_command.mojo @@ -0,0 +1,32 @@ +from tests.wrapper import MojoTest +from prism.command import Command, CommandArc +from prism.flag import FlagSet + + +fn test_command_operations(): + var test = MojoTest("Testing Command.new") + + fn dummy(command: CommandArc, args: List[String]) -> Error: + return Error() + + var cmd = Command.new[name="root", description="Base command.", run=dummy]() + + var get_all_flags_test = MojoTest("Testing Command.get_all_flags") + var flags = cmd.get_all_flags()[] + for flag in flags.flags: + get_all_flags_test.assert_equal("help", flag[].name) + + var add_command_test = MojoTest("Testing Command.add_command") + var child_cmd = Command.new[name="child", description="Child command.", run=dummy]() + cmd.add_command(child_cmd) + child_cmd.flags.add_string_flag[name="color", shorthand="c", usage="Text color", default="#3464eb"]() + + var full_command_test = MojoTest("Testing Command._full_command") + full_command_test.assert_equal(child_cmd._full_command(), "root child") + + var help_test = MojoTest("Testing Command._help") + cmd._help() + + +fn main(): + test_command_operations() diff --git a/tests/test_flags.mojo b/tests/test_flags.mojo new file mode 100644 index 0000000..c1b8acf --- /dev/null +++ b/tests/test_flags.mojo @@ -0,0 +1,35 @@ +from tests.wrapper import MojoTest +from prism.flag import string_to_bool, string_to_float, Flag, FlagSet, get_flags, parse_flag, parse_shorthand_flag + + +fn test_string_to_bool(): + pass + + +fn test_string_to_float(): + pass + + +fn test_get_flags(): + pass + + +fn test_parse_flag(): + pass + + +fn test_parse_shorthand_flag(): + pass + + +fn test_flag_equals(): + pass + + +fn main(): + test_string_to_bool() + test_string_to_float() + test_get_flags() + test_parse_flag() + test_parse_shorthand_flag() + test_flag_equals() diff --git a/tests/wrapper.mojo b/tests/wrapper.mojo new file mode 100644 index 0000000..5bbffe9 --- /dev/null +++ b/tests/wrapper.mojo @@ -0,0 +1,44 @@ +from testing import testing + + +@value +struct MojoTest: + """ + A utility struct for testing. + """ + + var test_name: String + + fn __init__(inout self, test_name: String): + self.test_name = test_name + print("# " + test_name) + + fn assert_true(self, cond: Bool, message: String = ""): + try: + if message == "": + testing.assert_true(cond) + else: + testing.assert_true(cond, message) + except e: + print(e) + + fn assert_false(self, cond: Bool, message: String = ""): + try: + if message == "": + testing.assert_false(cond) + else: + testing.assert_false(cond, message) + except e: + print(e) + + fn assert_equal(self, left: Int, right: Int): + try: + testing.assert_equal(left, right) + except e: + print(e) + + fn assert_equal(self, left: String, right: String): + try: + testing.assert_equal(left, right) + except e: + print(e) From 86793b6c7736332f4212b5126cac49b8cb9ed777 Mon Sep 17 00:00:00 2001 From: Mikhail Tavarez Date: Sat, 20 Apr 2024 18:46:26 -0500 Subject: [PATCH 2/2] Update README.md --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index a0d7b20..e75d277 100644 --- a/README.md +++ b/README.md @@ -4,6 +4,8 @@ A Budding CLI Library! Inspired by: `Cobra`! +NOTE: This does not work on Mojo 24.2, you must use the nightly build for now. This will be resolved in the next Mojo release. + ## Usage WIP: Documentation, but you should be able to figure out how to use the library by looking at the examples and referencing the Cobra documentation. You should be able to build the package by running `mojo package prism -I .`.