Skip to content

Commit

Permalink
use date-period time, optimize translation process
Browse files Browse the repository at this point in the history
  • Loading branch information
k-okada authored and sktometometo committed Jun 21, 2023
1 parent 00989ff commit 421a565
Showing 1 changed file with 94 additions and 46 deletions.
140 changes: 94 additions & 46 deletions database_talker/scripts/hoge.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@
import json
import os
import random
import re
import rospkg
import shutil
import sys
import yaml
import tempfile
import time
import traceback
Expand Down Expand Up @@ -91,20 +93,20 @@ def __init__(self):
rospy.loginfo("all done, ready")


def make_reply(self, message, lang="en"):
rospy.logwarn("Run make_reply({})".format(message))
def make_reply(self, message, lang="en", startdate=datetime.datetime.now(JST)-datetime.timedelta(hours=24), duration=datetime.timedelta(hours=24) ):
enddate = startdate+duration
rospy.logwarn("Run make_reply({} from {} to {})".format(message, startdate, enddate))
query = self.text_to_salience(message)
rospy.logwarn("query using salience word '{}'".format(query))
# look for images
try:
# get chat message
timestamp = datetime.datetime.now(JST)
results, chat_msgs = self.query_dialogflow(query, timestamp, threshold=0.25)
retry = 0
while retry < 3 and len(results) == 0 and len(chat_msgs.metas) > 0:
meta = json.loads(chat_msgs.metas[-1].pairs[0].second)
results, chat_msgs = self.query_dialogflow(query, datetime.datetime.fromtimestamp(meta['timestamp']//1000000000, JST))
retry = retry + 1
results, chat_msgs = self.query_dialogflow(query, startdate, enddate, threshold=0.25)
# retry = 0
# while retry < 3 and len(results) == 0 and len(chat_msgs.metas) > 0:
# meta = json.loads(chat_msgs.metas[-1].pairs[0].second)
# results, chat_msgs = self.query_dialogflow(query, datetime.datetime.fromtimestamp(meta['timestamp']//1000000000, JST))
# retry = retry + 1
# sort based on similarity with 'query'
chat_msgs_sorted = sorted(results, key=lambda x: x['similarity'], reverse=True)

Expand All @@ -115,27 +117,32 @@ def make_reply(self, message, lang="en"):
msg = chat_msgs_sorted[0]['msg']
meta = chat_msgs_sorted[0]['meta']
text = chat_msgs_sorted[0]['message']
timestamp = chat_msgs_sorted[0]['timestamp']
startdate = chat_msgs_sorted[0]['timestamp']
action = chat_msgs_sorted[0]['action']
similarity = chat_msgs_sorted[0]['similarity']
# query chat to get response
#meta = json.loads(chat_msgs_sorted[0]['meta'].pairs[0].second)
# text = msg.message.argument_text or msg.message.text
# timestamp = datetime.datetime.fromtimestamp(meta['timestamp']//1000000000, JST)
rospy.loginfo("Found message '{}'({}) at {}, corresponds to query '{}' with {:2f}%".format(text, action, timestamp.strftime('%Y-%m-%d %H:%M:%S'), query, similarity))
# startdate = datetime.datetime.fromtimestamp(meta['timestamp']//1000000000, JST)
rospy.loginfo("Found message '{}'({}) at {}, corresponds to query '{}' with {:2f}%".format(text, action, startdate.strftime('%Y-%m-%d %H:%M:%S'), query, similarity))

# query images when chat was received (+- 30 min)
start_time = timestamp-datetime.timedelta(minutes=30)
end_time = timestamp+datetime.timedelta(minutes=30)
# query images when chat was received
start_time = startdate # startdate is updated with found chat space
end_time = enddate # enddate is not modified within this function, it is given from chat
results = self.query_images_and_classify(query=query, start_time=start_time, end_time=end_time)

if len(results) > 0:
end_time = results[-1]['timestamp']
# no images found
if len(results) == 0:
return {'text': '記憶がありません🤯'}

end_time = results[-1]['timestamp']

# sort
results = sorted(results, key=lambda x: x['similarities'], reverse=True)
rospy.loginfo("Probabilities of all images {}".format(list(map(lambda x: (x['label'], x['similarities']), results))))
rospy.loginfo("Probabilities of all images {}".format(list(map(lambda x: (x['timestamp'].strftime('%Y-%m-%d %H:%M:%S'), x['similarities']), results))))
best_result = results[0]

'''
# if probability is too low, try again
while len(results) > 0 and results[0]['similarities'] < 0.25:
Expand All @@ -151,32 +158,28 @@ def make_reply(self, message, lang="en"):
best_result = results[0]
rospy.loginfo("Found '{}' image with {:0.2f} % simiarity at {}".format(best_result['label'], best_result['similarities'], best_result['timestamp'].strftime('%Y-%m-%d %H:%M:%S')))
'''

## make prompt
goal = VQATaskGoal()
goal.compressed_image = best_result['image']

# unusual objects
if random.randint(0,1) == 1:
goal.questions = ['what unusual things can be seen?']
reaction = 'and you saw '
else:
goal.questions = ['what is the atmosphere of this place?']
reaction = 'and the atmosphere of the scene was '

# get vqa result
self.vqa_ac.send_goal(goal)
self.vqa_ac.wait_for_result()
result = self.vqa_ac.get_result()
reaction += result.result.result[0].answer
reaction = self.describe_image_scene(best_result['image'])
if len(chat_msgs_sorted) > 0 and chat_msgs_sorted[0]['action'] and 'action' in chat_msgs_sorted[0]:
reaction += " and you felt " + chat_msgs_sorted[0]['action']
rospy.loginfo("reaction = {}".format(reaction))

# make prompt
prompt = 'if you are a pet and someone tells you \"' + message + '\" when we went together, ' + \
reaction + ' in your memory of that moment, what would you reply? '+ \
'and ' + reaction + ' in your memory of that moment, what would you reply? '+ \
'Show only the reply in {lang}'.format(lang={'en': 'English', 'ja':'Japanese'}[lang])
result = self.completion(prompt=prompt,temperature=0)
loop = 0
result = None
while loop < 3 and result is None:
try:
result = self.completion(prompt=prompt,temperature=0)
except rospy.ServiceException as e:
rospy.logerr("Service call failed: %s"%e)
result = None
loop += 1
result.text = result.text.lstrip()
rospy.loginfo("prompt = {}".format(prompt))
rospy.loginfo("result = {}".format(result))
# pubish as card
Expand All @@ -203,19 +206,27 @@ def write_image_with_annotation(self, filename, best_result, prompt):
rospy.logwarn("save images to {}".format(filename))


def query_dialogflow(self, query, end_time, limit=30, threshold=0.0):
rospy.logwarn("Query dialogflow until {}".format(end_time))
meta_query= {'inserted_at': {"$lt": end_time}}
def query_dialogflow(self, query, start_time, end_time, limit=30, threshold=0.0):
rospy.logwarn("Query dialogflow from {} until {}".format(start_time, end_time))
meta_query= {'inserted_at': {"$lt": end_time, "$gt": start_time}}
meta_tuple = (StringPair(MongoQueryMsgRequest.JSON_QUERY, json.dumps(meta_query, default=json_util.default)),)
chat_msgs = self.query(database = 'jsk_robot_lifelog',
collection = self.robot_name,
# type = 'google_chat_ros/MessageEvent',
type = 'dialogflow_task_executive/DialogTextActionResult',
single = False,
limit = limit,
# limit = limit,
meta_query = StringPairList(meta_tuple),
sort_query = StringPairList([StringPair('_meta.inserted_at', '-1')]))

# optimization... send translate once
messages = ''
for msg, meta in zip(chat_msgs.messages, chat_msgs.metas):
msg = deserialise_message(msg)
message = msg.result.response.query.replace('\n','')
messages += message + '\n'
messages = self.translate(messages, dest="en").text.split('\n')

# show chats
results = []
for msg, meta in zip(chat_msgs.messages, chat_msgs.metas):
Expand All @@ -224,7 +235,8 @@ def query_dialogflow(self, query, end_time, limit=30, threshold=0.0):
timestamp = datetime.datetime.fromtimestamp(meta['timestamp']//1000000000, JST)
# message = msg.message.argument_text or msg.message.text
message = msg.result.response.query
message_translate = self.translate(message, dest="en").text
#message_translate = self.translate(message, dest="en").text
message_translate = messages.pop(0).strip()
result = {'message': message,
'message_translate': message_translate,
'timestamp': timestamp,
Expand All @@ -233,9 +245,9 @@ def query_dialogflow(self, query, end_time, limit=30, threshold=0.0):
'msg': msg,
'meta': meta}
if msg.result.response.action in ['make_reply', 'input.unknown']:
rospy.logwarn("Found dialogflow messages {} at {} but skipping (action:{})".format(result['message'], result['timestamp'].strftime('%Y-%m-%d %H:%M:%S'), msg.result.response.action))
rospy.logwarn("Found dialogflow messages {}({}) at {} but skipping (action:{})".format(result['message'], result['message_translate'], result['timestamp'].strftime('%Y-%m-%d %H:%M:%S'), msg.result.response.action))
else:
rospy.logwarn("Found dialogflow messages {}({}) ({}) at {} ({}:{:.2f})".format(result['message'], result['message_translate'], msg.result.response.action, result['timestamp'].strftime('%Y-%m-%d %H:%M:%S'), query, result['similarity']))
rospy.loginfo("Found dialogflow messages {}({}) ({}) at {} ({}:{:.2f})".format(result['message'], result['message_translate'], msg.result.response.action, result['timestamp'].strftime('%Y-%m-%d %H:%M:%S'), query, result['similarity']))
if ( result['similarity'] > threshold):
results.append(result)
else:
Expand All @@ -245,7 +257,7 @@ def query_dialogflow(self, query, end_time, limit=30, threshold=0.0):
return results, chat_msgs


def query_images_and_classify(self, query, start_time, end_time, limit=30):
def query_images_and_classify(self, query, start_time, end_time, limit=10):
rospy.logwarn("Query images from {} to {}".format(start_time, end_time))
meta_query= {#'input_topic': '/spot/camera/hand_color/image/compressed/throttled',
'inserted_at': {"$gt": start_time, "$lt": end_time}}
Expand Down Expand Up @@ -285,6 +297,24 @@ def query_images_and_classify(self, query, start_time, end_time, limit=30):
# we do not sorty by probabilites, becasue we also need oldest timestamp
return results

def describe_image_scene(self, image):
goal = VQATaskGoal()
goal.compressed_image = image

# unusual objects
if random.randint(0,1) == 1:
goal.questions = ['what unusual things can be seen?']
reaction = 'you saw '
else:
goal.questions = ['what is the atmosphere of this place?']
reaction = 'the atmosphere of the scene was '

# get vqa result
self.vqa_ac.send_goal(goal)
self.vqa_ac.wait_for_result()
result = self.vqa_ac.get_result()
reaction += result.result.result[0].answer
return reaction

def publish_google_chat_card(self, text, space, filename=None):
goal = SendMessageGoal()
Expand Down Expand Up @@ -347,9 +377,27 @@ def cb(self, msg):
if result.response.action == 'input.unknown':
self.publish_google_chat_card("🤖", space)
elif result.response.action == 'make_reply':
self.publish_google_chat_card("・・・", space)

parameters = yaml.safe_load(result.response.parameters)
startdate=datetime.datetime.now(JST)-datetime.timedelta(hours=24)
duration=datetime.timedelta(hours=24)
if parameters['date']:
startdate = datetime.datetime.strptime(re.sub('\+(\d+):(\d+)$', '+\\1\\2',parameters['date']), "%Y-%m-%dT%H:%M:%S%z")
duration = datetime.timedelta(hours=24)
if parameters['date-period']:
startdate = datetime.datetime.strptime(re.sub('\+(\d+):(\d+)$', '+\\1\\2',parameters['date-period']['startDate']), "%Y-%m-%dT%H:%M:%S%z")
duration = datetime.datetime.strptime(re.sub('\+(\d+):(\d+)$', '+\\1\\2',parameters['date-period']['endDate']), "%Y-%m-%dT%H:%M:%S%z") - startdate
print(startdate)
print(duration)
translated = self.translate(result.response.query, dest="en")
ret = self.make_reply(translated.text, translated.src)
self.publish_google_chat_card(ret['text'], space, ret['filename'])
ret = self.make_reply(translated.text, translated.src, startdate=startdate, duration=duration)
if 'filename' in ret:
# upload text first, then upload images
self.publish_google_chat_card(ret['text'], space)
self.publish_google_chat_card('', space, ret['filename'])
else:
self.publish_google_chat_card(ret['text'], space)
else:
self.publish_google_chat_card(result.response.response, space)

Expand Down

0 comments on commit 421a565

Please sign in to comment.