-
Notifications
You must be signed in to change notification settings - Fork 1
/
actions.py
348 lines (283 loc) · 13.8 KB
/
actions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
# This file contains your custom actions, which can be used to run
# custom Python code.
#
# See this guide on how to implement these actions:
# https://rasa.com/docs/rasa/core/actions/#custom-actions/
import logging
import time
from typing import Any, Dict, List, Text, Union
from rasa_sdk import Action, Tracker
from rasa_sdk.events import AllSlotsReset, UserUttered, SlotSet, ActionExecuted, EventType, FollowupAction, Restarted, SessionStarted, SlotSet
from rasa_sdk.executor import CollectingDispatcher
from rasa_sdk.forms import Action, FormAction, REQUESTED_SLOT
from rasa.core.slots import Slot
from rasa.core.events import Event
from synonym_extraction import collect_synonym, add_synonym
from slots import valid_placements, update_known_colors, update_known_objects
from train import train_model
logger = logging.getLogger(__name__)

# ENABLE_ROS records whether the ROS bridge (ros_comm.nlp_node) could be
# imported; the execute_* actions branch on it.
try:
    from ros_comm import nlp_node
    ENABLE_ROS = True
except Exception as e:
    # Deliberately broad: any failure while importing the bridge only disables
    # ROS support instead of stopping the action server.
    logger.warning("Failed to load ros_comm module: {}".format(e))
    ENABLE_ROS = False

# NLU data file locations, relative to the current working directory.
# (An earlier pair of assignments to './data/nlu.md' / './data/user_nlu.md'
# was immediately overwritten by these and has been removed.)
input_nlu_file = './data/nlu/synonyms.md'
user_nlu_file = './data/nlu/user_nlu.md'
class FillActionSlot(Action):
    """Store the recognised intent in the ``action`` slot.

    The intents ``find``, ``pick up`` and ``move`` are written to the slot
    unchanged; ``show`` is written as ``learn``. Any other intent (or no
    intent name) produces no events.
    """

    # intent name -> value written to the "action" slot
    _ACTION_SLOT_VALUES = {
        'find': 'find',
        'pick up': 'pick up',
        'move': 'move',
        'show': 'learn',
    }

    def name(self) -> Text:
        return "action_fill"

    def run(self, dispatcher: CollectingDispatcher,
            tracker: Tracker,
            domain: Dict[Text, Any]) -> List[Dict[Text, Any]]:
        intent_name = tracker.latest_message['intent'].get('name')
        logger.info('Got action: {}'.format(intent_name))
        logger.debug(tracker.sender_id)

        slot_value = self._ACTION_SLOT_VALUES.get(intent_name)
        if slot_value is None:
            # Unrecognised intent: leave the slot untouched.
            return []
        return [SlotSet("action", slot_value)]
class ReceivedFind(Action):
    """Ask the robot to locate an object and report what it saw.

    Reads the ``object_name``, ``object_color`` and ``placement`` slots. When
    the ROS bridge is available, sends a ``"find"`` command through
    ``nlp_node`` and relays the 2D (and, if returned, 3D) vision result to the
    user as image attachments plus a text message. All slots are reset
    afterwards in every path.
    """

    def name(self) -> Text:
        return "execute_find"

    def run(self, dispatcher, tracker, domain):
        object_name = tracker.get_slot('object_name')
        object_color = tracker.get_slot('object_color')
        placement_origin = tracker.get_slot('placement')

        if not ENABLE_ROS:
            # No ROS bridge: nothing is actually searched; the reply is unconditional.
            dispatcher.utter_message(text="I found the {} {} you asked for.".format(
                object_color,
                object_name
            ))
            return [AllSlotsReset()]

        nlp_node.send_command("find", object_name, object_color, placement_origin)
        response = nlp_node.wait_for_response()
        # The response is unpacked as (msg, 2d_path, 3d_path), falling back to
        # (msg, 2d_path). Only the length mismatch (ValueError) is caught, and
        # path_3dimg is set explicitly so it is never left unbound.
        try:
            msg, path_2dimg, path_3dimg = response
        except ValueError:
            msg, path_2dimg = response
            path_3dimg = None

        if msg is None:
            dispatcher.utter_message(template="utter_command_failed")
            return [AllSlotsReset()]

        # --- 2D vision result ---
        logger.info("Found {} object: {}".format(msg.desired_color, msg.found_obj))
        if path_2dimg is not None:
            # The time query parameter presumably bypasses client-side caching
            # of a reused image path.
            imgurl_2d = "http://localhost:8888/{}?time={}".format(path_2dimg, int(time.time()))
            dispatcher.utter_attachment(None, image=imgurl_2d)

        in_known_area = placement_origin in valid_placements
        if msg.found_obj:
            if in_known_area:
                text = "I found the {} object you asked for in the {} area.".format(
                    msg.desired_color,
                    placement_origin
                )
            else:
                text = "I found the {} object you asked for.".format(
                    msg.desired_color
                )
        else:
            if in_known_area:
                text = "I didn't find anything {} in the {} area. This is what I can see".format(
                    msg.desired_color,
                    placement_origin
                )
            else:
                text = "I didn't find anything {}. This is what I can see.".format(
                    msg.desired_color
                )
        dispatcher.utter_message(text=text)

        # --- 3D vision result (only when the node returned a 3D image path) ---
        if path_3dimg is not None:
            try:
                if msg.pcl_obj:
                    imgurl_3d = "http://localhost:8888/{}?time={}".format(path_3dimg, int(time.time()))
                    dispatcher.utter_attachment(None, image=imgurl_3d)
                    # NOTE(review): the flag is read from `pcl_obj` but the name
                    # from `pcl_object`; confirm the message has both fields.
                    dispatcher.utter_message(text="I found the {} you asked for.".format(
                        msg.pcl_object
                    ))
            except Exception as e:
                # Best effort: a problem with the 3D result must not abort the reply.
                logger.warning("Failed to handle 3D vision response: {}".format(e))

        return [AllSlotsReset()]
