Skip to content

Commit

Permalink
Feature/add back ollama provider (#1522)
Browse files Browse the repository at this point in the history
* add extra parsers

* add back ollama

* rvert auth workflow
  • Loading branch information
emrgnt-cmplxty authored Oct 29, 2024
1 parent 0480d2e commit 939f7d4
Show file tree
Hide file tree
Showing 15 changed files with 321 additions and 72 deletions.
2 changes: 1 addition & 1 deletion docs/api-reference/openapi.json

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions py/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,7 @@
# Embeddings
"LiteLLMEmbeddingProvider",
"OpenAIEmbeddingProvider",
"OllamaEmbeddingProvider",
# LLM
"OpenAICompletionProvider",
"LiteLLMCompletionProvider",
Expand Down
23 changes: 12 additions & 11 deletions py/core/base/providers/email.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,18 @@ def supported_providers(self) -> list[str]:
] # Could add more providers like AWS SES, SendGrid etc.

def validate_config(self) -> None:
if self.provider == "smtp":
if not all(
[
self.smtp_server,
self.smtp_port,
self.smtp_username,
self.smtp_password,
self.from_email,
]
):
raise ValueError("SMTP configuration is incomplete")
pass
# if self.provider == "smtp":
# if not all(
# [
# self.smtp_server,
# self.smtp_port,
# self.smtp_username,
# self.smtp_password,
# self.from_email,
# ]
# ):
# raise ValueError("SMTP configuration is incomplete")


logger = logging.getLogger(__name__)
Expand Down
4 changes: 2 additions & 2 deletions py/core/base/providers/embedding.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import logging
import random
import time
from abc import abstractmethod
from enum import Enum
Expand All @@ -14,7 +15,6 @@
VectorSearchResult,
default_embedding_prefixes,
)
import random
from .base import Provider, ProviderConfig

logger = logging.getLogger()
Expand Down Expand Up @@ -44,7 +44,7 @@ def validate_config(self) -> None:

@property
def supported_providers(self) -> list[str]:
return ["litellm", "openai"]
return ["litellm", "openai", "ollama"]


class EmbeddingProvider(Provider):
Expand Down
2 changes: 1 addition & 1 deletion py/core/examples/scripts/run_auth_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
) # Replace with your R2R deployment URL

# Register a new user
user_result = client.register("[email protected]", "password123")
user_result = client.register("[email protected]", "password123")
print(user_result)

# # Uncomment when running with authentication
Expand Down
7 changes: 6 additions & 1 deletion py/core/main/abstractions.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
HatchetOrchestrationProvider,
LiteLLMCompletionProvider,
LiteLLMEmbeddingProvider,
OllamaEmbeddingProvider,
OpenAICompletionProvider,
OpenAIEmbeddingProvider,
PostgresDBProvider,
Expand All @@ -27,7 +28,11 @@ class R2RProviders(BaseModel):
auth: Union[R2RAuthProvider, SupabaseAuthProvider]
database: PostgresDBProvider
ingestion: Union[R2RIngestionProvider, UnstructuredIngestionProvider]
embedding: Union[LiteLLMEmbeddingProvider, OpenAIEmbeddingProvider]
embedding: Union[
LiteLLMEmbeddingProvider,
OpenAIEmbeddingProvider,
OllamaEmbeddingProvider,
]
llm: Union[LiteLLMCompletionProvider, OpenAICompletionProvider]
orchestration: Union[
HatchetOrchestrationProvider, SimpleOrchestrationProvider
Expand Down
4 changes: 2 additions & 2 deletions py/core/main/api/kg_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ async def create_graph(
description="Settings for the graph creation process.",
),
auth_user=Depends(self.service.providers.auth.auth_wrapper),
) -> WrappedKGCreationResponse: # type: ignore
): # -> WrappedKGCreationResponse: # type: ignore
"""
Creating a graph on your documents. This endpoint takes input a list of document ids and KGCreationSettings.
If document IDs are not provided, the graph will be created on all documents in the system.
Expand Down Expand Up @@ -170,7 +170,7 @@ async def enrich_graph(
description="Settings for the graph enrichment process.",
),
auth_user=Depends(self.service.providers.auth.auth_wrapper),
) -> WrappedKGEnrichmentResponse:
): # -> WrappedKGEnrichmentResponse:
"""
This endpoint enriches the graph with additional information.
It creates communities of nodes based on their similarity and adds embeddings to the graph.
Expand Down
18 changes: 16 additions & 2 deletions py/core/main/assembly/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
HatchetOrchestrationProvider,
LiteLLMCompletionProvider,
LiteLLMEmbeddingProvider,
OllamaEmbeddingProvider,
OpenAICompletionProvider,
OpenAIEmbeddingProvider,
PostgresDBProvider,
Expand Down Expand Up @@ -181,7 +182,11 @@ async def create_database_provider(
@staticmethod
def create_embedding_provider(
embedding: EmbeddingConfig, *args, **kwargs
) -> Union[LiteLLMEmbeddingProvider, OpenAIEmbeddingProvider]:
) -> Union[
LiteLLMEmbeddingProvider,
OllamaEmbeddingProvider,
OpenAIEmbeddingProvider,
]:
embedding_provider: Optional[EmbeddingProvider] = None

