From cc1d8c56127c5fc09ef04d99937c1dc5dc8474ae Mon Sep 17 00:00:00 2001 From: Brent Yi Date: Mon, 13 Nov 2023 00:56:04 -0800 Subject: [PATCH] Fix pyright errors --- src/viser/infra/_messages.py | 2 +- src/viser/infra/_typescript_interface_gen.py | 40 ++++++++++---------- 2 files changed, 21 insertions(+), 21 deletions(-) diff --git a/src/viser/infra/_messages.py b/src/viser/infra/_messages.py index b8f48d4c5..ea6e68677 100644 --- a/src/viser/infra/_messages.py +++ b/src/viser/infra/_messages.py @@ -66,7 +66,7 @@ def _prepare_for_serialization(value: Any, annotation: Type) -> Any: @functools.lru_cache(maxsize=None) -def get_type_hints_cached(cls: Type) -> Dict[str, Any]: +def get_type_hints_cached(cls: Type[Any]) -> Dict[str, Any]: return get_type_hints(cls) diff --git a/src/viser/infra/_typescript_interface_gen.py b/src/viser/infra/_typescript_interface_gen.py index 9c82a327b..961ac90ac 100644 --- a/src/viser/infra/_typescript_interface_gen.py +++ b/src/viser/infra/_typescript_interface_gen.py @@ -25,32 +25,23 @@ } -def _get_ts_type(typ: Type) -> str: - if get_origin(typ) is tuple: +def _get_ts_type(typ: Type[Any]) -> str: + origin_typ = get_origin(typ) + + if origin_typ is tuple: args = get_args(typ) if len(args) == 2 and args[1] == ...: return _get_ts_type(args[0]) + "[]" else: return "[" + ", ".join(map(_get_ts_type, args)) + "]" - if get_origin(typ) in (Literal, LiteralAlt): + elif origin_typ in (Literal, LiteralAlt): return " | ".join( map( lambda lit: repr(lit).lower() if type(lit) is bool else repr(lit), get_args(typ), ) ) - if is_typeddict(typ): - hints = get_type_hints(typ) - - def fmt(key): - val = hints[key] - ret = f"'{key}'" + ": " + _get_ts_type(val) - return ret - - ret = "{" + ", ".join(map(fmt, hints)) + "}" - # ret = "{" + f"type: \'{typ.__name__}\', " + ", ".join(map(fmt, hints)) + "}" - return ret - if get_origin(typ) is Union: + elif origin_typ is Union: return ( "(" + " | ".join( @@ -61,13 +52,22 @@ def fmt(key): ) + ")" ) + elif is_typeddict(typ): + hints = get_type_hints(typ) - if hasattr(typ, "__origin__"): - typ = typ.__origin__ - if typ in _raw_type_mapping: - return _raw_type_mapping[typ] + def fmt(key): + val = hints[key] + ret = f"'{key}'" + ": " + _get_ts_type(val) + return ret - assert False, f"Unsupported type: {typ}" + ret = "{" + ", ".join(map(fmt, hints)) + "}" + return ret + else: + # Like get_origin(), but also supports numpy.typing.NDArray[dtype]. + typ = getattr(typ, "__origin__", typ) + + assert typ in _raw_type_mapping, f"Unsupported type {typ}" + return _raw_type_mapping[typ] def generate_typescript_interfaces(message_cls: Type[Message]) -> str: