-
Notifications
You must be signed in to change notification settings - Fork 0
/
tools.py
104 lines (93 loc) · 3.78 KB
/
tools.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import os
from langchain.agents import Tool, tool
from llm_chain import get_cypher_chain
from knowledge_graph import GraphDBManager
# Shared graph-database client used by every tool in this module.
# Connection settings are read from the environment variables `uri`,
# `username` and `password`; os.getenv yields None for any that are unset,
# and that None is passed straight to GraphDBManager.
# NOTE(review): on Windows, environment variable names are case-insensitive,
# so `username` resolves to the OS login name (USERNAME) unless it is set
# explicitly for this process — confirm this is the intended credential.
db_manager = GraphDBManager(
    os.getenv("uri"),
    os.getenv("username"),
    os.getenv("password"),
)
@tool
def get_iso_tools(character: str) -> list:
    """Given a character name, get their ISO values"""
    # NOTE: with LangChain's @tool the docstring above is the tool description
    # the LLM reads, so keep it short and accurate.
    # Title-case the name, the same normalisation the other character lookups
    # in this module apply before querying.
    character = character.title()
    iso_stats = db_manager.get_character_iso_stats(character)
    # One human-readable line per ISO record returned by the query.
    return [
        f"Matrix: {record['Matrix']}, Active: {record['Active']}, "
        f"Iso Stat Value: {record['IsoValue']}"
        for record in iso_stats
    ]
@tool
def get_char_by_attributes(attributes: str, top: bool = True, limit: int = 10) -> list:
    """Given attribute name, get `limit` number of characters, ordered by top (true/false)"""
    # NOTE: with LangChain's @tool the docstring above is the tool description
    # the LLM reads.
    # `attributes` is a single attribute name (e.g. "speed"); it is passed to
    # the query unchanged, and a title-cased copy is used only for labelling
    # the output so the labels match what was actually queried.
    label = attributes.title()
    top_attr_char = db_manager.find_characters_by_attribute(attributes, top=top, limit=limit)
    # Debug output to stdout; the returned list is what the agent consumes.
    print(f"Characters by {label} (top={top}):")
    results = []
    for character in top_attr_char:
        result = f"Character: {character['CharacterName']}, {label}: {character['AttributeValue']}"
        print(result)
        results.append(result)
    return results
@tool
def get_char_by_trait(trait: str) -> list:
    """Given a trait name, get all the characters for this trait"""
    # NOTE: with LangChain's @tool the docstring above is the tool description
    # the LLM reads.
    # The trait is upper-cased before the lookup (this tool's normalisation;
    # presumably matches how traits are stored — confirm against the graph).
    trait = trait.upper()
    traited_character = db_manager.find_characters_by_trait(trait)
    results = []
    for character in traited_character:
        # Each row is a character carrying the trait, so label it as such
        # (previously the character name was mislabelled as "Trait").
        result = f"Trait: {trait}, Character: {character['CharacterName']}"
        # Debug output to stdout; the returned list is what the agent consumes.
        print(result)
        results.append(result)
    return results
@tool
def get_related_char(character_name: str) -> list:
    """Given a character, get most related characters"""
    # The docstring above doubles as the LLM-facing tool description.
    # Title-case the name (e.g. "iron man" -> "Iron Man"), like the other
    # character lookups here, and hand back the query result unformatted.
    return db_manager.find_related_characters(character_name=character_name.title())
@tool
def get_common_traits(limit: int = 10) -> list:
    """Get most common X number of traits"""
    # The docstring above doubles as the LLM-facing tool description.
    # Delegates to the graph query over high-level characters; `limit` is
    # forwarded unchanged (presumably the maximum number of traits returned).
    return db_manager.common_traits_among_high_level_characters(limit=limit)
@tool
def get_ability_desc_of_char(character: str, keyword: str) -> list:
    """Given a character and an ability keyword, return the ability where character has the keyword"""
    # The docstring above doubles as the LLM-facing tool description.
    # Both the character name and the search keyword are title-cased before
    # being handed to the graph query.
    matches = db_manager.find_ability_descriptions_by_keyword(
        character.title(), keyword.title()
    )
    formatted = []
    for row in matches:
        line = (
            f"Ability Type: {row['AbilityType']}, "
            f"Ability Name: {row['AbilityName']}, "
            f"Description: {row['Description']}"
        )
        # Echo each formatted ability to stdout as it is produced.
        print(line)
        formatted.append(line)
    return formatted
@tool
def get_ability_cost_of_char(character: str, abilities: list) -> list:
    """Return the starting energy and energy cost of a character's abilities.

    Takes a character name and a list of ability types: basic, ultimate, special, passive.
    """
    # NOTE: with LangChain's @tool the docstring above is the tool description
    # the LLM reads; it previously used a backslash continuation inside the
    # string, which leaked the next line's indentation into the description.
    # Title-case the name like the other character lookups in this module;
    # `abilities` is forwarded unchanged to the query.
    character = character.title()
    return db_manager.get_ability_energy(character, abilities)
def get_tools() -> list:
    """Return the list of tools the agent may call.

    The @tool-decorated functions in this module are pre-defined graph
    queries. The final "Graph" tool is a fallback whose input is passed to
    the `run` method of the chain returned by `get_cypher_chain()`.
    A new chain is built on every call to this function.
    """
    return [
        get_iso_tools,
        get_char_by_attributes,
        get_char_by_trait,
        get_common_traits,
        get_related_char,
        get_ability_desc_of_char,
        get_ability_cost_of_char,
        # Catch-all for questions none of the specific tools above answer.
        Tool(
            name="Graph",
            func=get_cypher_chain().run,
            # Implicit concatenation keeps the description free of the stray
            # indentation a backslash continuation inside a triple-quoted
            # string would embed.
            description=(
                "Useful when you need to answer questions about MSF character "
                "information, e.g. speed, power, iso, traits, etc., "
                "which are not already covered by the pre-defined tools."
            ),
        ),
    ]