Skip to content

Commit

Permalink
Fix pyright errors
Browse files Browse the repository at this point in the history
  • Loading branch information
brentyi committed Nov 13, 2023
1 parent b7a62c0 commit cc1d8c5
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 21 deletions.
2 changes: 1 addition & 1 deletion src/viser/infra/_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def _prepare_for_serialization(value: Any, annotation: Type) -> Any:


@functools.lru_cache(maxsize=None)
def get_type_hints_cached(cls: Type) -> Dict[str, Any]:
def get_type_hints_cached(cls: Type[Any]) -> Dict[str, Any]:
return get_type_hints(cls)


Expand Down
40 changes: 20 additions & 20 deletions src/viser/infra/_typescript_interface_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,32 +25,23 @@
}


def _get_ts_type(typ: Type) -> str:
if get_origin(typ) is tuple:
def _get_ts_type(typ: Type[Any]) -> str:
origin_typ = get_origin(typ)

if origin_typ is tuple:
args = get_args(typ)
if len(args) == 2 and args[1] == ...:
return _get_ts_type(args[0]) + "[]"
else:
return "[" + ", ".join(map(_get_ts_type, args)) + "]"
if get_origin(typ) in (Literal, LiteralAlt):
elif origin_typ in (Literal, LiteralAlt):
return " | ".join(
map(
lambda lit: repr(lit).lower() if type(lit) is bool else repr(lit),
get_args(typ),
)
)
if is_typeddict(typ):
hints = get_type_hints(typ)

def fmt(key):
val = hints[key]
ret = f"'{key}'" + ": " + _get_ts_type(val)
return ret

ret = "{" + ", ".join(map(fmt, hints)) + "}"
# ret = "{" + f"type: \'{typ.__name__}\', " + ", ".join(map(fmt, hints)) + "}"
return ret
if get_origin(typ) is Union:
elif origin_typ is Union:
return (
"("
+ " | ".join(
Expand All @@ -61,13 +52,22 @@ def fmt(key):
)
+ ")"
)
elif is_typeddict(typ):
hints = get_type_hints(typ)

if hasattr(typ, "__origin__"):
typ = typ.__origin__
if typ in _raw_type_mapping:
return _raw_type_mapping[typ]
def fmt(key):
val = hints[key]
ret = f"'{key}'" + ": " + _get_ts_type(val)
return ret

assert False, f"Unsupported type: {typ}"
ret = "{" + ", ".join(map(fmt, hints)) + "}"
return ret
else:
# Like get_origin(), but also supports numpy.typing.NDArray[dtype].
typ = getattr(typ, "__origin__", typ)

assert typ in _raw_type_mapping, f"Unsupported type {typ}"
return _raw_type_mapping[typ]


def generate_typescript_interfaces(message_cls: Type[Message]) -> str:
Expand Down

0 comments on commit cc1d8c5

Please sign in to comment.