Skip to content

Commit

Permalink
chat: function calls made consistent
Browse files Browse the repository at this point in the history
  • Loading branch information
rmackay9 committed Jan 3, 2024
1 parent b67a067 commit 3839df3
Showing 1 changed file with 49 additions and 166 deletions.
215 changes: 49 additions & 166 deletions MAVProxy/modules/mavproxy_chat/chat_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,149 +192,50 @@ def handle_function_call(self, run):
for tool_call in run.required_action.submit_tool_outputs.tool_calls:
# init output to None
output = "invalid function call"
recognised_function = False

# get current date and time
if tool_call.function.name == "get_current_datetime":
recognised_function = True
output = self.get_formatted_date()

# get vehicle type
if tool_call.function.name == "get_vehicle_type":
recognised_function = True
output = json.dumps(self.get_vehicle_type())

# get mode mapping
if tool_call.function.name == "get_mode_mapping":
recognised_function = True
try:
arguments = json.loads(tool_call.function.arguments)
output = self.get_mode_mapping(arguments)
except:
output = tool_call.function.name + ": failed"
print("chat: " + output)

# get vehicle state including armed, mode
if tool_call.function.name == "get_vehicle_state":
recognised_function = True
output = json.dumps(self.get_vehicle_state())

# get vehicle location and yaw
if tool_call.function.name == "get_vehicle_location_and_yaw":
recognised_function = True
output = json.dumps(self.get_vehicle_location_and_yaw())

# get_location_plus_offset
if tool_call.function.name == "get_location_plus_offset":
recognised_function = True
try:
arguments = json.loads(tool_call.function.arguments)
except:
print("chat::handle_function_call: get_location_plus_offset: failed to parse arguments")
output = "get_location_plus_offset: failed to parse arguments"
try:
output = json.dumps(self.get_location_plus_offset(arguments))
except:
print("chat::handle_function_call: get_location_plus_offset: failed to calc location")
output = "get_location_plus_offset: failed to get location"

# send mavlink command_int
if tool_call.function.name == "send_mavlink_command_int":
recognised_function = True
try:
arguments = json.loads(tool_call.function.arguments)
output = self.send_mavlink_command_int(arguments)
except:
print("chat::handle_function_call: failed to parse arguments")

# send mavlink set_position_target_global_int
if tool_call.function.name == "send_mavlink_set_position_target_global_int":
recognised_function = True
# handle supported functions
supported_funcs = ["get_current_datetime",
"get_vehicle_type",
"get_mode_mapping",
"get_vehicle_state",
"get_vehicle_location_and_yaw",
"get_location_plus_offset",
"send_mavlink_command_int",
"send_mavlink_set_position_target_global_int", # got here
"get_available_mavlink_messages", "get_mavlink_message",
"get_all_parameters", # ok
"get_parameter", # ok
"set_parameter", # ok
"set_wakeup_timer",
"get_wakeup_timers",
"delete_wakeup_timers"]

# convert function name to a callable function
func_name = tool_call.function.name
func = getattr(self, func_name, None)

if func_name in supported_funcs and func is not None:
stage_str = None
try:
# parse arguments
stage_str = "parse arguments"
arguments = json.loads(tool_call.function.arguments)
output = self.send_mavlink_set_position_target_global_int(arguments)
except:
print("chat: send_mavlink_set_position_target_global_int: failed to parse arguments")


# get a list of mavlink message names that can be retrieved using the get_mavlink_message function
if tool_call.function.name == "get_available_mavlink_messages":
recognised_function = True
output = self.get_available_mavlink_messages()

# get mavlink message from vehicle
if tool_call.function.name == "get_mavlink_message":
recognised_function = True
try:
arguments = json.loads(tool_call.function.arguments)
output = self.get_mavlink_message(arguments)
except:
output = "get_mavlink_message: failed to retrieve message"
print("chat: get_mavlink_message: failed to retrieve message")
# call function
stage_str = "function call"
output = func(arguments)

# get all parameters from vehicle
if tool_call.function.name == "get_all_parameters":
recognised_function = True
try:
arguments = json.loads(tool_call.function.arguments)
output = self.get_all_parameters(arguments)
# convert to json
stage_str = "convert output to json"
output = json.dumps(output)
except:
output = "get_all_parameters: failed to retrieve parameters"
print("chat: get_all_parameters: failed to retrieve parameters")
error_message = str(func_name) + ": " + stage_str + " failed"
print("chat: " + error_message)
output = error_message

# get a vehicle parameter's value
if tool_call.function.name == "get_parameter":
recognised_function = True
try:
arguments = json.loads(tool_call.function.arguments)
output = self.get_parameter(arguments)
except:
output = "get_parameter: failed to retrieve parameter value"
print("chat: get_parameters: failed to retrieve parameter value")

# set a vehicle parameter's value
if tool_call.function.name == "set_parameter":
recognised_function = True
try:
arguments = json.loads(tool_call.function.arguments)
output = self.set_parameter(arguments)
except:
output = "set_parameter: failed to set parameter value"
print("chat: set_parameter: failed to set parameter value")

