Add image support for Ollama (#14713)
Support [LLaVA](https://ollama.ai/library/llava):
* Upgrade Ollama
* `ollama pull llava`

Ensure compatibility with [image prompt template](#14263)
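
A minimal usage sketch (not part of the commit itself), assuming a local Ollama server at the default `http://localhost:11434` and that `ollama pull llava` has already been run; the message shape mirrors what the new `_convert_messages_to_ollama_messages` helper accepts:

```python
# Hypothetical usage; the model name and image payload are placeholders.
from langchain_community.chat_models import ChatOllama
from langchain_core.messages import HumanMessage

chat = ChatOllama(model="llava")  # assumes the llava model has been pulled

# The chat path accepts image_url content parts whose "image_url" value is a
# string: either a bare base64 payload or a "data:image/jpeg;base64,<image>" URI.
message = HumanMessage(
    content=[
        {"type": "text", "text": "What is shown in this image?"},
        {
            "type": "image_url",
            "image_url": "data:image/jpeg;base64,<base64-encoded image>",
        },
    ]
)

response = chat.invoke([message])
print(response.content)
```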

---------

Co-authored-by: jacoblee93 <[email protected]>
rlancemartin and jacoblee93 authored Dec 16, 2023
1 parent 1075e7d commit 4242186
Showing 4 changed files with 358 additions and 562 deletions.
354 changes: 100 additions & 254 deletions docs/docs/integrations/chat/ollama.ipynb

Large diffs are not rendered by default.

362 changes: 70 additions & 292 deletions docs/docs/integrations/llms/ollama.ipynb

Large diffs are not rendered by default.

150 changes: 143 additions & 7 deletions libs/community/langchain_community/chat_models/ollama.py
@@ -1,6 +1,7 @@
import json
from typing import Any, Iterator, List, Optional
from typing import Any, Dict, Iterator, List, Optional, Union

from langchain_core._api import deprecated
from langchain_core.callbacks import (
CallbackManagerForLLMRun,
)
@@ -15,9 +16,10 @@
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult

from langchain_community.llms.ollama import _OllamaCommon
from langchain_community.llms.ollama import OllamaEndpointNotFoundError, _OllamaCommon


@deprecated("0.0.3", alternative="_chat_stream_response_to_chat_generation_chunk")
def _stream_response_to_chat_generation_chunk(
stream_response: str,
) -> ChatGenerationChunk:
@@ -30,6 +32,20 @@ def _stream_response_to_chat_generation_chunk(
)


def _chat_stream_response_to_chat_generation_chunk(
stream_response: str,
) -> ChatGenerationChunk:
"""Convert a stream response to a generation chunk."""
parsed_response = json.loads(stream_response)
generation_info = parsed_response if parsed_response.get("done") is True else None
return ChatGenerationChunk(
message=AIMessageChunk(
content=parsed_response.get("message", {}).get("content", "")
),
generation_info=generation_info,
)


class ChatOllama(BaseChatModel, _OllamaCommon):
"""Ollama locally runs large language models.
@@ -52,11 +68,15 @@ def is_lc_serializable(cls) -> bool:
"""Return whether this model can be serialized by Langchain."""
return False

@deprecated("0.0.3", alternative="_convert_messages_to_ollama_messages")
def _format_message_as_text(self, message: BaseMessage) -> str:
if isinstance(message, ChatMessage):
message_text = f"\n\n{message.role.capitalize()}: {message.content}"
elif isinstance(message, HumanMessage):
message_text = f"[INST] {message.content} [/INST]"
if message.content[0].get("type") == "text":
message_text = f"[INST] {message.content[0]['text']} [/INST]"
elif message.content[0].get("type") == "image_url":
message_text = message.content[0]["image_url"]["url"]
elif isinstance(message, AIMessage):
message_text = f"{message.content}"
elif isinstance(message, SystemMessage):
@@ -70,6 +90,98 @@ def _format_messages_as_text(self, messages: List[BaseMessage]) -> str:
[self._format_message_as_text(message) for message in messages]
)

def _convert_messages_to_ollama_messages(
self, messages: List[BaseMessage]
) -> List[Dict[str, Union[str, List[str]]]]:
ollama_messages = []
for message in messages:
role = ""
if isinstance(message, HumanMessage):
role = "user"
elif isinstance(message, AIMessage):
role = "assistant"
elif isinstance(message, SystemMessage):
role = "system"
else:
raise ValueError("Received unsupported message type for Ollama.")

content = ""
images = []
if isinstance(message.content, str):
content = message.content
else:
for content_part in message.content:
if content_part.get("type") == "text":
content += f"\n{content_part['text']}"
elif content_part.get("type") == "image_url":
if isinstance(content_part.get("image_url"), str):
image_url_components = content_part["image_url"].split(",")
# Support data:image/jpeg;base64,<image> format
# and base64 strings
if len(image_url_components) > 1:
images.append(image_url_components[1])
else:
images.append(image_url_components[0])
else:
raise ValueError(
"Only string image_url " "content parts are supported."
)
else:
raise ValueError(
"Unsupported message content type. "
"Must either have type 'text' or type 'image_url' "
"with a string 'image_url' field."
)

ollama_messages.append(
{
"role": role,
"content": content,
"images": images,
}
)

return ollama_messages

def _create_chat_stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> Iterator[str]:
payload = {
"messages": self._convert_messages_to_ollama_messages(messages),
}
yield from self._create_stream(
payload=payload, stop=stop, api_url=f"{self.base_url}/api/chat/", **kwargs
)

def _chat_stream_with_aggregation(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
verbose: bool = False,
**kwargs: Any,
) -> ChatGenerationChunk:
final_chunk: Optional[ChatGenerationChunk] = None
for stream_resp in self._create_chat_stream(messages, stop, **kwargs):
if stream_resp:
chunk = _chat_stream_response_to_chat_generation_chunk(stream_resp)
if final_chunk is None:
final_chunk = chunk
else:
final_chunk += chunk
if run_manager:
run_manager.on_llm_new_token(
chunk.text,
verbose=verbose,
)
if final_chunk is None:
raise ValueError("No data received from Ollama stream.")

return final_chunk

def _generate(
self,
messages: List[BaseMessage],
@@ -94,9 +206,12 @@ def _generate(
])
"""

prompt = self._format_messages_as_text(messages)
final_chunk = super()._stream_with_aggregation(
prompt, stop=stop, run_manager=run_manager, verbose=self.verbose, **kwargs
final_chunk = self._chat_stream_with_aggregation(
messages,
stop=stop,
run_manager=run_manager,
verbose=self.verbose,
**kwargs,
)
chat_generation = ChatGeneration(
message=AIMessage(content=final_chunk.text),
@@ -110,9 +225,30 @@ def _stream(
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
try:
for stream_resp in self._create_chat_stream(messages, stop, **kwargs):
if stream_resp:
chunk = _stream_response_to_chat_generation_chunk(stream_resp)
yield chunk
if run_manager:
run_manager.on_llm_new_token(
chunk.text,
verbose=self.verbose,
)
except OllamaEndpointNotFoundError:
yield from self._legacy_stream(messages, stop, **kwargs)

@deprecated("0.0.3", alternative="_stream")
def _legacy_stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
prompt = self._format_messages_as_text(messages)
for stream_resp in self._create_stream(prompt, stop, **kwargs):
for stream_resp in self._create_generate_stream(prompt, stop, **kwargs):
if stream_resp:
chunk = _stream_response_to_chat_generation_chunk(stream_resp)
yield chunk
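
For orientation, a sketch (values illustrative only) of the structure the new `_convert_messages_to_ollama_messages` helper emits and `_create_chat_stream` then posts to `{base_url}/api/chat/`:

```python
# Illustrative only: one dict per LangChain message.
ollama_messages = [
    {
        "role": "user",  # "user", "assistant", or "system"
        "content": "\nWhat is shown in this image?",  # text parts joined by newlines
        "images": ["<base64-encoded image>"],  # base64 data taken from image_url parts
    }
]
```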
54 changes: 45 additions & 9 deletions libs/community/langchain_community/llms/ollama.py
@@ -20,6 +20,10 @@ def _stream_response_to_generation_chunk(
)


class OllamaEndpointNotFoundError(Exception):
"""Raised when the Ollama endpoint is not found."""


class _OllamaCommon(BaseLanguageModel):
base_url: str = "http://localhost:11434"
"""Base url the model is hosted under."""
@@ -129,10 +133,26 @@ def _identifying_params(self) -> Mapping[str, Any]:
"""Get the identifying parameters."""
return {**{"model": self.model, "format": self.format}, **self._default_params}

def _create_stream(
def _create_generate_stream(
self,
prompt: str,
stop: Optional[List[str]] = None,
images: Optional[List[str]] = None,
**kwargs: Any,
) -> Iterator[str]:
payload = {"prompt": prompt, "images": images}
yield from self._create_stream(
payload=payload,
stop=stop,
api_url=f"{self.base_url}/api/generate/",
**kwargs,
)

def _create_stream(
self,
api_url: str,
payload: Any,
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> Iterator[str]:
if self.stop is not None and stop is not None:
@@ -156,20 +176,34 @@ def _create_stream(
**kwargs,
}

if payload.get("messages"):
request_payload = {"messages": payload.get("messages", []), **params}
else:
request_payload = {
"prompt": payload.get("prompt"),
"images": payload.get("images", []),
**params,
}

response = requests.post(
url=f"{self.base_url}/api/generate/",
url=api_url,
headers={"Content-Type": "application/json"},
json={"prompt": prompt, **params},
json=request_payload,
stream=True,
timeout=self.timeout,
)
response.encoding = "utf-8"
if response.status_code != 200:
optional_detail = response.json().get("error")
raise ValueError(
f"Ollama call failed with status code {response.status_code}."
f" Details: {optional_detail}"
)
if response.status_code == 404:
raise OllamaEndpointNotFoundError(
"Ollama call failed with status code 404."
)
else:
optional_detail = response.json().get("error")
raise ValueError(
f"Ollama call failed with status code {response.status_code}."
f" Details: {optional_detail}"
)
return response.iter_lines(decode_unicode=True)

def _stream_with_aggregation(
@@ -181,7 +215,7 @@ def _stream_with_aggregation(
**kwargs: Any,
) -> GenerationChunk:
final_chunk: Optional[GenerationChunk] = None
for stream_resp in self._create_stream(prompt, stop, **kwargs):
for stream_resp in self._create_generate_stream(prompt, stop, **kwargs):
if stream_resp:
chunk = _stream_response_to_generation_chunk(stream_resp)
if final_chunk is None:
@@ -225,6 +259,7 @@ def _generate(
self,
prompts: List[str],
stop: Optional[List[str]] = None,
images: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
@@ -248,6 +283,7 @@ def _generate(
final_chunk = super()._stream_with_aggregation(
prompt,
stop=stop,
images=images,
run_manager=run_manager,
verbose=self.verbose,
**kwargs,
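
To summarize the routing above (a sketch with placeholder values): `_create_stream` now posts chat-style payloads to `/api/chat/` and prompt-style payloads to `/api/generate/`, and a 404 raises `OllamaEndpointNotFoundError`, which `ChatOllama._stream` catches in order to fall back to the legacy prompt-based path:

```python
# Placeholder payloads; model and option params from _default_params are merged in.
chat_request = {  # sent to f"{base_url}/api/chat/"
    "messages": [
        {"role": "user", "content": "Describe this image.", "images": ["<base64>"]}
    ],
}

generate_request = {  # sent to f"{base_url}/api/generate/" (legacy / text-only path)
    "prompt": "[INST] Describe this image. [/INST]",
    "images": ["<base64>"],
}
```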
