Skip to content

Commit

Permalink
Adding context and ollama implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
Getty committed Dec 29, 2024
1 parent d4298f1 commit 6cbe809
Show file tree
Hide file tree
Showing 6 changed files with 196 additions and 3 deletions.
6 changes: 6 additions & 0 deletions mcp/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from .registry import MCPRegistry
from .engine import MCPEngine
from .context import MCPContext

class MCPMain:
count = 0
Expand All @@ -20,6 +21,11 @@ def __init__(self,
def __str__(self):
return f"MCPMain instance {self.name}"

def context(self,
system_prompt: Optional[str] = None,
):
return MCPContext(system_prompt)

def model(self,
model: str,
**kwargs,
Expand Down
91 changes: 91 additions & 0 deletions mcp/context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
from typing import Any, Dict, Optional, Union

import copy

class MCPMessage:

def __init__(self, content: str, role: str):
self.role = role
self.content = content

class MCPMessageSystem(MCPMessage):

def __init__(self, system_prompt: str):
super().__init__(system_prompt, role="system")

class MCPMessageAssistant(MCPMessage):

def __init__(self, assistant_context: str):
super().__init__(assistant_context, role="assistant")

class MCPMessageUser(MCPMessage):

def __init__(self, user_message: str, role: Optional[str] = None):
super().__init__(user_message, role if role is not None else "user")

class MCPContext:

def __init__(self, system_prompt: Optional[str] = None, user_query: Optional[str] = None):
self.messages = []
self.query = user_query
if not system_prompt is None:
self.messages.append(MCPMessageSystem(system_prompt))

def add(self, message: MCPMessage):
self.messages.append(message)
return self

def add_system(self, system_prompt: str):
self.messages.append(MCPMessageSystem(system_prompt))
return self

def add_assistant(self, assistant_context: str):
self.messages.append(MCPMessageAssistant(assistant_context))
return self

def add_user(self, user_message: str):
self.messages.append(MCPMessageUser(user_message))
return self

def add_query(self, user_query: str):
self.query = user_query
return self

def spawn(self):
return copy.deepcopy(self)

def spawn_with(self, message: Union[MCPMessage, 'MCPContext']):
clone = self.spawn()
if isinstance(message, MCPMessage):
clone.add(message)
elif isinstance(message, MCPContext):
for cmessage in message.messages:
clone.add(cmessage)
return clone

def spawn_with_system(self, system_prompt: str):
return self.spawn_with(MCPMessageSystem(system_prompt))

def spawn_with_assistant(self, assistant_context: str):
return self.spawn_with(MCPMessageAssistant(assistant_context))

def spawn_with_user(self, user_message: str):
return self.spawn_with(MCPMessageUser(user_message))

def spawn_with_query(self, user_query: str):
return self.spawn().add_query(user_query)

def all_messages(self):
messages = self.messages
if not self.query is None:
messages.append(MCPMessageUser(self.query))
return messages

def to_messages(self):
message_list = []
for message in self.all_messages():
message_list.append({
"content": message.content,
"role": message.role,
})
return message_list
54 changes: 54 additions & 0 deletions mcp/engine/ollama.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from typing import Any, Union, Dict, List, Optional
import os
from urllib.parse import urlparse
import base64

from . import MCPEngine

class MCPEngineOllama(MCPEngine):

def __init__(self,
url: str = None,
**kwargs,
):
url = os.environ.get("OLLAMA_HOST")
userinfo = None
if os.environ.get("OLLAMA_PROXY_URL"):
if not url is None:
raise Exception("OLLAMA_PROXY_URL and OLLAMA_HOST set, please just use one")
else:
url = os.environ.get("OLLAMA_PROXY_URL")
if url:
parsed_url = urlparse(os.environ.get("OLLAMA_HOST"))
if parsed_url.scheme in ["http", "https"] and parsed_url.netloc:
if "@" in parsed_url.netloc:
userinfo = parsed_url.netloc.split("@")[0]
if parsed_url.port:
netloc = f"{parsed_url.hostname}:{parsed_url.port}"
else:
netloc = parsed_url.hostname
parsed_url = parsed_url._replace(netloc=netloc)
url = parsed_url.geturl()
elif parsed_url.path:
url = 'http://'+parsed_url.path+'/'
kwargs['host'] = url
if userinfo:
if not 'headers' in kwargs:
kwargs['headers'] = {}
auth_bytes = userinfo.encode("utf-8")
auth_base64 = base64.b64encode(auth_bytes).decode("utf-8")
kwargs['headers']['Authorization'] = 'Basic '+auth_base64
from ollama import Client as Ollama
self.client = Ollama(
**kwargs,
)

def get_models(self):
models = []
for model in self.client.list().models:
models.append(model.model)
models.sort()
return models

def __str__(self):
return f"MCP Engine Ollama {hex(id(self))}"
2 changes: 1 addition & 1 deletion mcp/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def get_engine_by_model(self, model: str):
def add_engine(self, name: str, engine: MCPEngine = None):
self._engines[name] = engine
self._update_models()
return self._engines[name]
return self

def __str__(self):
return f"MCP Registry {self.name}"
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ authors = [
dependencies = [
"pydantic>=2.0.0",
"typing-extensions>=4.0.0",
"pyyaml>=6.0.0"
"pyyaml>=6.0.0",
]
requires-python = ">=3.8"
readme = "README.md"
Expand Down
44 changes: 43 additions & 1 deletion tests/test_mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from mcp import MCP
from mcp.engine import MCPEngine
from mcp.model import MCPModel
from mcp.context import MCPContext, MCPMessage

class TestMCP(unittest.TestCase):

Expand Down Expand Up @@ -59,7 +60,7 @@ def test_002_anthropic(self):
warnings.warn("Can't test Anthropic engine without ANTHROPIC_API_KEY", UserWarning)

def test_003_groq(self):
"""Test Anthropic"""
"""Test Groq"""
if os.environ.get("GROQ_API_KEY"):
engine = MCP.engine('groq')
self.assertIsInstance(engine, MCPEngine)
Expand All @@ -80,9 +81,50 @@ def test_003_groq(self):
else:
warnings.warn("Can't test Groq engine without GROQ_API_KEY", UserWarning)

def test_004_ollama(self):
"""Test Ollama"""
if os.environ.get("OLLAMA_HOST") or os.environ.get("OLLAMA_PROXY_URL"):
engine = MCP.engine('ollama')
self.assertIsInstance(engine, MCPEngine)
self.assertEqual(type(engine).__name__, "MCPEngineOllama")
models = engine.get_models()
self.assertTrue(len(models) > 0)
else:
warnings.warn("Can't test Ollama engine without explicit setting OLLAMA_HOST or OLLAMA_PROXY_URL", UserWarning)

def test_010_registry(self):
"""Test registry"""
self.assertEqual(MCP.count, 1)

def test_020_context(self):
"""Test context"""
nosys_context = MCP.context()
self.assertIsInstance(nosys_context, MCPContext)
self.assertTrue(len(nosys_context.messages) == 0)
nosys_context.add_system("you are an assistant")
self.assertTrue(len(nosys_context.messages) == 1)
self.assertEqual(nosys_context.to_messages(), [{
'content': 'you are an assistant',
'role': 'system',
}])

sys_context = MCP.context("you are an assistant that hates his work")
self.assertIsInstance(sys_context, MCPContext)
self.assertTrue(len(sys_context.messages) == 1)
sys_context.add_assistant("roger rabbit is a fictional animated anthropomorphic rabbit")
self.assertTrue(len(sys_context.messages) == 2)
sys_context.add_user("who is roger rabbit?")
self.assertTrue(len(sys_context.messages) == 3)
self.assertEqual(sys_context.to_messages(), [{
'content': 'you are an assistant that hates his work',
'role': 'system',
}, {
'content': 'roger rabbit is a fictional animated anthropomorphic rabbit',
'role': 'assistant',
}, {
'content': 'who is roger rabbit?',
'role': 'user',
}])

if __name__ == "__main__":
unittest.main()

0 comments on commit 6cbe809

Please sign in to comment.