diff --git a/examples/aliases.mojo b/examples/aliases.mojo index 4518303..898c213 100644 --- a/examples/aliases.mojo +++ b/examples/aliases.mojo @@ -13,13 +13,13 @@ fn tool_func(ctx: Context) -> None: fn main() -> None: var root = Command( name="my", - description="This is a dummy command!", + usage="This is a dummy command!", run=test, ) var print_tool = Arc( Command( - name="tool", description="This is a dummy command!", run=tool_func, aliases=List[String]("object", "thing") + name="tool", usage="This is a dummy command!", run=tool_func, aliases=List[String]("object", "thing") ) ) diff --git a/examples/arg_validators.mojo b/examples/arg_validators.mojo index 66d1397..0c7ba9f 100644 --- a/examples/arg_validators.mojo +++ b/examples/arg_validators.mojo @@ -23,32 +23,33 @@ fn hello(ctx: Context) -> None: fn main() -> None: var root = Command( name="hello", - description="This is a dummy command!", + usage="This is a dummy command!", run=test, ) - var no_args_command = Arc(Command(name="no_args", description="This is a dummy command!", run=hello)) + var no_args_command = Arc(Command(name="no_args", usage="This is a dummy command!", run=hello)) no_args_command[].arg_validator = no_args var valid_args_command = Arc( Command( name="valid_args", - description="This is a dummy command!", + usage="This is a dummy command!", run=hello, + valid_args=List[String]("Pineapple") ) ) valid_args_command[].arg_validator = valid_args - var minimum_n_args_command = Arc(Command(name="minimum_n_args", description="This is a dummy command!", run=hello)) + var minimum_n_args_command = Arc(Command(name="minimum_n_args", usage="This is a dummy command!", run=hello)) minimum_n_args_command[].arg_validator = minimum_n_args[4]() - var maximum_n_args_command = Arc(Command(name="maximum_n_args", description="This is a dummy command!", run=hello)) + var maximum_n_args_command = Arc(Command(name="maximum_n_args", usage="This is a dummy command!", run=hello)) maximum_n_args_command[].arg_validator = maximum_n_args[1]() - var exact_args_command = Arc(Command(name="exact_args", description="This is a dummy command!", run=hello)) + var exact_args_command = Arc(Command(name="exact_args", usage="This is a dummy command!", run=hello)) exact_args_command[].arg_validator = exact_args[1]() - var range_args_command = Arc(Command(name="range_args", description="This is a dummy command!", run=hello)) + var range_args_command = Arc(Command(name="range_args", usage="This is a dummy command!", run=hello)) range_args_command[].arg_validator = range_args[0, 1]() root.add_subcommand(no_args_command) diff --git a/examples/chromeria.mojo b/examples/chromeria.mojo index 3aeffb8..d95f3d4 100644 --- a/examples/chromeria.mojo +++ b/examples/chromeria.mojo @@ -13,11 +13,11 @@ fn hello(ctx: Context) -> None: fn main() -> None: var root = Command( name="hello", - description="This is a dummy command!", + usage="This is a dummy command!", run=test, ) - var hello_command = Arc(Command(name="chromeria", description="This is a dummy command!", run=hello)) + var hello_command = Arc(Command(name="chromeria", usage="This is a dummy command!", run=hello)) root.add_subcommand(hello_command) root.execute() diff --git a/examples/fg_child.mojo b/examples/fg_child.mojo index f9f7eb3..982e058 100644 --- a/examples/fg_child.mojo +++ b/examples/fg_child.mojo @@ -13,7 +13,7 @@ fn tool_func(ctx: Context) -> None: fn main() -> None: var root = Command( name="my", - description="This is a dummy command!", + usage="This is a dummy command!", run=test, ) root.persistent_flags.bool_flag(name="required", shorthand="r", usage="Always required.") @@ -21,7 +21,7 @@ fn main() -> None: root.persistent_flags.string_flag(name="port", shorthand="p", usage="Port") root.mark_persistent_flag_required("required") - var print_tool = Arc(Command(name="tool", description="This is a dummy command!", run=tool_func)) + var print_tool = Arc(Command(name="tool", usage="This is a dummy command!", run=tool_func)) print_tool[].flags.bool_flag(name="also", shorthand="a", usage="Also always required.") print_tool[].flags.string_flag(name="uri", shorthand="u", usage="URI") root.add_subcommand(print_tool) diff --git a/examples/fg_parent.mojo b/examples/fg_parent.mojo index 06cf794..373664d 100644 --- a/examples/fg_parent.mojo +++ b/examples/fg_parent.mojo @@ -2,15 +2,15 @@ from memory import Arc from prism import Command, Context -fn test(ctx: Context) -> None: +fn test(ctx: Context) raises -> None: var host = ctx.command[].flags.get_string("host") var port = ctx.command[].flags.get_string("port") var uri = ctx.command[].flags.get_string("uri") - if uri: - print("URI: ", uri.value()) + if uri != "": + print("URI:", uri) else: - print(host.value(), ":", port.value()) + print(host + ":" + port) fn tool_func(ctx: Context) -> None: @@ -20,8 +20,8 @@ fn tool_func(ctx: Context) -> None: fn main() -> None: var root = Command( name="my", - description="This is a dummy command!", - run=test, + usage="This is a dummy command!", + raising_run=test, ) root.persistent_flags.bool_flag(name="required", shorthand="r", usage="Always required.") root.persistent_flags.string_flag(name="host", shorthand="h", usage="Host") diff --git a/examples/hello_world.mojo b/examples/hello_world.mojo index c643fd4..cf46691 100644 --- a/examples/hello_world.mojo +++ b/examples/hello_world.mojo @@ -15,7 +15,7 @@ fn build_printer_command() -> Arc[Command]: var cmd = Arc( Command( name="printer", - description="Print the first arg.", + usage="Print the first arg.", run=printer, ) ) @@ -42,7 +42,7 @@ fn build_say_command() -> Arc[Command]: return Arc( Command( name="say", - description="Say something to someone", + usage="Say something to someone", run=say, ) ) @@ -52,7 +52,7 @@ fn build_hello_command() -> Arc[Command]: var cmd = Arc( Command( name="hello", - description="Say hello to someone", + usage="Say hello to someone", run=say_hello, ) ) @@ -63,7 +63,7 @@ fn build_goodbye_command() -> Arc[Command]: var cmd = Arc( Command( name="goodbye", - description="Say goodbye to someone", + usage="Say goodbye to someone", run=say_goodbye, ) ) @@ -71,7 +71,10 @@ fn build_goodbye_command() -> Arc[Command]: fn test(ctx: Context) -> None: - print(ctx.command[].flags.get_string("env").value()) + try: + print(ctx.command[].flags.get_string("env")) + except: + print("No env flag provided.") for item in ctx.command[].flags.flags: if item[].value: print(item[].name, item[].value.value()) @@ -84,7 +87,7 @@ fn test(ctx: Context) -> None: fn main() -> None: var root = Command( name="tones", - description="This is a dummy command!", + usage="This is a dummy command!", run=test, ) root.flags.string_flag(name="env", shorthand="e", usage="Environment.", default="") diff --git a/examples/printer/mojoproject.toml b/examples/printer/mojoproject.toml index 04d0cd8..7fd4c2e 100644 --- a/examples/printer/mojoproject.toml +++ b/examples/printer/mojoproject.toml @@ -11,4 +11,4 @@ version = "0.1.0" [dependencies] max = ">=24.5.0,<25" mist = ">=0.1.8,<0.2" -prism = ">=0.1.6,<0.2" +prism = ">=0.1.7,<0.2" diff --git a/examples/printer/printer.mojo b/examples/printer/printer.mojo index e3cc105..59de23b 100644 --- a/examples/printer/printer.mojo +++ b/examples/printer/printer.mojo @@ -1,35 +1,25 @@ from memory import Arc from prism import Command, Context, exact_args -from mist import Style +import mist - -fn printer(ctx: Context) -> None: +fn printer(ctx: Context) raises -> None: if len(ctx.args) <= 0: print("No text to print! Pass in some text as a positional argument.") return None var color = ctx.command[].flags.get_uint32("color") var formatting = ctx.command[].flags.get_string("formatting") - var style = Style() - - if not color: - color = 0xFFFFFF - if not formatting: - formatting = String("") - - if color: - style = style.foreground(color.value()) + var style = mist.Style().foreground(color) - var formatting_value = formatting.or_else("") - if formatting_value == "": + if formatting == "": print(style.render(ctx.args[0])) return None - if formatting.value() == "bold": + if formatting == "bold": style = style.bold() - elif formatting.value() == "underline": + elif formatting == "underline": style = style.underline() - elif formatting.value() == "italic": + elif formatting == "italic": style = style.italic() print(style.render(ctx.args[0])) @@ -49,8 +39,8 @@ fn post_hook(ctx: Context) -> None: fn main() -> None: var root = Command( name="printer", - description="Base command.", - run=printer, + usage="Base command.", + raising_run=printer, pre_run=pre_hook, post_run=post_hook, ) diff --git a/examples/requests/mojoproject.toml b/examples/requests/mojoproject.toml index 814e96c..7ce8840 100644 --- a/examples/requests/mojoproject.toml +++ b/examples/requests/mojoproject.toml @@ -10,5 +10,5 @@ version = "0.1.0" [dependencies] max = ">=24.5.0,<25" -prism = ">=0.1.6,<0.2" +prism = ">=0.1.7,<0.2" requests = ">=2.32.3,<3" diff --git a/examples/requests/nested.mojo b/examples/requests/nested.mojo index 0be5fbf..b255ada 100644 --- a/examples/requests/nested.mojo +++ b/examples/requests/nested.mojo @@ -16,7 +16,7 @@ fn print_information(ctx: Context) -> None: fn get_cat_fact(ctx: Context) raises -> None: var flags = ctx.command[].flags var lover = flags.get_bool("lover") - if lover and lover.value(): + if lover: print("Hello fellow cat lover!") var requests = Python.import_module("requests") @@ -29,7 +29,7 @@ fn get_cat_fact(ctx: Context) raises -> None: if not count: raise Error("Count flag was not found.") - for _ in range(count.value()): + for _ in range(count): var response = requests.get(url) # Check if the request was successful (status code 200) @@ -57,12 +57,12 @@ fn get_dog_breeds(ctx: Context) raises -> None: fn main() -> None: - var root = Command(name="nested", description="Base command.", run=base) + var root = Command(name="nested", usage="Base command.", run=base) var get_command = Arc( Command( name="get", - description="Base command for getting some data.", + usage="Base command for getting some data.", run=print_information, ) ) @@ -70,8 +70,8 @@ fn main() -> None: var cat_command = Arc( Command( name="cat", - description="Get some cat facts!", - erroring_run=get_cat_fact, + usage="Get some cat facts!", + raising_run=get_cat_fact, ) ) cat_command[].flags.int_flag(name="count", shorthand="c", usage="Number of facts to get.", default=1) @@ -80,8 +80,8 @@ fn main() -> None: var dog_command = Arc( Command( name="dog", - description="Get some dog breeds!", - erroring_run=get_dog_breeds, + usage="Get some dog breeds!", + raising_run=get_dog_breeds, ) ) diff --git a/examples/requests/persistent_flags.mojo b/examples/requests/persistent_flags.mojo index 09484c0..64da5d3 100644 --- a/examples/requests/persistent_flags.mojo +++ b/examples/requests/persistent_flags.mojo @@ -16,7 +16,7 @@ fn print_information(ctx: Context) -> None: fn get_cat_fact(ctx: Context) raises -> None: var flags = ctx.command[].flags var lover = flags.get_bool("lover") - if lover and lover.value(): + if lover: print("Hello fellow cat lover!") var requests = Python.import_module("requests") @@ -29,7 +29,7 @@ fn get_cat_fact(ctx: Context) raises -> None: if not count: raise Error("Count flag was not found.") - for _ in range(count.value()): + for _ in range(count): var response = requests.get(url) # Check if the request was successful (status code 200) @@ -43,7 +43,7 @@ fn get_cat_fact(ctx: Context) raises -> None: fn get_dog_breeds(ctx: Context) raises -> None: var flags = ctx.command[].flags var lover = flags.get_bool("lover") - if lover and lover.value(): + if lover: print("Hello fellow dog lover!") var requests = Python.import_module("requests") @@ -69,12 +69,12 @@ fn post_hook(ctx: Context) -> None: fn main() -> None: - var root = Command(name="nested", description="Base command.", run=base) + var root = Command(name="nested", usage="Base command.", run=base) var get_command = Arc( Command( name="get", - description="Base command for getting some data.", + usage="Base command for getting some data.", run=print_information, persistent_pre_run=pre_hook, persistent_post_run=post_hook, @@ -85,8 +85,8 @@ fn main() -> None: var cat_command = Arc( Command( name="cat", - description="Get some cat facts!", - erroring_run=get_cat_fact, + usage="Get some cat facts!", + raising_run=get_cat_fact, ) ) cat_command[].flags.int_flag(name="count", shorthand="c", usage="Number of facts to get.") @@ -94,8 +94,8 @@ fn main() -> None: var dog_command = Arc( Command( name="dog", - description="Get some dog breeds!", - erroring_run=get_dog_breeds, + usage="Get some dog breeds!", + raising_run=get_dog_breeds, ) ) diff --git a/scripts/examples.sh b/scripts/examples.sh index 355fabc..ce63c07 100755 --- a/scripts/examples.sh +++ b/scripts/examples.sh @@ -16,16 +16,16 @@ magic run mojo build $TEMP_DIR/arg_validators.mojo -o $TEMP_DIR/validators echo "[INFO] Running examples..." # Need to run these first examples as part of a mojo project as they have external dependencies. # printer is a portable binary, but nested and persistent_flags are not because they depend on a python library. -cd examples/printer -magic run mojo build printer.mojo -./printer "sample-text" --formatting=underline +# cd examples/printer +# magic run mojo build printer.mojo +# ./printer "sample-text" --formatting=underline -cd ../requests -magic run mojo build nested.mojo -magic run nested get cat --count 3 -l -magic run mojo build persistent_flags.mojo -o persistent -magic run persistent get cat --count 2 --lover -magic run persistent get dog +# cd ../requests +# magic run mojo build nested.mojo +# magic run nested get cat --count 3 -l +# magic run mojo build persistent_flags.mojo -o persistent +# magic run persistent get cat --count 2 --lover +# magic run persistent get dog cd ../.. diff --git a/src/prism/__init__.mojo b/src/prism/__init__.mojo index 5161851..4ec8cb5 100644 --- a/src/prism/__init__.mojo +++ b/src/prism/__init__.mojo @@ -1,7 +1,7 @@ from .command import ( Command, - CommandFunction, - ArgValidator, + CmdFn, + ArgValidatorFn, ) from .args import no_args, valid_args, arbitrary_args, minimum_n_args, maximum_n_args, exact_args, range_args, match_all from .flag import Flag diff --git a/src/prism/args.mojo b/src/prism/args.mojo index b8fc2bf..fd4d215 100644 --- a/src/prism/args.mojo +++ b/src/prism/args.mojo @@ -1,7 +1,7 @@ from memory.arc import Arc from collections.optional import Optional import gojo.fmt -from .command import ArgValidator +from .command import ArgValidatorFn from .context import Context @@ -24,7 +24,7 @@ fn arbitrary_args(ctx: Context) raises -> None: return None -fn minimum_n_args[n: Int]() -> ArgValidator: +fn minimum_n_args[n: Int]() -> ArgValidatorFn: """Returns an error if there is not at least n arguments. Params: @@ -48,7 +48,7 @@ fn minimum_n_args[n: Int]() -> ArgValidator: return less_than_n_args -fn maximum_n_args[n: Int]() -> ArgValidator: +fn maximum_n_args[n: Int]() -> ArgValidatorFn: """Returns an error if there are more than n arguments. Params: @@ -72,7 +72,7 @@ fn maximum_n_args[n: Int]() -> ArgValidator: return more_than_n_args -fn exact_args[n: Int]() -> ArgValidator: +fn exact_args[n: Int]() -> ArgValidatorFn: """Returns an error if there are not exactly n arguments. Params: @@ -108,7 +108,7 @@ fn valid_args(ctx: Context) raises -> None: raise Error(fmt.sprintf("Invalid argument: `%s`, for the command `%s`.", arg[], ctx.command[].name)) -fn range_args[minimum: Int, maximum: Int]() -> ArgValidator: +fn range_args[minimum: Int, maximum: Int]() -> ArgValidatorFn: """Returns an error if there are not exactly n arguments. Params: @@ -135,11 +135,11 @@ fn range_args[minimum: Int, maximum: Int]() -> ArgValidator: # TODO: Having some issues with varadic list of functions, so using List for now. -fn match_all[arg_validators: List[ArgValidator]]() -> ArgValidator: +fn match_all[arg_validators: List[ArgValidatorFn]]() -> ArgValidatorFn: """Returns an error if any of the arg_validators return an error. Params: - arg_validators: A list of ArgValidator functions that check the arguments. + arg_validators: A list of ArgValidatorFn functions that check the arguments. Returns: A function that checks all the arguments using the arg_validators list.. diff --git a/src/prism/command.mojo b/src/prism/command.mojo index 84f8960..b502eec 100644 --- a/src/prism/command.mojo +++ b/src/prism/command.mojo @@ -42,7 +42,7 @@ fn default_help(inout command: Arc[Command]) -> String: """Prints the help information for the command. TODO: Add padding for commands, options, and aliases. """ - var description_style = mog.Style().border(mog.HIDDEN_BORDER) + var usage_style = mog.Style().border(mog.HIDDEN_BORDER) var border_style = mog.Style().border(mog.ROUNDED_BORDER).border_foreground(mog.Color(0x383838)).padding(0, 1) var option_style = mog.Style().foreground(mog.Color(0x81C8BE)) var bold_style = mog.Style().bold() @@ -58,7 +58,7 @@ fn default_help(inout command: Arc[Command]) -> String: _ = builder.write_string(" COMMAND") _ = builder.write_string(" [ARGS]...") - var description = description_style.render(mog.join_vertical(mog.left, str(builder), "\n", cmd[].description)) + var usage = usage_style.render(mog.join_vertical(mog.left, str(builder), "\n", cmd[].usage)) builder = StringBuilder() if cmd[].flags.flags: @@ -73,9 +73,7 @@ fn default_help(inout command: Arc[Command]) -> String: _ = builder.write_string(bold_style.render("Commands")) for i in range(len(cmd[].children)): _ = builder.write_string( - fmt.sprintf( - "\n%s %s", option_style.render(cmd[].children[i][].name), cmd[].children[i][].description - ) + fmt.sprintf("\n%s %s", option_style.render(cmd[].children[i][].name), cmd[].children[i][].usage) ) if i == len(cmd[].children) - 1: @@ -86,16 +84,16 @@ fn default_help(inout command: Arc[Command]) -> String: _ = builder.write_string(fmt.sprintf("\n%s", option_style.render(cmd[].aliases.__str__()))) var commands = border_style.render(str(builder)) - return mog.join_vertical(mog.left, description, options, commands) + return mog.join_vertical(mog.left, usage, options, commands) -alias CommandFunction = fn (ctx: Context) -> None +alias CmdFn = fn (ctx: Context) -> None """The function for a command to run.""" -alias CommandFunctionErr = fn (ctx: Context) raises -> None +alias RaisingCmdFn = fn (ctx: Context) raises -> None """The function for a command to run that can error.""" -alias HelpFunction = fn (inout command: Arc[Command]) -> String -"""The function for a help function.""" -alias ArgValidator = fn (ctx: Context) raises -> None +alias HelpFn = fn (inout command: Arc[Command]) -> String +"""The function to generate help output.""" +alias ArgValidatorFn = fn (ctx: Context) raises -> None """The function for an argument validator.""" alias ParentVisitorFn = fn (parent: Arc[Command]) capturing -> None """The function for visiting parents of a command.""" @@ -120,7 +118,7 @@ struct Command(CollectionElement): fn main(): var command = Command( name="hello", - description="This is a dummy command!", + usage="This is a dummy command!", run=test, ) command.execute() @@ -135,38 +133,38 @@ struct Command(CollectionElement): var name: String """The name of the command.""" - var description: String + var usage: String """Description of the command.""" var aliases: List[String] """Aliases that can be used instead of the first word in name.""" - var help: HelpFunction + var help: HelpFn """Generates help text.""" - var pre_run: Optional[CommandFunction] + var pre_run: Optional[CmdFn] """A function to run before the run function is executed.""" - var run: Optional[CommandFunction] + var run: Optional[CmdFn] """A function to run when the command is executed.""" - var post_run: Optional[CommandFunction] + var post_run: Optional[CmdFn] """A function to run after the run function is executed.""" - var erroring_pre_run: Optional[CommandFunctionErr] + var raising_pre_run: Optional[RaisingCmdFn] """A raising function to run before the run function is executed.""" - var erroring_run: Optional[CommandFunctionErr] + var raising_run: Optional[RaisingCmdFn] """A raising function to run when the command is executed.""" - var erroring_post_run: Optional[CommandFunctionErr] + var raising_post_run: Optional[RaisingCmdFn] """A raising function to run after the run function is executed.""" - var persistent_pre_run: Optional[CommandFunction] + var persistent_pre_run: Optional[CmdFn] """A function to run before the run function is executed. This persists to children.""" - var persistent_post_run: Optional[CommandFunction] + var persistent_post_run: Optional[CmdFn] """A function to run after the run function is executed. This persists to children.""" - var persistent_erroring_pre_run: Optional[CommandFunctionErr] + var persistent_raising_pre_run: Optional[RaisingCmdFn] """A raising function to run before the run function is executed. This persists to children.""" - var persistent_erroring_post_run: Optional[CommandFunctionErr] + var persistent_raising_post_run: Optional[RaisingCmdFn] """A raising function to run after the run function is executed. This persists to children.""" - var arg_validator: ArgValidator + var arg_validator: ArgValidatorFn """Function to validate arguments passed to the command.""" var valid_args: List[String] """Valid arguments for the command.""" @@ -192,43 +190,43 @@ struct Command(CollectionElement): fn __init__( inout self, name: String, - description: String, + usage: String, aliases: List[String] = List[String](), valid_args: List[String] = List[String](), - run: Optional[CommandFunction] = None, - pre_run: Optional[CommandFunction] = None, - post_run: Optional[CommandFunction] = None, - erroring_run: Optional[CommandFunctionErr] = None, - erroring_pre_run: Optional[CommandFunctionErr] = None, - erroring_post_run: Optional[CommandFunctionErr] = None, - persistent_pre_run: Optional[CommandFunction] = None, - persistent_post_run: Optional[CommandFunction] = None, - persistent_erroring_pre_run: Optional[CommandFunctionErr] = None, - persistent_erroring_post_run: Optional[CommandFunctionErr] = None, + run: Optional[CmdFn] = None, + pre_run: Optional[CmdFn] = None, + post_run: Optional[CmdFn] = None, + raising_run: Optional[RaisingCmdFn] = None, + raising_pre_run: Optional[RaisingCmdFn] = None, + raising_post_run: Optional[RaisingCmdFn] = None, + persistent_pre_run: Optional[CmdFn] = None, + persistent_post_run: Optional[CmdFn] = None, + persistent_raising_pre_run: Optional[RaisingCmdFn] = None, + persistent_raising_post_run: Optional[RaisingCmdFn] = None, ): """ Args: name: The name of the command. - description: The description of the command. + usage: The usage of the command. arg_validator: The function to validate the arguments passed to the command. valid_args: The valid arguments for the command. run: The function to run when the command is executed. pre_run: The function to run before the command is executed. post_run: The function to run after the command is executed. - erroring_run: The function to run when the command is executed that returns an error. - erroring_pre_run: The function to run before the command is executed that returns an error. - erroring_post_run: The function to run after the command is executed that returns an error. + raising_run: The function to run when the command is executed that returns an error. + raising_pre_run: The function to run before the command is executed that returns an error. + raising_post_run: The function to run after the command is executed that returns an error. persisting_pre_run: The function to run before the command is executed. This persists to children. persisting_post_run: The function to run after the command is executed. This persists to children. - persisting_erroring_pre_run: The function to run before the command is executed that returns an error. This persists to children. - persisting_erroring_post_run: The function to run after the command is executed that returns an error. This persists to children. + persisting_raising_pre_run: The function to run before the command is executed that returns an error. This persists to children. + persisting_raising_post_run: The function to run after the command is executed that returns an error. This persists to children. help: The function to generate help text for the command. """ - if not run and not erroring_run: - panic("A command must have a run or erroring_run function.") + if not run and not raising_run: + panic("A command must have a run or raising_run function.") self.name = name - self.description = description + self.usage = usage self.aliases = aliases self.help = default_help @@ -237,14 +235,14 @@ struct Command(CollectionElement): self.run = run self.post_run = post_run - self.erroring_pre_run = erroring_pre_run - self.erroring_run = erroring_run - self.erroring_post_run = erroring_post_run + self.raising_pre_run = raising_pre_run + self.raising_run = raising_run + self.raising_post_run = raising_post_run self.persistent_pre_run = persistent_pre_run self.persistent_post_run = persistent_post_run - self.persistent_erroring_pre_run = persistent_erroring_pre_run - self.persistent_erroring_post_run = persistent_erroring_post_run + self.persistent_raising_pre_run = persistent_raising_pre_run + self.persistent_raising_post_run = persistent_raising_post_run self.arg_validator = arbitrary_args self.valid_args = valid_args @@ -261,7 +259,7 @@ struct Command(CollectionElement): fn __moveinit__(inout self, owned existing: Self): self.name = existing.name^ - self.description = existing.description^ + self.usage = existing.usage^ self.aliases = existing.aliases^ self.help = existing.help @@ -270,14 +268,14 @@ struct Command(CollectionElement): self.run = existing.run^ self.post_run = existing.post_run^ - self.erroring_pre_run = existing.erroring_pre_run^ - self.erroring_run = existing.erroring_run^ - self.erroring_post_run = existing.erroring_post_run^ + self.raising_pre_run = existing.raising_pre_run^ + self.raising_run = existing.raising_run^ + self.raising_post_run = existing.raising_post_run^ self.persistent_pre_run = existing.persistent_pre_run^ self.persistent_post_run = existing.persistent_post_run^ - self.persistent_erroring_pre_run = existing.persistent_erroring_pre_run^ - self.persistent_erroring_post_run = existing.persistent_erroring_post_run^ + self.persistent_raising_pre_run = existing.persistent_raising_pre_run^ + self.persistent_raising_post_run = existing.persistent_raising_post_run^ self.arg_validator = existing.arg_validator self.valid_args = existing.valid_args^ @@ -311,8 +309,8 @@ struct Command(CollectionElement): writer.write("Command(Name: ") writer.write(self.name) - writer.write(", Description: ") - writer.write(self.description) + writer.write(", usage: ") + writer.write(self.usage) if self.aliases: writer.write(", Aliases: ") @@ -340,34 +338,36 @@ struct Command(CollectionElement): return self.parent[0][].root() return self - - fn _parse_command(self, command: Self, arg: String, children: List[Arc[Self]], inout leftover_start: Int) -> (Self, List[Arc[Self]]): + + fn _parse_command( + self, command: Self, arg: String, children: List[Arc[Self]], inout leftover_start: Int + ) -> (Self, List[Arc[Self]]): for command_ref in children: if command_ref[][].name == arg or arg in command_ref[][].aliases: leftover_start += 1 return command_ref[][], command_ref[][].children - + return command, children fn _parse_command_from_args(self, args: List[String]) -> (Self, List[String]): # If there's no children, then the root command is used. if not self.children or not args: return self, args - + var command = self var children = self.children var leftover_start = 0 # Start at 1 to start slice at the first remaining arg, not the last child command. for arg in args: command, children = self._parse_command(command, arg[], children, leftover_start) - + if leftover_start == 0: return self, args - + # If the there are more or equivalent args to the index, then there are remaining args to pass to the command. var remaining_args = List[String]() if len(args) >= leftover_start: - remaining_args = args[leftover_start:len(args)] + remaining_args = args[leftover_start : len(args)] return command, remaining_args @@ -376,8 +376,8 @@ struct Command(CollectionElement): try: # Run the persistent pre-run hooks. for parent in parents: - if parent[][].persistent_erroring_pre_run: - parent[][].persistent_erroring_pre_run.value()(ctx) + if parent[][].persistent_raising_pre_run: + parent[][].persistent_raising_pre_run.value()(ctx) @parameter if not ENABLE_TRAVERSE_RUN_HOOKS: @@ -393,8 +393,8 @@ struct Command(CollectionElement): # Run the pre-run hooks. if ctx.command[].pre_run: ctx.command[].pre_run.value()(ctx) - elif ctx.command[].erroring_pre_run: - ctx.command[].erroring_pre_run.value()(ctx) + elif ctx.command[].raising_pre_run: + ctx.command[].raising_pre_run.value()(ctx) except e: print("Failed to run pre-run hooks for command: " + ctx.command[].name) raise e @@ -404,8 +404,8 @@ struct Command(CollectionElement): try: # Run the persistent post-run hooks. for parent in parents: - if parent[][].persistent_erroring_post_run: - parent[][].persistent_erroring_post_run.value()(ctx) + if parent[][].persistent_raising_post_run: + parent[][].persistent_raising_post_run.value()(ctx) @parameter if not ENABLE_TRAVERSE_RUN_HOOKS: @@ -421,8 +421,8 @@ struct Command(CollectionElement): # Run the post-run hooks. if ctx.command[].post_run: ctx.command[].post_run.value()(ctx) - elif ctx.command[].erroring_post_run: - ctx.command[].erroring_post_run.value()(ctx) + elif ctx.command[].raising_post_run: + ctx.command[].raising_post_run.value()(ctx) except e: print("Failed to run post-run hooks for command: " + ctx.command[].name, file=2) raise e @@ -447,6 +447,7 @@ struct Command(CollectionElement): # Add all parents to the list to check if they have persistent pre/post hooks. var parents = List[Arc[Self]]() + @parameter fn append_parents(parent: Arc[Self]) capturing -> None: parents.append(parent) @@ -459,20 +460,17 @@ struct Command(CollectionElement): if ENABLE_TRAVERSE_RUN_HOOKS: parents.reverse() - # Get the flags for the command to be executed. try: + # Get the flags for the command to be executed. remaining_args = command.flags.from_args(remaining_args) - except e: - panic(e) - # Check if the help flag was passed - var command_ref = Arc(command) - var help_passed = command.flags.get_bool("help") - if help_passed.value() == True: - print(command.help(command_ref)) - return None + # Check if the help flag was passed + var help_passed = command.flags.get_bool("help") + var command_ref = Arc(command) + if help_passed == True: + print(command.help(command_ref)) + return None - try: # Validate individual required flags (eg: flag is required) validate_required_flags(command.flags) @@ -488,7 +486,7 @@ struct Command(CollectionElement): if command.run: command.run.value()(ctx) else: - command.erroring_run.value()(ctx) + command.raising_run.value()(ctx) self._execute_post_run_hooks(ctx, parents) except e: panic(e) diff --git a/src/prism/context.mojo b/src/prism/context.mojo index dde09b6..9398374 100644 --- a/src/prism/context.mojo +++ b/src/prism/context.mojo @@ -8,4 +8,4 @@ struct Context: def __init__(inout self, command: Arc[Command], args: List[String]) -> None: self.command = command - self.args = args \ No newline at end of file + self.args = args diff --git a/src/prism/flag.mojo b/src/prism/flag.mojo index 43c849c..c1ec162 100644 --- a/src/prism/flag.mojo +++ b/src/prism/flag.mojo @@ -120,7 +120,7 @@ struct Flag(RepresentableCollectionElement, Stringable, Formattable): self.value = value self.changed = True - fn get_with_transform[T: CollectionElement, //, transform: fn (value: String) -> T](self) -> Optional[T]: + fn get_with_transform[T: CollectionElement, //, transform: fn (value: String) -> T](self) -> T: """Returns the value of the flag with a transformation applied to it. Params: @@ -132,3 +132,11 @@ struct Flag(RepresentableCollectionElement, Stringable, Formattable): if self.value: return transform(self.value.value()) return transform(self.default) + + fn value_or_default(self) -> String: + """Returns the value of the flag or the default value if it isn't set. + + Returns: + The value of the flag or the default value. + """ + return self.value.or_else(self.default) diff --git a/src/prism/flag_group.mojo b/src/prism/flag_group.mojo index e9de762..6dc514e 100644 --- a/src/prism/flag_group.mojo +++ b/src/prism/flag_group.mojo @@ -7,7 +7,9 @@ from gojo import fmt fn has_all_flags(flags: FlagSet, flag_names: List[String]) -> Bool: for name in flag_names: - if not flags.lookup(name[]): + try: + _ = flags.lookup(name[]) + except: return False return True @@ -40,7 +42,7 @@ fn process_flag_for_group_annotation( except e: raise Error( String( - "process_flag_for_group_annotation: Failed to set group status for annotation {}: {}" + "process_flag_for_group_annotation: Failed to set group status for annotation {}: {}." ).format(annotation, str(e)) ) @@ -71,7 +73,7 @@ fn validate_required_flag_group(data: Dict[String, Dict[String, Bool]]) -> None: panic( fmt.sprintf( - "if any flags in the group, %s, are set they must all be set; missing %s", + "If any flags in the group, %s, are set they must all be set; missing %s.", keys.__str__(), unset.__str__(), ) @@ -101,7 +103,7 @@ fn validate_one_required_flag_group(data: Dict[String, Dict[String, Bool]]) -> N for key in pair[].value.keys(): keys.append(key[]) - panic(fmt.sprintf("at least one of the flags in the group %s is required", keys.__str__())) + panic(fmt.sprintf("At least one of the flags in the group %s is required.", keys.__str__())) fn validate_mutually_exclusive_flag_group(data: Dict[String, Dict[String, Bool]]) -> None: @@ -129,7 +131,7 @@ fn validate_mutually_exclusive_flag_group(data: Dict[String, Dict[String, Bool]] panic( fmt.sprintf( - "if any flags in the group %s are set none of the others can be; %s were all set", + "If any flags in the group %s are set none of the others can be; %s were all set.", keys.__str__(), set.__str__(), ) diff --git a/src/prism/flag_parser.mojo b/src/prism/flag_parser.mojo index efe4b32..302cdbe 100644 --- a/src/prism/flag_parser.mojo +++ b/src/prism/flag_parser.mojo @@ -37,8 +37,11 @@ struct FlagParser: raise Error("Command does not accept the flag supplied: " + name) # If it's a bool flag, set it to True and only increment the index by 1 (one arg used). - if flags.get_bool(name): + try: + _ = flags.lookup(name, "Bool") return name, String("True"), 1 + except: + pass if self.index + 1 >= len(arguments): raise Error("Flag `" + name + "` requires a value to be set but reached the end of arguments.") @@ -67,8 +70,7 @@ struct FlagParser: var flag = split(argument, "=") var shorthand = flag[0][1:] var value = flag[1] - var name = flags.lookup_name(shorthand).value() - + var name = flags.lookup_name(shorthand) if name not in flags.names(): raise Error("Command does not accept the shorthand flag supplied: " + name) @@ -76,10 +78,9 @@ struct FlagParser: # Flag with value set like "-f " var shorthand = argument[1:] - var result = flags.lookup_name(shorthand) - if not result: + var name = flags.lookup_name(shorthand) + if name not in flags.names(): raise Error("Command does not accept the shorthand flag supplied: " + shorthand) - var name = result.value() # If it's a bool flag, set it to True and only increment the index by 1 (one arg used). if flags.get_bool(name): @@ -126,11 +127,7 @@ struct FlagParser: raise Error("Expected a flag but found: " + argument) # Set the value of the flag. - var flag = flags.lookup(name) - if not flag: - raise Error("No flag found with the name: " + name) - - flag.value()[].set(value) + flags.lookup(name).set(value) self.index += increment_by return remaining_args diff --git a/src/prism/flag_set.mojo b/src/prism/flag_set.mojo index 203f2c9..199dd73 100644 --- a/src/prism/flag_set.mojo +++ b/src/prism/flag_set.mojo @@ -4,22 +4,7 @@ import gojo.fmt from .flag import Flag from .util import panic, string_to_bool, string_to_float, split from .flag_parser import FlagParser -from .transform import ( - get_string, - get_bool, - get_int, - get_int8, - get_int16, - get_int32, - get_int64, - get_uint8, - get_uint16, - get_uint32, - get_uint64, - as_float16, - as_float32, - as_float64, -) + alias FlagVisitorFn = fn (Flag) capturing -> None """Function perform some action while visiting all flags.""" @@ -34,7 +19,23 @@ alias REQUIRED_AS_GROUP = "REQUIRED_AS_GROUP" alias ONE_REQUIRED = "ONE_REQUIRED" alias MUTUALLY_EXCLUSIVE = "MUTUALLY_EXCLUSIVE" -alias FLAG_TYPES = ["String", "Bool", "Int", "Int8", "Int16", "Int32", "Int64", "UInt8", "UInt16", "UInt32", "UInt64", "Float16", "Float32", "Float64"] +alias FLAG_TYPES = [ + "String", + "Bool", + "Int", + "Int8", + "Int16", + "Int32", + "Int64", + "UInt8", + "UInt16", + "UInt32", + "UInt64", + "Float16", + "Float32", + "Float64", +] + @value struct FlagSet(CollectionElement, Stringable, Sized, Boolable, EqualityComparable): @@ -84,9 +85,7 @@ struct FlagSet(CollectionElement, Stringable, Sized, Boolable, EqualityComparabl fn __iadd__(inout self, other: Self): self.merge(other) - fn lookup( - ref [_] self, name: String, type: String = "" - ) -> Optional[Reference[Flag, __lifetime_of(self.flags)]]: + fn lookup(ref [_]self, name: String, type: String = "") raises -> ref [__lifetime_of(self.flags)] Flag: """Returns an mutable or immutable reference to a Flag with the given name. Mutable if FlagSet is mutable, immutable if FlagSet is immutable. @@ -100,14 +99,15 @@ struct FlagSet(CollectionElement, Stringable, Sized, Boolable, EqualityComparabl if type == "": for i in range(len(self.flags)): if self.flags[i].name == name: - return Reference(self.flags[i]) - else: + return self.flags[i] + else: for i in range(len(self.flags)): if self.flags[i].name == name and self.flags[i].type == type: - return Reference(self.flags[i]) - return None - - fn lookup_name(self, shorthand: String) -> Optional[String]: + return self.flags[i] + + raise Error("FlagNotFoundError: Could not find the following flag: " + name) + + fn lookup_name(self, shorthand: String) raises -> String: """Returns the name of a flag given its shorthand. Args: @@ -116,68 +116,69 @@ struct FlagSet(CollectionElement, Stringable, Sized, Boolable, EqualityComparabl for flag in self.flags: if flag[].shorthand and flag[].shorthand == shorthand: return flag[].name - return None + + raise Error("FlagNotFoundError: Could not find the following flag shorthand: " + shorthand) fn get_as[ - R: CollectionElement, transform: fn (flag_set: FlagSet, name: String) -> Optional[R] - ](self, name: String) -> Optional[R]: + R: CollectionElement, transform: fn (flag_set: FlagSet, name: String) raises -> R + ](self, name: String) raises -> R: return transform(self, name) - fn get_string(self, name: String) -> Optional[String]: + fn get_string(self, name: String) raises -> String: """Returns the value of a flag as a String. If it isn't set, then return the default value.""" - return self.get_as[R=String, transform=get_string](name) + return self.lookup(name, "String").value_or_default() - fn get_bool(self, name: String) -> Optional[Bool]: + fn get_bool(self, name: String) raises -> Bool: """Returns the value of a flag as a Bool. If it isn't set, then return the default value.""" - return self.get_as[R=Bool, transform=get_bool](name) + return string_to_bool(self.lookup(name, "Bool").value_or_default()) - fn get_int(self, name: String) -> Optional[Int]: + fn get_int(self, name: String, type: String = "Int") raises -> Int: """Returns the value of a flag as an Int. If it isn't set, then return the default value.""" - return self.get_as[R=Int, transform=get_int](name) + return atol(self.lookup(name, type).value_or_default()) - fn get_int8(self, name: String) -> Optional[Int8]: + fn get_int8(self, name: String) raises -> Int8: """Returns the value of a flag as a Int8. If it isn't set, then return the default value.""" - return self.get_as[R=Int8, transform=get_int8](name) + return Int8(self.get_int(name, "Int8")) - fn get_int16(self, name: String) -> Optional[Int16]: + fn get_int16(self, name: String) raises -> Int16: """Returns the value of a flag as a Int16. If it isn't set, then return the default value.""" - return self.get_as[R=Int16, transform=get_int16](name) + return Int16(self.get_int(name, "Int16")) - fn get_int32(self, name: String) -> Optional[Int32]: + fn get_int32(self, name: String) raises -> Int32: """Returns the value of a flag as a Int32. If it isn't set, then return the default value.""" - return self.get_as[R=Int32, transform=get_int32](name) + return Int32(self.get_int(name, "Int32")) - fn get_int64(self, name: String) -> Optional[Int64]: + fn get_int64(self, name: String) raises -> Int64: """Returns the value of a flag as a Int64. If it isn't set, then return the default value.""" - return self.get_as[R=Int64, transform=get_int64](name) + return Int64(self.get_int(name, "Int64")) - fn get_uint8(self, name: String) -> Optional[UInt8]: + fn get_uint8(self, name: String) raises -> UInt8: """Returns the value of a flag as a UInt8. If it isn't set, then return the default value.""" - return self.get_as[R=UInt8, transform=get_uint8](name) + return UInt8(self.get_int(name, "UInt8")) - fn get_uint16(self, name: String) -> Optional[UInt16]: + fn get_uint16(self, name: String) raises -> UInt16: """Returns the value of a flag as a UInt16. If it isn't set, then return the default value.""" - return self.get_as[R=UInt16, transform=get_uint16](name) + return UInt16(self.get_int(name, "UInt16")) - fn get_uint32(self, name: String) -> Optional[UInt32]: + fn get_uint32(self, name: String) raises -> UInt32: """Returns the value of a flag as a UInt32. If it isn't set, then return the default value.""" - return self.get_as[R=UInt32, transform=get_uint32](name) + return UInt32(self.get_int(name, "UInt32")) - fn get_uint64(self, name: String) -> Optional[UInt64]: + fn get_uint64(self, name: String) raises -> UInt64: """Returns the value of a flag as a UInt64. If it isn't set, then return the default value.""" - return self.get_as[R=UInt64, transform=get_uint64](name) + return UInt64(self.get_int(name, "UInt64")) - fn as_float16(self, name: String) -> Optional[Float16]: + fn as_float16(self, name: String) raises -> Float16: """Returns the value of a flag as a Float64. If it isn't set, then return the default value.""" - return self.get_as[R=Float16, transform=as_float16](name) + return self.as_float64(name).cast[DType.float16]() - fn as_float32(self, name: String) -> Optional[Float32]: + fn as_float32(self, name: String) raises -> Float32: """Returns the value of a flag as a Float64. If it isn't set, then return the default value.""" - return self.get_as[R=Float32, transform=as_float32](name) + return self.as_float64(name).cast[DType.float32]() - fn as_float64(self, name: String) -> Optional[Float64]: + fn as_float64(self, name: String) raises -> Float64: """Returns the value of a flag as a Float64. If it isn't set, then return the default value.""" - return self.get_as[R=Float64, transform=as_float64](name) + return string_to_float(self.lookup(name, "Float64").value_or_default()) fn names(self) -> List[String]: """Returns a list of names of all flags in the flag set.""" @@ -202,7 +203,7 @@ struct FlagSet(CollectionElement, Stringable, Sized, Boolable, EqualityComparabl default: Bool = False, ) -> None: """Adds a `Bool` flag to the flag set.""" - self.flags.append(Flag(name=name, shorthand=shorthand, usage=usage, value=None, default=str(default), type="Bool")) + self.flags.append(Flag(name=name, shorthand=shorthand, usage=usage, default=str(default), type="Bool")) fn string_flag( inout self, @@ -212,55 +213,55 @@ struct FlagSet(CollectionElement, Stringable, Sized, Boolable, EqualityComparabl default: String = "", ) -> None: """Adds a `String` flag to the flag set.""" - self.flags.append(Flag(name=name, shorthand=shorthand, usage=usage, value=None, default=str(default), type="String")) + self.flags.append(Flag(name=name, shorthand=shorthand, usage=usage, default=str(default), type="String")) fn int_flag(inout self, name: String, usage: String, shorthand: String = "", default: Int = 0) -> None: """Adds an `Int` flag to the flag set.""" - self.flags.append(Flag(name=name, shorthand=shorthand, usage=usage, value=None, default=str(default), type="Int")) + self.flags.append(Flag(name=name, shorthand=shorthand, usage=usage, default=str(default), type="Int")) fn int8_flag(inout self, name: String, usage: String, shorthand: String = "", default: Int8 = 0) -> None: """Adds an `Int8` flag to the flag set.""" - self.flags.append(Flag(name=name, shorthand=shorthand, usage=usage, value=None, default=str(default), type="Int8")) + self.flags.append(Flag(name=name, shorthand=shorthand, usage=usage, default=str(default), type="Int8")) fn int16_flag(inout self, name: String, usage: String, shorthand: String = "", default: Int16 = 0) -> None: """Adds an `Int16` flag to the flag set.""" - self.flags.append(Flag(name=name, shorthand=shorthand, usage=usage, value=None, default=str(default), type="Int16")) + self.flags.append(Flag(name=name, shorthand=shorthand, usage=usage, default=str(default), type="Int16")) fn int32_flag(inout self, name: String, usage: String, shorthand: String = "", default: Int32 = 0) -> None: """Adds an `Int32` flag to the flag set.""" - self.flags.append(Flag(name=name, shorthand=shorthand, usage=usage, value=None, default=str(default), type="Int32")) + self.flags.append(Flag(name=name, shorthand=shorthand, usage=usage, default=str(default), type="Int32")) fn int64_flag(inout self, name: String, usage: String, shorthand: String = "", default: Int64 = 0) -> None: """Adds an `Int64` flag to the flag set.""" - self.flags.append(Flag(name=name, shorthand=shorthand, usage=usage, value=None, default=str(default), type="Int64")) + self.flags.append(Flag(name=name, shorthand=shorthand, usage=usage, default=str(default), type="Int64")) fn uint8_flag(inout self, name: String, usage: String, shorthand: String = "", default: UInt8 = 0) -> None: """Adds a `UInt8` flag to the flag set.""" - self.flags.append(Flag(name=name, shorthand=shorthand, usage=usage, value=None, default=str(default), type="UInt8")) + self.flags.append(Flag(name=name, shorthand=shorthand, usage=usage, default=str(default), type="UInt8")) fn uint16_flag(inout self, name: String, usage: String, shorthand: String = "", default: UInt16 = 0) -> None: """Adds a `UInt16` flag to the flag set.""" - self.flags.append(Flag(name=name, shorthand=shorthand, usage=usage, value=None, default=str(default), type="UInt16")) + self.flags.append(Flag(name=name, shorthand=shorthand, usage=usage, default=str(default), type="UInt16")) fn uint32_flag(inout self, name: String, usage: String, shorthand: String = "", default: UInt32 = 0) -> None: """Adds a `UInt32` flag to the flag set.""" - self.flags.append(Flag(name=name, shorthand=shorthand, usage=usage, value=None, default=str(default), type="UInt32")) + self.flags.append(Flag(name=name, shorthand=shorthand, usage=usage, default=str(default), type="UInt32")) fn uint64_flag(inout self, name: String, usage: String, shorthand: String = "", default: UInt64 = 0) -> None: """Adds a `UInt64` flag to the flag set.""" - self.flags.append(Flag(name=name, shorthand=shorthand, usage=usage, value=None, default=str(default), type="UInt64")) + self.flags.append(Flag(name=name, shorthand=shorthand, usage=usage, default=str(default), type="UInt64")) fn float16_flag(inout self, name: String, usage: String, shorthand: String = "", default: Float16 = 0) -> None: """Adds a `Float16` flag to the flag set.""" - self.flags.append(Flag(name=name, shorthand=shorthand, usage=usage, value=None, default=str(default), type="Float16")) + self.flags.append(Flag(name=name, shorthand=shorthand, usage=usage, default=str(default), type="Float16")) fn float32_flag(inout self, name: String, usage: String, shorthand: String = "", default: Float32 = 0) -> None: """Adds a `Float32` flag to the flag set.""" - self.flags.append(Flag(name=name, shorthand=shorthand, usage=usage, value=None, default=str(default), type="Float32")) + self.flags.append(Flag(name=name, shorthand=shorthand, usage=usage, default=str(default), type="Float32")) fn float64_flag(inout self, name: String, usage: String, shorthand: String = "", default: Float64 = 0) -> None: """Adds a `Float64` flag to the flag set.""" - self.flags.append(Flag(name=name, shorthand=shorthand, usage=usage, value=None, default=str(default), type="Float64")) + self.flags.append(Flag(name=name, shorthand=shorthand, usage=usage, default=str(default), type="Float64")) fn set_annotation(inout self, name: String, key: String, values: String) raises -> None: """Sets an annotation for a flag. @@ -270,17 +271,15 @@ struct FlagSet(CollectionElement, Stringable, Sized, Boolable, EqualityComparabl key: The key of the annotation. values: The values of the annotation. """ - var result = self.lookup(name) - if not result: - raise Error(String("FlagSet.set_annotation: Failed to find flag: {}.").format(name)) - # Annotation value can be a concatenated string of values. # Why? Because we can have multiple required groups of flags for example. # So each value of the list for the annotation can be a group of flag names. - if not result.value()[].annotations.get(key): - result.value()[].annotations[key] = List[String](values) - else: - result.value()[].annotations[key].extend(values) + try: + # TODO: remove running 2 lookups when ref can return a reference + # we can store as a var without copying the result. + self.lookup(name).annotations[key].extend(values) + except: + self.lookup(name).annotations[key] = List[String](values) fn set_required(inout self, name: String) raises -> None: """Sets a flag as required or not. @@ -303,9 +302,9 @@ struct FlagSet(CollectionElement, Stringable, Sized, Boolable, EqualityComparabl self.set_annotation(name, annotation_type, names) except e: print( - String( - "FlagSet.set_as: Failed to set flag, {}, with the following annotation: {}" - ).format(name, annotation_type), + String("FlagSet.set_as: Failed to set flag, {}, with the following annotation: {}").format( + name, annotation_type + ), file=2, ) raise e @@ -337,8 +336,11 @@ struct FlagSet(CollectionElement, Stringable, Sized, Boolable, EqualityComparabl @always_inline fn add_flag(flag: Flag) capturing -> None: - if not self.lookup(flag.name): - self.flags.append(flag) + try: + _ = self.lookup(flag.name) + except e: + if str(e).find("FlagNotFoundError") != -1: + self.flags.append(flag) new_set.visit_all[add_flag]() @@ -369,4 +371,4 @@ fn validate_required_flags(flags: FlagSet) raises -> None: flags.visit_all[check_required_flag]() if len(missing_flag_names) > 0: - raise Error("required flag(s) " + missing_flag_names.__str__() + " not set") + raise Error("Required flag(s): " + missing_flag_names.__str__() + " not set.") diff --git a/src/prism/transform.mojo b/src/prism/transform.mojo deleted file mode 100644 index 3b1ff30..0000000 --- a/src/prism/transform.mojo +++ /dev/null @@ -1,226 +0,0 @@ -from collections import Optional -from .flag_set import FlagSet -from .util import panic, string_to_bool, string_to_float - - -fn get_string(flag_set: FlagSet, name: String) -> Optional[String]: - """Returns the value of a flag as a String. If it isn't set, then return the default value. - - Args: - flag_set: The FlagSet to get the value from. - name: The name of the flag to return. - """ - var flag = flag_set.lookup(name, "String") - if not flag: - return None - - fn get(value: String) -> String: - return value - - # TODO: inferring the return type in the parameter only works for String as of 24.5. - # Will switch the other transform functions in the future when it works. - return flag.value()[].get_with_transform[get]() - - -fn get_bool(flag_set: FlagSet, name: String) -> Optional[Bool]: - """Returns the value of a flag as a Bool. If it isn't set, then return the default value. - - Args: - flag_set: The FlagSet to get the value from. - name: The name of the flag to return. - """ - var result = flag_set.lookup(name, "Bool") - if not result: - return None - - var flag = result.value() - if not flag[].value: - return string_to_bool(flag[].default) - - return string_to_bool(flag[].value.value()) - - -fn get_int(flag_set: FlagSet, name: String) -> Optional[Int]: - """Returns the value of a flag as an Int. If it isn't set, then return the default value. - - Args: - flag_set: The FlagSet to get the value from. - name: The name of the flag to return. - """ - var result = flag_set.lookup(name, "Int") - if not result: - return None - - var flag = result.value() - - # TODO: I don't like this swallowing up a failure to convert to int. Maybe return a tuple of optional and error? - try: - if not flag[].value: - return atol(flag[].default) - - return atol(flag[].value.value()) - except: - return None - - -fn get_int8(flag_set: FlagSet, name: String) -> Optional[Int8]: - """Returns the value of a flag as a Int8. If it isn't set, then return the default value. - - Args: - flag_set: The FlagSet to get the value from. - name: The name of the flag to return. - """ - var value = flag_set.get_int(name) - if not value: - return None - - return Int8(value.value()) - - -fn get_int16(flag_set: FlagSet, name: String) -> Optional[Int16]: - """Returns the value of a flag as a Int16. If it isn't set, then return the default value. - - Args: - flag_set: The FlagSet to get the value from. - name: The name of the flag to return. - """ - var value = flag_set.get_int(name) - if not value: - return None - - return Int16(value.value()) - - -fn get_int32(flag_set: FlagSet, name: String) -> Optional[Int32]: - """Returns the value of a flag as a Int32. If it isn't set, then return the default value. - - Args: - flag_set: The FlagSet to get the value from. - name: The name of the flag to return. - """ - var value = flag_set.get_int(name) - if not value: - return None - - return Int32(value.value()) - - -fn get_int64(flag_set: FlagSet, name: String) -> Optional[Int64]: - """Returns the value of a flag as a Int64. If it isn't set, then return the default value. - - Args: - flag_set: The FlagSet to get the value from. - name: The name of the flag to return. - """ - var value = flag_set.get_int(name) - if not value: - return None - - return Int64(value.value()) - - -fn get_uint8(flag_set: FlagSet, name: String) -> Optional[UInt8]: - """Returns the value of a flag as a UInt8. If it isn't set, then return the default value. - - Args: - flag_set: The FlagSet to get the value from. - name: The name of the flag to return. - """ - var value = flag_set.get_int(name) - if not value: - return None - - return UInt8(value.value()) - - -fn get_uint16(flag_set: FlagSet, name: String) -> Optional[UInt16]: - """Returns the value of a flag as a UInt16. If it isn't set, then return the default value. - - Args: - flag_set: The FlagSet to get the value from. - name: The name of the flag to return. - """ - var value = flag_set.get_int(name) - if not value: - return None - - return UInt16(value.value()) - - -fn get_uint32(flag_set: FlagSet, name: String) -> Optional[UInt32]: - """Returns the value of a flag as a UInt32. If it isn't set, then return the default value. - - Args: - flag_set: The FlagSet to get the value from. - name: The name of the flag to return. - """ - var value = flag_set.get_int(name) - if not value: - return None - - return UInt32(value.value()) - - -fn get_uint64(flag_set: FlagSet, name: String) -> Optional[UInt64]: - """Returns the value of a flag as a UInt64. If it isn't set, then return the default value. - - Args: - flag_set: The FlagSet to get the value from. - name: The name of the flag to return. - """ - var value = flag_set.get_int(name) - if not value: - return None - - return UInt64(value.value()) - - -fn as_float16(flag_set: FlagSet, name: String) -> Optional[Float16]: - """Returns the value of a flag as a Float64. If it isn't set, then return the default value. - - Args: - flag_set: The FlagSet to get the value from. - name: The name of the flag to return. - """ - var value = flag_set.as_float64(name) - if not value: - return None - - return value.value().cast[DType.float16]() - - -fn as_float32(flag_set: FlagSet, name: String) -> Optional[Float32]: - """Returns the value of a flag as a Float64. If it isn't set, then return the default value. - - Args: - flag_set: The FlagSet to get the value from. - name: The name of the flag to return. - """ - var value = flag_set.as_float64(name) - if not value: - return None - - return value.value().cast[DType.float32]() - - -fn as_float64(flag_set: FlagSet, name: String) -> Optional[Float64]: - """Returns the value of a flag as a Float64. If it isn't set, then return the default value. - - Args: - flag_set: The FlagSet to get the value from. - name: The name of the flag to return. - """ - var result = flag_set.lookup(name, "Float64") - if not result: - return None - - var flag = result.value() - - # TODO: I don't like this swallowing up a failure to convert to int. Maybe return a tuple of optional and error? - try: - if not flag[].value: - return string_to_float(flag[].default) - - return string_to_float(flag[].value.value()) - except e: - return None diff --git a/test/test_args.mojo b/test/test_args.mojo index 84e8bd5..037c9b2 100644 --- a/test/test_args.mojo +++ b/test/test_args.mojo @@ -10,7 +10,7 @@ from prism.args import ( exact_args, range_args, match_all, - ArgValidator, + ArgValidatorFn, ) @@ -20,20 +20,20 @@ from prism.args import ( # TODO: renable these when we have assert raises in testing # def test_no_args(): -# var cmd = Command(name="root", description="Base command.", run=dummy) +# var cmd = Command(name="root", usage="Base command.", run=dummy) # var arc = Arc(cmd) # var result = no_args(arc, List[String]("abc")) # testing.assert_equal(result.value(), String("The command `root` does not take any arguments.")) # def test_valid_args(): -# var cmd = Command(name="root", description="Base command.", run=dummy) +# var cmd = Command(name="root", usage="Base command.", run=dummy) # var result = valid_args[List[String]("Pineapple")]()(Arc(cmd), List[String]("abc")) # testing.assert_equal(result.value(), "Invalid argument: `abc`, for the command `root`.") # def test_arbitrary_args(): -# var cmd = Command(name="root", description="Base command.", run=dummy) +# var cmd = Command(name="root", usage="Base command.", run=dummy) # var result = arbitrary_args(Arc(cmd), List[String]("abc", "blah", "blah")) # # If the result is anything but None, fail the test. @@ -42,33 +42,33 @@ from prism.args import ( # def test_minimum_n_args(): -# var cmd = Command(name="root", description="Base command.", run=dummy) +# var cmd = Command(name="root", usage="Base command.", run=dummy) # var result = minimum_n_args[3]()(Arc(cmd), List[String]("abc", "123")) # testing.assert_equal(result.value(), "The command `root` accepts at least 3 argument(s). Received: 2.") # def test_maximum_n_args(): -# var cmd = Command(name="root", description="Base command.", run=dummy) +# var cmd = Command(name="root", usage="Base command.", run=dummy) # var result = maximum_n_args[1]()(Arc(cmd), List[String]("abc", "123")) # testing.assert_equal(result.value(), "The command `root` accepts at most 1 argument(s). Received: 2.") # def test_exact_args(): -# var cmd = Command(name="root", description="Base command.", run=dummy) +# var cmd = Command(name="root", usage="Base command.", run=dummy) # var result = exact_args[1]()(Arc(cmd), List[String]("abc", "123")) # testing.assert_equal(result.value(), "The command `root` accepts exactly 1 argument(s). Received: 2.") # def test_range_args(): -# var cmd = Command(name="root", description="Base command.", run=dummy) +# var cmd = Command(name="root", usage="Base command.", run=dummy) # var result = range_args[0, 1]()(Arc(cmd), List[String]("abc", "123")) # testing.assert_equal(result.value(), "The command `root`, accepts between 0 to 1 argument(s). Received: 2.") # def test_match_all(): -# var cmd = Command(name="root", description="Base command.", run=dummy) +# var cmd = Command(name="root", usage="Base command.", run=dummy) # var args = List[String]("abc", "123") -# alias validators = List[ArgValidator]( +# alias validators = List[ArgValidatorFn]( # range_args[0, 1](), # valid_args[List[String]("Pineapple")]() # ) diff --git a/test/test_command.mojo b/test/test_command.mojo index 48f2ea5..428e927 100644 --- a/test/test_command.mojo +++ b/test/test_command.mojo @@ -8,15 +8,15 @@ def test_command_operations(): fn dummy(ctx: Context) -> None: return None - var cmd = Arc(Command(name="root", description="Base command.", run=dummy)) + var cmd = Arc(Command(name="root", usage="Base command.", run=dummy)) var flags = cmd[].flags.flags for flag in flags: testing.assert_equal(String("help"), flag[].name) - var child_cmd = Arc(Command(name="child", description="Child command.", run=dummy)) + var child_cmd = Arc(Command(name="child", usage="Child command.", run=dummy)) cmd[].add_subcommand(child_cmd) - child_cmd[].flags.string_flag(name="color", shorthand="c", usage="Text color", default="#3464eb") + child_cmd[].flags.get_string(name="color", shorthand="c", usage="Text color", default="#3464eb") testing.assert_equal(child_cmd[].full_name(), "root child") diff --git a/test/test_flags.mojo b/test/test_flags.mojo index 64406b6..68fe009 100644 --- a/test/test_flags.mojo +++ b/test/test_flags.mojo @@ -20,8 +20,8 @@ def test_string_to_float(): def test_get_flags(): var flag_set = FlagSet() - flag_set.string_flag("key", "description", "default") - flag_set.bool_flag("flag", "description", "False") + flag_set.get_string("key", "usage", "default") + flag_set.get_bool("flag", "usage", "False") var flags = List[String]("--key=value", "positional", "--flag") _ = flag_set.from_args(flags) @@ -31,8 +31,8 @@ def test_get_flags(): def test_parse_flag(): var flag_set = FlagSet() - flag_set.string_flag(name="key", usage="description", default="default") - flag_set.bool_flag(name="flag", usage="description", default=False) + flag_set.get_string(name="key", usage="usage", default="default") + flag_set.get_bool(name="flag", usage="usage", default=False) var parser = FlagParser() var name: String @@ -51,8 +51,8 @@ def test_parse_flag(): def test_parse_shorthand_flag(): var flag_set = FlagSet() - flag_set.string_flag(name="key", usage="description", default="default", shorthand="k") - flag_set.bool_flag(name="flag", usage="description", default=False, shorthand="f") + flag_set.get_string(name="key", usage="usage", default="default", shorthand="k") + flag_set.get_bool(name="flag", usage="usage", default=False, shorthand="f") var parser = FlagParser() var name: String