class ReceivedLearn(Action):
    """Handle the "show" command: look for the object the user is presenting.

    Reads the ``object_name``, ``object_color`` and ``placement`` slots; a
    placement outside ``valid_placements`` (including an unset slot) falls
    back to ``"middle"``. When the ROS bridge is available, sends a
    ``"show"`` command via ``nlp_node``. On success, records the name and
    colour with ``update_known_objects`` / ``update_known_colors``. All slots
    are reset afterwards in every path.
    """

    def name(self) -> Text:
        return "execute_learn"

    def run(self, dispatcher: CollectingDispatcher,
            tracker: Tracker,
            domain: Dict[Text, Any]) -> List[Dict[Text, Any]]:
        object_name = tracker.get_slot('object_name')
        object_color = tracker.get_slot('object_color')
        placement = tracker.get_slot('placement')
        placement_origin = placement if placement in valid_placements else "middle"

        if not ENABLE_ROS:
            dispatcher.utter_message(template="utter_got_description")
            return [AllSlotsReset()]

        nlp_node.send_command(action="show",
                              object=object_name,
                              obj_color=object_color,
                              placement_origin=placement_origin,
                              placement_destination=None)
        response = nlp_node.wait_for_response()
        # Unpack (msg, 2d_path, 3d_path), falling back to (msg, 2d_path);
        # the 3D path is not used here. Only the length mismatch is caught.
        try:
            msg, path_2dimg, _ = response
        except ValueError:
            msg, path_2dimg = response

        if msg is None:
            dispatcher.utter_message(template="utter_command_failed")
            return [AllSlotsReset()]

        # Previously "Found object: {}" was formatted with two arguments and
        # logged the colour under that label; log both fields explicitly.
        logger.info("Found {} object: {}".format(msg.desired_color, msg.found_obj))
        if path_2dimg is not None:
            logger.info("Image saved at {}".format(path_2dimg))
            imgurl = "http://localhost:8888/{}?time={}".format(path_2dimg, int(time.time()))
            dispatcher.utter_attachment(None, image=imgurl)

        if msg.found_obj:
            dispatcher.utter_message(template="utter_got_description")
            update_known_objects([object_name])
            update_known_colors([object_color])
        else:
            dispatcher.utter_message(text="Sorry, I didn't find any object. Make sure the {} {} you want to show me is in the {} area of the platform.".format(
                msg.desired_color,
                object_name,
                placement_origin))
        return [AllSlotsReset()]
