Skip to content

Commit

Permalink
Remove LLMOutputParser and add documentation to DocumentsDistiller an…
Browse files Browse the repository at this point in the history
…d iEntitiesExtractor classes
  • Loading branch information
lairgiyassir committed Jul 15, 2024
1 parent 16eea93 commit 976d64a
Show file tree
Hide file tree
Showing 9 changed files with 377 additions and 81 deletions.
32 changes: 32 additions & 0 deletions itext2kg/documents_distiller/documents_distiller.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,33 @@


class DocumentsDisiller:
"""
A class designed to distill essential information from multiple documents into a combined
structure, using natural language processing tools to extract and consolidate information.
"""
def __init__(self, openai_api_key:str, model_name:str = "gpt-4-0125-preview", temperature:str=0) -> None:
"""
Initializes the DocumentsDistiller with specified API key, model name, and operational parameters.
Args:
openai_api_key (str): The API key for accessing OpenAI services.
model_name (str): The model name for the Chat API.
temperature (float): The temperature setting for the Chat API's responses.
"""
self.temperature = temperature
self.langchain_output_parser = LangchainOutputParser(openai_api_key=openai_api_key, model_name=model_name, temperature=temperature)

@staticmethod
def __combine_dicts(dict_list:List[dict]):
"""
Combine a list of dictionaries into a single dictionary, merging values based on their types.
Args:
dict_list (List[dict]): A list of dictionaries to combine.
Returns:
dict: A combined dictionary with merged values.
"""
combined_dict = {}

for d in dict_list:
Expand All @@ -32,6 +53,17 @@ def __combine_dicts(dict_list:List[dict]):


