Skip to content

Commit

Permalink
better handling of optional args
Browse files Browse the repository at this point in the history
  • Loading branch information
Zomatree committed May 28, 2024
1 parent bc8d650 commit 48e0f0f
Showing 1 changed file with 21 additions and 3 deletions.
24 changes: 21 additions & 3 deletions revolt/ext/commands/command.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,16 @@
from typing import (TYPE_CHECKING, Annotated, Any, Callable, Coroutine,
Generic, Literal, Optional, Union, get_args, get_origin)
from typing_extensions import ParamSpec
import sys

from revolt.utils import maybe_coroutine
if sys.version_info >= (3, 10):
from types import UnionType

UnionTypes = (Union, UnionType)
else:
UnionTypes = Union

from ...utils import maybe_coroutine

from .errors import CommandOnCooldown, InvalidLiteralArgument, UnionConverterError
from .utils import ClientT_Co_D, evaluate_parameters, ClientT_Co
Expand Down Expand Up @@ -126,7 +134,7 @@ async def _default_error_handler(self, ctx: Context[ClientT_Co_D], error: Except

@classmethod
async def handle_origin(cls, context: Context[ClientT_Co_D], origin: Any, annotation: Any, arg: str) -> Any:
if origin is Union:
if origin in UnionTypes:
for converter in get_args(annotation):
try:
return await cls.convert_argument(arg, converter, context)
Expand Down Expand Up @@ -175,6 +183,10 @@ async def parse_arguments(self, context: Context[ClientT_Co_D]) -> None:
except StopIteration:
if parameter.default is not parameter.empty:
arg = parameter.default

elif is_optional(parameter.annotation):
arg = None

else:
raise

Expand All @@ -192,7 +204,10 @@ async def parse_arguments(self, context: Context[ClientT_Co_D]) -> None:
except StopIteration:
if parameter.default is not parameter.empty:
arg = parameter.default
context.view.undo()

elif is_optional(parameter.annotation):
arg = None

else:
raise

Expand Down Expand Up @@ -251,6 +266,9 @@ def get_usage(self) -> str:

return f"{' '.join(parents[::-1])} {self.name} {' '.join(parameters)}"

def is_optional(arg: Any) -> bool:
return get_origin(arg) in UnionTypes and any(arg is NoneType for arg in get_args(arg))

def command(
*,
name: Optional[str] = None,
Expand Down

0 comments on commit 48e0f0f

Please sign in to comment.