class ReceivedPickup(Action):
    """Ask the robot to pick up an object and report the gripper outcome.

    Reads the ``object_name``, ``object_color`` and ``placement`` slots. When
    the ROS bridge is available, sends a ``"pick up"`` command via
    ``nlp_node``, shows the returned 2D image (if any) and reports failure
    for gripper codes 1, 2 or 3, success otherwise. All slots are reset
    afterwards in every path.
    """

    def name(self) -> Text:
        return "execute_pickup"

    def run(self, dispatcher, tracker, domain):
        object_name = tracker.get_slot('object_name')
        object_color = tracker.get_slot('object_color')
        placement_origin = tracker.get_slot('placement')

        if not ENABLE_ROS:
            # No ROS bridge: the success reply is unconditional.
            dispatcher.utter_message(text="I've managed to pick it up!")
            return [AllSlotsReset()]

        nlp_node.send_command("pick up", object_name, object_color, placement_origin)
        response = nlp_node.wait_for_response()
        # Unpack (msg, 2d_path, 3d_path), falling back to (msg, 2d_path);
        # only the length mismatch (ValueError) is caught.
        try:
            msg, path_2dimg, _ = response
        except ValueError:
            msg, path_2dimg = response

        if msg is None:
            dispatcher.utter_message(template="utter_command_failed")
            return [AllSlotsReset()]

        if path_2dimg is not None:
            logger.info("Image saved at {}".format(path_2dimg))
            logger.info("Found {} object: {}".format(msg.desired_color, msg.found_obj))
            imgurl = "http://localhost:8888/{}?time={}".format(path_2dimg, int(time.time()))
            dispatcher.utter_attachment(None, image=imgurl)

        # Gripper codes 1-3 are reported as a failed command.
        if msg.grippercode in (1, 2, 3):
            dispatcher.utter_message(template="utter_command_failed")
        else:
            dispatcher.utter_message(text="I've managed to pick it up!")
        return [AllSlotsReset()]
class ReceivedMove(Action):
    """Ask the robot to move an object to a destination and report the outcome.

    Reads the ``object_name``, ``object_color`` and ``placement`` slots; the
    placement is used as the destination and the origin is always ``"any"``.
    When the ROS bridge is available, sends a ``"move"`` command via
    ``nlp_node``, shows the returned 2D image (if any) and reports failure for
    gripper codes 1, 2 or 3, success otherwise. All slots are reset afterwards
    in every path.
    """

    def name(self) -> Text:
        return "execute_move"

    def run(self, dispatcher, tracker, domain):
        object_name = tracker.get_slot('object_name')
        object_color = tracker.get_slot('object_color')
        placement_destination = tracker.get_slot('placement')

        if not ENABLE_ROS:
            # No ROS bridge: the success reply is unconditional.
            dispatcher.utter_message(text="I've managed to move it!")
            return [AllSlotsReset()]

        nlp_node.send_command("move", object_name, object_color,
                              placement_destination=placement_destination,
                              placement_origin="any")
        response = nlp_node.wait_for_response()
        # Unpack (msg, 2d_path, 3d_path), falling back to (msg, 2d_path);
        # only the length mismatch (ValueError) is caught.
        try:
            msg, path_2dimg, _ = response
        except ValueError:
            msg, path_2dimg = response

        if msg is None:
            dispatcher.utter_message(template="utter_command_failed")
            return [AllSlotsReset()]

        if path_2dimg is not None:
            logger.info("Image saved at {}".format(path_2dimg))
            logger.info("Found {} object: {}".format(msg.desired_color, msg.found_obj))
            imgurl = "http://localhost:8888/{}?time={}".format(path_2dimg, int(time.time()))
            dispatcher.utter_attachment(None, image=imgurl)

        # Gripper codes 1-3 are reported as a failed command.
        if msg.grippercode in (1, 2, 3):
            dispatcher.utter_message(template="utter_command_failed")
        else:
            dispatcher.utter_message(text="I've managed to move it!")
        return [AllSlotsReset()]
class ReceivedCancel(Action):
    """Clear every slot after the user denies a command."""

    def name(self) -> Text:
        return "action_cancel"

    def run(self, dispatcher: CollectingDispatcher,
            tracker: Tracker,
            domain: Dict[Text, Any]) -> List[Dict[Text, Any]]:
        # Nothing is said to the user; the only effect is the slot-reset event.
        reset_events = [AllSlotsReset()]
        return reset_events
class FallbackAction(Action):
    """Apologise and list the supported commands.

    Intended to run when the intent prediction is uncertain or ambiguous.
    Emits no events.
    """

    def name(self) -> Text:
        return "action_fallback"

    def run(self, dispatcher: CollectingDispatcher,
            tracker: Tracker,
            domain: Dict[Text, Any]) -> List[Dict[Text, Any]]:
        help_text = "Sorry, I'm not that smart yet so I'm not sure what you want me to do. I can find an object, pick it up, or move it to a certain location."
        dispatcher.utter_message(text=help_text)
        no_events: List[Dict[Text, Any]] = []
        return no_events
class RetrainAction(Action):
    """Retrain the model on request.

    Tells the user that re-training has started, then calls ``train_model()``
    directly, so the action does not return until training finishes.
    Exceptions raised by ``train_model()`` are not caught here. Emits no
    events.

    (The previous docstring was copied from the fallback action and
    described this action incorrectly.)
    """

    def name(self) -> Text:
        return "action_retrain"

    def run(self, dispatcher: CollectingDispatcher,
            tracker: Tracker,
            domain: Dict[Text, Any]) -> List[Dict[Text, Any]]:
        dispatcher.utter_message(
            text="Re-training model")
        train_model()
        return []