-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
26e86ac
commit 6997315
Showing
7 changed files
with
277 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
WEAVIATE_HOST= | ||
WEAVIATE_PORT= |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
import weaviate | ||
import os | ||
|
||
from data.lecture.lectures import Lectures | ||
from data.repository.repositories import Repositories | ||
|
||
|
||
class VectorDatabase: | ||
def __init__(self): | ||
weaviate_host = os.getenv("WEAVIATE_HOST") | ||
weaviate_port = os.getenv("WEAVIATE_PORT") | ||
assert weaviate_host, "WEAVIATE_HOST environment variable must be set" | ||
assert weaviate_port, "WEAVIATE_PORT environment variable must be set" | ||
assert ( | ||
weaviate_port.isdigit() | ||
), "WEAVIATE_PORT environment variable must be an integer" | ||
self._client = weaviate.connect_to_local( | ||
host=weaviate_host, port=int(weaviate_port) | ||
) | ||
self.repositories = Repositories(self._client) | ||
self.lectures = Lectures(self._client) | ||
|
||
def __del__(self): | ||
# Close the connection to Weaviate when the object is deleted | ||
self._client.close() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
import weaviate.classes as wvc | ||
from weaviate import WeaviateClient | ||
from weaviate.collections import Collection | ||
|
||
|
||
COLLECTION_NAME = "LectureSlides" | ||
|
||
|
||
# Potential improvement: | ||
# Don't store the names of the courses, lectures, and units for every single chunk | ||
# These can be looked up via the IDs when needed - query Artemis? or store locally? | ||
|
||
|
||
class LectureSlideChunk: | ||
PAGE_CONTENT = "page_content" # The only property which will be embedded | ||
COURSE_ID = "course_id" | ||
COURSE_NAME = "course_name" | ||
LECTURE_ID = "lecture_id" | ||
LECTURE_NAME = "lecture_name" | ||
LECTURE_UNIT_ID = "lecture_unit_id" # The attachment unit ID in Artemis | ||
LECTURE_UNIT_NAME = "lecture_unit_name" | ||
FILENAME = "filename" | ||
PAGE_NUMBER = "page_number" | ||
|
||
|
||
def init_schema(client: WeaviateClient) -> Collection: | ||
if client.collections.exists(COLLECTION_NAME): | ||
return client.collections.get(COLLECTION_NAME) | ||
return client.collections.create( | ||
name=COLLECTION_NAME, | ||
vectorizer_config=wvc.config.Configure.Vectorizer.none(), # We do not want to vectorize the text automatically | ||
# HNSW is preferred over FLAT for large amounts of data, which is the case here | ||
vector_index_config=wvc.config.Configure.VectorIndex.hnsw( | ||
distance_metric=wvc.config.VectorDistances.COSINE # select preferred distance metric | ||
), | ||
# The properties are like the columns of a table in a relational database | ||
properties=[ | ||
wvc.config.Property( | ||
name=LectureSlideChunk.PAGE_CONTENT, | ||
description="The original text content from the slide", | ||
data_type=wvc.config.DataType.TEXT, | ||
), | ||
wvc.config.Property( | ||
name=LectureSlideChunk.COURSE_ID, | ||
description="The ID of the course", | ||
data_type=wvc.config.DataType.INT, | ||
), | ||
wvc.config.Property( | ||
name=LectureSlideChunk.COURSE_NAME, | ||
description="The name of the course", | ||
data_type=wvc.config.DataType.TEXT, | ||
), | ||
wvc.config.Property( | ||
name=LectureSlideChunk.LECTURE_ID, | ||
description="The ID of the lecture", | ||
data_type=wvc.config.DataType.INT, | ||
), | ||
wvc.config.Property( | ||
name=LectureSlideChunk.LECTURE_NAME, | ||
description="The name of the lecture", | ||
data_type=wvc.config.DataType.TEXT, | ||
), | ||
wvc.config.Property( | ||
name=LectureSlideChunk.LECTURE_UNIT_ID, | ||
description="The ID of the lecture unit", | ||
data_type=wvc.config.DataType.INT, | ||
), | ||
wvc.config.Property( | ||
name=LectureSlideChunk.LECTURE_UNIT_NAME, | ||
description="The name of the lecture unit", | ||
data_type=wvc.config.DataType.TEXT, | ||
), | ||
wvc.config.Property( | ||
name=LectureSlideChunk.FILENAME, | ||
description="The name of the file from which the slide was extracted", | ||
data_type=wvc.config.DataType.TEXT, | ||
), | ||
wvc.config.Property( | ||
name=LectureSlideChunk.PAGE_NUMBER, | ||
description="The page number of the slide", | ||
data_type=wvc.config.DataType.INT, | ||
), | ||
], | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
import json | ||
import os | ||
import time | ||
|
||
import fitz # PyMuPDF | ||
import openai | ||
import weaviate | ||
from unstructured.cleaners.core import clean | ||
import weaviate.classes as wvc | ||
|
||
from data.lecture.lecture_schema import init_schema, COLLECTION_NAME, LectureSlideChunk | ||
|
||
|
||
def chunk_files(subdirectory_path, subdirectory): | ||
data = [] | ||
# Process each PDF file in this subdirectory | ||
for filename in os.listdir(subdirectory_path): | ||
if not filename.endswith(".pdf"): | ||
continue | ||
file_path = os.path.join(subdirectory_path, filename) | ||
# Open the PDF | ||
with fitz.open(file_path) as doc: | ||
for page_num in range(len(doc)): | ||
page_text = doc[page_num].get_text() | ||
page_text = clean(page_text, bullets=True, extra_whitespace=True) | ||
data.append( | ||
{ | ||
LectureSlideChunk.PAGE_CONTENT: page_text, | ||
LectureSlideChunk.COURSE_ID: "", | ||
LectureSlideChunk.LECTURE_ID: "", | ||
LectureSlideChunk.LECTURE_NAME: "", | ||
LectureSlideChunk.LECTURE_UNIT_ID: "", | ||
LectureSlideChunk.LECTURE_UNIT_NAME: "", | ||
LectureSlideChunk.FILENAME: file_path, | ||
LectureSlideChunk.PAGE_NUMBER: "", | ||
} | ||
) | ||
return data | ||
|
||
|
||
class Lectures: | ||
|
||
def __init__(self, client: weaviate.WeaviateClient): | ||
self.collection = init_schema(client) | ||
|
||
def ingest(self, lectures): | ||
pass | ||
|
||
def search(self, query, k=3, filter=None): | ||
pass | ||
|
||
def batch_import(self, directory_path, subdirectory): | ||
data = chunk_files(directory_path, subdirectory) | ||
with self.collection.batch.dynamic() as batch: | ||
for i, properties in enumerate(data): | ||
embeddings_created = False | ||
for j in range(5): # max 5 retries | ||
if not embeddings_created: | ||
try: | ||
batch.add_data_object(properties, COLLECTION_NAME) | ||
embeddings_created = True # Set flag to True on success | ||
break # Break the loop as embedding creation was successful | ||
except openai.error.RateLimitError: | ||
time.sleep(2**j) # wait 2^j seconds before retrying | ||
print("Retrying import...") | ||
else: | ||
break # Exit loop if embeddings already created | ||
# Raise an error if embeddings were not created after retries | ||
if not embeddings_created: | ||
raise RuntimeError("Failed to create embeddings.") | ||
|
||
def query_database(self, user_message: str, lecture_id: int = None): | ||
response = self.collection.query.near_text( | ||
near_text=user_message, | ||
filters=( | ||
wvc.query.Filter.by_property(LectureSlideChunk.LECTURE_ID).equal( | ||
lecture_id | ||
) | ||
if lecture_id | ||
else None | ||
), | ||
return_properties=[ | ||
LectureSlideChunk.PAGE_CONTENT, | ||
LectureSlideChunk.COURSE_NAME, | ||
], | ||
limit=5, | ||
) | ||
print(json.dumps(response, indent=2)) | ||
return response |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
import weaviate | ||
|
||
from data.repository.repository_schema import init_schema | ||
|
||
|
||
class Repositories: | ||
|
||
def __init__(self, client: weaviate.WeaviateClient): | ||
self.collection = init_schema(client) | ||
|
||
def ingest(self, repositories: dict[str, str]): | ||
pass | ||
|
||
def search(self, query, k=3, filter=None): | ||
pass | ||
|
||
def create_tree_structure(self): | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
import weaviate.classes as wvc | ||
from weaviate import WeaviateClient | ||
from weaviate.collections import Collection | ||
|
||
|
||
COLLECTION_NAME = "StudentRepository" | ||
|
||
|
||
class RepositoryChunk: | ||
CONTENT = "content" # The only property which will be embedded | ||
COURSE_ID = "course_id" | ||
EXERCISE_ID = "exercise_id" | ||
REPOSITORY_ID = "repository_id" | ||
FILEPATH = "filepath" | ||
|
||
|
||
def init_schema(client: WeaviateClient) -> Collection: | ||
if client.collections.exists(COLLECTION_NAME): | ||
return client.collections.get(COLLECTION_NAME) | ||
return client.collections.create( | ||
name=COLLECTION_NAME, | ||
vectorizer_config=wvc.config.Configure.Vectorizer.none(), # We do not want to vectorize the text automatically | ||
# HNSW is preferred over FLAT for large amounts of data, which is the case here | ||
vector_index_config=wvc.config.Configure.VectorIndex.hnsw( | ||
distance_metric=wvc.config.VectorDistances.COSINE # select preferred distance metric | ||
), | ||
# The properties are like the columns of a table in a relational database | ||
properties=[ | ||
wvc.config.Property( | ||
name=RepositoryChunk.CONTENT, | ||
description="The content of this chunk of code", | ||
data_type=wvc.config.DataType.TEXT, | ||
), | ||
wvc.config.Property( | ||
name=RepositoryChunk.COURSE_ID, | ||
description="The ID of the course", | ||
data_type=wvc.config.DataType.INT, | ||
), | ||
wvc.config.Property( | ||
name=RepositoryChunk.EXERCISE_ID, | ||
description="The ID of the exercise", | ||
data_type=wvc.config.DataType.INT, | ||
), | ||
wvc.config.Property( | ||
name=RepositoryChunk.REPOSITORY_ID, | ||
description="The ID of the repository", | ||
data_type=wvc.config.DataType.INT, | ||
), | ||
wvc.config.Property( | ||
name=RepositoryChunk.FILEPATH, | ||
description="The filepath of the code", | ||
data_type=wvc.config.DataType.TEXT, | ||
), | ||
], | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters