Merge pull request #1077 from MichaelDecent/comm_pkg_14
Add Swarmauri MlmEmbedding and MlmVectorStore pkg
Showing 7 changed files with 456 additions and 0 deletions.
@@ -0,0 +1 @@
# Swarmauri Example Community Package
58 changes: 58 additions & 0 deletions
pkgs/community/swarmauri_vectorstore_communitymlm/pyproject.toml
@@ -0,0 +1,58 @@
[tool.poetry]
name = "swarmauri_vectorstore_communitymlm"
version = "0.6.0.dev1"
description = "Swarmauri MLM Vector Store"
authors = ["Jacob Stewart <[email protected]>"]
license = "Apache-2.0"
readme = "README.md"
repository = "http://github.com/swarmauri/swarmauri-sdk"
classifiers = [
    "License :: OSI Approved :: Apache Software License",
    "Programming Language :: Python :: 3.10",
    "Programming Language :: Python :: 3.11",
    "Programming Language :: Python :: 3.12"
]

[tool.poetry.dependencies]
python = ">=3.10,<3.13"

# Swarmauri
swarmauri_core = { path = "../../core" }
swarmauri_base = { path = "../../base" }

# Dependencies
transformers = ">=4.45.0"
torch = "^2.4.1"

[tool.poetry.group.dev.dependencies]
flake8 = "^7.0"
pytest = "^8.0"
pytest-asyncio = ">=0.24.0"
pytest-xdist = "^3.6.1"
pytest-json-report = "^1.5.0"
python-dotenv = "*"
requests = "^2.32.3"

[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"

[tool.pytest.ini_options]
norecursedirs = ["combined", "scripts"]

markers = [
    "test: standard test",
    "unit: Unit tests",
    "integration: Integration tests",
    "acceptance: Acceptance tests",
    "experimental: Experimental tests"
]
log_cli = true
log_cli_level = "INFO"
log_cli_format = "%(asctime)s [%(levelname)s] %(message)s"
log_cli_date_format = "%Y-%m-%d %H:%M:%S"
asyncio_default_fixture_loop_scope = "function"

[tool.poetry.plugins."swarmauri.vectorstores"]
MlmVectorStore = "swarmauri_vectorstore_communitymlm.MlmVectorStore:MlmVectorStore"
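The entry point above registers the package's vector store under the `swarmauri.vectorstores` group. As a rough sketch of how such a registration can be discovered with the standard library on Python 3.10+ (this discovery snippet is not part of the package; only the group name is taken from the table above):

```python
# Sketch: list plugins registered under the "swarmauri.vectorstores" entry-point
# group. Assumes the package has been installed so its metadata is visible.
from importlib.metadata import entry_points

for ep in entry_points(group="swarmauri.vectorstores"):
    plugin_cls = ep.load()  # imports the module and returns the registered class
    print(ep.name, plugin_cls)
```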
227 changes: 227 additions & 0 deletions
...ity/swarmauri_vectorstore_communitymlm/swarmauri_vectorstore_communitymlm/MlmEmbedding.py
@@ -0,0 +1,227 @@
from typing import List, Union, Any, Literal
import logging
from pydantic import PrivateAttr
import torch
from torch.utils.data import TensorDataset, DataLoader
from torch.optim import AdamW
from transformers import AutoModelForMaskedLM, AutoTokenizer

from swarmauri_base.embeddings.EmbeddingBase import EmbeddingBase
from swarmauri_standard.vectors.Vector import Vector


class MlmEmbedding(EmbeddingBase):
    """
    EmbeddingBase implementation that fine-tunes a Masked Language Model (MLM).
    """

    embedding_name: str = "bert-base-uncased"
    batch_size: int = 32
    learning_rate: float = 5e-5
    masking_ratio: float = 0.15
    randomness_ratio: float = 0.10
    epochs: int = 0
    add_new_tokens: bool = False
    _tokenizer = PrivateAttr()
    _model = PrivateAttr()
    _device = PrivateAttr()
    _mask_token_id = PrivateAttr()
    type: Literal["MlmEmbedding"] = "MlmEmbedding"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self._tokenizer = AutoTokenizer.from_pretrained(self.embedding_name)
        self._model = AutoModelForMaskedLM.from_pretrained(self.embedding_name)
        self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self._model.to(self._device)
        self._mask_token_id = self._tokenizer.convert_tokens_to_ids(
            [self._tokenizer.mask_token]
        )[0]

    def extract_features(self) -> List[str]:
        """
        Extracts the tokens from the vocabulary of the fine-tuned MLM.
        Returns:
        - List[str]: A list of token strings in the model's vocabulary.
        """
        # Get the vocabulary size
        vocab_size = len(self._tokenizer)

        # Retrieve the token strings for each id in the vocabulary
        token_strings = [
            self._tokenizer.convert_ids_to_tokens(i) for i in range(vocab_size)
        ]

        return token_strings

    def _mask_tokens(self, encodings):
        input_ids = encodings.input_ids.to(self._device)
        attention_mask = encodings.attention_mask.to(self._device)

        labels = input_ids.detach().clone()

        probability_matrix = torch.full(
            labels.shape, self.masking_ratio, device=self._device
        )
        special_tokens_mask = [
            self._tokenizer.get_special_tokens_mask(
                val, already_has_special_tokens=True
            )
            for val in labels.tolist()
        ]
        probability_matrix.masked_fill_(
            torch.tensor(special_tokens_mask, dtype=torch.bool, device=self._device),
            value=0.0,
        )
        masked_indices = torch.bernoulli(probability_matrix).bool()

        labels[~masked_indices] = -100

        indices_replaced = (
            torch.bernoulli(
                torch.full(labels.shape, self.masking_ratio, device=self._device)
            ).bool()
            & masked_indices
        )
        input_ids[indices_replaced] = self._mask_token_id

        indices_random = (
            torch.bernoulli(
                torch.full(labels.shape, self.randomness_ratio, device=self._device)
            ).bool()
            & masked_indices
            & ~indices_replaced
        )
        random_words = torch.randint(
            len(self._tokenizer), labels.shape, dtype=torch.long, device=self._device
        )
        input_ids[indices_random] = random_words[indices_random]

        return input_ids, attention_mask, labels

    def fit(self, documents: List[Union[str, Any]]):
        # Check if we need to add new tokens
        if self.add_new_tokens:
            new_tokens = self.find_new_tokens(documents)
            if new_tokens:
                num_added_toks = self._tokenizer.add_tokens(new_tokens)
                if num_added_toks > 0:
                    logging.info(f"Added {num_added_toks} new tokens.")
                    self._model.resize_token_embeddings(len(self._tokenizer))

        encodings = self._tokenizer(
            documents,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,
        )
        input_ids, attention_mask, labels = self._mask_tokens(encodings)
        optimizer = AdamW(self._model.parameters(), lr=self.learning_rate)
        dataset = TensorDataset(input_ids, attention_mask, labels)
        data_loader = DataLoader(dataset, batch_size=self.batch_size, shuffle=True)

        self._model.train()
        for batch in data_loader:
            batch = {
                k: v.to(self._device)
                for k, v in zip(["input_ids", "attention_mask", "labels"], batch)
            }
            outputs = self._model(**batch)
            loss = outputs.loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        self.epochs += 1
        logging.info(f"Epoch {self.epochs} complete. Loss {loss.item()}")

    def find_new_tokens(self, documents):
        # Identify unique words in documents that are not in the tokenizer's vocabulary
        unique_words = set()
        for doc in documents:
            tokens = set(doc.split())  # Simple whitespace tokenization
            unique_words.update(tokens)
        existing_vocab = set(self._tokenizer.get_vocab().keys())
        new_tokens = list(unique_words - existing_vocab)
        return new_tokens if new_tokens else None

    def transform(self, documents: List[Union[str, Any]]) -> List[Vector]:
        """
        Generates embeddings for a list of documents using the fine-tuned MLM.
        """
        self._model.eval()
        embedding_list = []

        for document in documents:
            inputs = self._tokenizer(
                document,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=512,
            )
            inputs = {k: v.to(self._device) for k, v in inputs.items()}
            with torch.no_grad():
                outputs = self._model(**inputs)
            # Extract embedding (for simplicity, averaging the last hidden states)
            if hasattr(outputs, "last_hidden_state"):
                embedding = outputs.last_hidden_state.mean(1)
            else:
                # Fallback or corrected attribute access
                embedding = outputs["logits"].mean(1)
            embedding = embedding.cpu().numpy()
            embedding_list.append(Vector(value=embedding.squeeze().tolist()))

        return embedding_list

    def fit_transform(self, documents: List[Union[str, Any]], **kwargs) -> List[Vector]:
        """
        Fine-tunes the MLM and generates embeddings for the provided documents.
        """
        self.fit(documents, **kwargs)
        return self.transform(documents)

    def infer_vector(self, data: Union[str, Any], *args, **kwargs) -> Vector:
        """
        Generates an embedding for the input data.
        Parameters:
        - data (Union[str, Any]): The input data, expected to be a textual representation.
          Could be a single string or a batch of strings.
        """
        # Tokenize the input data and ensure the tensors are on the correct device.
        self._model.eval()
        inputs = self._tokenizer(
            data, return_tensors="pt", padding=True, truncation=True, max_length=512
        )
        inputs = {k: v.to(self._device) for k, v in inputs.items()}

        # Generate embeddings using the model
        with torch.no_grad():
            outputs = self._model(**inputs)

        if hasattr(outputs, "last_hidden_state"):
            # Access the last layer and calculate the mean across all tokens (simple pooling)
            embedding = outputs.last_hidden_state.mean(dim=1)
        else:
            embedding = outputs["logits"].mean(1)
        # Move the embeddings back to CPU for compatibility with downstream tasks if necessary
        embedding = embedding.cpu().numpy()

        return Vector(value=embedding.squeeze().tolist())

    def save_model(self, path: str) -> None:
        """
        Saves the model and tokenizer to the specified directory.
        """
        self._model.save_pretrained(path)
        self._tokenizer.save_pretrained(path)

    def load_model(self, path: str) -> None:
        """
        Loads the model and tokenizer from the specified directory.
        """
        self._model = AutoModelForMaskedLM.from_pretrained(path)
        self._tokenizer = AutoTokenizer.from_pretrained(path)
        self._model.to(self._device)  # Ensure the model is loaded to the correct device
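A minimal usage sketch for the embedder above, assuming the `torch`, `transformers`, and Swarmauri dependencies declared in the pyproject are installed; the corpus is illustrative only, and the vector width depends on which model output the pooling fallback uses:

```python
# Sketch: fine-tune the masked LM briefly on a small corpus, then embed text.
# The documents below are made up for illustration.
from swarmauri_vectorstore_communitymlm.MlmEmbedding import MlmEmbedding

embedder = MlmEmbedding(embedding_name="bert-base-uncased", batch_size=8)

corpus = [
    "Swarmauri provides pluggable vector stores.",
    "Masked language models can be fine-tuned on domain text.",
]

vectors = embedder.fit_transform(corpus)    # one fit pass, then transform
print(len(vectors), len(vectors[0].value))  # documents embedded, vector width

# Embed a single query without further training.
query_vector = embedder.infer_vector("How are vector stores plugged in?")
```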
89 changes: 89 additions & 0 deletions
...y/swarmauri_vectorstore_communitymlm/swarmauri_vectorstore_communitymlm/MlmVectorStore.py
@@ -0,0 +1,89 @@
from typing import List, Union, Literal
from swarmauri_standard.documents.Document import Document
from swarmauri_vectorstore_communitymlm.MlmEmbedding import MlmEmbedding
from swarmauri_standard.distances.CosineDistance import CosineDistance

from swarmauri_base.vector_stores.VectorStoreBase import VectorStoreBase
from swarmauri_base.vector_stores.VectorStoreRetrieveMixin import (
    VectorStoreRetrieveMixin,
)
from swarmauri_base.vector_stores.VectorStoreSaveLoadMixin import (
    VectorStoreSaveLoadMixin,
)


class MlmVectorStore(
    VectorStoreSaveLoadMixin, VectorStoreRetrieveMixin, VectorStoreBase
):
    type: Literal["MlmVectorStore"] = "MlmVectorStore"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self._embedder = MlmEmbedding()
        self._distance = CosineDistance()
        self.documents: List[Document] = []

    def add_document(self, document: Document) -> None:
        self.documents.append(document)
        documents_text = [_d.content for _d in self.documents if _d.content]
        embeddings = self._embedder.fit_transform(documents_text)

        embedded_documents = [
            Document(
                id=_d.id,
                content=_d.content,
                metadata=_d.metadata,
                embedding=embeddings[_count],
            )
            for _count, _d in enumerate(self.documents)
            if _d.content
        ]

        self.documents = embedded_documents

    def add_documents(self, documents: List[Document]) -> None:
        self.documents.extend(documents)
        documents_text = [_d.content for _d in self.documents if _d.content]
        embeddings = self._embedder.fit_transform(documents_text)

        embedded_documents = [
            Document(
                id=_d.id,
                content=_d.content,
                metadata=_d.metadata,
                embedding=embeddings[_count],
            )
            for _count, _d in enumerate(self.documents)
            if _d.content
        ]

        self.documents = embedded_documents

    def get_document(self, id: str) -> Union[Document, None]:
        for document in self.documents:
            if document.id == id:
                return document
        return None

    def get_all_documents(self) -> List[Document]:
        return self.documents

    def delete_document(self, id: str) -> None:
        self.documents = [_d for _d in self.documents if _d.id != id]

    def update_document(self, id: str) -> None:
        raise NotImplementedError(
            "Update_document not implemented on MLMVectorStore class."
        )

    def retrieve(self, query: str, top_k: int = 5) -> List[Document]:
        query_vector = self._embedder.infer_vector(query)
        document_vectors = [_d.embedding for _d in self.documents if _d.content]
        distances = self._distance.distances(query_vector, document_vectors)

        # Get the indices of the top_k most similar documents
        top_k_indices = sorted(range(len(distances)), key=lambda i: distances[i])[
            :top_k
        ]

        return [self.documents[i] for i in top_k_indices]
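A minimal usage sketch for the store above, assuming `Document` accepts the same `id`/`content` keywords (with `metadata` optional) that the store's own re-embedding code passes; the documents and query are illustrative only. Note that `add_document` and `add_documents` re-embed the entire collection on every call, so batching inserts is cheaper than adding documents one at a time:

```python
# Sketch: index two documents and retrieve the closest match for a query.
from swarmauri_standard.documents.Document import Document
from swarmauri_vectorstore_communitymlm.MlmVectorStore import MlmVectorStore

store = MlmVectorStore()
store.add_documents(
    [
        Document(id="1", content="Poetry manages the package dependencies."),
        Document(id="2", content="The vector store embeds documents with an MLM."),
    ]
)

results = store.retrieve(query="How are documents embedded?", top_k=1)
print(results[0].id, results[0].content)
```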
17 changes: 17 additions & 0 deletions
...mmunity/swarmauri_vectorstore_communitymlm/swarmauri_vectorstore_communitymlm/__init__.py
@@ -0,0 +1,17 @@
from .MlmEmbedding import MlmEmbedding
from .MlmVectorStore import MlmVectorStore

__version__ = "0.6.0.dev26"
__long_desc__ = """
# Swarmauri Mlm Based Components
Components Included:
- MlmEmbedding
- MlmVectorStore
Visit us at: https://swarmauri.com
Follow us at: https://github.com/swarmauri
Star us at: https://github.com/swarmauri/swarmauri-sdk
"""