-
Notifications
You must be signed in to change notification settings - Fork 0
/
scope_env.py
164 lines (127 loc) · 5.45 KB
/
scope_env.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
# This file is loaded into two different global scopes.
# The first scope is the one created when another file imports this module; it provides the list of available functions to add to the LLM prompts.
# The second scope is created when this file is loaded and run by exec(); it provides a valid scope in which to run the program generated by the LLM.
# A single scope is not enough, because some predefined functions use global variables that are statically (lexically) bound.
import abc
import inspect
from typing import Callable
import numpy.typing as npt
from misc_utils import is_list_of_type
from scannet_utils import ObjInstance
# Registry of predefined functions: maps every exported function name to a
# handler instance. Entries are added by the register_handler decorator.
AVAILABLE_HANDLERS: dict[str, object] = dict()
class TargetInfo:
    """Class-level record of the current grounding result.

    All fields live on the class itself. They are written by
    ``set_target_info`` and cleared back to their defaults by ``reset``.
    """

    # instance recorded as the best match (None until set)
    best_instance: ObjInstance | None = None
    # instances recorded as candidates
    candidate_instances: set[ObjInstance] = set()
    # anchor objects keyed by a string name; a value is one instance or a list
    anchor_instances: dict[str, ObjInstance | list[ObjInstance]] = {}
    # textual description stored alongside the result (None until set)
    csp_desc: str | None = None
    # flag recorded with the result; False by default
    llm_used: bool = False

    @staticmethod
    def reset():
        """Restore every field to its default, allocating fresh containers."""
        # ``__class__`` is the implicit class cell, so this targets TargetInfo
        # even though the method is a staticmethod.
        owner = __class__
        # The tuple is rebuilt on every call, so set() / {} are new objects.
        for field_name, default in (
            ("best_instance", None),
            ("candidate_instances", set()),
            ("anchor_instances", {}),
            ("csp_desc", None),
            ("llm_used", False),
        ):
            setattr(owner, field_name, default)
class GlobalState:
    """Class-level state shared with the predefined functions.

    Fields are assigned through the module-level setters
    (``set_relevant_obj_map``, ``set_instance_map``, ``set_room_center``).
    """

    # integer key -> object label, used in the type-checking phase
    relevant_obj_map: dict[int, str] = {}
    # distinct labels (values of relevant_obj_map)
    relevant_obj_set: set[str] = set()
    # label -> instances with that label, used in the grounding phase
    relevant_obj_instances: dict[str, list[ObjInstance]] = {}
    # room center, array of shape (3,); None until set_room_center is called.
    # (Previously defaulted to the tuple (None,) because of a stray comma.)
    room_center: npt.NDArray | None = None
    # four corners of the room (scene). Declared only: there is no default, so
    # reading it before something assigns it raises AttributeError.
    room_corners: tuple[
        npt.NDArray, npt.NDArray, npt.NDArray, npt.NDArray
    ]

    @staticmethod
    def get_cand_insts(label: str | list[str]) -> list[ObjInstance]:
        """Return the instances for one label, or the union for several labels.

        Args:
            label: a single label, or a list of labels.

        Returns:
            A new list of instances sorted by ``inst_id``. For a list of labels,
            an instance appearing under several labels is returned once.

        Raises:
            KeyError: if a label is missing from ``relevant_obj_instances``
                (when that mapping is a dict).
            SystemError: if ``label`` is neither a str nor accepted by
                ``is_list_of_type(label, str)``.
        """
        instances = __class__.relevant_obj_instances
        if isinstance(label, str):
            return sorted(instances[label], key=lambda x: x.inst_id)
        if is_list_of_type(label, str):
            # set().union(...) also copes with an empty label list (if
            # is_list_of_type accepts one), where set.union(*[]) raises TypeError.
            merged = set().union(*(instances[lbl] for lbl in label))
            return sorted(merged, key=lambda x: x.inst_id)
        raise SystemError(f"invalid argument: {label}")