# set a wakeup timer
if tool_call.function.name == "set_wakeup_timer":
recognised_function = True
try:
arguments = json.loads(tool_call.function.arguments)
output = self.set_wakeup_timer(arguments)
except:
output = tool_call.function.name + ": failed"
print("chat: " + output)

# get wakeup timers
if tool_call.function.name == "get_wakeup_timers":
recognised_function = True
try:
arguments = json.loads(tool_call.function.arguments)
output = self.get_wakeup_timers(arguments)
except:
output = tool_call.function.name + ": failed"
print("chat: " + output)

# delete wakeup timers
if tool_call.function.name == "delete_wakeup_timers":
recognised_function = True
try:
arguments = json.loads(tool_call.function.arguments)
output = self.delete_wakeup_timers(arguments)
except:
output = tool_call.function.name + ": failed"
print("chat: " + output)

if not recognised_function:
print("chat: handle_function_call: unrecognised function call: " + tool_call.function.name)
output = "unrecognised function call: " + tool_call.function.name
else:
print("chat: unrecognised function name: " + func_name)
output = "unrecognised function call: " + func_name

# append output to list of outputs
tool_outputs.append({"tool_call_id": tool_call.id, "output": output})
Expand All @@ -351,11 +252,11 @@ def handle_function_call(self, run):
print(tool_outputs)

# get the current date and time in the format, Saturday, June 24, 2023 6:14:14 PM
def get_formatted_date(self):
def get_current_datetime(self, arguments):
return datetime.now().strftime("%A, %B %d, %Y %I:%M:%S %p")

# get vehicle vehicle type (e.g. "Copter", "Plane", "Rover", "Boat", etc)
def get_vehicle_type(self):
def get_vehicle_type(self, arguments):
# get vehicle type from latest HEARTBEAT message
hearbeat_msg = self.mpstate.master().messages.get('HEARTBEAT', None)
vehicle_type_str = "unknown"
Expand Down Expand Up @@ -426,13 +327,10 @@ def get_mode_mapping(self, arguments):
mode_list.append({"name": mname.upper(), "number": mnumber})

# return list of modes
try:
return json.dumps(mode_list)
except:
return "get_mode_mapping: failed to convert mode list to json"
return mode_list

# get vehicle state including armed, mode
def get_vehicle_state(self):
def get_vehicle_state(self, arguments):
# get mode from latest HEARTBEAT message
hearbeat_msg = self.mpstate.master().messages.get('HEARTBEAT', None)
if hearbeat_msg is None:
Expand All @@ -446,7 +344,7 @@ def get_vehicle_state(self):
}

# return the vehicle's location and yaw
def get_vehicle_location_and_yaw(self):
def get_vehicle_location_and_yaw(self, arguments):
lat_deg = 0
lon_deg = 0
alt_amsl_m = 0
Expand Down Expand Up @@ -547,7 +445,7 @@ def send_mavlink_set_position_target_global_int(self, arguments):
return "set_position_target_global_int sent"

# get a list of mavlink message names that can be retrieved using the get_mavlink_message function
def get_available_mavlink_messages(self):
def get_available_mavlink_messages(self, arguments):
# check if no messages available
if self.mpstate.master().messages is None or len(self.mpstate.master().messages) == 0:
return "get_available_mavlink_messages: no messages available"
Expand All @@ -560,10 +458,7 @@ def get_available_mavlink_messages(self):
mav_msg_names.append(msg)

# return list of message names
try:
return json.dumps(mav_msg_names)
except:
return "get_available_mavlink_messages: failed to convert message name list to json"
return mav_msg_names

# get a mavlink message including all fields and values sent by the vehicle
def get_mavlink_message(self, arguments):
Expand All @@ -580,11 +475,8 @@ def get_mavlink_message(self, arguments):
if mav_msg is None:
return "get_mavlink_message: message not found"

# convert message to json
try:
return json.dumps(mav_msg.to_dict())
except:
return "get_mavlink_message: failed to convert message to json"
# return message
return mav_msg.to_dict()

# get all available parameters names and their values
def get_all_parameters(self, arguments):
Expand All @@ -594,10 +486,7 @@ def get_all_parameters(self, arguments):
param_list = {}
for param_name in sorted(self.mpstate.mav_param.keys()):
param_list[param_name] = self.mpstate.mav_param.get(param_name)
try:
return json.dumps(param_list)
except:
return "get_all_parameters: failed to convert parameter list to json"
return param_list

# get a vehicle parameter's value
def get_parameter(self, arguments):
Expand Down Expand Up @@ -626,10 +515,7 @@ def get_parameter(self, arguments):
return "get_parameter: " + param_name + " parameter not found"
param_list[param_name] = param_value

try:
return json.dumps(param_list)
except:
return "get_parameter: failed to convert parameter list to json"
return param_list

# set a vehicle parameter's value
def set_parameter(self, arguments):
Expand Down Expand Up @@ -682,10 +568,7 @@ def get_wakeup_timers(self, arguments):
matching_timers.append(wakeup_timer)

# return matching timers
try:
return json.dumps(matching_timers)
except:
return "get_wakeup_timers: failed to convert wakeup timer list to json"
return matching_timers

# delete wake timers
def delete_wakeup_timers(self, arguments):
Expand Down

0 comments on commit 3839df3

Please sign in to comment.