From 18e03733bc1591efcf846f7384e5272250e32d66 Mon Sep 17 00:00:00 2001 From: Mikhail Tavarez Date: Wed, 9 Oct 2024 11:34:17 -0500 Subject: [PATCH 1/4] refactor func names --- README.md | 30 +-- examples/aliases.mojo | 4 +- examples/arg_validators.mojo | 9 +- examples/chromeria.mojo | 4 +- examples/fg_child.mojo | 16 +- examples/fg_parent.mojo | 21 +- examples/hello_world.mojo | 20 +- examples/printer/printer.mojo | 22 +- examples/requests/nested.mojo | 18 +- examples/requests/persistent_flags.mojo | 26 +-- src/prism/args.mojo | 60 +++--- src/prism/command.mojo | 56 ++--- src/prism/flag.mojo | 2 + src/prism/flag_parser.mojo | 4 +- src/prism/flag_set.mojo | 267 ++++++++++-------------- src/prism/transform.mojo | 56 ++--- test/test_command.mojo | 4 +- test/test_flags.mojo | 16 +- 18 files changed, 304 insertions(+), 331 deletions(-) diff --git a/README.md b/README.md index c024a63..acc884d 100644 --- a/README.md +++ b/README.md @@ -61,7 +61,7 @@ Commands can have typed flags added to them to enable different behaviors. var root = Command( name="logger", description="Base command.", run=handler ) - root.flags.add_string_flag(name="type", shorthand="t", usage="Formatting type: [json, custom]") + root.flags.string_flag(name="type", shorthand="t", usage="Formatting type: [json, custom]") ``` ![Logging](https://github.com/thatstoasty/prism/blob/main/doc/tapes/logging.gif) @@ -120,7 +120,7 @@ fn main() -> None: persistent_pre_run=pre_hook, persistent_post_run=post_hook, )) - get_command[].flags.persistent_flags.add_bool_flag(name="lover", shorthand="l", usage="Are you an animal lover?") + get_command[].flags.persistent_flags.bool_flag(name="lover", shorthand="l", usage="Are you an animal lover?") ``` ![Persistent](https://github.com/thatstoasty/prism/blob/main/doc/tapes/persistent.gif) @@ -135,7 +135,7 @@ By default flags are considered optional. If you want your command to report an var print_tool = Arc(Command( name="tool", description="This is a dummy command!", run=tool_func, aliases=List[String]("object", "thing") )) - print_tool[].flags.add_bool_flag(name="required", shorthand="r", usage="Always required.") + print_tool[].flags.bool_flag(name="required", shorthand="r", usage="Always required.") print_tool[].mark_flag_required("required") ``` @@ -147,7 +147,7 @@ Same for persistent flags: description="This is a dummy command!", run=test, ) - root.persistent_flags.add_bool_flag(name="free", shorthand="f", usage="Always required.") + root.persistent_flags.bool_flag(name="free", shorthand="f", usage="Always required.") root.mark_persistent_flag_required("free") ``` @@ -159,8 +159,8 @@ If you have different flags that must be provided together (e.g. if they provide var print_tool = Arc(Command( name="tool", description="This is a dummy command!", run=tool_func, aliases=List[String]("object", "thing") )) - print_tool[].flags.add_string_flag(name="color", shorthand="c", usage="Text color", default="#3464eb") - print_tool[].flags.add_string_flag(name="formatting", shorthand="f", usage="Text formatting") + print_tool[].flags.string_flag(name="color", shorthand="c", usage="Text color", default="#3464eb") + print_tool[].flags.string_flag(name="formatting", shorthand="f", usage="Text formatting") print_tool[].mark_flags_required_together("color", "formatting") ``` @@ -170,8 +170,8 @@ You can also prevent different flags from being provided together if they repres var print_tool = Arc(Command( name="tool", description="This is a dummy command!", run=tool_func, aliases=List[String]("object", "thing") )) - print_tool[].add_string_flag(name="color", shorthand="c", usage="Text color", default="#3464eb") - print_tool[].add_string_flag(name="hue", shorthand="x", usage="Text color", default="#3464eb") + print_tool[].string_flag(name="color", shorthand="c", usage="Text color", default="#3464eb") + print_tool[].string_flag(name="hue", shorthand="x", usage="Text color", default="#3464eb") print_tool[].mark_flags_mutually_exclusive("color", "hue") ``` @@ -181,8 +181,8 @@ If you want to require at least one flag from a group to be present, you can use var print_tool = Arc(Command( name="tool", description="This is a dummy command!", run=tool_func, aliases=List[String]("object", "thing") )) - print_tool[].flags.add_string_flag(name="color", shorthand="c", usage="Text color", default="#3464eb") - print_tool[].flags.add_string_flag(name="formatting", shorthand="f", usage="Text formatting") + print_tool[].flags.string_flag(name="color", shorthand="c", usage="Text color", default="#3464eb") + print_tool[].flags.string_flag(name="formatting", shorthand="f", usage="Text formatting") print_tool[].mark_flags_one_required("color", "formatting") print_tool[].mark_flags_mutually_exclusive("color", "formatting") ``` @@ -208,16 +208,16 @@ fn main() -> None: run=test, ) # Persistent flags are defined on the parent command. - root.persistent_flags.add_bool_flag(name="required", shorthand="r", usage="Always required.") - root.persistent_flags.add_string_flag(name="host", shorthand="h", usage="Host") - root.persistent_flags.add_string_flag(name="port", shorthand="p", usage="Port") + root.persistent_flags.bool_flag(name="required", shorthand="r", usage="Always required.") + root.persistent_flags.string_flag(name="host", shorthand="h", usage="Host") + root.persistent_flags.string_flag(name="port", shorthand="p", usage="Port") root.mark_persistent_flag_required("required") var print_tool = Arc(Command( name="tool", description="This is a dummy command!", run=tool_func )) - print_tool[].flags.add_bool_flag(name="also", shorthand="a", usage="Also always required.") - print_tool[].flags.add_string_flag(name="uri", shorthand="u", usage="URI") + print_tool[].flags.bool_flag(name="also", shorthand="a", usage="Also always required.") + print_tool[].flags.string_flag(name="uri", shorthand="u", usage="URI") # Child commands are added to the parent command. root.add_subcommand(print_tool) diff --git a/examples/aliases.mojo b/examples/aliases.mojo index 0c4127a..4518303 100644 --- a/examples/aliases.mojo +++ b/examples/aliases.mojo @@ -2,11 +2,11 @@ from memory import Arc from prism import Context, Command -fn test(context: Context) -> None: +fn test(ctx: Context) -> None: print("Pass tool, object, or thing as a subcommand!") -fn tool_func(context: Context) -> None: +fn tool_func(ctx: Context) -> None: print("My tool!") diff --git a/examples/arg_validators.mojo b/examples/arg_validators.mojo index a4a81e8..66d1397 100644 --- a/examples/arg_validators.mojo +++ b/examples/arg_validators.mojo @@ -2,7 +2,6 @@ from memory import Arc from prism import ( Command, Context, - CommandArc, no_args, valid_args, minimum_n_args, @@ -12,13 +11,13 @@ from prism import ( ) -fn test(context: Context) -> None: - for arg in context.args: +fn test(ctx: Context) -> None: + for arg in ctx.args: print("Received", arg[]) -fn hello(context: Context) -> None: - print(context.command[].name, "Hello from Chromeria!") +fn hello(ctx: Context) -> None: + print(ctx.command[].name, "Hello from Chromeria!") fn main() -> None: diff --git a/examples/chromeria.mojo b/examples/chromeria.mojo index fe53022..3aeffb8 100644 --- a/examples/chromeria.mojo +++ b/examples/chromeria.mojo @@ -2,11 +2,11 @@ from memory import Arc from prism import Command, Context -fn test(context: Context) -> None: +fn test(ctx: Context) -> None: print("Pass chromeria as a subcommand!") -fn hello(context: Context) -> None: +fn hello(ctx: Context) -> None: print("Hello from Chromeria!") diff --git a/examples/fg_child.mojo b/examples/fg_child.mojo index c55221c..f9f7eb3 100644 --- a/examples/fg_child.mojo +++ b/examples/fg_child.mojo @@ -1,12 +1,12 @@ from memory import Arc -from prism import Command, Context, CommandArc +from prism import Command, Context -fn test(context: Context) -> None: +fn test(ctx: Context) -> None: print("Pass tool, object, or thing as a subcommand!") -fn tool_func(context: Context) -> None: +fn tool_func(ctx: Context) -> None: print("My tool!") @@ -16,14 +16,14 @@ fn main() -> None: description="This is a dummy command!", run=test, ) - root.persistent_flags.add_bool_flag(name="required", shorthand="r", usage="Always required.") - root.persistent_flags.add_string_flag(name="host", shorthand="h", usage="Host") - root.persistent_flags.add_string_flag(name="port", shorthand="p", usage="Port") + root.persistent_flags.bool_flag(name="required", shorthand="r", usage="Always required.") + root.persistent_flags.string_flag(name="host", shorthand="h", usage="Host") + root.persistent_flags.string_flag(name="port", shorthand="p", usage="Port") root.mark_persistent_flag_required("required") var print_tool = Arc(Command(name="tool", description="This is a dummy command!", run=tool_func)) - print_tool[].flags.add_bool_flag(name="also", shorthand="a", usage="Also always required.") - print_tool[].flags.add_string_flag(name="uri", shorthand="u", usage="URI") + print_tool[].flags.bool_flag(name="also", shorthand="a", usage="Also always required.") + print_tool[].flags.string_flag(name="uri", shorthand="u", usage="URI") root.add_subcommand(print_tool) # Make sure to add the child command to the parent before marking flags. diff --git a/examples/fg_parent.mojo b/examples/fg_parent.mojo index 5e976da..06cf794 100644 --- a/examples/fg_parent.mojo +++ b/examples/fg_parent.mojo @@ -2,11 +2,18 @@ from memory import Arc from prism import Command, Context -fn test(context: Context) -> None: - print("Pass tool, object, or thing as a subcommand!") +fn test(ctx: Context) -> None: + var host = ctx.command[].flags.get_string("host") + var port = ctx.command[].flags.get_string("port") + var uri = ctx.command[].flags.get_string("uri") + if uri: + print("URI: ", uri.value()) + else: + print(host.value(), ":", port.value()) -fn tool_func(context: Context) -> None: + +fn tool_func(ctx: Context) -> None: print("My tool!") @@ -16,10 +23,10 @@ fn main() -> None: description="This is a dummy command!", run=test, ) - root.persistent_flags.add_bool_flag(name="required", shorthand="r", usage="Always required.") - root.persistent_flags.add_string_flag(name="host", shorthand="h", usage="Host") - root.persistent_flags.add_string_flag(name="port", shorthand="p", usage="Port") - root.persistent_flags.add_string_flag(name="uri", shorthand="u", usage="URI") + root.persistent_flags.bool_flag(name="required", shorthand="r", usage="Always required.") + root.persistent_flags.string_flag(name="host", shorthand="h", usage="Host") + root.persistent_flags.string_flag(name="port", shorthand="p", usage="Port") + root.persistent_flags.string_flag(name="uri", shorthand="u", usage="URI") root.mark_flags_required_together("host", "port") root.mark_flags_mutually_exclusive("host", "uri") root.mark_flag_required("required") diff --git a/examples/hello_world.mojo b/examples/hello_world.mojo index 35ff927..c643fd4 100644 --- a/examples/hello_world.mojo +++ b/examples/hello_world.mojo @@ -2,12 +2,12 @@ from memory import Arc from prism import Command, Context -fn printer(context: Context) -> None: - if len(context.args) == 0: +fn printer(ctx: Context) -> None: + if len(ctx.args) == 0: print("No args provided.") return - print(context.args[0]) + print(ctx.args[0]) return @@ -22,17 +22,17 @@ fn build_printer_command() -> Arc[Command]: return cmd -fn say(context: Context) -> None: +fn say(ctx: Context) -> None: print("Shouldn't be here!") return None -fn say_hello(context: Context) -> None: +fn say_hello(ctx: Context) -> None: print("Hello World!") return None -fn say_goodbye(context: Context) -> None: +fn say_goodbye(ctx: Context) -> None: print("Goodbye World!") return None @@ -70,9 +70,9 @@ fn build_goodbye_command() -> Arc[Command]: return cmd -fn test(context: Context) -> None: - print(context.command[].flags.get_as_string("env").value()) - for item in context.command[].flags.flags: +fn test(ctx: Context) -> None: + print(ctx.command[].flags.get_string("env").value()) + for item in ctx.command[].flags.flags: if item[].value: print(item[].name, item[].value.value()) else: @@ -87,7 +87,7 @@ fn main() -> None: description="This is a dummy command!", run=test, ) - root.flags.add_string_flag(name="env", shorthand="e", usage="Environment.") + root.flags.string_flag(name="env", shorthand="e", usage="Environment.", default="") var say_command = build_say_command() var hello_command = build_hello_command() diff --git a/examples/printer/printer.mojo b/examples/printer/printer.mojo index f1f4130..e3cc105 100644 --- a/examples/printer/printer.mojo +++ b/examples/printer/printer.mojo @@ -1,15 +1,15 @@ from memory import Arc -from prism import Command, Context, CommandArc, exact_args +from prism import Command, Context, exact_args from mist import Style -fn printer(context: Context) -> None: - if len(context.args) <= 0: +fn printer(ctx: Context) -> None: + if len(ctx.args) <= 0: print("No text to print! Pass in some text as a positional argument.") return None - var color = context.command[].flags.get_as_uint32("color") - var formatting = context.command[].flags.get_as_string("formatting") + var color = ctx.command[].flags.get_uint32("color") + var formatting = ctx.command[].flags.get_string("formatting") var style = Style() if not color: @@ -22,7 +22,7 @@ fn printer(context: Context) -> None: var formatting_value = formatting.or_else("") if formatting_value == "": - print(style.render(context.args[0])) + print(style.render(ctx.args[0])) return None if formatting.value() == "bold": @@ -32,16 +32,16 @@ fn printer(context: Context) -> None: elif formatting.value() == "italic": style = style.italic() - print(style.render(context.args[0])) + print(style.render(ctx.args[0])) return None -fn pre_hook(context: Context) -> None: +fn pre_hook(ctx: Context) -> None: print("Pre-hook executed!") return None -fn post_hook(context: Context) -> None: +fn post_hook(ctx: Context) -> None: print("Post-hook executed!") return None @@ -56,7 +56,7 @@ fn main() -> None: ) root.arg_validator = exact_args[1]() - root.flags.add_uint32_flag(name="color", shorthand="c", usage="Text color", default=0x3464EB) - root.flags.add_string_flag(name="formatting", shorthand="f", usage="Text formatting") + root.flags.uint32_flag(name="color", shorthand="c", usage="Text color", default=0x3464EB) + root.flags.string_flag(name="formatting", shorthand="f", usage="Text formatting") root.execute() diff --git a/examples/requests/nested.mojo b/examples/requests/nested.mojo index 1096be8..0be5fbf 100644 --- a/examples/requests/nested.mojo +++ b/examples/requests/nested.mojo @@ -3,19 +3,19 @@ from prism import Command, Context from python import Python -fn base(context: Context) -> None: +fn base(ctx: Context) -> None: print("This is the base command!") return None -fn print_information(context: Context) -> None: +fn print_information(ctx: Context) -> None: print("Pass cat or dog as a subcommand, and see what you get!") return None -fn get_cat_fact(context: Context) raises -> None: - var flags = context.command[].flags - var lover = flags.get_as_bool("lover") +fn get_cat_fact(ctx: Context) raises -> None: + var flags = ctx.command[].flags + var lover = flags.get_bool("lover") if lover and lover.value(): print("Hello fellow cat lover!") @@ -25,7 +25,7 @@ fn get_cat_fact(context: Context) raises -> None: var url = "https://catfact.ninja/fact" # Send the GET requests - var count = flags.get_as_int("count") + var count = flags.get_int("count") if not count: raise Error("Count flag was not found.") @@ -40,7 +40,7 @@ fn get_cat_fact(context: Context) raises -> None: raise Error("Request failed!") -fn get_dog_breeds(context: Context) raises -> None: +fn get_dog_breeds(ctx: Context) raises -> None: var requests = Python.import_module("requests") # URL you want to send a GET request to var url = "https://dog.ceo/api/breeds/list/all" @@ -74,8 +74,8 @@ fn main() -> None: erroring_run=get_cat_fact, ) ) - cat_command[].flags.add_int_flag(name="count", shorthand="c", usage="Number of facts to get.", default=1) - cat_command[].flags.add_bool_flag(name="lover", shorthand="l", usage="Are you a cat lover?") + cat_command[].flags.int_flag(name="count", shorthand="c", usage="Number of facts to get.", default=1) + cat_command[].flags.bool_flag(name="lover", shorthand="l", usage="Are you a cat lover?") var dog_command = Arc( Command( diff --git a/examples/requests/persistent_flags.mojo b/examples/requests/persistent_flags.mojo index 6b1faca..09484c0 100644 --- a/examples/requests/persistent_flags.mojo +++ b/examples/requests/persistent_flags.mojo @@ -3,19 +3,19 @@ from prism import Command, Context from python import Python -fn base(context: Context) -> None: +fn base(ctx: Context) -> None: print("This is the base command!") return None -fn print_information(context: Context) -> None: +fn print_information(ctx: Context) -> None: print("Pass cat or dog as a subcommand, and see what you get!") return None -fn get_cat_fact(context: Context) raises -> None: - var flags = context.command[].flags - var lover = flags.get_as_bool("lover") +fn get_cat_fact(ctx: Context) raises -> None: + var flags = ctx.command[].flags + var lover = flags.get_bool("lover") if lover and lover.value(): print("Hello fellow cat lover!") @@ -25,7 +25,7 @@ fn get_cat_fact(context: Context) raises -> None: var url = "https://catfact.ninja/fact" # Send the GET requests - var count = flags.get_as_int("count") + var count = flags.get_int("count") if not count: raise Error("Count flag was not found.") @@ -40,9 +40,9 @@ fn get_cat_fact(context: Context) raises -> None: raise Error("Request failed!") -fn get_dog_breeds(context: Context) raises -> None: - var flags = context.command[].flags - var lover = flags.get_as_bool("lover") +fn get_dog_breeds(ctx: Context) raises -> None: + var flags = ctx.command[].flags + var lover = flags.get_bool("lover") if lover and lover.value(): print("Hello fellow dog lover!") @@ -60,11 +60,11 @@ fn get_dog_breeds(context: Context) raises -> None: raise Error("Request failed!") -fn pre_hook(context: Context) -> None: +fn pre_hook(ctx: Context) -> None: print("Pre-hook executed!") -fn post_hook(context: Context) -> None: +fn post_hook(ctx: Context) -> None: print("Post-hook executed!") @@ -80,7 +80,7 @@ fn main() -> None: persistent_post_run=post_hook, ) ) - get_command[].persistent_flags.add_bool_flag(name="lover", shorthand="l", usage="Are you an animal lover?") + get_command[].persistent_flags.bool_flag(name="lover", shorthand="l", usage="Are you an animal lover?") var cat_command = Arc( Command( @@ -89,7 +89,7 @@ fn main() -> None: erroring_run=get_cat_fact, ) ) - cat_command[].flags.add_int_flag(name="count", shorthand="c", usage="Number of facts to get.") + cat_command[].flags.int_flag(name="count", shorthand="c", usage="Number of facts to get.") var dog_command = Arc( Command( diff --git a/src/prism/args.mojo b/src/prism/args.mojo index ff2a34f..b8fc2bf 100644 --- a/src/prism/args.mojo +++ b/src/prism/args.mojo @@ -5,21 +5,21 @@ from .command import ArgValidator from .context import Context -fn no_args(context: Context) raises -> None: +fn no_args(ctx: Context) raises -> None: """Returns an error if the command has any arguments. Args: - context: The context of the command being executed. + ctx: The context of the command being executed. """ - if len(context.args) > 0: - raise Error(fmt.sprintf("The command `%s` does not take any arguments.", context.command[].name)) + if len(ctx.args) > 0: + raise Error(fmt.sprintf("The command `%s` does not take any arguments.", ctx.command[].name)) -fn arbitrary_args(context: Context) raises -> None: +fn arbitrary_args(ctx: Context) raises -> None: """Never returns an error. Args: - context: The context of the command being executed. + ctx: The context of the command being executed. """ return None @@ -34,14 +34,14 @@ fn minimum_n_args[n: Int]() -> ArgValidator: A function that checks the number of arguments. """ - fn less_than_n_args(context: Context) raises -> None: - if len(context.args) < n: + fn less_than_n_args(ctx: Context) raises -> None: + if len(ctx.args) < n: raise Error( fmt.sprintf( "The command `%s` accepts at least %d argument(s). Received: %d.", - context.command[].name, + ctx.command[].name, n, - len(context.args), + len(ctx.args), ) ) @@ -58,14 +58,14 @@ fn maximum_n_args[n: Int]() -> ArgValidator: A function that checks the number of arguments. """ - fn more_than_n_args(context: Context) raises -> None: - if len(context.args) > n: + fn more_than_n_args(ctx: Context) raises -> None: + if len(ctx.args) > n: raise Error( fmt.sprintf( "The command `%s` accepts at most %d argument(s). Received: %d.", - context.command[].name, + ctx.command[].name, n, - len(context.args), + len(ctx.args), ) ) @@ -82,30 +82,30 @@ fn exact_args[n: Int]() -> ArgValidator: A function that checks the number of arguments. """ - fn exactly_n_args(context: Context) raises -> None: - if len(context.args) != n: + fn exactly_n_args(ctx: Context) raises -> None: + if len(ctx.args) != n: raise Error( fmt.sprintf( "The command `%s` accepts exactly %d argument(s). Received: %d.", - context.command[].name, + ctx.command[].name, n, - len(context.args), + len(ctx.args), ) ) return exactly_n_args -fn valid_args(context: Context) raises -> None: +fn valid_args(ctx: Context) raises -> None: """Returns an error if threre are any positional args that are not in the command's `valid_args`. Args: - context: The context of the command being executed. + ctx: The context of the command being executed. """ - if context.command[].valid_args: - for arg in context.args: - if arg[] not in context.command[].valid_args: - raise Error(fmt.sprintf("Invalid argument: `%s`, for the command `%s`.", arg[], context.command[].name)) + if ctx.command[].valid_args: + for arg in ctx.args: + if arg[] not in ctx.command[].valid_args: + raise Error(fmt.sprintf("Invalid argument: `%s`, for the command `%s`.", arg[], ctx.command[].name)) fn range_args[minimum: Int, maximum: Int]() -> ArgValidator: @@ -119,15 +119,15 @@ fn range_args[minimum: Int, maximum: Int]() -> ArgValidator: A function that checks the number of arguments. """ - fn range_n_args(context: Context) raises -> None: - if len(context.args) < minimum or len(context.args) > maximum: + fn range_n_args(ctx: Context) raises -> None: + if len(ctx.args) < minimum or len(ctx.args) > maximum: raise Error( fmt.sprintf( "The command `%s`, accepts between %d to %d argument(s). Received: %d.", - context.command[].name, + ctx.command[].name, minimum, maximum, - len(context.args), + len(ctx.args), ) ) @@ -145,9 +145,9 @@ fn match_all[arg_validators: List[ArgValidator]]() -> ArgValidator: A function that checks all the arguments using the arg_validators list.. """ - fn match_all_args(context: Context) raises -> None: + fn match_all_args(ctx: Context) raises -> None: for i in range(len(arg_validators)): - arg_validators[i](context) + arg_validators[i](ctx) return match_all_args diff --git a/src/prism/command.mojo b/src/prism/command.mojo index 6bd1e98..08796a3 100644 --- a/src/prism/command.mojo +++ b/src/prism/command.mojo @@ -89,13 +89,13 @@ fn default_help(inout command: Arc[Command]) -> String: return mog.join_vertical(mog.left, description, options, commands) -alias CommandFunction = fn (context: Context) -> None +alias CommandFunction = fn (ctx: Context) -> None """The function for a command to run.""" -alias CommandFunctionErr = fn (context: Context) raises -> None +alias CommandFunctionErr = fn (ctx: Context) raises -> None """The function for a command to run that can error.""" alias HelpFunction = fn (inout command: Arc[Command]) -> String """The function for a help function.""" -alias ArgValidator = fn (context: Context) raises -> None +alias ArgValidator = fn (ctx: Context) raises -> None """The function for an argument validator.""" alias ParentVisitorFn = fn (parent: Arc[Command]) capturing -> None """The function for visiting parents of a command.""" @@ -114,7 +114,7 @@ struct Command(CollectionElement): from memory import Arc from prism import Command, Context - fn test(context: Context) -> None: + fn test(ctx: Context) -> None: print("Hello from Chromeria!") fn main(): @@ -257,7 +257,7 @@ struct Command(CollectionElement): self.local_flags = FlagSet() self.persistent_flags = FlagSet() self._inherited_flags = FlagSet() - self.flags.add_bool_flag(name="help", shorthand="h", usage="Displays help information about the command.") + self.flags.bool_flag(name="help", shorthand="h", usage="Displays help information about the command.") fn __moveinit__(inout self, owned existing: Self): self.name = existing.name^ @@ -368,60 +368,60 @@ struct Command(CollectionElement): return command, remaining_args - fn _execute_pre_run_hooks(self, context: Context, parents: List[Arc[Self]]) raises -> None: + fn _execute_pre_run_hooks(self, ctx: Context, parents: List[Arc[Self]]) raises -> None: """Runs the pre-run hooks for the command.""" try: # Run the persistent pre-run hooks. for parent in parents: if parent[][].persistent_erroring_pre_run: - parent[][].persistent_erroring_pre_run.value()(context) + parent[][].persistent_erroring_pre_run.value()(ctx) @parameter if not ENABLE_TRAVERSE_RUN_HOOKS: break else: if parent[][].persistent_pre_run: - parent[][].persistent_pre_run.value()(context) + parent[][].persistent_pre_run.value()(ctx) @parameter if not ENABLE_TRAVERSE_RUN_HOOKS: break # Run the pre-run hooks. - if context.command[].pre_run: - context.command[].pre_run.value()(context) - elif context.command[].erroring_pre_run: - context.command[].erroring_pre_run.value()(context) + if ctx.command[].pre_run: + ctx.command[].pre_run.value()(ctx) + elif ctx.command[].erroring_pre_run: + ctx.command[].erroring_pre_run.value()(ctx) except e: - print("Failed to run pre-run hooks for command: " + context.command[].name) + print("Failed to run pre-run hooks for command: " + ctx.command[].name) raise e - fn _execute_post_run_hooks(self, context: Context, parents: List[Arc[Self]]) raises -> None: + fn _execute_post_run_hooks(self, ctx: Context, parents: List[Arc[Self]]) raises -> None: """Runs the pre-run hooks for the command.""" try: # Run the persistent post-run hooks. for parent in parents: if parent[][].persistent_erroring_post_run: - parent[][].persistent_erroring_post_run.value()(context) + parent[][].persistent_erroring_post_run.value()(ctx) @parameter if not ENABLE_TRAVERSE_RUN_HOOKS: break else: if parent[][].persistent_post_run: - parent[][].persistent_post_run.value()(context) + parent[][].persistent_post_run.value()(ctx) @parameter if not ENABLE_TRAVERSE_RUN_HOOKS: break # Run the post-run hooks. - if context.command[].post_run: - context.command[].post_run.value()(context) - elif context.command[].erroring_post_run: - context.command[].erroring_post_run.value()(context) + if ctx.command[].post_run: + ctx.command[].post_run.value()(ctx) + elif ctx.command[].erroring_post_run: + ctx.command[].erroring_post_run.value()(ctx) except e: - print("Failed to run post-run hooks for command: " + context.command[].name, file=2) + print("Failed to run post-run hooks for command: " + ctx.command[].name, file=2) raise e fn execute(inout self) -> None: @@ -471,7 +471,7 @@ struct Command(CollectionElement): panic(e) # Check if the help flag was passed - var help_passed = command[].flags.get_as_bool("help") + var help_passed = command[].flags.get_bool("help") if help_passed.value() == True: print(command[].help(command)) return None @@ -484,16 +484,16 @@ struct Command(CollectionElement): validate_flag_groups(command[].flags) # Validate the remaining arguments - var context = Context(command, remaining_args) - command[].arg_validator(context) + var ctx = Context(command, remaining_args) + command[].arg_validator(ctx) # Run the function's commands. - self._execute_pre_run_hooks(context, parents) + self._execute_pre_run_hooks(ctx, parents) if command[].run: - command[].run.value()(context) + command[].run.value()(ctx) else: - command[].erroring_run.value()(context) - self._execute_post_run_hooks(context, parents) + command[].erroring_run.value()(ctx) + self._execute_post_run_hooks(ctx, parents) except e: panic(e) diff --git a/src/prism/flag.mojo b/src/prism/flag.mojo index 2779213..43c849c 100644 --- a/src/prism/flag.mojo +++ b/src/prism/flag.mojo @@ -1,6 +1,8 @@ from collections import Optional, Dict +# TODO: When we have trait objects, switch to using actual flag structs per type instead of +# needing to cast values to and from string. @value struct Flag(RepresentableCollectionElement, Stringable, Formattable): """Represents a flag that can be passed via the command line. diff --git a/src/prism/flag_parser.mojo b/src/prism/flag_parser.mojo index cb4d4c5..efe4b32 100644 --- a/src/prism/flag_parser.mojo +++ b/src/prism/flag_parser.mojo @@ -37,7 +37,7 @@ struct FlagParser: raise Error("Command does not accept the flag supplied: " + name) # If it's a bool flag, set it to True and only increment the index by 1 (one arg used). - if flags.get_as_bool(name): + if flags.get_bool(name): return name, String("True"), 1 if self.index + 1 >= len(arguments): @@ -82,7 +82,7 @@ struct FlagParser: var name = result.value() # If it's a bool flag, set it to True and only increment the index by 1 (one arg used). - if flags.get_as_bool(name): + if flags.get_bool(name): return name, String("True"), 1 if self.index + 1 >= len(arguments): diff --git a/src/prism/flag_set.mojo b/src/prism/flag_set.mojo index 668ef9c..203f2c9 100644 --- a/src/prism/flag_set.mojo +++ b/src/prism/flag_set.mojo @@ -1,23 +1,24 @@ from collections import Optional, Dict, InlineList +from utils import Variant import gojo.fmt from .flag import Flag from .util import panic, string_to_bool, string_to_float, split from .flag_parser import FlagParser from .transform import ( - get_as_string, - get_as_bool, - get_as_int, - get_as_int8, - get_as_int16, - get_as_int32, - get_as_int64, - get_as_uint8, - get_as_uint16, - get_as_uint32, - get_as_uint64, - get_as_float16, - get_as_float32, - get_as_float64, + get_string, + get_bool, + get_int, + get_int8, + get_int16, + get_int32, + get_int64, + get_uint8, + get_uint16, + get_uint32, + get_uint64, + as_float16, + as_float32, + as_float64, ) alias FlagVisitorFn = fn (Flag) capturing -> None @@ -33,6 +34,7 @@ alias REQUIRED_AS_GROUP = "REQUIRED_AS_GROUP" alias ONE_REQUIRED = "ONE_REQUIRED" alias MUTUALLY_EXCLUSIVE = "MUTUALLY_EXCLUSIVE" +alias FLAG_TYPES = ["String", "Bool", "Int", "Int8", "Int16", "Int32", "Int64", "UInt8", "UInt16", "UInt32", "UInt64", "Float16", "Float32", "Float64"] @value struct FlagSet(CollectionElement, Stringable, Sized, Boolable, EqualityComparable): @@ -65,22 +67,13 @@ struct FlagSet(CollectionElement, Stringable, Sized, Boolable, EqualityComparabl return self.flags.__bool__() fn __contains__(self, value: Flag) -> Bool: - for flag in self.flags: - if flag[] == value: - return True - return False + return value in self.flags fn __eq__(self, other: Self) -> Bool: - if len(self.flags) != len(other.flags): - return False - - for i in range(len(self.flags)): - if self.flags[i] != other.flags[i]: - return False - return True + return self.flags == other.flags fn __ne__(self, other: Self) -> Bool: - return not self == other + return self.flags != other.flags fn __add__(inout self, other: Self) -> Self: var new = Self(self) @@ -91,34 +84,38 @@ struct FlagSet(CollectionElement, Stringable, Sized, Boolable, EqualityComparabl fn __iadd__(inout self, other: Self): self.merge(other) - fn lookup(ref [_]self, name: String) -> Optional[Reference[Flag, __lifetime_of(self.flags)]]: - """Returns an mutable or immutable reference to a Flag with the given name. - Mutable if FlagSet is mutable, immutable if FlagSet is immutable. - - Returns: - Optional Reference to the Flag. - """ - for i in range(len(self.flags)): - if self.flags[i].name == name: - return Reference(self.flags[i]) - - return None - - fn lookup_with_type( - ref [_]self, name: String, type: String + fn lookup( + ref [_] self, name: String, type: String = "" ) -> Optional[Reference[Flag, __lifetime_of(self.flags)]]: """Returns an mutable or immutable reference to a Flag with the given name. Mutable if FlagSet is mutable, immutable if FlagSet is immutable. - type: The type of the flag to return. + Args: + name: The name of the Flag to lookup. + type: The type of the Flag to lookup. Returns: Optional Reference to the Flag. """ - for i in range(len(self.flags)): - if self.flags[i].name == name and self.flags[i].type == type: - return Reference(self.flags[i]) + if type == "": + for i in range(len(self.flags)): + if self.flags[i].name == name: + return Reference(self.flags[i]) + else: + for i in range(len(self.flags)): + if self.flags[i].name == name and self.flags[i].type == type: + return Reference(self.flags[i]) + return None + + fn lookup_name(self, shorthand: String) -> Optional[String]: + """Returns the name of a flag given its shorthand. + Args: + shorthand: The shorthand of the flag to lookup. + """ + for flag in self.flags: + if flag[].shorthand and flag[].shorthand == shorthand: + return flag[].name return None fn get_as[ @@ -126,64 +123,61 @@ struct FlagSet(CollectionElement, Stringable, Sized, Boolable, EqualityComparabl ](self, name: String) -> Optional[R]: return transform(self, name) - fn get_as_string(self, name: String) -> Optional[String]: - """Returns the value of a flag as a String. - If it isn't set, then return the default value. - If there isn't a flag of the type specified, then return None. - """ - return self.get_as[R=String, transform=get_as_string](name) + fn get_string(self, name: String) -> Optional[String]: + """Returns the value of a flag as a String. If it isn't set, then return the default value.""" + return self.get_as[R=String, transform=get_string](name) - fn get_as_bool(self, name: String) -> Optional[Bool]: + fn get_bool(self, name: String) -> Optional[Bool]: """Returns the value of a flag as a Bool. If it isn't set, then return the default value.""" - return self.get_as[R=Bool, transform=get_as_bool](name) + return self.get_as[R=Bool, transform=get_bool](name) - fn get_as_int(self, name: String) -> Optional[Int]: + fn get_int(self, name: String) -> Optional[Int]: """Returns the value of a flag as an Int. If it isn't set, then return the default value.""" - return self.get_as[R=Int, transform=get_as_int](name) + return self.get_as[R=Int, transform=get_int](name) - fn get_as_int8(self, name: String) -> Optional[Int8]: + fn get_int8(self, name: String) -> Optional[Int8]: """Returns the value of a flag as a Int8. If it isn't set, then return the default value.""" - return self.get_as[R=Int8, transform=get_as_int8](name) + return self.get_as[R=Int8, transform=get_int8](name) - fn get_as_int16(self, name: String) -> Optional[Int16]: + fn get_int16(self, name: String) -> Optional[Int16]: """Returns the value of a flag as a Int16. If it isn't set, then return the default value.""" - return self.get_as[R=Int16, transform=get_as_int16](name) + return self.get_as[R=Int16, transform=get_int16](name) - fn get_as_int32(self, name: String) -> Optional[Int32]: + fn get_int32(self, name: String) -> Optional[Int32]: """Returns the value of a flag as a Int32. If it isn't set, then return the default value.""" - return self.get_as[R=Int32, transform=get_as_int32](name) + return self.get_as[R=Int32, transform=get_int32](name) - fn get_as_int64(self, name: String) -> Optional[Int64]: + fn get_int64(self, name: String) -> Optional[Int64]: """Returns the value of a flag as a Int64. If it isn't set, then return the default value.""" - return self.get_as[R=Int64, transform=get_as_int64](name) + return self.get_as[R=Int64, transform=get_int64](name) - fn get_as_uint8(self, name: String) -> Optional[UInt8]: + fn get_uint8(self, name: String) -> Optional[UInt8]: """Returns the value of a flag as a UInt8. If it isn't set, then return the default value.""" - return self.get_as[R=UInt8, transform=get_as_uint8](name) + return self.get_as[R=UInt8, transform=get_uint8](name) - fn get_as_uint16(self, name: String) -> Optional[UInt16]: + fn get_uint16(self, name: String) -> Optional[UInt16]: """Returns the value of a flag as a UInt16. If it isn't set, then return the default value.""" - return self.get_as[R=UInt16, transform=get_as_uint16](name) + return self.get_as[R=UInt16, transform=get_uint16](name) - fn get_as_uint32(self, name: String) -> Optional[UInt32]: + fn get_uint32(self, name: String) -> Optional[UInt32]: """Returns the value of a flag as a UInt32. If it isn't set, then return the default value.""" - return self.get_as[R=UInt32, transform=get_as_uint32](name) + return self.get_as[R=UInt32, transform=get_uint32](name) - fn get_as_uint64(self, name: String) -> Optional[UInt64]: + fn get_uint64(self, name: String) -> Optional[UInt64]: """Returns the value of a flag as a UInt64. If it isn't set, then return the default value.""" - return self.get_as[R=UInt64, transform=get_as_uint64](name) + return self.get_as[R=UInt64, transform=get_uint64](name) - fn get_as_float16(self, name: String) -> Optional[Float16]: + fn as_float16(self, name: String) -> Optional[Float16]: """Returns the value of a flag as a Float64. If it isn't set, then return the default value.""" - return self.get_as[R=Float16, transform=get_as_float16](name) + return self.get_as[R=Float16, transform=as_float16](name) - fn get_as_float32(self, name: String) -> Optional[Float32]: + fn as_float32(self, name: String) -> Optional[Float32]: """Returns the value of a flag as a Float64. If it isn't set, then return the default value.""" - return self.get_as[R=Float32, transform=get_as_float32](name) + return self.get_as[R=Float32, transform=as_float32](name) - fn get_as_float64(self, name: String) -> Optional[Float64]: + fn as_float64(self, name: String) -> Optional[Float64]: """Returns the value of a flag as a Float64. If it isn't set, then return the default value.""" - return self.get_as[R=Float64, transform=get_as_float64](name) + return self.get_as[R=Float64, transform=as_float64](name) fn names(self) -> List[String]: """Returns a list of names of all flags in the flag set.""" @@ -200,101 +194,73 @@ struct FlagSet(CollectionElement, Stringable, Sized, Boolable, EqualityComparabl result.append(flag[].shorthand) return result - fn lookup_name(self, shorthand: String) -> Optional[String]: - """Returns the name of a flag given its shorthand. - - Args: - shorthand: The shorthand of the flag to lookup. - """ - for flag in self.flags: - if flag[].shorthand and flag[].shorthand == shorthand: - return flag[].name - return None - - fn _add_flag( - inout self, name: String, usage: String, default: String, type: String, shorthand: String = "" - ) -> None: - """Adds a flag to the flag set. - Valid type values: [String, Bool, Int, Int8, Int16, Int32, Int64, - UInt8, UInt16, UInt32, UInt64, Float16, Float32, Float64] - - Args: - name: The name of the flag. - usage: The usage of the flag. - default: The default value of the flag. - type: The type of the flag. - shorthand: The shorthand of the flag. - """ - var flag = Flag(name=name, shorthand=shorthand, usage=usage, value=None, default=str(default), type=type) - self.flags.append(flag) - - fn add_bool_flag( + fn bool_flag( inout self, name: String, usage: String, shorthand: String = "", default: Bool = False, ) -> None: - """Adds a Bool flag to the flag set.""" - self._add_flag(name, usage, str(default), "Bool", shorthand) + """Adds a `Bool` flag to the flag set.""" + self.flags.append(Flag(name=name, shorthand=shorthand, usage=usage, value=None, default=str(default), type="Bool")) - fn add_string_flag( + fn string_flag( inout self, name: String, usage: String, shorthand: String = "", default: String = "", ) -> None: - """Adds a String flag to the flag set.""" - self._add_flag(name, usage, str(default), "String", shorthand) + """Adds a `String` flag to the flag set.""" + self.flags.append(Flag(name=name, shorthand=shorthand, usage=usage, value=None, default=str(default), type="String")) - fn add_int_flag(inout self, name: String, usage: String, shorthand: String = "", default: Int = 0) -> None: - """Adds an Int flag to the flag set.""" - self._add_flag(name, usage, str(default), "Int", shorthand) + fn int_flag(inout self, name: String, usage: String, shorthand: String = "", default: Int = 0) -> None: + """Adds an `Int` flag to the flag set.""" + self.flags.append(Flag(name=name, shorthand=shorthand, usage=usage, value=None, default=str(default), type="Int")) - fn add_int8_flag(inout self, name: String, usage: String, shorthand: String = "", default: Int8 = 0) -> None: - """Adds an Int8 flag to the flag set.""" - self._add_flag(name, usage, str(default), "Int8", shorthand) + fn int8_flag(inout self, name: String, usage: String, shorthand: String = "", default: Int8 = 0) -> None: + """Adds an `Int8` flag to the flag set.""" + self.flags.append(Flag(name=name, shorthand=shorthand, usage=usage, value=None, default=str(default), type="Int8")) - fn add_int16_flag(inout self, name: String, usage: String, shorthand: String = "", default: Int16 = 0) -> None: - """Adds an Int16 flag to the flag set.""" - self._add_flag(name, usage, str(default), "Int16", shorthand) + fn int16_flag(inout self, name: String, usage: String, shorthand: String = "", default: Int16 = 0) -> None: + """Adds an `Int16` flag to the flag set.""" + self.flags.append(Flag(name=name, shorthand=shorthand, usage=usage, value=None, default=str(default), type="Int16")) - fn add_int32_flag(inout self, name: String, usage: String, shorthand: String = "", default: Int32 = 0) -> None: - """Adds an Int32 flag to the flag set.""" - self._add_flag(name, usage, str(default), "Int32", shorthand) + fn int32_flag(inout self, name: String, usage: String, shorthand: String = "", default: Int32 = 0) -> None: + """Adds an `Int32` flag to the flag set.""" + self.flags.append(Flag(name=name, shorthand=shorthand, usage=usage, value=None, default=str(default), type="Int32")) - fn add_int64_flag(inout self, name: String, usage: String, shorthand: String = "", default: Int64 = 0) -> None: - """Adds an Int64 flag to the flag set.""" - self._add_flag(name, usage, str(default), "Int64", shorthand) + fn int64_flag(inout self, name: String, usage: String, shorthand: String = "", default: Int64 = 0) -> None: + """Adds an `Int64` flag to the flag set.""" + self.flags.append(Flag(name=name, shorthand=shorthand, usage=usage, value=None, default=str(default), type="Int64")) - fn add_uint8_flag(inout self, name: String, usage: String, shorthand: String = "", default: UInt8 = 0) -> None: - """Adds a UInt8 flag to the flag set.""" - self._add_flag(name, usage, str(default), "UInt8", shorthand) + fn uint8_flag(inout self, name: String, usage: String, shorthand: String = "", default: UInt8 = 0) -> None: + """Adds a `UInt8` flag to the flag set.""" + self.flags.append(Flag(name=name, shorthand=shorthand, usage=usage, value=None, default=str(default), type="UInt8")) - fn add_uint16_flag(inout self, name: String, usage: String, shorthand: String = "", default: UInt16 = 0) -> None: - """Adds a UInt16 flag to the flag set.""" - self._add_flag(name, usage, str(default), "UInt16", shorthand) + fn uint16_flag(inout self, name: String, usage: String, shorthand: String = "", default: UInt16 = 0) -> None: + """Adds a `UInt16` flag to the flag set.""" + self.flags.append(Flag(name=name, shorthand=shorthand, usage=usage, value=None, default=str(default), type="UInt16")) - fn add_uint32_flag(inout self, name: String, usage: String, shorthand: String = "", default: UInt32 = 0) -> None: - """Adds a UInt32 flag to the flag set.""" - self._add_flag(name, usage, str(default), "UInt32", shorthand) + fn uint32_flag(inout self, name: String, usage: String, shorthand: String = "", default: UInt32 = 0) -> None: + """Adds a `UInt32` flag to the flag set.""" + self.flags.append(Flag(name=name, shorthand=shorthand, usage=usage, value=None, default=str(default), type="UInt32")) - fn add_uint64_flag(inout self, name: String, usage: String, shorthand: String = "", default: UInt64 = 0) -> None: - """Adds a UInt64 flag to the flag set.""" - self._add_flag(name, usage, str(default), "UInt64", shorthand) + fn uint64_flag(inout self, name: String, usage: String, shorthand: String = "", default: UInt64 = 0) -> None: + """Adds a `UInt64` flag to the flag set.""" + self.flags.append(Flag(name=name, shorthand=shorthand, usage=usage, value=None, default=str(default), type="UInt64")) - fn add_float16_flag(inout self, name: String, usage: String, shorthand: String = "", default: Float16 = 0) -> None: - """Adds a Float16 flag to the flag set.""" - self._add_flag(name, usage, str(default), "Float16", shorthand) + fn float16_flag(inout self, name: String, usage: String, shorthand: String = "", default: Float16 = 0) -> None: + """Adds a `Float16` flag to the flag set.""" + self.flags.append(Flag(name=name, shorthand=shorthand, usage=usage, value=None, default=str(default), type="Float16")) - fn add_float32_flag(inout self, name: String, usage: String, shorthand: String = "", default: Float32 = 0) -> None: - """Adds a Float32 flag to the flag set.""" - self._add_flag(name, usage, str(default), "Float32", shorthand) + fn float32_flag(inout self, name: String, usage: String, shorthand: String = "", default: Float32 = 0) -> None: + """Adds a `Float32` flag to the flag set.""" + self.flags.append(Flag(name=name, shorthand=shorthand, usage=usage, value=None, default=str(default), type="Float32")) - fn add_float64_flag(inout self, name: String, usage: String, shorthand: String = "", default: Float64 = 0) -> None: - """Adds a Float64 flag to the flag set.""" - self._add_flag(name, usage, str(default), "Float64", shorthand) + fn float64_flag(inout self, name: String, usage: String, shorthand: String = "", default: Float64 = 0) -> None: + """Adds a `Float64` flag to the flag set.""" + self.flags.append(Flag(name=name, shorthand=shorthand, usage=usage, value=None, default=str(default), type="Float64")) fn set_annotation(inout self, name: String, key: String, values: String) raises -> None: """Sets an annotation for a flag. @@ -337,9 +303,9 @@ struct FlagSet(CollectionElement, Stringable, Sized, Boolable, EqualityComparabl self.set_annotation(name, annotation_type, names) except e: print( - String("FlagSet.set_as: Failed to set flag, {}, with the following annotation: {}").format( - name, annotation_type - ), + String( + "FlagSet.set_as: Failed to set flag, {}, with the following annotation: {}" + ).format(name, annotation_type), file=2, ) raise e @@ -376,7 +342,6 @@ struct FlagSet(CollectionElement, Stringable, Sized, Boolable, EqualityComparabl new_set.visit_all[add_flag]() - # TODO: This parsing is dirty atm, will come back around and clean it up. fn from_args(inout self, arguments: List[String]) raises -> List[String]: """Parses flags and args from the args passed via the command line and adds them to their appropriate collections. diff --git a/src/prism/transform.mojo b/src/prism/transform.mojo index 97357c4..3b1ff30 100644 --- a/src/prism/transform.mojo +++ b/src/prism/transform.mojo @@ -3,14 +3,14 @@ from .flag_set import FlagSet from .util import panic, string_to_bool, string_to_float -fn get_as_string(flag_set: FlagSet, name: String) -> Optional[String]: +fn get_string(flag_set: FlagSet, name: String) -> Optional[String]: """Returns the value of a flag as a String. If it isn't set, then return the default value. Args: flag_set: The FlagSet to get the value from. name: The name of the flag to return. """ - var flag = flag_set.lookup_with_type(name, "String") + var flag = flag_set.lookup(name, "String") if not flag: return None @@ -22,14 +22,14 @@ fn get_as_string(flag_set: FlagSet, name: String) -> Optional[String]: return flag.value()[].get_with_transform[get]() -fn get_as_bool(flag_set: FlagSet, name: String) -> Optional[Bool]: +fn get_bool(flag_set: FlagSet, name: String) -> Optional[Bool]: """Returns the value of a flag as a Bool. If it isn't set, then return the default value. Args: flag_set: The FlagSet to get the value from. name: The name of the flag to return. """ - var result = flag_set.lookup_with_type(name, "Bool") + var result = flag_set.lookup(name, "Bool") if not result: return None @@ -40,14 +40,14 @@ fn get_as_bool(flag_set: FlagSet, name: String) -> Optional[Bool]: return string_to_bool(flag[].value.value()) -fn get_as_int(flag_set: FlagSet, name: String) -> Optional[Int]: +fn get_int(flag_set: FlagSet, name: String) -> Optional[Int]: """Returns the value of a flag as an Int. If it isn't set, then return the default value. Args: flag_set: The FlagSet to get the value from. name: The name of the flag to return. """ - var result = flag_set.lookup_with_type(name, "Int") + var result = flag_set.lookup(name, "Int") if not result: return None @@ -63,154 +63,154 @@ fn get_as_int(flag_set: FlagSet, name: String) -> Optional[Int]: return None -fn get_as_int8(flag_set: FlagSet, name: String) -> Optional[Int8]: +fn get_int8(flag_set: FlagSet, name: String) -> Optional[Int8]: """Returns the value of a flag as a Int8. If it isn't set, then return the default value. Args: flag_set: The FlagSet to get the value from. name: The name of the flag to return. """ - var value = flag_set.get_as_int(name) + var value = flag_set.get_int(name) if not value: return None return Int8(value.value()) -fn get_as_int16(flag_set: FlagSet, name: String) -> Optional[Int16]: +fn get_int16(flag_set: FlagSet, name: String) -> Optional[Int16]: """Returns the value of a flag as a Int16. If it isn't set, then return the default value. Args: flag_set: The FlagSet to get the value from. name: The name of the flag to return. """ - var value = flag_set.get_as_int(name) + var value = flag_set.get_int(name) if not value: return None return Int16(value.value()) -fn get_as_int32(flag_set: FlagSet, name: String) -> Optional[Int32]: +fn get_int32(flag_set: FlagSet, name: String) -> Optional[Int32]: """Returns the value of a flag as a Int32. If it isn't set, then return the default value. Args: flag_set: The FlagSet to get the value from. name: The name of the flag to return. """ - var value = flag_set.get_as_int(name) + var value = flag_set.get_int(name) if not value: return None return Int32(value.value()) -fn get_as_int64(flag_set: FlagSet, name: String) -> Optional[Int64]: +fn get_int64(flag_set: FlagSet, name: String) -> Optional[Int64]: """Returns the value of a flag as a Int64. If it isn't set, then return the default value. Args: flag_set: The FlagSet to get the value from. name: The name of the flag to return. """ - var value = flag_set.get_as_int(name) + var value = flag_set.get_int(name) if not value: return None return Int64(value.value()) -fn get_as_uint8(flag_set: FlagSet, name: String) -> Optional[UInt8]: +fn get_uint8(flag_set: FlagSet, name: String) -> Optional[UInt8]: """Returns the value of a flag as a UInt8. If it isn't set, then return the default value. Args: flag_set: The FlagSet to get the value from. name: The name of the flag to return. """ - var value = flag_set.get_as_int(name) + var value = flag_set.get_int(name) if not value: return None return UInt8(value.value()) -fn get_as_uint16(flag_set: FlagSet, name: String) -> Optional[UInt16]: +fn get_uint16(flag_set: FlagSet, name: String) -> Optional[UInt16]: """Returns the value of a flag as a UInt16. If it isn't set, then return the default value. Args: flag_set: The FlagSet to get the value from. name: The name of the flag to return. """ - var value = flag_set.get_as_int(name) + var value = flag_set.get_int(name) if not value: return None return UInt16(value.value()) -fn get_as_uint32(flag_set: FlagSet, name: String) -> Optional[UInt32]: +fn get_uint32(flag_set: FlagSet, name: String) -> Optional[UInt32]: """Returns the value of a flag as a UInt32. If it isn't set, then return the default value. Args: flag_set: The FlagSet to get the value from. name: The name of the flag to return. """ - var value = flag_set.get_as_int(name) + var value = flag_set.get_int(name) if not value: return None return UInt32(value.value()) -fn get_as_uint64(flag_set: FlagSet, name: String) -> Optional[UInt64]: +fn get_uint64(flag_set: FlagSet, name: String) -> Optional[UInt64]: """Returns the value of a flag as a UInt64. If it isn't set, then return the default value. Args: flag_set: The FlagSet to get the value from. name: The name of the flag to return. """ - var value = flag_set.get_as_int(name) + var value = flag_set.get_int(name) if not value: return None return UInt64(value.value()) -fn get_as_float16(flag_set: FlagSet, name: String) -> Optional[Float16]: +fn as_float16(flag_set: FlagSet, name: String) -> Optional[Float16]: """Returns the value of a flag as a Float64. If it isn't set, then return the default value. Args: flag_set: The FlagSet to get the value from. name: The name of the flag to return. """ - var value = flag_set.get_as_float64(name) + var value = flag_set.as_float64(name) if not value: return None return value.value().cast[DType.float16]() -fn get_as_float32(flag_set: FlagSet, name: String) -> Optional[Float32]: +fn as_float32(flag_set: FlagSet, name: String) -> Optional[Float32]: """Returns the value of a flag as a Float64. If it isn't set, then return the default value. Args: flag_set: The FlagSet to get the value from. name: The name of the flag to return. """ - var value = flag_set.get_as_float64(name) + var value = flag_set.as_float64(name) if not value: return None return value.value().cast[DType.float32]() -fn get_as_float64(flag_set: FlagSet, name: String) -> Optional[Float64]: +fn as_float64(flag_set: FlagSet, name: String) -> Optional[Float64]: """Returns the value of a flag as a Float64. If it isn't set, then return the default value. Args: flag_set: The FlagSet to get the value from. name: The name of the flag to return. """ - var result = flag_set.lookup_with_type(name, "Float64") + var result = flag_set.lookup(name, "Float64") if not result: return None diff --git a/test/test_command.mojo b/test/test_command.mojo index 37a1f05..48f2ea5 100644 --- a/test/test_command.mojo +++ b/test/test_command.mojo @@ -5,7 +5,7 @@ from prism.flag_set import FlagSet def test_command_operations(): - fn dummy(context: Context) -> None: + fn dummy(ctx: Context) -> None: return None var cmd = Arc(Command(name="root", description="Base command.", run=dummy)) @@ -16,7 +16,7 @@ def test_command_operations(): var child_cmd = Arc(Command(name="child", description="Child command.", run=dummy)) cmd[].add_subcommand(child_cmd) - child_cmd[].flags.add_string_flag(name="color", shorthand="c", usage="Text color", default="#3464eb") + child_cmd[].flags.string_flag(name="color", shorthand="c", usage="Text color", default="#3464eb") testing.assert_equal(child_cmd[].full_name(), "root child") diff --git a/test/test_flags.mojo b/test/test_flags.mojo index 5ac973b..64406b6 100644 --- a/test/test_flags.mojo +++ b/test/test_flags.mojo @@ -20,19 +20,19 @@ def test_string_to_float(): def test_get_flags(): var flag_set = FlagSet() - flag_set.add_string_flag("key", "description", "default") - flag_set.add_bool_flag("flag", "description", "False") + flag_set.string_flag("key", "description", "default") + flag_set.bool_flag("flag", "description", "False") var flags = List[String]("--key=value", "positional", "--flag") _ = flag_set.from_args(flags) - testing.assert_equal(flag_set.get_as_string("key").value(), "value") - testing.assert_equal(flag_set.get_as_bool("flag").value(), True) + testing.assert_equal(flag_set.get_string("key").value(), "value") + testing.assert_equal(flag_set.get_bool("flag").value(), True) def test_parse_flag(): var flag_set = FlagSet() - flag_set.add_string_flag(name="key", usage="description", default="default") - flag_set.add_bool_flag(name="flag", usage="description", default=False) + flag_set.string_flag(name="key", usage="description", default="default") + flag_set.bool_flag(name="flag", usage="description", default=False) var parser = FlagParser() var name: String @@ -51,8 +51,8 @@ def test_parse_flag(): def test_parse_shorthand_flag(): var flag_set = FlagSet() - flag_set.add_string_flag(name="key", usage="description", default="default", shorthand="k") - flag_set.add_bool_flag(name="flag", usage="description", default=False, shorthand="f") + flag_set.string_flag(name="key", usage="description", default="default", shorthand="k") + flag_set.bool_flag(name="flag", usage="description", default=False, shorthand="f") var parser = FlagParser() var name: String From 60cade35c8a4ff57ce5ed4d1309b5586dc53a58c Mon Sep 17 00:00:00 2001 From: Mikhail Tavarez Date: Wed, 9 Oct 2024 15:29:03 -0500 Subject: [PATCH 2/4] lots of cleanup and breaking changes :) --- examples/aliases.mojo | 4 +- examples/arg_validators.mojo | 15 +- examples/chromeria.mojo | 4 +- examples/fg_child.mojo | 4 +- examples/fg_parent.mojo | 12 +- examples/hello_world.mojo | 15 +- examples/printer/mojoproject.toml | 2 +- examples/printer/printer.mojo | 28 +-- examples/requests/mojoproject.toml | 2 +- examples/requests/nested.mojo | 16 +- examples/requests/persistent_flags.mojo | 18 +- scripts/examples.sh | 18 +- src/prism/__init__.mojo | 4 +- src/prism/args.mojo | 14 +- src/prism/command.mojo | 166 +++++++++-------- src/prism/context.mojo | 2 +- src/prism/flag.mojo | 10 +- src/prism/flag_group.mojo | 12 +- src/prism/flag_parser.mojo | 19 +- src/prism/flag_set.mojo | 172 +++++++++--------- src/prism/transform.mojo | 226 ------------------------ test/test_args.mojo | 20 +-- test/test_command.mojo | 6 +- test/test_flags.mojo | 12 +- 24 files changed, 288 insertions(+), 513 deletions(-) delete mode 100644 src/prism/transform.mojo diff --git a/examples/aliases.mojo b/examples/aliases.mojo index 4518303..898c213 100644 --- a/examples/aliases.mojo +++ b/examples/aliases.mojo @@ -13,13 +13,13 @@ fn tool_func(ctx: Context) -> None: fn main() -> None: var root = Command( name="my", - description="This is a dummy command!", + usage="This is a dummy command!", run=test, ) var print_tool = Arc( Command( - name="tool", description="This is a dummy command!", run=tool_func, aliases=List[String]("object", "thing") + name="tool", usage="This is a dummy command!", run=tool_func, aliases=List[String]("object", "thing") ) ) diff --git a/examples/arg_validators.mojo b/examples/arg_validators.mojo index 66d1397..0c7ba9f 100644 --- a/examples/arg_validators.mojo +++ b/examples/arg_validators.mojo @@ -23,32 +23,33 @@ fn hello(ctx: Context) -> None: fn main() -> None: var root = Command( name="hello", - description="This is a dummy command!", + usage="This is a dummy command!", run=test, ) - var no_args_command = Arc(Command(name="no_args", description="This is a dummy command!", run=hello)) + var no_args_command = Arc(Command(name="no_args", usage="This is a dummy command!", run=hello)) no_args_command[].arg_validator = no_args var valid_args_command = Arc( Command( name="valid_args", - description="This is a dummy command!", + usage="This is a dummy command!", run=hello, + valid_args=List[String]("Pineapple") ) ) valid_args_command[].arg_validator = valid_args - var minimum_n_args_command = Arc(Command(name="minimum_n_args", description="This is a dummy command!", run=hello)) + var minimum_n_args_command = Arc(Command(name="minimum_n_args", usage="This is a dummy command!", run=hello)) minimum_n_args_command[].arg_validator = minimum_n_args[4]() - var maximum_n_args_command = Arc(Command(name="maximum_n_args", description="This is a dummy command!", run=hello)) + var maximum_n_args_command = Arc(Command(name="maximum_n_args", usage="This is a dummy command!", run=hello)) maximum_n_args_command[].arg_validator = maximum_n_args[1]() - var exact_args_command = Arc(Command(name="exact_args", description="This is a dummy command!", run=hello)) + var exact_args_command = Arc(Command(name="exact_args", usage="This is a dummy command!", run=hello)) exact_args_command[].arg_validator = exact_args[1]() - var range_args_command = Arc(Command(name="range_args", description="This is a dummy command!", run=hello)) + var range_args_command = Arc(Command(name="range_args", usage="This is a dummy command!", run=hello)) range_args_command[].arg_validator = range_args[0, 1]() root.add_subcommand(no_args_command) diff --git a/examples/chromeria.mojo b/examples/chromeria.mojo index 3aeffb8..d95f3d4 100644 --- a/examples/chromeria.mojo +++ b/examples/chromeria.mojo @@ -13,11 +13,11 @@ fn hello(ctx: Context) -> None: fn main() -> None: var root = Command( name="hello", - description="This is a dummy command!", + usage="This is a dummy command!", run=test, ) - var hello_command = Arc(Command(name="chromeria", description="This is a dummy command!", run=hello)) + var hello_command = Arc(Command(name="chromeria", usage="This is a dummy command!", run=hello)) root.add_subcommand(hello_command) root.execute() diff --git a/examples/fg_child.mojo b/examples/fg_child.mojo index f9f7eb3..982e058 100644 --- a/examples/fg_child.mojo +++ b/examples/fg_child.mojo @@ -13,7 +13,7 @@ fn tool_func(ctx: Context) -> None: fn main() -> None: var root = Command( name="my", - description="This is a dummy command!", + usage="This is a dummy command!", run=test, ) root.persistent_flags.bool_flag(name="required", shorthand="r", usage="Always required.") @@ -21,7 +21,7 @@ fn main() -> None: root.persistent_flags.string_flag(name="port", shorthand="p", usage="Port") root.mark_persistent_flag_required("required") - var print_tool = Arc(Command(name="tool", description="This is a dummy command!", run=tool_func)) + var print_tool = Arc(Command(name="tool", usage="This is a dummy command!", run=tool_func)) print_tool[].flags.bool_flag(name="also", shorthand="a", usage="Also always required.") print_tool[].flags.string_flag(name="uri", shorthand="u", usage="URI") root.add_subcommand(print_tool) diff --git a/examples/fg_parent.mojo b/examples/fg_parent.mojo index 06cf794..373664d 100644 --- a/examples/fg_parent.mojo +++ b/examples/fg_parent.mojo @@ -2,15 +2,15 @@ from memory import Arc from prism import Command, Context -fn test(ctx: Context) -> None: +fn test(ctx: Context) raises -> None: var host = ctx.command[].flags.get_string("host") var port = ctx.command[].flags.get_string("port") var uri = ctx.command[].flags.get_string("uri") - if uri: - print("URI: ", uri.value()) + if uri != "": + print("URI:", uri) else: - print(host.value(), ":", port.value()) + print(host + ":" + port) fn tool_func(ctx: Context) -> None: @@ -20,8 +20,8 @@ fn tool_func(ctx: Context) -> None: fn main() -> None: var root = Command( name="my", - description="This is a dummy command!", - run=test, + usage="This is a dummy command!", + raising_run=test, ) root.persistent_flags.bool_flag(name="required", shorthand="r", usage="Always required.") root.persistent_flags.string_flag(name="host", shorthand="h", usage="Host") diff --git a/examples/hello_world.mojo b/examples/hello_world.mojo index c643fd4..cf46691 100644 --- a/examples/hello_world.mojo +++ b/examples/hello_world.mojo @@ -15,7 +15,7 @@ fn build_printer_command() -> Arc[Command]: var cmd = Arc( Command( name="printer", - description="Print the first arg.", + usage="Print the first arg.", run=printer, ) ) @@ -42,7 +42,7 @@ fn build_say_command() -> Arc[Command]: return Arc( Command( name="say", - description="Say something to someone", + usage="Say something to someone", run=say, ) ) @@ -52,7 +52,7 @@ fn build_hello_command() -> Arc[Command]: var cmd = Arc( Command( name="hello", - description="Say hello to someone", + usage="Say hello to someone", run=say_hello, ) ) @@ -63,7 +63,7 @@ fn build_goodbye_command() -> Arc[Command]: var cmd = Arc( Command( name="goodbye", - description="Say goodbye to someone", + usage="Say goodbye to someone", run=say_goodbye, ) ) @@ -71,7 +71,10 @@ fn build_goodbye_command() -> Arc[Command]: fn test(ctx: Context) -> None: - print(ctx.command[].flags.get_string("env").value()) + try: + print(ctx.command[].flags.get_string("env")) + except: + print("No env flag provided.") for item in ctx.command[].flags.flags: if item[].value: print(item[].name, item[].value.value()) @@ -84,7 +87,7 @@ fn test(ctx: Context) -> None: fn main() -> None: var root = Command( name="tones", - description="This is a dummy command!", + usage="This is a dummy command!", run=test, ) root.flags.string_flag(name="env", shorthand="e", usage="Environment.", default="") diff --git a/examples/printer/mojoproject.toml b/examples/printer/mojoproject.toml index 04d0cd8..7fd4c2e 100644 --- a/examples/printer/mojoproject.toml +++ b/examples/printer/mojoproject.toml @@ -11,4 +11,4 @@ version = "0.1.0" [dependencies] max = ">=24.5.0,<25" mist = ">=0.1.8,<0.2" -prism = ">=0.1.6,<0.2" +prism = ">=0.1.7,<0.2" diff --git a/examples/printer/printer.mojo b/examples/printer/printer.mojo index e3cc105..59de23b 100644 --- a/examples/printer/printer.mojo +++ b/examples/printer/printer.mojo @@ -1,35 +1,25 @@ from memory import Arc from prism import Command, Context, exact_args -from mist import Style +import mist - -fn printer(ctx: Context) -> None: +fn printer(ctx: Context) raises -> None: if len(ctx.args) <= 0: print("No text to print! Pass in some text as a positional argument.") return None var color = ctx.command[].flags.get_uint32("color") var formatting = ctx.command[].flags.get_string("formatting") - var style = Style() - - if not color: - color = 0xFFFFFF - if not formatting: - formatting = String("") - - if color: - style = style.foreground(color.value()) + var style = mist.Style().foreground(color) - var formatting_value = formatting.or_else("") - if formatting_value == "": + if formatting == "": print(style.render(ctx.args[0])) return None - if formatting.value() == "bold": + if formatting == "bold": style = style.bold() - elif formatting.value() == "underline": + elif formatting == "underline": style = style.underline() - elif formatting.value() == "italic": + elif formatting == "italic": style = style.italic() print(style.render(ctx.args[0])) @@ -49,8 +39,8 @@ fn post_hook(ctx: Context) -> None: fn main() -> None: var root = Command( name="printer", - description="Base command.", - run=printer, + usage="Base command.", + raising_run=printer, pre_run=pre_hook, post_run=post_hook, ) diff --git a/examples/requests/mojoproject.toml b/examples/requests/mojoproject.toml index 814e96c..7ce8840 100644 --- a/examples/requests/mojoproject.toml +++ b/examples/requests/mojoproject.toml @@ -10,5 +10,5 @@ version = "0.1.0" [dependencies] max = ">=24.5.0,<25" -prism = ">=0.1.6,<0.2" +prism = ">=0.1.7,<0.2" requests = ">=2.32.3,<3" diff --git a/examples/requests/nested.mojo b/examples/requests/nested.mojo index 0be5fbf..b255ada 100644 --- a/examples/requests/nested.mojo +++ b/examples/requests/nested.mojo @@ -16,7 +16,7 @@ fn print_information(ctx: Context) -> None: fn get_cat_fact(ctx: Context) raises -> None: var flags = ctx.command[].flags var lover = flags.get_bool("lover") - if lover and lover.value(): + if lover: print("Hello fellow cat lover!") var requests = Python.import_module("requests") @@ -29,7 +29,7 @@ fn get_cat_fact(ctx: Context) raises -> None: if not count: raise Error("Count flag was not found.") - for _ in range(count.value()): + for _ in range(count): var response = requests.get(url) # Check if the request was successful (status code 200) @@ -57,12 +57,12 @@ fn get_dog_breeds(ctx: Context) raises -> None: fn main() -> None: - var root = Command(name="nested", description="Base command.", run=base) + var root = Command(name="nested", usage="Base command.", run=base) var get_command = Arc( Command( name="get", - description="Base command for getting some data.", + usage="Base command for getting some data.", run=print_information, ) ) @@ -70,8 +70,8 @@ fn main() -> None: var cat_command = Arc( Command( name="cat", - description="Get some cat facts!", - erroring_run=get_cat_fact, + usage="Get some cat facts!", + raising_run=get_cat_fact, ) ) cat_command[].flags.int_flag(name="count", shorthand="c", usage="Number of facts to get.", default=1) @@ -80,8 +80,8 @@ fn main() -> None: var dog_command = Arc( Command( name="dog", - description="Get some dog breeds!", - erroring_run=get_dog_breeds, + usage="Get some dog breeds!", + raising_run=get_dog_breeds, ) ) diff --git a/examples/requests/persistent_flags.mojo b/examples/requests/persistent_flags.mojo index 09484c0..64da5d3 100644 --- a/examples/requests/persistent_flags.mojo +++ b/examples/requests/persistent_flags.mojo @@ -16,7 +16,7 @@ fn print_information(ctx: Context) -> None: fn get_cat_fact(ctx: Context) raises -> None: var flags = ctx.command[].flags var lover = flags.get_bool("lover") - if lover and lover.value(): + if lover: print("Hello fellow cat lover!") var requests = Python.import_module("requests") @@ -29,7 +29,7 @@ fn get_cat_fact(ctx: Context) raises -> None: if not count: raise Error("Count flag was not found.") - for _ in range(count.value()): + for _ in range(count): var response = requests.get(url) # Check if the request was successful (status code 200) @@ -43,7 +43,7 @@ fn get_cat_fact(ctx: Context) raises -> None: fn get_dog_breeds(ctx: Context) raises -> None: var flags = ctx.command[].flags var lover = flags.get_bool("lover") - if lover and lover.value(): + if lover: print("Hello fellow dog lover!") var requests = Python.import_module("requests") @@ -69,12 +69,12 @@ fn post_hook(ctx: Context) -> None: fn main() -> None: - var root = Command(name="nested", description="Base command.", run=base) + var root = Command(name="nested", usage="Base command.", run=base) var get_command = Arc( Command( name="get", - description="Base command for getting some data.", + usage="Base command for getting some data.", run=print_information, persistent_pre_run=pre_hook, persistent_post_run=post_hook, @@ -85,8 +85,8 @@ fn main() -> None: var cat_command = Arc( Command( name="cat", - description="Get some cat facts!", - erroring_run=get_cat_fact, + usage="Get some cat facts!", + raising_run=get_cat_fact, ) ) cat_command[].flags.int_flag(name="count", shorthand="c", usage="Number of facts to get.") @@ -94,8 +94,8 @@ fn main() -> None: var dog_command = Arc( Command( name="dog", - description="Get some dog breeds!", - erroring_run=get_dog_breeds, + usage="Get some dog breeds!", + raising_run=get_dog_breeds, ) ) diff --git a/scripts/examples.sh b/scripts/examples.sh index 355fabc..ce63c07 100755 --- a/scripts/examples.sh +++ b/scripts/examples.sh @@ -16,16 +16,16 @@ magic run mojo build $TEMP_DIR/arg_validators.mojo -o $TEMP_DIR/validators echo "[INFO] Running examples..." # Need to run these first examples as part of a mojo project as they have external dependencies. # printer is a portable binary, but nested and persistent_flags are not because they depend on a python library. -cd examples/printer -magic run mojo build printer.mojo -./printer "sample-text" --formatting=underline +# cd examples/printer +# magic run mojo build printer.mojo +# ./printer "sample-text" --formatting=underline -cd ../requests -magic run mojo build nested.mojo -magic run nested get cat --count 3 -l -magic run mojo build persistent_flags.mojo -o persistent -magic run persistent get cat --count 2 --lover -magic run persistent get dog +# cd ../requests +# magic run mojo build nested.mojo +# magic run nested get cat --count 3 -l +# magic run mojo build persistent_flags.mojo -o persistent +# magic run persistent get cat --count 2 --lover +# magic run persistent get dog cd ../.. diff --git a/src/prism/__init__.mojo b/src/prism/__init__.mojo index 5161851..4ec8cb5 100644 --- a/src/prism/__init__.mojo +++ b/src/prism/__init__.mojo @@ -1,7 +1,7 @@ from .command import ( Command, - CommandFunction, - ArgValidator, + CmdFn, + ArgValidatorFn, ) from .args import no_args, valid_args, arbitrary_args, minimum_n_args, maximum_n_args, exact_args, range_args, match_all from .flag import Flag diff --git a/src/prism/args.mojo b/src/prism/args.mojo index b8fc2bf..fd4d215 100644 --- a/src/prism/args.mojo +++ b/src/prism/args.mojo @@ -1,7 +1,7 @@ from memory.arc import Arc from collections.optional import Optional import gojo.fmt -from .command import ArgValidator +from .command import ArgValidatorFn from .context import Context @@ -24,7 +24,7 @@ fn arbitrary_args(ctx: Context) raises -> None: return None -fn minimum_n_args[n: Int]() -> ArgValidator: +fn minimum_n_args[n: Int]() -> ArgValidatorFn: """Returns an error if there is not at least n arguments. Params: @@ -48,7 +48,7 @@ fn minimum_n_args[n: Int]() -> ArgValidator: return less_than_n_args -fn maximum_n_args[n: Int]() -> ArgValidator: +fn maximum_n_args[n: Int]() -> ArgValidatorFn: """Returns an error if there are more than n arguments. Params: @@ -72,7 +72,7 @@ fn maximum_n_args[n: Int]() -> ArgValidator: return more_than_n_args -fn exact_args[n: Int]() -> ArgValidator: +fn exact_args[n: Int]() -> ArgValidatorFn: """Returns an error if there are not exactly n arguments. Params: @@ -108,7 +108,7 @@ fn valid_args(ctx: Context) raises -> None: raise Error(fmt.sprintf("Invalid argument: `%s`, for the command `%s`.", arg[], ctx.command[].name)) -fn range_args[minimum: Int, maximum: Int]() -> ArgValidator: +fn range_args[minimum: Int, maximum: Int]() -> ArgValidatorFn: """Returns an error if there are not exactly n arguments. Params: @@ -135,11 +135,11 @@ fn range_args[minimum: Int, maximum: Int]() -> ArgValidator: # TODO: Having some issues with varadic list of functions, so using List for now. -fn match_all[arg_validators: List[ArgValidator]]() -> ArgValidator: +fn match_all[arg_validators: List[ArgValidatorFn]]() -> ArgValidatorFn: """Returns an error if any of the arg_validators return an error. Params: - arg_validators: A list of ArgValidator functions that check the arguments. + arg_validators: A list of ArgValidatorFn functions that check the arguments. Returns: A function that checks all the arguments using the arg_validators list.. diff --git a/src/prism/command.mojo b/src/prism/command.mojo index 84f8960..b502eec 100644 --- a/src/prism/command.mojo +++ b/src/prism/command.mojo @@ -42,7 +42,7 @@ fn default_help(inout command: Arc[Command]) -> String: """Prints the help information for the command. TODO: Add padding for commands, options, and aliases. """ - var description_style = mog.Style().border(mog.HIDDEN_BORDER) + var usage_style = mog.Style().border(mog.HIDDEN_BORDER) var border_style = mog.Style().border(mog.ROUNDED_BORDER).border_foreground(mog.Color(0x383838)).padding(0, 1) var option_style = mog.Style().foreground(mog.Color(0x81C8BE)) var bold_style = mog.Style().bold() @@ -58,7 +58,7 @@ fn default_help(inout command: Arc[Command]) -> String: _ = builder.write_string(" COMMAND") _ = builder.write_string(" [ARGS]...") - var description = description_style.render(mog.join_vertical(mog.left, str(builder), "\n", cmd[].description)) + var usage = usage_style.render(mog.join_vertical(mog.left, str(builder), "\n", cmd[].usage)) builder = StringBuilder() if cmd[].flags.flags: @@ -73,9 +73,7 @@ fn default_help(inout command: Arc[Command]) -> String: _ = builder.write_string(bold_style.render("Commands")) for i in range(len(cmd[].children)): _ = builder.write_string( - fmt.sprintf( - "\n%s %s", option_style.render(cmd[].children[i][].name), cmd[].children[i][].description - ) + fmt.sprintf("\n%s %s", option_style.render(cmd[].children[i][].name), cmd[].children[i][].usage) ) if i == len(cmd[].children) - 1: @@ -86,16 +84,16 @@ fn default_help(inout command: Arc[Command]) -> String: _ = builder.write_string(fmt.sprintf("\n%s", option_style.render(cmd[].aliases.__str__()))) var commands = border_style.render(str(builder)) - return mog.join_vertical(mog.left, description, options, commands) + return mog.join_vertical(mog.left, usage, options, commands) -alias CommandFunction = fn (ctx: Context) -> None +alias CmdFn = fn (ctx: Context) -> None """The function for a command to run.""" -alias CommandFunctionErr = fn (ctx: Context) raises -> None +alias RaisingCmdFn = fn (ctx: Context) raises -> None """The function for a command to run that can error.""" -alias HelpFunction = fn (inout command: Arc[Command]) -> String -"""The function for a help function.""" -alias ArgValidator = fn (ctx: Context) raises -> None +alias HelpFn = fn (inout command: Arc[Command]) -> String +"""The function to generate help output.""" +alias ArgValidatorFn = fn (ctx: Context) raises -> None """The function for an argument validator.""" alias ParentVisitorFn = fn (parent: Arc[Command]) capturing -> None """The function for visiting parents of a command.""" @@ -120,7 +118,7 @@ struct Command(CollectionElement): fn main(): var command = Command( name="hello", - description="This is a dummy command!", + usage="This is a dummy command!", run=test, ) command.execute() @@ -135,38 +133,38 @@ struct Command(CollectionElement): var name: String """The name of the command.""" - var description: String + var usage: String """Description of the command.""" var aliases: List[String] """Aliases that can be used instead of the first word in name.""" - var help: HelpFunction + var help: HelpFn """Generates help text.""" - var pre_run: Optional[CommandFunction] + var pre_run: Optional[CmdFn] """A function to run before the run function is executed.""" - var run: Optional[CommandFunction] + var run: Optional[CmdFn] """A function to run when the command is executed.""" - var post_run: Optional[CommandFunction] + var post_run: Optional[CmdFn] """A function to run after the run function is executed.""" - var erroring_pre_run: Optional[CommandFunctionErr] + var raising_pre_run: Optional[RaisingCmdFn] """A raising function to run before the run function is executed.""" - var erroring_run: Optional[CommandFunctionErr] + var raising_run: Optional[RaisingCmdFn] """A raising function to run when the command is executed.""" - var erroring_post_run: Optional[CommandFunctionErr] + var raising_post_run: Optional[RaisingCmdFn] """A raising function to run after the run function is executed.""" - var persistent_pre_run: Optional[CommandFunction] + var persistent_pre_run: Optional[CmdFn] """A function to run before the run function is executed. This persists to children.""" - var persistent_post_run: Optional[CommandFunction] + var persistent_post_run: Optional[CmdFn] """A function to run after the run function is executed. This persists to children.""" - var persistent_erroring_pre_run: Optional[CommandFunctionErr] + var persistent_raising_pre_run: Optional[RaisingCmdFn] """A raising function to run before the run function is executed. This persists to children.""" - var persistent_erroring_post_run: Optional[CommandFunctionErr] + var persistent_raising_post_run: Optional[RaisingCmdFn] """A raising function to run after the run function is executed. This persists to children.""" - var arg_validator: ArgValidator + var arg_validator: ArgValidatorFn """Function to validate arguments passed to the command.""" var valid_args: List[String] """Valid arguments for the command.""" @@ -192,43 +190,43 @@ struct Command(CollectionElement): fn __init__( inout self, name: String, - description: String, + usage: String, aliases: List[String] = List[String](), valid_args: List[String] = List[String](), - run: Optional[CommandFunction] = None, - pre_run: Optional[CommandFunction] = None, - post_run: Optional[CommandFunction] = None, - erroring_run: Optional[CommandFunctionErr] = None, - erroring_pre_run: Optional[CommandFunctionErr] = None, - erroring_post_run: Optional[CommandFunctionErr] = None, - persistent_pre_run: Optional[CommandFunction] = None, - persistent_post_run: Optional[CommandFunction] = None, - persistent_erroring_pre_run: Optional[CommandFunctionErr] = None, - persistent_erroring_post_run: Optional[CommandFunctionErr] = None, + run: Optional[CmdFn] = None, + pre_run: Optional[CmdFn] = None, + post_run: Optional[CmdFn] = None, + raising_run: Optional[RaisingCmdFn] = None, + raising_pre_run: Optional[RaisingCmdFn] = None, + raising_post_run: Optional[RaisingCmdFn] = None, + persistent_pre_run: Optional[CmdFn] = None, + persistent_post_run: Optional[CmdFn] = None, + persistent_raising_pre_run: Optional[RaisingCmdFn] = None, + persistent_raising_post_run: Optional[RaisingCmdFn] = None, ): """ Args: name: The name of the command. - description: The description of the command. + usage: The usage of the command. arg_validator: The function to validate the arguments passed to the command. valid_args: The valid arguments for the command. run: The function to run when the command is executed. pre_run: The function to run before the command is executed. post_run: The function to run after the command is executed. - erroring_run: The function to run when the command is executed that returns an error. - erroring_pre_run: The function to run before the command is executed that returns an error. - erroring_post_run: The function to run after the command is executed that returns an error. + raising_run: The function to run when the command is executed that returns an error. + raising_pre_run: The function to run before the command is executed that returns an error. + raising_post_run: The function to run after the command is executed that returns an error. persisting_pre_run: The function to run before the command is executed. This persists to children. persisting_post_run: The function to run after the command is executed. This persists to children. - persisting_erroring_pre_run: The function to run before the command is executed that returns an error. This persists to children. - persisting_erroring_post_run: The function to run after the command is executed that returns an error. This persists to children. + persisting_raising_pre_run: The function to run before the command is executed that returns an error. This persists to children. + persisting_raising_post_run: The function to run after the command is executed that returns an error. This persists to children. help: The function to generate help text for the command. """ - if not run and not erroring_run: - panic("A command must have a run or erroring_run function.") + if not run and not raising_run: + panic("A command must have a run or raising_run function.") self.name = name - self.description = description + self.usage = usage self.aliases = aliases self.help = default_help @@ -237,14 +235,14 @@ struct Command(CollectionElement): self.run = run self.post_run = post_run - self.erroring_pre_run = erroring_pre_run - self.erroring_run = erroring_run - self.erroring_post_run = erroring_post_run + self.raising_pre_run = raising_pre_run + self.raising_run = raising_run + self.raising_post_run = raising_post_run self.persistent_pre_run = persistent_pre_run self.persistent_post_run = persistent_post_run - self.persistent_erroring_pre_run = persistent_erroring_pre_run - self.persistent_erroring_post_run = persistent_erroring_post_run + self.persistent_raising_pre_run = persistent_raising_pre_run + self.persistent_raising_post_run = persistent_raising_post_run self.arg_validator = arbitrary_args self.valid_args = valid_args @@ -261,7 +259,7 @@ struct Command(CollectionElement): fn __moveinit__(inout self, owned existing: Self): self.name = existing.name^ - self.description = existing.description^ + self.usage = existing.usage^ self.aliases = existing.aliases^ self.help = existing.help @@ -270,14 +268,14 @@ struct Command(CollectionElement): self.run = existing.run^ self.post_run = existing.post_run^ - self.erroring_pre_run = existing.erroring_pre_run^ - self.erroring_run = existing.erroring_run^ - self.erroring_post_run = existing.erroring_post_run^ + self.raising_pre_run = existing.raising_pre_run^ + self.raising_run = existing.raising_run^ + self.raising_post_run = existing.raising_post_run^ self.persistent_pre_run = existing.persistent_pre_run^ self.persistent_post_run = existing.persistent_post_run^ - self.persistent_erroring_pre_run = existing.persistent_erroring_pre_run^ - self.persistent_erroring_post_run = existing.persistent_erroring_post_run^ + self.persistent_raising_pre_run = existing.persistent_raising_pre_run^ + self.persistent_raising_post_run = existing.persistent_raising_post_run^ self.arg_validator = existing.arg_validator self.valid_args = existing.valid_args^ @@ -311,8 +309,8 @@ struct Command(CollectionElement): writer.write("Command(Name: ") writer.write(self.name) - writer.write(", Description: ") - writer.write(self.description) + writer.write(", usage: ") + writer.write(self.usage) if self.aliases: writer.write(", Aliases: ") @@ -340,34 +338,36 @@ struct Command(CollectionElement): return self.parent[0][].root() return self - - fn _parse_command(self, command: Self, arg: String, children: List[Arc[Self]], inout leftover_start: Int) -> (Self, List[Arc[Self]]): + + fn _parse_command( + self, command: Self, arg: String, children: List[Arc[Self]], inout leftover_start: Int + ) -> (Self, List[Arc[Self]]): for command_ref in children: if command_ref[][].name == arg or arg in command_ref[][].aliases: leftover_start += 1 return command_ref[][], command_ref[][].children - + return command, children fn _parse_command_from_args(self, args: List[String]) -> (Self, List[String]): # If there's no children, then the root command is used. if not self.children or not args: return self, args - + var command = self var children = self.children var leftover_start = 0 # Start at 1 to start slice at the first remaining arg, not the last child command. for arg in args: command, children = self._parse_command(command, arg[], children, leftover_start) - + if leftover_start == 0: return self, args - + # If the there are more or equivalent args to the index, then there are remaining args to pass to the command. var remaining_args = List[String]() if len(args) >= leftover_start: - remaining_args = args[leftover_start:len(args)] + remaining_args = args[leftover_start : len(args)] return command, remaining_args @@ -376,8 +376,8 @@ struct Command(CollectionElement): try: # Run the persistent pre-run hooks. for parent in parents: - if parent[][].persistent_erroring_pre_run: - parent[][].persistent_erroring_pre_run.value()(ctx) + if parent[][].persistent_raising_pre_run: + parent[][].persistent_raising_pre_run.value()(ctx) @parameter if not ENABLE_TRAVERSE_RUN_HOOKS: @@ -393,8 +393,8 @@ struct Command(CollectionElement): # Run the pre-run hooks. if ctx.command[].pre_run: ctx.command[].pre_run.value()(ctx) - elif ctx.command[].erroring_pre_run: - ctx.command[].erroring_pre_run.value()(ctx) + elif ctx.command[].raising_pre_run: + ctx.command[].raising_pre_run.value()(ctx) except e: print("Failed to run pre-run hooks for command: " + ctx.command[].name) raise e @@ -404,8 +404,8 @@ struct Command(CollectionElement): try: # Run the persistent post-run hooks. for parent in parents: - if parent[][].persistent_erroring_post_run: - parent[][].persistent_erroring_post_run.value()(ctx) + if parent[][].persistent_raising_post_run: + parent[][].persistent_raising_post_run.value()(ctx) @parameter if not ENABLE_TRAVERSE_RUN_HOOKS: @@ -421,8 +421,8 @@ struct Command(CollectionElement): # Run the post-run hooks. if ctx.command[].post_run: ctx.command[].post_run.value()(ctx) - elif ctx.command[].erroring_post_run: - ctx.command[].erroring_post_run.value()(ctx) + elif ctx.command[].raising_post_run: + ctx.command[].raising_post_run.value()(ctx) except e: print("Failed to run post-run hooks for command: " + ctx.command[].name, file=2) raise e @@ -447,6 +447,7 @@ struct Command(CollectionElement): # Add all parents to the list to check if they have persistent pre/post hooks. var parents = List[Arc[Self]]() + @parameter fn append_parents(parent: Arc[Self]) capturing -> None: parents.append(parent) @@ -459,20 +460,17 @@ struct Command(CollectionElement): if ENABLE_TRAVERSE_RUN_HOOKS: parents.reverse() - # Get the flags for the command to be executed. try: + # Get the flags for the command to be executed. remaining_args = command.flags.from_args(remaining_args) - except e: - panic(e) - # Check if the help flag was passed - var command_ref = Arc(command) - var help_passed = command.flags.get_bool("help") - if help_passed.value() == True: - print(command.help(command_ref)) - return None + # Check if the help flag was passed + var help_passed = command.flags.get_bool("help") + var command_ref = Arc(command) + if help_passed == True: + print(command.help(command_ref)) + return None - try: # Validate individual required flags (eg: flag is required) validate_required_flags(command.flags) @@ -488,7 +486,7 @@ struct Command(CollectionElement): if command.run: command.run.value()(ctx) else: - command.erroring_run.value()(ctx) + command.raising_run.value()(ctx) self._execute_post_run_hooks(ctx, parents) except e: panic(e) diff --git a/src/prism/context.mojo b/src/prism/context.mojo index dde09b6..9398374 100644 --- a/src/prism/context.mojo +++ b/src/prism/context.mojo @@ -8,4 +8,4 @@ struct Context: def __init__(inout self, command: Arc[Command], args: List[String]) -> None: self.command = command - self.args = args \ No newline at end of file + self.args = args diff --git a/src/prism/flag.mojo b/src/prism/flag.mojo index 43c849c..c1ec162 100644 --- a/src/prism/flag.mojo +++ b/src/prism/flag.mojo @@ -120,7 +120,7 @@ struct Flag(RepresentableCollectionElement, Stringable, Formattable): self.value = value self.changed = True - fn get_with_transform[T: CollectionElement, //, transform: fn (value: String) -> T](self) -> Optional[T]: + fn get_with_transform[T: CollectionElement, //, transform: fn (value: String) -> T](self) -> T: """Returns the value of the flag with a transformation applied to it. Params: @@ -132,3 +132,11 @@ struct Flag(RepresentableCollectionElement, Stringable, Formattable): if self.value: return transform(self.value.value()) return transform(self.default) + + fn value_or_default(self) -> String: + """Returns the value of the flag or the default value if it isn't set. + + Returns: + The value of the flag or the default value. + """ + return self.value.or_else(self.default) diff --git a/src/prism/flag_group.mojo b/src/prism/flag_group.mojo index e9de762..6dc514e 100644 --- a/src/prism/flag_group.mojo +++ b/src/prism/flag_group.mojo @@ -7,7 +7,9 @@ from gojo import fmt fn has_all_flags(flags: FlagSet, flag_names: List[String]) -> Bool: for name in flag_names: - if not flags.lookup(name[]): + try: + _ = flags.lookup(name[]) + except: return False return True @@ -40,7 +42,7 @@ fn process_flag_for_group_annotation( except e: raise Error( String( - "process_flag_for_group_annotation: Failed to set group status for annotation {}: {}" + "process_flag_for_group_annotation: Failed to set group status for annotation {}: {}." ).format(annotation, str(e)) ) @@ -71,7 +73,7 @@ fn validate_required_flag_group(data: Dict[String, Dict[String, Bool]]) -> None: panic( fmt.sprintf( - "if any flags in the group, %s, are set they must all be set; missing %s", + "If any flags in the group, %s, are set they must all be set; missing %s.", keys.__str__(), unset.__str__(), ) @@ -101,7 +103,7 @@ fn validate_one_required_flag_group(data: Dict[String, Dict[String, Bool]]) -> N for key in pair[].value.keys(): keys.append(key[]) - panic(fmt.sprintf("at least one of the flags in the group %s is required", keys.__str__())) + panic(fmt.sprintf("At least one of the flags in the group %s is required.", keys.__str__())) fn validate_mutually_exclusive_flag_group(data: Dict[String, Dict[String, Bool]]) -> None: @@ -129,7 +131,7 @@ fn validate_mutually_exclusive_flag_group(data: Dict[String, Dict[String, Bool]] panic( fmt.sprintf( - "if any flags in the group %s are set none of the others can be; %s were all set", + "If any flags in the group %s are set none of the others can be; %s were all set.", keys.__str__(), set.__str__(), ) diff --git a/src/prism/flag_parser.mojo b/src/prism/flag_parser.mojo index efe4b32..302cdbe 100644 --- a/src/prism/flag_parser.mojo +++ b/src/prism/flag_parser.mojo @@ -37,8 +37,11 @@ struct FlagParser: raise Error("Command does not accept the flag supplied: " + name) # If it's a bool flag, set it to True and only increment the index by 1 (one arg used). - if flags.get_bool(name): + try: + _ = flags.lookup(name, "Bool") return name, String("True"), 1 + except: + pass if self.index + 1 >= len(arguments): raise Error("Flag `" + name + "` requires a value to be set but reached the end of arguments.") @@ -67,8 +70,7 @@ struct FlagParser: var flag = split(argument, "=") var shorthand = flag[0][1:] var value = flag[1] - var name = flags.lookup_name(shorthand).value() - + var name = flags.lookup_name(shorthand) if name not in flags.names(): raise Error("Command does not accept the shorthand flag supplied: " + name) @@ -76,10 +78,9 @@ struct FlagParser: # Flag with value set like "-f " var shorthand = argument[1:] - var result = flags.lookup_name(shorthand) - if not result: + var name = flags.lookup_name(shorthand) + if name not in flags.names(): raise Error("Command does not accept the shorthand flag supplied: " + shorthand) - var name = result.value() # If it's a bool flag, set it to True and only increment the index by 1 (one arg used). if flags.get_bool(name): @@ -126,11 +127,7 @@ struct FlagParser: raise Error("Expected a flag but found: " + argument) # Set the value of the flag. - var flag = flags.lookup(name) - if not flag: - raise Error("No flag found with the name: " + name) - - flag.value()[].set(value) + flags.lookup(name).set(value) self.index += increment_by return remaining_args diff --git a/src/prism/flag_set.mojo b/src/prism/flag_set.mojo index 203f2c9..199dd73 100644 --- a/src/prism/flag_set.mojo +++ b/src/prism/flag_set.mojo @@ -4,22 +4,7 @@ import gojo.fmt from .flag import Flag from .util import panic, string_to_bool, string_to_float, split from .flag_parser import FlagParser -from .transform import ( - get_string, - get_bool, - get_int, - get_int8, - get_int16, - get_int32, - get_int64, - get_uint8, - get_uint16, - get_uint32, - get_uint64, - as_float16, - as_float32, - as_float64, -) + alias FlagVisitorFn = fn (Flag) capturing -> None """Function perform some action while visiting all flags.""" @@ -34,7 +19,23 @@ alias REQUIRED_AS_GROUP = "REQUIRED_AS_GROUP" alias ONE_REQUIRED = "ONE_REQUIRED" alias MUTUALLY_EXCLUSIVE = "MUTUALLY_EXCLUSIVE" -alias FLAG_TYPES = ["String", "Bool", "Int", "Int8", "Int16", "Int32", "Int64", "UInt8", "UInt16", "UInt32", "UInt64", "Float16", "Float32", "Float64"] +alias FLAG_TYPES = [ + "String", + "Bool", + "Int", + "Int8", + "Int16", + "Int32", + "Int64", + "UInt8", + "UInt16", + "UInt32", + "UInt64", + "Float16", + "Float32", + "Float64", +] + @value struct FlagSet(CollectionElement, Stringable, Sized, Boolable, EqualityComparable): @@ -84,9 +85,7 @@ struct FlagSet(CollectionElement, Stringable, Sized, Boolable, EqualityComparabl fn __iadd__(inout self, other: Self): self.merge(other) - fn lookup( - ref [_] self, name: String, type: String = "" - ) -> Optional[Reference[Flag, __lifetime_of(self.flags)]]: + fn lookup(ref [_]self, name: String, type: String = "") raises -> ref [__lifetime_of(self.flags)] Flag: """Returns an mutable or immutable reference to a Flag with the given name. Mutable if FlagSet is mutable, immutable if FlagSet is immutable. @@ -100,14 +99,15 @@ struct FlagSet(CollectionElement, Stringable, Sized, Boolable, EqualityComparabl if type == "": for i in range(len(self.flags)): if self.flags[i].name == name: - return Reference(self.flags[i]) - else: + return self.flags[i] + else: for i in range(len(self.flags)): if self.flags[i].name == name and self.flags[i].type == type: - return Reference(self.flags[i]) - return None - - fn lookup_name(self, shorthand: String) -> Optional[String]: + return self.flags[i] + + raise Error("FlagNotFoundError: Could not find the following flag: " + name) + + fn lookup_name(self, shorthand: String) raises -> String: """Returns the name of a flag given its shorthand. Args: @@ -116,68 +116,69 @@ struct FlagSet(CollectionElement, Stringable, Sized, Boolable, EqualityComparabl for flag in self.flags: if flag[].shorthand and flag[].shorthand == shorthand: return flag[].name - return None + + raise Error("FlagNotFoundError: Could not find the following flag shorthand: " + shorthand) fn get_as[ - R: CollectionElement, transform: fn (flag_set: FlagSet, name: String) -> Optional[R] - ](self, name: String) -> Optional[R]: + R: CollectionElement, transform: fn (flag_set: FlagSet, name: String) raises -> R + ](self, name: String) raises -> R: return transform(self, name) - fn get_string(self, name: String) -> Optional[String]: + fn get_string(self, name: String) raises -> String: """Returns the value of a flag as a String. If it isn't set, then return the default value.""" - return self.get_as[R=String, transform=get_string](name) + return self.lookup(name, "String").value_or_default() - fn get_bool(self, name: String) -> Optional[Bool]: + fn get_bool(self, name: String) raises -> Bool: """Returns the value of a flag as a Bool. If it isn't set, then return the default value.""" - return self.get_as[R=Bool, transform=get_bool](name) + return string_to_bool(self.lookup(name, "Bool").value_or_default()) - fn get_int(self, name: String) -> Optional[Int]: + fn get_int(self, name: String, type: String = "Int") raises -> Int: """Returns the value of a flag as an Int. If it isn't set, then return the default value.""" - return self.get_as[R=Int, transform=get_int](name) + return atol(self.lookup(name, type).value_or_default()) - fn get_int8(self, name: String) -> Optional[Int8]: + fn get_int8(self, name: String) raises -> Int8: """Returns the value of a flag as a Int8. If it isn't set, then return the default value.""" - return self.get_as[R=Int8, transform=get_int8](name) + return Int8(self.get_int(name, "Int8")) - fn get_int16(self, name: String) -> Optional[Int16]: + fn get_int16(self, name: String) raises -> Int16: """Returns the value of a flag as a Int16. If it isn't set, then return the default value.""" - return self.get_as[R=Int16, transform=get_int16](name) + return Int16(self.get_int(name, "Int16")) - fn get_int32(self, name: String) -> Optional[Int32]: + fn get_int32(self, name: String) raises -> Int32: """Returns the value of a flag as a Int32. If it isn't set, then return the default value.""" - return self.get_as[R=Int32, transform=get_int32](name) + return Int32(self.get_int(name, "Int32")) - fn get_int64(self, name: String) -> Optional[Int64]: + fn get_int64(self, name: String) raises -> Int64: """Returns the value of a flag as a Int64. If it isn't set, then return the default value.""" - return self.get_as[R=Int64, transform=get_int64](name) + return Int64(self.get_int(name, "Int64")) - fn get_uint8(self, name: String) -> Optional[UInt8]: + fn get_uint8(self, name: String) raises -> UInt8: """Returns the value of a flag as a UInt8. If it isn't set, then return the default value.""" - return self.get_as[R=UInt8, transform=get_uint8](name) + return UInt8(self.get_int(name, "UInt8")) - fn get_uint16(self, name: String) -> Optional[UInt16]: + fn get_uint16(self, name: String) raises -> UInt16: """Returns the value of a flag as a UInt16. If it isn't set, then return the default value.""" - return self.get_as[R=UInt16, transform=get_uint16](name) + return UInt16(self.get_int(name, "UInt16")) - fn get_uint32(self, name: String) -> Optional[UInt32]: + fn get_uint32(self, name: String) raises -> UInt32: """Returns the value of a flag as a UInt32. If it isn't set, then return the default value.""" - return self.get_as[R=UInt32, transform=get_uint32](name) + return UInt32(self.get_int(name, "UInt32")) - fn get_uint64(self, name: String) -> Optional[UInt64]: + fn get_uint64(self, name: String) raises -> UInt64: """Returns the value of a flag as a UInt64. If it isn't set, then return the default value.""" - return self.get_as[R=UInt64, transform=get_uint64](name) + return UInt64(self.get_int(name, "UInt64")) - fn as_float16(self, name: String) -> Optional[Float16]: + fn as_float16(self, name: String) raises -> Float16: """Returns the value of a flag as a Float64. If it isn't set, then return the default value.""" - return self.get_as[R=Float16, transform=as_float16](name) + return self.as_float64(name).cast[DType.float16]() - fn as_float32(self, name: String) -> Optional[Float32]: + fn as_float32(self, name: String) raises -> Float32: """Returns the value of a flag as a Float64. If it isn't set, then return the default value.""" - return self.get_as[R=Float32, transform=as_float32](name) + return self.as_float64(name).cast[DType.float32]() - fn as_float64(self, name: String) -> Optional[Float64]: + fn as_float64(self, name: String) raises -> Float64: """Returns the value of a flag as a Float64. If it isn't set, then return the default value.""" - return self.get_as[R=Float64, transform=as_float64](name) + return string_to_float(self.lookup(name, "Float64").value_or_default()) fn names(self) -> List[String]: """Returns a list of names of all flags in the flag set.""" @@ -202,7 +203,7 @@ struct FlagSet(CollectionElement, Stringable, Sized, Boolable, EqualityComparabl default: Bool = False, ) -> None: """Adds a `Bool` flag to the flag set.""" - self.flags.append(Flag(name=name, shorthand=shorthand, usage=usage, value=None, default=str(default), type="Bool")) + self.flags.append(Flag(name=name, shorthand=shorthand, usage=usage, default=str(default), type="Bool")) fn string_flag( inout self, @@ -212,55 +213,55 @@ struct FlagSet(CollectionElement, Stringable, Sized, Boolable, EqualityComparabl default: String = "", ) -> None: """Adds a `String` flag to the flag set.""" - self.flags.append(Flag(name=name, shorthand=shorthand, usage=usage, value=None, default=str(default), type="String")) + self.flags.append(Flag(name=name, shorthand=shorthand, usage=usage, default=str(default), type="String")) fn int_flag(inout self, name: String, usage: String, shorthand: String = "", default: Int = 0) -> None: """Adds an `Int` flag to the flag set.""" - self.flags.append(Flag(name=name, shorthand=shorthand, usage=usage, value=None, default=str(default), type="Int")) + self.flags.append(Flag(name=name, shorthand=shorthand, usage=usage, default=str(default), type="Int")) fn int8_flag(inout self, name: String, usage: String, shorthand: String = "", default: Int8 = 0) -> None: """Adds an `Int8` flag to the flag set.""" - self.flags.append(Flag(name=name, shorthand=shorthand, usage=usage, value=None, default=str(default), type="Int8")) + self.flags.append(Flag(name=name, shorthand=shorthand, usage=usage, default=str(default), type="Int8")) fn int16_flag(inout self, name: String, usage: String, shorthand: String = "", default: Int16 = 0) -> None: """Adds an `Int16` flag to the flag set.""" - self.flags.append(Flag(name=name, shorthand=shorthand, usage=usage, value=None, default=str(default), type="Int16")) + self.flags.append(Flag(name=name, shorthand=shorthand, usage=usage, default=str(default), type="Int16")) fn int32_flag(inout self, name: String, usage: String, shorthand: String = "", default: Int32 = 0) -> None: """Adds an `Int32` flag to the flag set.""" - self.flags.append(Flag(name=name, shorthand=shorthand, usage=usage, value=None, default=str(default), type="Int32")) + self.flags.append(Flag(name=name, shorthand=shorthand, usage=usage, default=str(default), type="Int32")) fn int64_flag(inout self, name: String, usage: String, shorthand: String = "", default: Int64 = 0) -> None: """Adds an `Int64` flag to the flag set.""" - self.flags.append(Flag(name=name, shorthand=shorthand, usage=usage, value=None, default=str(default), type="Int64")) + self.flags.append(Flag(name=name, shorthand=shorthand, usage=usage, default=str(default), type="Int64")) fn uint8_flag(inout self, name: String, usage: String, shorthand: String = "", default: UInt8 = 0) -> None: """Adds a `UInt8` flag to the flag set.""" - self.flags.append(Flag(name=name, shorthand=shorthand, usage=usage, value=None, default=str(default), type="UInt8")) + self.flags.append(Flag(name=name, shorthand=shorthand, usage=usage, default=str(default), type="UInt8")) fn uint16_flag(inout self, name: String, usage: String, shorthand: String = "", default: UInt16 = 0) -> None: """Adds a `UInt16` flag to the flag set.""" - self.flags.append(Flag(name=name, shorthand=shorthand, usage=usage, value=None, default=str(default), type="UInt16")) + self.flags.append(Flag(name=name, shorthand=shorthand, usage=usage, default=str(default), type="UInt16")) fn uint32_flag(inout self, name: String, usage: String, shorthand: String = "", default: UInt32 = 0) -> None: """Adds a `UInt32` flag to the flag set.""" - self.flags.append(Flag(name=name, shorthand=shorthand, usage=usage, value=None, default=str(default), type="UInt32")) + self.flags.append(Flag(name=name, shorthand=shorthand, usage=usage, default=str(default), type="UInt32")) fn uint64_flag(inout self, name: String, usage: String, shorthand: String = "", default: UInt64 = 0) -> None: """Adds a `UInt64` flag to the flag set.""" - self.flags.append(Flag(name=name, shorthand=shorthand, usage=usage, value=None, default=str(default), type="UInt64")) + self.flags.append(Flag(name=name, shorthand=shorthand, usage=usage, default=str(default), type="UInt64")) fn float16_flag(inout self, name: String, usage: String, shorthand: String = "", default: Float16 = 0) -> None: """Adds a `Float16` flag to the flag set.""" - self.flags.append(Flag(name=name, shorthand=shorthand, usage=usage, value=None, default=str(default), type="Float16")) + self.flags.append(Flag(name=name, shorthand=shorthand, usage=usage, default=str(default), type="Float16")) fn float32_flag(inout self, name: String, usage: String, shorthand: String = "", default: Float32 = 0) -> None: """Adds a `Float32` flag to the flag set.""" - self.flags.append(Flag(name=name, shorthand=shorthand, usage=usage, value=None, default=str(default), type="Float32")) + self.flags.append(Flag(name=name, shorthand=shorthand, usage=usage, default=str(default), type="Float32")) fn float64_flag(inout self, name: String, usage: String, shorthand: String = "", default: Float64 = 0) -> None: """Adds a `Float64` flag to the flag set.""" - self.flags.append(Flag(name=name, shorthand=shorthand, usage=usage, value=None, default=str(default), type="Float64")) + self.flags.append(Flag(name=name, shorthand=shorthand, usage=usage, default=str(default), type="Float64")) fn set_annotation(inout self, name: String, key: String, values: String) raises -> None: """Sets an annotation for a flag. @@ -270,17 +271,15 @@ struct FlagSet(CollectionElement, Stringable, Sized, Boolable, EqualityComparabl key: The key of the annotation. values: The values of the annotation. """ - var result = self.lookup(name) - if not result: - raise Error(String("FlagSet.set_annotation: Failed to find flag: {}.").format(name)) - # Annotation value can be a concatenated string of values. # Why? Because we can have multiple required groups of flags for example. # So each value of the list for the annotation can be a group of flag names. - if not result.value()[].annotations.get(key): - result.value()[].annotations[key] = List[String](values) - else: - result.value()[].annotations[key].extend(values) + try: + # TODO: remove running 2 lookups when ref can return a reference + # we can store as a var without copying the result. + self.lookup(name).annotations[key].extend(values) + except: + self.lookup(name).annotations[key] = List[String](values) fn set_required(inout self, name: String) raises -> None: """Sets a flag as required or not. @@ -303,9 +302,9 @@ struct FlagSet(CollectionElement, Stringable, Sized, Boolable, EqualityComparabl self.set_annotation(name, annotation_type, names) except e: print( - String( - "FlagSet.set_as: Failed to set flag, {}, with the following annotation: {}" - ).format(name, annotation_type), + String("FlagSet.set_as: Failed to set flag, {}, with the following annotation: {}").format( + name, annotation_type + ), file=2, ) raise e @@ -337,8 +336,11 @@ struct FlagSet(CollectionElement, Stringable, Sized, Boolable, EqualityComparabl @always_inline fn add_flag(flag: Flag) capturing -> None: - if not self.lookup(flag.name): - self.flags.append(flag) + try: + _ = self.lookup(flag.name) + except e: + if str(e).find("FlagNotFoundError") != -1: + self.flags.append(flag) new_set.visit_all[add_flag]() @@ -369,4 +371,4 @@ fn validate_required_flags(flags: FlagSet) raises -> None: flags.visit_all[check_required_flag]() if len(missing_flag_names) > 0: - raise Error("required flag(s) " + missing_flag_names.__str__() + " not set") + raise Error("Required flag(s): " + missing_flag_names.__str__() + " not set.") diff --git a/src/prism/transform.mojo b/src/prism/transform.mojo deleted file mode 100644 index 3b1ff30..0000000 --- a/src/prism/transform.mojo +++ /dev/null @@ -1,226 +0,0 @@ -from collections import Optional -from .flag_set import FlagSet -from .util import panic, string_to_bool, string_to_float - - -fn get_string(flag_set: FlagSet, name: String) -> Optional[String]: - """Returns the value of a flag as a String. If it isn't set, then return the default value. - - Args: - flag_set: The FlagSet to get the value from. - name: The name of the flag to return. - """ - var flag = flag_set.lookup(name, "String") - if not flag: - return None - - fn get(value: String) -> String: - return value - - # TODO: inferring the return type in the parameter only works for String as of 24.5. - # Will switch the other transform functions in the future when it works. - return flag.value()[].get_with_transform[get]() - - -fn get_bool(flag_set: FlagSet, name: String) -> Optional[Bool]: - """Returns the value of a flag as a Bool. If it isn't set, then return the default value. - - Args: - flag_set: The FlagSet to get the value from. - name: The name of the flag to return. - """ - var result = flag_set.lookup(name, "Bool") - if not result: - return None - - var flag = result.value() - if not flag[].value: - return string_to_bool(flag[].default) - - return string_to_bool(flag[].value.value()) - - -fn get_int(flag_set: FlagSet, name: String) -> Optional[Int]: - """Returns the value of a flag as an Int. If it isn't set, then return the default value. - - Args: - flag_set: The FlagSet to get the value from. - name: The name of the flag to return. - """ - var result = flag_set.lookup(name, "Int") - if not result: - return None - - var flag = result.value() - - # TODO: I don't like this swallowing up a failure to convert to int. Maybe return a tuple of optional and error? - try: - if not flag[].value: - return atol(flag[].default) - - return atol(flag[].value.value()) - except: - return None - - -fn get_int8(flag_set: FlagSet, name: String) -> Optional[Int8]: - """Returns the value of a flag as a Int8. If it isn't set, then return the default value. - - Args: - flag_set: The FlagSet to get the value from. - name: The name of the flag to return. - """ - var value = flag_set.get_int(name) - if not value: - return None - - return Int8(value.value()) - - -fn get_int16(flag_set: FlagSet, name: String) -> Optional[Int16]: - """Returns the value of a flag as a Int16. If it isn't set, then return the default value. - - Args: - flag_set: The FlagSet to get the value from. - name: The name of the flag to return. - """ - var value = flag_set.get_int(name) - if not value: - return None - - return Int16(value.value()) - - -fn get_int32(flag_set: FlagSet, name: String) -> Optional[Int32]: - """Returns the value of a flag as a Int32. If it isn't set, then return the default value. - - Args: - flag_set: The FlagSet to get the value from. - name: The name of the flag to return. - """ - var value = flag_set.get_int(name) - if not value: - return None - - return Int32(value.value()) - - -fn get_int64(flag_set: FlagSet, name: String) -> Optional[Int64]: - """Returns the value of a flag as a Int64. If it isn't set, then return the default value. - - Args: - flag_set: The FlagSet to get the value from. - name: The name of the flag to return. - """ - var value = flag_set.get_int(name) - if not value: - return None - - return Int64(value.value()) - - -fn get_uint8(flag_set: FlagSet, name: String) -> Optional[UInt8]: - """Returns the value of a flag as a UInt8. If it isn't set, then return the default value. - - Args: - flag_set: The FlagSet to get the value from. - name: The name of the flag to return. - """ - var value = flag_set.get_int(name) - if not value: - return None - - return UInt8(value.value()) - - -fn get_uint16(flag_set: FlagSet, name: String) -> Optional[UInt16]: - """Returns the value of a flag as a UInt16. If it isn't set, then return the default value. - - Args: - flag_set: The FlagSet to get the value from. - name: The name of the flag to return. - """ - var value = flag_set.get_int(name) - if not value: - return None - - return UInt16(value.value()) - - -fn get_uint32(flag_set: FlagSet, name: String) -> Optional[UInt32]: - """Returns the value of a flag as a UInt32. If it isn't set, then return the default value. - - Args: - flag_set: The FlagSet to get the value from. - name: The name of the flag to return. - """ - var value = flag_set.get_int(name) - if not value: - return None - - return UInt32(value.value()) - - -fn get_uint64(flag_set: FlagSet, name: String) -> Optional[UInt64]: - """Returns the value of a flag as a UInt64. If it isn't set, then return the default value. - - Args: - flag_set: The FlagSet to get the value from. - name: The name of the flag to return. - """ - var value = flag_set.get_int(name) - if not value: - return None - - return UInt64(value.value()) - - -fn as_float16(flag_set: FlagSet, name: String) -> Optional[Float16]: - """Returns the value of a flag as a Float64. If it isn't set, then return the default value. - - Args: - flag_set: The FlagSet to get the value from. - name: The name of the flag to return. - """ - var value = flag_set.as_float64(name) - if not value: - return None - - return value.value().cast[DType.float16]() - - -fn as_float32(flag_set: FlagSet, name: String) -> Optional[Float32]: - """Returns the value of a flag as a Float64. If it isn't set, then return the default value. - - Args: - flag_set: The FlagSet to get the value from. - name: The name of the flag to return. - """ - var value = flag_set.as_float64(name) - if not value: - return None - - return value.value().cast[DType.float32]() - - -fn as_float64(flag_set: FlagSet, name: String) -> Optional[Float64]: - """Returns the value of a flag as a Float64. If it isn't set, then return the default value. - - Args: - flag_set: The FlagSet to get the value from. - name: The name of the flag to return. - """ - var result = flag_set.lookup(name, "Float64") - if not result: - return None - - var flag = result.value() - - # TODO: I don't like this swallowing up a failure to convert to int. Maybe return a tuple of optional and error? - try: - if not flag[].value: - return string_to_float(flag[].default) - - return string_to_float(flag[].value.value()) - except e: - return None diff --git a/test/test_args.mojo b/test/test_args.mojo index 84e8bd5..037c9b2 100644 --- a/test/test_args.mojo +++ b/test/test_args.mojo @@ -10,7 +10,7 @@ from prism.args import ( exact_args, range_args, match_all, - ArgValidator, + ArgValidatorFn, ) @@ -20,20 +20,20 @@ from prism.args import ( # TODO: renable these when we have assert raises in testing # def test_no_args(): -# var cmd = Command(name="root", description="Base command.", run=dummy) +# var cmd = Command(name="root", usage="Base command.", run=dummy) # var arc = Arc(cmd) # var result = no_args(arc, List[String]("abc")) # testing.assert_equal(result.value(), String("The command `root` does not take any arguments.")) # def test_valid_args(): -# var cmd = Command(name="root", description="Base command.", run=dummy) +# var cmd = Command(name="root", usage="Base command.", run=dummy) # var result = valid_args[List[String]("Pineapple")]()(Arc(cmd), List[String]("abc")) # testing.assert_equal(result.value(), "Invalid argument: `abc`, for the command `root`.") # def test_arbitrary_args(): -# var cmd = Command(name="root", description="Base command.", run=dummy) +# var cmd = Command(name="root", usage="Base command.", run=dummy) # var result = arbitrary_args(Arc(cmd), List[String]("abc", "blah", "blah")) # # If the result is anything but None, fail the test. @@ -42,33 +42,33 @@ from prism.args import ( # def test_minimum_n_args(): -# var cmd = Command(name="root", description="Base command.", run=dummy) +# var cmd = Command(name="root", usage="Base command.", run=dummy) # var result = minimum_n_args[3]()(Arc(cmd), List[String]("abc", "123")) # testing.assert_equal(result.value(), "The command `root` accepts at least 3 argument(s). Received: 2.") # def test_maximum_n_args(): -# var cmd = Command(name="root", description="Base command.", run=dummy) +# var cmd = Command(name="root", usage="Base command.", run=dummy) # var result = maximum_n_args[1]()(Arc(cmd), List[String]("abc", "123")) # testing.assert_equal(result.value(), "The command `root` accepts at most 1 argument(s). Received: 2.") # def test_exact_args(): -# var cmd = Command(name="root", description="Base command.", run=dummy) +# var cmd = Command(name="root", usage="Base command.", run=dummy) # var result = exact_args[1]()(Arc(cmd), List[String]("abc", "123")) # testing.assert_equal(result.value(), "The command `root` accepts exactly 1 argument(s). Received: 2.") # def test_range_args(): -# var cmd = Command(name="root", description="Base command.", run=dummy) +# var cmd = Command(name="root", usage="Base command.", run=dummy) # var result = range_args[0, 1]()(Arc(cmd), List[String]("abc", "123")) # testing.assert_equal(result.value(), "The command `root`, accepts between 0 to 1 argument(s). Received: 2.") # def test_match_all(): -# var cmd = Command(name="root", description="Base command.", run=dummy) +# var cmd = Command(name="root", usage="Base command.", run=dummy) # var args = List[String]("abc", "123") -# alias validators = List[ArgValidator]( +# alias validators = List[ArgValidatorFn]( # range_args[0, 1](), # valid_args[List[String]("Pineapple")]() # ) diff --git a/test/test_command.mojo b/test/test_command.mojo index 48f2ea5..428e927 100644 --- a/test/test_command.mojo +++ b/test/test_command.mojo @@ -8,15 +8,15 @@ def test_command_operations(): fn dummy(ctx: Context) -> None: return None - var cmd = Arc(Command(name="root", description="Base command.", run=dummy)) + var cmd = Arc(Command(name="root", usage="Base command.", run=dummy)) var flags = cmd[].flags.flags for flag in flags: testing.assert_equal(String("help"), flag[].name) - var child_cmd = Arc(Command(name="child", description="Child command.", run=dummy)) + var child_cmd = Arc(Command(name="child", usage="Child command.", run=dummy)) cmd[].add_subcommand(child_cmd) - child_cmd[].flags.string_flag(name="color", shorthand="c", usage="Text color", default="#3464eb") + child_cmd[].flags.get_string(name="color", shorthand="c", usage="Text color", default="#3464eb") testing.assert_equal(child_cmd[].full_name(), "root child") diff --git a/test/test_flags.mojo b/test/test_flags.mojo index 64406b6..68fe009 100644 --- a/test/test_flags.mojo +++ b/test/test_flags.mojo @@ -20,8 +20,8 @@ def test_string_to_float(): def test_get_flags(): var flag_set = FlagSet() - flag_set.string_flag("key", "description", "default") - flag_set.bool_flag("flag", "description", "False") + flag_set.get_string("key", "usage", "default") + flag_set.get_bool("flag", "usage", "False") var flags = List[String]("--key=value", "positional", "--flag") _ = flag_set.from_args(flags) @@ -31,8 +31,8 @@ def test_get_flags(): def test_parse_flag(): var flag_set = FlagSet() - flag_set.string_flag(name="key", usage="description", default="default") - flag_set.bool_flag(name="flag", usage="description", default=False) + flag_set.get_string(name="key", usage="usage", default="default") + flag_set.get_bool(name="flag", usage="usage", default=False) var parser = FlagParser() var name: String @@ -51,8 +51,8 @@ def test_parse_flag(): def test_parse_shorthand_flag(): var flag_set = FlagSet() - flag_set.string_flag(name="key", usage="description", default="default", shorthand="k") - flag_set.bool_flag(name="flag", usage="description", default=False, shorthand="f") + flag_set.get_string(name="key", usage="usage", default="default", shorthand="k") + flag_set.get_bool(name="flag", usage="usage", default=False, shorthand="f") var parser = FlagParser() var name: String From f68ebb9b1f0dbe3aa1e75804f6340e55c875195b Mon Sep 17 00:00:00 2001 From: Mikhail Tavarez Date: Wed, 9 Oct 2024 15:39:18 -0500 Subject: [PATCH 3/4] fix shorthand flag parsing --- README.md | 18 ++++++++---------- src/prism/flag_parser.mojo | 7 +++++-- test/test_command.mojo | 2 +- test/test_flags.mojo | 20 ++++++++++---------- 4 files changed, 24 insertions(+), 23 deletions(-) diff --git a/README.md b/README.md index acc884d..1b43e14 100644 --- a/README.md +++ b/README.md @@ -159,7 +159,7 @@ If you have different flags that must be provided together (e.g. if they provide var print_tool = Arc(Command( name="tool", description="This is a dummy command!", run=tool_func, aliases=List[String]("object", "thing") )) - print_tool[].flags.string_flag(name="color", shorthand="c", usage="Text color", default="#3464eb") + print_tool[].flags.uint32_flag(name="color", shorthand="c", usage="Text color", default=0x3464eb) print_tool[].flags.string_flag(name="formatting", shorthand="f", usage="Text formatting") print_tool[].mark_flags_required_together("color", "formatting") ``` @@ -246,13 +246,13 @@ Validation of positional arguments can be specified using the `arg_validator` fi - `exact_args[Int]` - report an error if there are not exactly N positional args. - `range_args[min, max]` - report an error if the number of args is not between min and max. - Content of the arguments: - - `only_valid_args` - report an error if there are any positional args not specified in the `valid_args` field of `Command`, which can optionally be set to a list of valid values for positional args. + - `valid_args` - report an error if there are any positional args not specified in the `valid_args` field of `Command`, which can optionally be set to a list of valid values for positional args. If `arg_validator` is undefined, it defaults to `arbitrary_args`. > NOTE: `match_all` is unstable at the moment. I will work on ironing it out in the near future. This most likely does not work. -Moreover, `match_all[arg_validators: List[ArgValidator]]` enables combining existing checks with arbitrary other checks. For instance, if you want to report an error if there are not exactly N positional args OR if there are any positional args that are not in the ValidArgs field of Command, you can call `match_all` on `exact_args` and `only_valid_args`, as shown below: +Moreover, `match_all[arg_validators: List[ArgValidator]]` enables combining existing checks with arbitrary other checks. For instance, if you want to report an error if there are not exactly N positional args OR if there are any positional args that are not in the ValidArgs field of Command, you can call `match_all` on `exact_args` and `valid_args`, as shown below: ```mojo fn test_match_all(): @@ -273,16 +273,15 @@ Commands are configured to accept a `--help` flag by default. This will print th ```mojo fn help_func(inout command: Arc[Command]) -> String: - return "" + return "My help function." fn main() -> None: var root = Command( name="hello", description="This is a dummy command!", run=test, + help=help_func ) - - var hello_command = Arc(Command(name="chromeria", description="This is a dummy command!", run=hello, help=help_func)) ``` ![Help](https://github.com/thatstoasty/prism/blob/main/doc/tapes/help.gif) @@ -293,11 +292,10 @@ fn main() -> None: ## TODO -### Documentation - ### Features -- Add find suggestion logic to `Command` struct. +- Add suggestion logic to `Command` struct. +- Autocomplete generation. - Enable usage function to return the results of a usage function upon calling wrong functions or commands. - Replace print usage with writers to enable stdout/stderr/file writing. - Update default help command to improve available commands and flags section. @@ -305,7 +303,7 @@ fn main() -> None: ### Improvements - Tree traversal improvements. -- Arc[Command] being passed to validators and command functions is marked as inout because the compiler complains about forming a reference to a borrowed register value. This is a temporary fix, I will try to get it back to a borrowed reference. +- `Arc[Command]` being passed to validators and command functions is marked as inout because the compiler complains about forming a reference to a borrowed register value. This is a temporary fix, I will try to get it back to a borrowed reference. - For now, help functions and arg validators will need to be set after the command is constructed. This is to help reduce cyclical dependencies, but I will work on a way to set these values in the constructor as the type system matures. ### Bugs diff --git a/src/prism/flag_parser.mojo b/src/prism/flag_parser.mojo index 302cdbe..4f94c16 100644 --- a/src/prism/flag_parser.mojo +++ b/src/prism/flag_parser.mojo @@ -83,9 +83,12 @@ struct FlagParser: raise Error("Command does not accept the shorthand flag supplied: " + shorthand) # If it's a bool flag, set it to True and only increment the index by 1 (one arg used). - if flags.get_bool(name): + try: + _ = flags.lookup(name, "Bool") return name, String("True"), 1 - + except: + pass + if self.index + 1 >= len(arguments): raise Error("Flag `" + name + "` requires a value to be set but reached the end of arguments.") diff --git a/test/test_command.mojo b/test/test_command.mojo index 428e927..377fa6b 100644 --- a/test/test_command.mojo +++ b/test/test_command.mojo @@ -16,7 +16,7 @@ def test_command_operations(): var child_cmd = Arc(Command(name="child", usage="Child command.", run=dummy)) cmd[].add_subcommand(child_cmd) - child_cmd[].flags.get_string(name="color", shorthand="c", usage="Text color", default="#3464eb") + child_cmd[].flags.string_flag(name="color", shorthand="c", usage="Text color", default="#3464eb") testing.assert_equal(child_cmd[].full_name(), "root child") diff --git a/test/test_flags.mojo b/test/test_flags.mojo index 68fe009..409d022 100644 --- a/test/test_flags.mojo +++ b/test/test_flags.mojo @@ -20,19 +20,19 @@ def test_string_to_float(): def test_get_flags(): var flag_set = FlagSet() - flag_set.get_string("key", "usage", "default") - flag_set.get_bool("flag", "usage", "False") + flag_set.string_flag("key", "usage", "default") + flag_set.bool_flag("flag", "usage", "False") var flags = List[String]("--key=value", "positional", "--flag") _ = flag_set.from_args(flags) - testing.assert_equal(flag_set.get_string("key").value(), "value") - testing.assert_equal(flag_set.get_bool("flag").value(), True) + testing.assert_equal(flag_set.get_string("key"), "value") + testing.assert_equal(flag_set.get_bool("flag"), True) def test_parse_flag(): var flag_set = FlagSet() - flag_set.get_string(name="key", usage="usage", default="default") - flag_set.get_bool(name="flag", usage="usage", default=False) + flag_set.string_flag(name="key", usage="usage", default="default") + flag_set.bool_flag(name="flag", usage="usage", default=False) var parser = FlagParser() var name: String @@ -51,19 +51,19 @@ def test_parse_flag(): def test_parse_shorthand_flag(): var flag_set = FlagSet() - flag_set.get_string(name="key", usage="usage", default="default", shorthand="k") - flag_set.get_bool(name="flag", usage="usage", default=False, shorthand="f") + flag_set.string_flag(name="key", usage="usage", default="default", shorthand="k") + flag_set.bool_flag(name="flag", usage="usage", default=False, shorthand="f") var parser = FlagParser() var name: String var value: String var increment_by: Int - name, value, increment_by = parser.parse_shorthand_flag(String("-k"), List[String]("-k", "value"), flag_set) + name, value, increment_by = parser.parse_shorthand_flag("-k", List[String]("-k", "value"), flag_set) testing.assert_equal(name, "key") testing.assert_equal(value, "value") testing.assert_equal(increment_by, 2) - name, value, increment_by = parser.parse_shorthand_flag(String("-k=value"), List[String]("-k=value"), flag_set) + name, value, increment_by = parser.parse_shorthand_flag("-k=value", List[String]("-k=value"), flag_set) testing.assert_equal(name, "key") testing.assert_equal(value, "value") testing.assert_equal(increment_by, 1) From 88f619af3935b71ca1a55970bbbaec6429b9f0b5 Mon Sep 17 00:00:00 2001 From: Mikhail Tavarez Date: Wed, 9 Oct 2024 15:40:39 -0500 Subject: [PATCH 4/4] update examples --- CHANGELOG.md | 5 +++++ examples/printer/mojoproject.toml | 2 +- examples/requests/mojoproject.toml | 2 +- mojoproject.toml | 2 +- 4 files changed, 8 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 736989a..6e4db3d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,11 @@ and this project adheres to [Semantic Versioning](http://semver.org/). ## [Unreleased] - yyyy-mm-dd +## [0.2.0] - 2024-10-09 + +- Refactor flag get/parsing to use ref to simplify code and enable removal of transform module. +- Renamed some fields to make them more accurate in the Command struct. + ## [0.1.7] - 2024-10-09 - Fix command parsing bug. diff --git a/examples/printer/mojoproject.toml b/examples/printer/mojoproject.toml index 7fd4c2e..ee293a1 100644 --- a/examples/printer/mojoproject.toml +++ b/examples/printer/mojoproject.toml @@ -11,4 +11,4 @@ version = "0.1.0" [dependencies] max = ">=24.5.0,<25" mist = ">=0.1.8,<0.2" -prism = ">=0.1.7,<0.2" +prism = ">=0.2.0,<0.2" diff --git a/examples/requests/mojoproject.toml b/examples/requests/mojoproject.toml index 7ce8840..4bc55b8 100644 --- a/examples/requests/mojoproject.toml +++ b/examples/requests/mojoproject.toml @@ -10,5 +10,5 @@ version = "0.1.0" [dependencies] max = ">=24.5.0,<25" -prism = ">=0.1.7,<0.2" +prism = ">=0.2.0,<0.2" requests = ">=2.32.3,<3" diff --git a/mojoproject.toml b/mojoproject.toml index ede4808..3691ffa 100644 --- a/mojoproject.toml +++ b/mojoproject.toml @@ -4,7 +4,7 @@ channels = ["https://repo.prefix.dev/mojo-community", "conda-forge", "https://co description = "A Budding CLI Library!" name = "prism" platforms = ["osx-arm64", "linux-64"] -version = "0.1.7" +version = "0.2.0" license = "MIT" license-file = "LICENSE" homepage = "https://github.com/thatstoasty/prism"