Add image support for Ollama #14713

Merged
merged 15 commits on Dec 16, 2023
354 changes: 100 additions & 254 deletions docs/docs/integrations/chat/ollama.ipynb

Large diffs are not rendered by default.

362 changes: 70 additions & 292 deletions docs/docs/integrations/llms/ollama.ipynb

Large diffs are not rendered by default.
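Since the notebook diffs are collapsed here, below is a minimal sketch of the multimodal usage this change enables. It assumes a locally running Ollama server and a vision-capable model such as `bakllava` already pulled; the model name, image path, and prompt are placeholders rather than values taken from this PR.

```python
import base64

from langchain_community.chat_models import ChatOllama
from langchain_core.messages import HumanMessage

# Assumes an Ollama server on the default localhost:11434 and a
# vision-capable model (e.g. "bakllava"); both are illustrative choices.
chat = ChatOllama(model="bakllava")

# Encode the image as base64; the new message-conversion code accepts
# either a bare base64 string or a data URI for the "image_url" field.
with open("example.jpg", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode("utf-8")

message = HumanMessage(
    content=[
        {"type": "text", "text": "What is shown in this image?"},
        {"type": "image_url", "image_url": f"data:image/jpeg;base64,{image_b64}"},
    ]
)

print(chat.invoke([message]).content)
```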

150 changes: 143 additions & 7 deletions libs/community/langchain_community/chat_models/ollama.py
@@ -1,6 +1,7 @@
import json
from typing import Any, Iterator, List, Optional
from typing import Any, Dict, Iterator, List, Optional, Union

from langchain_core._api import deprecated
from langchain_core.callbacks import (
CallbackManagerForLLMRun,
)
@@ -15,9 +16,10 @@
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult

from langchain_community.llms.ollama import _OllamaCommon
from langchain_community.llms.ollama import OllamaEndpointNotFoundError, _OllamaCommon


@deprecated("0.0.3", alternative="_chat_stream_response_to_chat_generation_chunk")
def _stream_response_to_chat_generation_chunk(
stream_response: str,
) -> ChatGenerationChunk:
@@ -30,6 +32,20 @@ def _stream_response_to_chat_generation_chunk(
)


def _chat_stream_response_to_chat_generation_chunk(
stream_response: str,
) -> ChatGenerationChunk:
"""Convert a stream response to a generation chunk."""
parsed_response = json.loads(stream_response)
generation_info = parsed_response if parsed_response.get("done") is True else None
return ChatGenerationChunk(
message=AIMessageChunk(
content=parsed_response.get("message", {}).get("content", "")
),
generation_info=generation_info,
)


class ChatOllama(BaseChatModel, _OllamaCommon):
"""Ollama locally runs large language models.

@@ -52,11 +68,15 @@ def is_lc_serializable(cls) -> bool:
"""Return whether this model can be serialized by Langchain."""
return False

@deprecated("0.0.3", alternative="_convert_messages_to_ollama_messages")
def _format_message_as_text(self, message: BaseMessage) -> str:
if isinstance(message, ChatMessage):
message_text = f"\n\n{message.role.capitalize()}: {message.content}"
elif isinstance(message, HumanMessage):
message_text = f"[INST] {message.content} [/INST]"
if message.content[0].get("type") == "text":
message_text = f"[INST] {message.content[0]['text']} [/INST]"
elif message.content[0].get("type") == "image_url":
message_text = message.content[0]["image_url"]["url"]
elif isinstance(message, AIMessage):
message_text = f"{message.content}"
elif isinstance(message, SystemMessage):
@@ -70,6 +90,98 @@ def _format_messages_as_text(self, messages: List[BaseMessage]) -> str:
[self._format_message_as_text(message) for message in messages]
)

def _convert_messages_to_ollama_messages(
self, messages: List[BaseMessage]
) -> List[Dict[str, Union[str, List[str]]]]:
ollama_messages = []
for message in messages:
role = ""
if isinstance(message, HumanMessage):
role = "user"
elif isinstance(message, AIMessage):
role = "assistant"
elif isinstance(message, SystemMessage):
role = "system"
else:
raise ValueError("Received unsupported message type for Ollama.")

content = ""
images = []
if isinstance(message.content, str):
content = message.content
else:
for content_part in message.content:
if content_part.get("type") == "text":
content += f"\n{content_part['text']}"
elif content_part.get("type") == "image_url":
if isinstance(content_part.get("image_url"), str):
image_url_components = content_part["image_url"].split(",")
# Support data:image/jpeg;base64,<image> format
# and base64 strings
if len(image_url_components) > 1:
images.append(image_url_components[1])
else:
images.append(image_url_components[0])
else:
raise ValueError(
"Only string image_url " "content parts are supported."
)
else:
raise ValueError(
"Unsupported message content type. "
"Must either have type 'text' or type 'image_url' "
"with a string 'image_url' field."
)

ollama_messages.append(
{
"role": role,
"content": content,
"images": images,
}
)

return ollama_messages

def _create_chat_stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> Iterator[str]:
payload = {
"messages": self._convert_messages_to_ollama_messages(messages),
}
yield from self._create_stream(
payload=payload, stop=stop, api_url=f"{self.base_url}/api/chat/", **kwargs
)

def _chat_stream_with_aggregation(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
verbose: bool = False,
**kwargs: Any,
) -> ChatGenerationChunk:
final_chunk: Optional[ChatGenerationChunk] = None
for stream_resp in self._create_chat_stream(messages, stop, **kwargs):
if stream_resp:
chunk = _chat_stream_response_to_chat_generation_chunk(stream_resp)
if final_chunk is None:
final_chunk = chunk
else:
final_chunk += chunk
if run_manager:
run_manager.on_llm_new_token(
chunk.text,
verbose=verbose,
)
if final_chunk is None:
raise ValueError("No data received from Ollama stream.")

return final_chunk

def _generate(
self,
messages: List[BaseMessage],
@@ -94,9 +206,12 @@ def _generate(
])
"""

prompt = self._format_messages_as_text(messages)
final_chunk = super()._stream_with_aggregation(
prompt, stop=stop, run_manager=run_manager, verbose=self.verbose, **kwargs
final_chunk = self._chat_stream_with_aggregation(
messages,
stop=stop,
run_manager=run_manager,
verbose=self.verbose,
**kwargs,
)
chat_generation = ChatGeneration(
message=AIMessage(content=final_chunk.text),
@@ -110,9 +225,30 @@ def _stream(
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
try:
for stream_resp in self._create_chat_stream(messages, stop, **kwargs):
if stream_resp:
chunk = _chat_stream_response_to_chat_generation_chunk(stream_resp)
yield chunk
if run_manager:
run_manager.on_llm_new_token(
chunk.text,
verbose=self.verbose,
)
except OllamaEndpointNotFoundError:
yield from self._legacy_stream(messages, stop, **kwargs)

@deprecated("0.0.3", alternative="_stream")
def _legacy_stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
prompt = self._format_messages_as_text(messages)
for stream_resp in self._create_stream(prompt, stop, **kwargs):
for stream_resp in self._create_generate_stream(prompt, stop, **kwargs):
if stream_resp:
chunk = _stream_response_to_chat_generation_chunk(stream_resp)
yield chunk
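For reference, a rough sketch of the JSON body that `_create_chat_stream` ends up POSTing to `{base_url}/api/chat/` for a single user turn with one attached image; the concrete values are placeholders, and the remaining keys are merged in from `self._default_params`.

```python
# Illustrative request body only; field values are placeholders and the
# trailing params depend on how the model instance was configured.
request_payload = {
    "messages": [
        {
            "role": "user",
            # _convert_messages_to_ollama_messages prefixes each text part
            # with a newline when building the content string.
            "content": "\nWhat is shown in this image?",
            "images": ["iVBORw0KGgoAAAANSUhEUg..."],  # base64 data, truncated
        }
    ],
    # **params merged in by _create_stream: model, format, options, ...
}
```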
54 changes: 45 additions & 9 deletions libs/community/langchain_community/llms/ollama.py
@@ -20,6 +20,10 @@ def _stream_response_to_generation_chunk(
)


class OllamaEndpointNotFoundError(Exception):
"""Raised when the Ollama endpoint is not found."""


class _OllamaCommon(BaseLanguageModel):
base_url: str = "http://localhost:11434"
"""Base url the model is hosted under."""
@@ -129,10 +133,26 @@ def _identifying_params(self) -> Mapping[str, Any]:
"""Get the identifying parameters."""
return {**{"model": self.model, "format": self.format}, **self._default_params}

def _create_stream(
def _create_generate_stream(
self,
prompt: str,
stop: Optional[List[str]] = None,
images: Optional[List[str]] = None,
**kwargs: Any,
) -> Iterator[str]:
payload = {"prompt": prompt, "images": images}
yield from self._create_stream(
payload=payload,
stop=stop,
api_url=f"{self.base_url}/api/generate/",
**kwargs,
)

def _create_stream(
self,
api_url: str,
payload: Any,
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> Iterator[str]:
if self.stop is not None and stop is not None:
@@ -156,20 +176,34 @@ def _create_stream(
**kwargs,
}

if payload.get("messages"):
request_payload = {"messages": payload.get("messages", []), **params}
else:
request_payload = {
"prompt": payload.get("prompt"),
"images": payload.get("images", []),
**params,
}

response = requests.post(
url=f"{self.base_url}/api/generate/",
url=api_url,
headers={"Content-Type": "application/json"},
json={"prompt": prompt, **params},
json=request_payload,
stream=True,
timeout=self.timeout,
)
response.encoding = "utf-8"
if response.status_code != 200:
optional_detail = response.json().get("error")
raise ValueError(
f"Ollama call failed with status code {response.status_code}."
f" Details: {optional_detail}"
)
if response.status_code == 404:
raise OllamaEndpointNotFoundError(
"Ollama call failed with status code 404."
)
else:
optional_detail = response.json().get("error")
raise ValueError(
f"Ollama call failed with status code {response.status_code}."
f" Details: {optional_detail}"
)
return response.iter_lines(decode_unicode=True)

def _stream_with_aggregation(
@@ -181,7 +215,7 @@ def _stream_with_aggregation(
**kwargs: Any,
) -> GenerationChunk:
final_chunk: Optional[GenerationChunk] = None
for stream_resp in self._create_stream(prompt, stop, **kwargs):
for stream_resp in self._create_generate_stream(prompt, stop, **kwargs):
if stream_resp:
chunk = _stream_response_to_generation_chunk(stream_resp)
if final_chunk is None:
@@ -225,6 +259,7 @@ def _generate(
self,
prompts: List[str],
stop: Optional[List[str]] = None,
images: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
@@ -248,6 +283,7 @@ def _generate(
final_chunk = super()._stream_with_aggregation(
prompt,
stop=stop,
images=images,
run_manager=run_manager,
verbose=self.verbose,
**kwargs,
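On the completion-style `Ollama` LLM, the `images` argument added to `_generate` is forwarded to `/api/generate/` as a list of base64-encoded strings. A minimal usage sketch, again assuming a local server and a vision-capable model (names and paths are illustrative):

```python
import base64

from langchain_community.llms import Ollama

# Assumes a local Ollama server and a vision-capable model such as
# "bakllava"; the model name and image path are placeholders.
llm = Ollama(model="bakllava")

with open("example.jpg", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode("utf-8")

# bind() fixes the images kwarg so it is passed through to _generate,
# which now forwards it to the /api/generate/ endpoint.
llm_with_image = llm.bind(images=[image_b64])
print(llm_with_image.invoke("What is shown in this image?"))
```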