def register_handler(disabled: bool = False):
    """Build a class decorator that registers a handler in AVAILABLE_HANDLERS.

    Args:
        disabled: when true, the returned decorator leaves the class untouched
            and registers nothing.

    The decorated class must define ``FUNC_NAME``, ``SIG_STR``,
    ``call_type_check`` and ``call``. ``FUNC_NAME`` is either a single name or
    a list of names. Each name is bound to its own newly constructed instance
    of the class. The decorator always returns the class unchanged.

    Raises (at decoration time):
        AssertionError: if a required attribute is missing, or a list
            FUNC_NAME contains a non-str entry.
        SystemError: if FUNC_NAME is neither a str nor a list.
    """

    def _passthrough(handler_class):
        return handler_class

    def _register(handler_class):
        for required in ("FUNC_NAME", "SIG_STR", "call_type_check", "call"):
            assert hasattr(handler_class, required)

        declared = handler_class.FUNC_NAME
        if isinstance(declared, str):
            names = [declared]
        elif isinstance(declared, list):
            assert all(isinstance(entry, str) for entry in declared)
            names = declared
        else:
            raise SystemError(f"invalid handler class: {handler_class}!")

        for exported_name in names:
            # a fresh handler object per exported name
            AVAILABLE_HANDLERS[exported_name] = handler_class()
        return handler_class

    return _passthrough if disabled else _register
def get_predef_func_sigs():
    """List every registered predefined function for use in LLM prompts.

    Returns:
        A list of ``(name, SIG_STR, DOC_STR or None)`` tuples, sorted by name.
        ``DOC_STR`` is optional on a handler; ``None`` is used when absent.
    """
    entries = [
        (name, handler.SIG_STR, getattr(handler, "DOC_STR", None))
        for name, handler in AVAILABLE_HANDLERS.items()
    ]
    entries.sort(key=lambda entry: entry[0])
    return entries
def get_eval_scope(use_type_check_funcs) -> dict[str, Callable]:
    """Build a name -> callable mapping suitable as the globals for exec().

    Each registered handler is wrapped in a plain function.
    - If ``use_type_check_funcs`` is truthy, the wrapper forwards to the
      handler's ``call_type_check``; otherwise it forwards to ``call``.
    - If the returned object has an ``apparent_name`` attribute, the wrapper
      calls its ``set_apparent_name`` with the exported function name.
    """
    method_name = "call_type_check" if use_type_check_funcs else "call"

    def wrap(instance, name):
        # instance/name are bound per call of wrap(), so each closure keeps its own.
        def f(*args, **kwargs):
            # method looked up at call time, like instance.<method>(...)
            ret = getattr(instance, method_name)(*args, **kwargs)
            if hasattr(ret, "apparent_name"):
                ret.set_apparent_name(name)
            return ret

        return f

    return {
        name: wrap(handler, name) for name, handler in AVAILABLE_HANDLERS.items()
    }
def set_relevant_obj_map(relevant_obj_map):
    """Install the object map used during type checking.

    Also caches the map's distinct values in ``GlobalState.relevant_obj_set``.
    """
    state = GlobalState
    state.relevant_obj_map = relevant_obj_map
    state.relevant_obj_set = {obj_label for obj_label in relevant_obj_map.values()}
def set_instance_map(instance_map):
    """Install the label -> instances map read by GlobalState.get_cand_insts during grounding."""
    setattr(GlobalState, "relevant_obj_instances", instance_map)
def set_room_center(room_center):
    """Store the room center in GlobalState for use during grounding."""
    setattr(GlobalState, "room_center", room_center)
def set_target_info(
    best_instance: ObjInstance,
    candidate_instances: set[ObjInstance],
    anchor_instances: dict[str, ObjInstance | list[ObjInstance]],
    csp_desc: str,
    llm_used: bool,
):
    """Validate the grounding result and store it on TargetInfo.

    ``candidate_instances`` may be a set or a list; it is copied into a new set.

    Raises:
        AssertionError: if any argument has an unexpected type (checks are
            skipped under ``python -O``).
    """
    # Validate every argument before writing, so a bad call leaves TargetInfo untouched.
    for value, expected_type in (
        (best_instance, ObjInstance),
        (candidate_instances, (set, list)),
        (anchor_instances, dict),
        (csp_desc, str),
        (llm_used, bool),
    ):
        assert isinstance(value, expected_type)

    TargetInfo.best_instance = best_instance
    TargetInfo.candidate_instances = set(candidate_instances)
    TargetInfo.anchor_instances = anchor_instances
    TargetInfo.csp_desc = csp_desc
    TargetInfo.llm_used = llm_used