if embedding.provider == "openai":
Expand All @@ -198,6 +203,11 @@ def create_embedding_provider(

embedding_provider = LiteLLMEmbeddingProvider(embedding)

elif embedding.provider == "ollama":
from core.providers import OllamaEmbeddingProvider

embedding_provider = OllamaEmbeddingProvider(embedding)

else:
raise ValueError(
f"Embedding provider {embedding.provider} not supported"
Expand Down Expand Up @@ -252,7 +262,11 @@ async def create_providers(
Union[AsyncSMTPEmailProvider, ConsoleMockEmailProvider]
] = None,
embedding_provider_override: Optional[
Union[LiteLLMEmbeddingProvider, OpenAIEmbeddingProvider]
Union[
LiteLLMEmbeddingProvider,
OpenAIEmbeddingProvider,
OllamaEmbeddingProvider,
]
] = None,
ingestion_provider_override: Optional[
Union[R2RIngestionProvider, UnstructuredIngestionProvider]
Expand Down
5 changes: 4 additions & 1 deletion py/core/pipes/kg/deduplication.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from core.providers import (
LiteLLMCompletionProvider,
LiteLLMEmbeddingProvider,
OllamaEmbeddingProvider,
OpenAICompletionProvider,
OpenAIEmbeddingProvider,
PostgresDBProvider,
Expand All @@ -26,7 +27,9 @@ def __init__(
OpenAICompletionProvider, LiteLLMCompletionProvider
],
embedding_provider: Union[
LiteLLMEmbeddingProvider, OpenAIEmbeddingProvider
LiteLLMEmbeddingProvider,
OpenAIEmbeddingProvider,
OllamaEmbeddingProvider,
],
logging_provider: SqlitePersistentLoggingProvider,
**kwargs,
Expand Down
5 changes: 4 additions & 1 deletion py/core/pipes/kg/deduplication_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from core.providers import (
LiteLLMCompletionProvider,
LiteLLMEmbeddingProvider,
OllamaEmbeddingProvider,
OpenAICompletionProvider,
OpenAIEmbeddingProvider,
PostgresDBProvider,
Expand All @@ -30,7 +31,9 @@ def __init__(
LiteLLMCompletionProvider, OpenAICompletionProvider
],
embedding_provider: Union[
LiteLLMEmbeddingProvider, OpenAIEmbeddingProvider
LiteLLMEmbeddingProvider,
OpenAIEmbeddingProvider,
OllamaEmbeddingProvider,
],
config: AsyncPipe.PipeConfig,
logging_provider: SqlitePersistentLoggingProvider,
Expand Down
7 changes: 6 additions & 1 deletion py/core/providers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@
from .crypto import BCryptConfig, BCryptProvider
from .database import PostgresDBProvider
from .email import AsyncSMTPEmailProvider, ConsoleMockEmailProvider
from .embeddings import LiteLLMEmbeddingProvider, OpenAIEmbeddingProvider
from .embeddings import (
LiteLLMEmbeddingProvider,
OllamaEmbeddingProvider,
OpenAIEmbeddingProvider,
)
from .ingestion import ( # type: ignore
R2RIngestionConfig,
R2RIngestionProvider,
Expand Down Expand Up @@ -32,6 +36,7 @@
"PostgresDBProvider",
# Embeddings
"LiteLLMEmbeddingProvider",
"OllamaEmbeddingProvider",
"OpenAIEmbeddingProvider",
# Email
"AsyncSMTPEmailProvider",
Expand Down
124 changes: 77 additions & 47 deletions py/core/providers/email/smtp.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
import asyncio
import logging
import os
from abc import ABC, abstractmethod
import smtplib
import ssl
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from typing import Optional

from aiosmtplib import SMTP

from core.base import EmailConfig, EmailProvider

logger = logging.getLogger()
logger = logging.getLogger(__name__)


class AsyncSMTPEmailProvider(EmailProvider):
"""Email provider implementation using Brevo SMTP relay"""

def __init__(self, config: EmailConfig):
super().__init__(config)
self.smtp_server = config.smtp_server or os.getenv("R2R_SMTP_SERVER")
Expand All @@ -35,16 +37,36 @@ def __init__(self, config: EmailConfig):
if not self.smtp_password:
raise ValueError("SMTP password is required")

self.from_email: Optional[str] = config.from_email or os.getenv(
"R2R_FROM_EMAIL"
self.from_email: Optional[str] = (
config.from_email
or os.getenv("R2R_FROM_EMAIL")
or self.smtp_username
)
if not self.from_email:
raise ValueError("From email is required")
self.ssl_context = ssl.create_default_context()

async def _send_email_sync(self, msg: MIMEMultipart) -> None:
"""Synchronous email sending wrapped in asyncio executor"""
loop = asyncio.get_running_loop()

def _send():
with smtplib.SMTP_SSL(
self.smtp_server,
self.smtp_port,
context=self.ssl_context,
timeout=30,
) as server:
logger.info("Connected to SMTP server")
server.login(self.smtp_username, self.smtp_password)
logger.info("Login successful")
server.send_message(msg)
logger.info("Message sent successfully!")

self.use_tls = (
config.use_tls
or os.getenv("R2R_SMTP_USE_TLS", "true").lower() == "true"
)
try:
await loop.run_in_executor(None, _send)
except Exception as e:
error_msg = f"Failed to send email: {str(e)}"
logger.error(error_msg)
raise RuntimeError(error_msg) from e

async def send_email(
self,
Expand All @@ -54,7 +76,7 @@ async def send_email(
html_body: Optional[str] = None,
) -> None:
msg = MIMEMultipart("alternative")
msg["Subject"] = subject # type: ignore
msg["Subject"] = subject
msg["From"] = self.from_email # type: ignore
msg["To"] = to_email

Expand All @@ -63,58 +85,66 @@ async def send_email(
msg.attach(MIMEText(html_body, "html"))

try:
smtp = SMTP(
hostname=self.smtp_server,
port=int(self.smtp_port) if self.smtp_port else None,
use_tls=self.use_tls,
)

await smtp.connect()
if self.smtp_username and self.smtp_password:
await smtp.login(self.smtp_username, self.smtp_password)

await smtp.send_message(msg)
await smtp.quit()

logger.info("Initializing SMTP connection...")
async with asyncio.timeout(30): # Overall timeout
await self._send_email_sync(msg)
except asyncio.TimeoutError:
error_msg = "Operation timed out while trying to send email"
logger.error(error_msg)
raise RuntimeError(error_msg)
except Exception as e:
logger.error(f"Failed to send email: {str(e)}")
raise
error_msg = f"Failed to send email: {str(e)}"
logger.error(error_msg)
raise RuntimeError(error_msg) from e

async def send_verification_email(
self, to_email: str, verification_code: str
) -> None:
subject = "Verify Your Email Address"
body = f"""
Thank you for registering! Please verify your email address by entering the following code:
Please verify your email address by entering the following code:
{verification_code}
Verification code: {verification_code}
This code will expire in 24 hours.
If you did not request this verification, please ignore this email.
"""

html_body = f"""
<h2>Email Verification</h2>
<p>Thank you for registering! Please verify your email address by entering the following code:</p>
<h3>{verification_code}</h3>
<p>This code will expire in 24 hours.</p>
<p>Please verify your email address by entering the following code:</p>
<p style="font-size: 24px; font-weight: bold; margin: 20px 0;">
Verification code: {verification_code}
</p>
<p>If you did not request this verification, please ignore this email.</p>
"""
await self.send_email(to_email, subject, body, html_body)

await self.send_email(
to_email=to_email,
subject="Please verify your email address",
body=body,
html_body=html_body,
)

async def send_password_reset_email(
self, to_email: str, reset_token: str
) -> None:
subject = "Password Reset Request"
body = f"""
We received a request to reset your password. Use the following code to reset your password:
You have requested to reset your password.
{reset_token}
Reset token: {reset_token}
This code will expire in 1 hour. If you didn't request this reset, please ignore this email.
If you did not request a password reset, please ignore this email.
"""

html_body = f"""
<h2>Password Reset Request</h2>
<p>We received a request to reset your password. Use the following code to reset your password:</p>
<h3>{reset_token}</h3>
<p>This code will expire in 1 hour.</p>
<p>If you didn't request this reset, please ignore this email.</p>
<p>You have requested to reset your password.</p>
<p style="font-size: 24px; font-weight: bold; margin: 20px 0;">
Reset token: {reset_token}
</p>
<p>If you did not request a password reset, please ignore this email.</p>
"""
await self.send_email(to_email, subject, body, html_body)

await self.send_email(
to_email=to_email,
subject="Password Reset Request",
body=body,
html_body=html_body,
)
Loading

0 comments on commit 939f7d4

Please sign in to comment.