def distill(self, documents: List[str], output_data_structure, IE_query:str) -> dict:
"""
Distill information from multiple documents based on a specific information extraction query.
Args:
documents (List[str]): A list of documents from which to extract information.
output_data_structure: The data structure definition for formatting the output JSON.
IE_query (str): The query to provide to the language model for extracting information.
Returns:
dict: A dictionary representing distilled information from all documents.
"""
output_jsons = list(
map(
lambda context: self.langchain_output_parser.extract_information_as_json_for_context(
Expand Down
65 changes: 65 additions & 0 deletions itext2kg/graph_integration/graph_integrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,40 @@
from typing import List

class GraphIntegrator:
"""
A class to integrate and manage graph data in a Neo4j database.
"""
def __init__(self, uri: str, username: str, password: str):
"""
Initializes the GraphIntegrator with database connection parameters.
Args:
uri (str): URI for the Neo4j database.
username (str): Username for database access.
password (str): Password for database access.
"""
self.uri = uri
self.username = username
self.password = password
self.driver = self.connect()

def connect(self):
"""
Establishes a connection to the Neo4j database.
Returns:
A Neo4j driver instance for executing queries.
"""
driver = GraphDatabase.driver(self.uri, auth=(self.username, self.password))
return driver

def run_query(self, query: str):
"""
Runs a Cypher query against the Neo4j database.
Args:
query (str): The Cypher query to run.
"""
session = self.driver.session()
try:
session.run(query)
Expand All @@ -22,17 +45,44 @@ def run_query(self, query: str):

@staticmethod
def transform_embeddings_to_str_list(embeddings:np.array):
"""
Transforms a NumPy array of embeddings into a comma-separated string.
Args:
embeddings (np.array): An array of embeddings.
Returns:
str: A comma-separated string of embeddings.
"""
if embeddings is None:
return ""
return ",".join(list(embeddings.astype("str")))

@staticmethod
def transform_str_list_to_embeddings(embeddings:List[str]):
"""
Transforms a comma-separated string of embeddings back into a NumPy array.
Args:
embeddings (str): A comma-separated string of embeddings.
Returns:
np.array: A NumPy array of embeddings.
"""
if embeddings is None:
return ""
return np.array(embeddings.split(",")).astype(np.float64)

def create_nodes(self, json_graph:dict) -> List[str]:
"""
Constructs Cypher queries for creating nodes in the graph database from a JSON structure.
Args:
json_graph (dict): A dictionary representing the nodes to be created.
Returns:
List[str]: A list of Cypher queries for node creation.
"""
queries = []
for node in json_graph["nodes"]:
properties = []
Expand All @@ -46,6 +96,15 @@ def create_nodes(self, json_graph:dict) -> List[str]:
return queries

def create_relationships(self, json_graph:dict) -> list:
"""
Constructs Cypher queries for creating relationships in the graph database from a JSON structure.
Args:
json_graph (dict): A dictionary representing the relationships to be created.
Returns:
List[str]: A list of Cypher queries for relationship creation.
"""
rels = []
for rel in json_graph["relationships"]:
property_statements = ' '.join(
Expand All @@ -60,6 +119,12 @@ def create_relationships(self, json_graph:dict) -> list:


def visualize_graph(self, json_graph:dict) -> None:
"""
Runs the necessary queries to visualize a graph structure from a JSON input.
Args:
json_graph (dict): A dictionary containing the graph structure.
"""
self.connect()

nodes, relationships = (
Expand Down
40 changes: 21 additions & 19 deletions itext2kg/graph_integration/itext2kg.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import List
from ..ientities_extraction import iEntitiesExtractor
from ..irelation_extraction import iRelationsExtractor
from ..utils import Matcher, DataHandler
from ..utils import Matcher, DataHandler, LangchainOutputParser

class iText2KG:
def __init__(self, openai_api_key:str, embeddings_model_name :str = "text-embedding-3-large", model_name:str = "gpt-4-turbo", temperature:float = 0, sleep_time:int=5) -> None:
Expand All @@ -20,6 +20,7 @@ def __init__(self, openai_api_key:str, embeddings_model_name :str = "text-embedd

self.data_handler = DataHandler()
self.matcher = Matcher()
self.langchain_output_parser = LangchainOutputParser(openai_api_key=openai_api_key,)


def extract_entities_for_all_sections(self, sections:List[str], ent_threshold = 0.8):
Expand All @@ -40,38 +41,26 @@ def extract_relations_for_all_sections(self, sections:List[str], entities, rel_t

global_relationships = self.irelations_extractor.extract_relations(context=sections[0], entities = entities)

relations_with_isolated_entities = self.data_handler.find_relations_with_isolated_entities(global_entities=entities, relations=global_relationships)
if relations_with_isolated_entities:
corrected_relations = self.irelations_extractor.correct_relations_for_isolated_entities(context=sections[0], entities=entities, relations_with_isolated_entities=relations_with_isolated_entities)
global_relationships = [rel for rel in global_relationships if rel not in relations_with_isolated_entities] + [corrected_relations]


isolated_entities = self.data_handler.find_isolated_entities(global_entities=entities, relations=global_relationships)
if isolated_entities:
corrected_relations = self.irelations_extractor.extract_relations_for_isolated_entities(context=sections[0], isolated_entities=isolated_entities)
global_relationships.extend(corrected_relations)


global_relationships = self.data_handler.match_relations_with_isolated_entities(global_entities=entities, relations=global_relationships, matcher= lambda ent:self.matcher.find_match(ent, entities, match_type="entity", threshold=0.5), embedding_calculator= lambda ent:self.langchain_output_parser.calculate_embeddings(ent))


for i in range(1, len(sections)):
print("[INFO] Extracting Relations from the Document", i+1)
entities = self.irelations_extractor.extract_relations(context= sections[i], entities=entities)
processed_relationships, global_relationships_ = self.matcher.process_lists(list1 = entities, list2=global_relationships, for_entity_or_relation="relation", threshold = rel_threshold)

print("proce", processed_relationships)

relations_with_isolated_entities = self.data_handler.find_relations_with_isolated_entities(global_entities=entities, relations=processed_relationships)
if relations_with_isolated_entities:
corrected_relations = self.irelations_extractor.correct_relations_for_isolated_entities(context=sections[i], entities=entities, relations_with_isolated_entities=relations_with_isolated_entities)
processed_relationships = [rel for rel in processed_relationships if rel not in relations_with_isolated_entities] + [corrected_relations]

print("first case corrected ...", corrected_relations)


isolated_entities = self.data_handler.find_isolated_entities(global_entities=entities, relations=processed_relationships)
if isolated_entities:
corrected_relations = self.irelations_extractor.extract_relations_for_isolated_entities(context=sections[i], isolated_entities=isolated_entities)
print("second case corrected ...", corrected_relations)
processed_relationships.extend(corrected_relations)

processed_relationships = self.data_handler.match_relations_with_isolated_entities(global_entities=entities, relations=processed_relationships, matcher= lambda ent:self.matcher.find_match(ent, entities, match_type="entity", threshold=0.5), embedding_calculator= lambda ent:self.langchain_output_parser.calculate_embeddings(ent))

global_relationships.extend(processed_relationships)
#return self.data_handler.handle_data(global_relationships, data_type="relation")
return global_relationships
Expand All @@ -83,6 +72,12 @@ def build_graph(self, sections:List[str], ent_threshold:float = 0.7, rel_thresho
print("[INFO] Extracting Relations from the Document", 1)
global_relationships = self.irelations_extractor.extract_relations(context=sections[0], entities = list(map(lambda w:w["name"], global_entities)))

isolated_entities = self.data_handler.find_isolated_entities(global_entities=global_entities, relations=global_relationships)
if isolated_entities:
corrected_relations = self.irelations_extractor.extract_relations_for_isolated_entities(context=sections[0], isolated_entities=isolated_entities)
global_relationships.extend(corrected_relations)

global_relationships = self.data_handler.match_relations_with_isolated_entities(global_entities=global_entities, relations=global_relationships, matcher= lambda ent:self.matcher.find_match(ent, global_entities, match_type="entity", threshold=0.5), embedding_calculator= lambda ent:self.langchain_output_parser.calculate_embeddings(ent))

for i in range(1, len(sections)):
print("[INFO] Extracting Entities from the Document", i+1)
Expand All @@ -95,6 +90,13 @@ def build_graph(self, sections:List[str], ent_threshold:float = 0.7, rel_thresho
relationships = self.irelations_extractor.extract_relations(context= sections, entities=list(map(lambda w:w["name"], processed_entities)))
processed_relationships, _ = self.matcher.process_lists(list1 = relationships, list2=global_relationships, for_entity_or_relation="relation", threshold=rel_threshold)

isolated_entities = self.data_handler.find_isolated_entities(global_entities=processed_entities, relations=processed_relationships)
if isolated_entities:
corrected_relations = self.irelations_extractor.extract_relations_for_isolated_entities(context=sections[i], isolated_entities=isolated_entities)
processed_relationships.extend(corrected_relations)

processed_relationships = self.data_handler.match_relations_with_isolated_entities(global_entities=processed_entities, relations=processed_relationships, matcher= lambda ent:self.matcher.find_match(ent, processed_entities, match_type="entity", threshold=0.5), embedding_calculator= lambda ent:self.langchain_output_parser.calculate_embeddings(ent))

global_relationships.extend(processed_relationships)

return self.data_handler.handle_data(global_entities, data_type="entity"), self.data_handler.handle_data(global_relationships, data_type="relation")
Expand Down
39 changes: 37 additions & 2 deletions itext2kg/ientities_extraction/ientities_extractor.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,20 @@
from typing import List
from ..utils import LangchainOutputParser, EntitiesExtractor
from ..utils import Matcher

class iEntitiesExtractor():
"""
A class to extract entities from text using natural language processing tools and embeddings.
"""
def __init__(self, openai_api_key:str, embeddings_model_name :str = "text-embedding-3-large", model_name:str = "gpt-4-turbo", temperature:float = 0, sleep_time:int=5) -> None:
"""
Initializes the iEntitiesExtractor with specified API key, models, and operational parameters.
Args:
openai_api_key (str): The API key for accessing OpenAI services.
embeddings_model_name (str): The model name for text embeddings.
model_name (str): The model name for the Chat API.
temperature (float): The temperature setting for the Chat API's responses.
sleep_time (int): The time to wait (in seconds) when encountering rate limits or errors.
"""
self.langchain_output_parser = LangchainOutputParser(openai_api_key=openai_api_key,
embeddings_model_name=embeddings_model_name,
model_name=model_name,
Expand All @@ -13,6 +24,19 @@ def __init__(self, openai_api_key:str, embeddings_model_name :str = "text-embedd


def __add_embeddings_as_property(self, entity:dict, property_name = "properties", embeddings_name = "embeddings", entity_name_name = "name", embeddings:bool = True):
"""
Add embeddings as a property to the given entity dictionary.
Args:
entity (dict): The entity to which embeddings will be added.
property_name (str): The key under which embeddings will be stored.
embeddings_name (str): The name of the embeddings key.
entity_name_name (str): The key name for the entity's name.
embeddings (bool): A flag to determine whether to calculate embeddings.
Returns:
dict: The entity dictionary with added embeddings.
"""
entity = entity.copy()

entity[entity_name_name] = entity[entity_name_name].lower().replace("_", " ").replace("-", " ")
Expand All @@ -26,7 +50,18 @@ def __add_embeddings_as_property(self, entity:dict, property_name = "properties"


def extract_entities(self, context: str, embeddings: bool = True, property_name = "properties", entity_name_name = "name"):
"""
Extract entities from a given context and optionally add embeddings to each.
Args:
context (str): The textual context from which entities will be extracted.
embeddings (bool): A flag to determine whether to add embeddings to the extracted entities.
property_name (str): The property name under which embeddings will be stored in the entity.
entity_name_name (str): The key name for the entity's name.
Returns:
List[dict]: A list of extracted entities with optional embeddings.
"""
entities = self.langchain_output_parser.extract_information_as_json_for_context(context=context, output_data_structure=EntitiesExtractor)
print(entities)

Expand Down
Loading

0 comments on commit 976d64a

Please sign in to comment.