Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support for claude #38

Open
wants to merge 16 commits into
base: main
Choose a base branch
from
137 changes: 89 additions & 48 deletions lib/instructor.ex
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ defmodule Instructor do
{:ok, Ecto.Schema.t()}
| {:error, Ecto.Changeset.t()}
| {:error, String.t()}
| {:error, any()}
| Stream.t()
def chat_completion(params, config \\ nil) do
params =
Expand Down Expand Up @@ -270,7 +271,7 @@ defmodule Instructor do
params = Keyword.put(params, :response_model, wrapped_model)
validation_context = Keyword.get(params, :validation_context, %{})
mode = Keyword.get(params, :mode, :tools)
params = params_for_mode(mode, wrapped_model, params)
params = params_for_mode(mode, wrapped_model, params, adapter(config))

model =
if is_ecto_schema(response_model) do
Expand Down Expand Up @@ -341,7 +342,7 @@ defmodule Instructor do
params = Keyword.put(params, :response_model, wrapped_model)
validation_context = Keyword.get(params, :validation_context, %{})
mode = Keyword.get(params, :mode, :tools)
params = params_for_mode(mode, wrapped_model, params)
params = params_for_mode(mode, wrapped_model, params, adapter(config))

adapter(config).chat_completion(params, config)
|> Stream.map(&parse_stream_chunk_for_mode(mode, &1))
Expand Down Expand Up @@ -389,7 +390,7 @@ defmodule Instructor do
params = Keyword.put(params, :response_model, wrapped_model)
validation_context = Keyword.get(params, :validation_context, %{})
mode = Keyword.get(params, :mode, :tools)
params = params_for_mode(mode, wrapped_model, params)
params = params_for_mode(mode, wrapped_model, params, adapter(config))

adapter(config).chat_completion(params, config)
|> Stream.map(&parse_stream_chunk_for_mode(mode, &1))
Expand Down Expand Up @@ -419,7 +420,7 @@ defmodule Instructor do
validation_context = Keyword.get(params, :validation_context, %{})
max_retries = Keyword.get(params, :max_retries)
mode = Keyword.get(params, :mode, :tools)
params = params_for_mode(mode, response_model, params)
params = params_for_mode(mode, response_model, params, adapter(config))

model =
if is_ecto_schema(response_model) do
Expand All @@ -436,10 +437,16 @@ defmodule Instructor do
{:ok, changeset |> Ecto.Changeset.apply_changes()}
else
{:llm, {:error, error}} ->
{:error, "LLM Adapter Error: #{inspect(error)}"}
{:error, {:adapter_error, error}}

{:valid_json, {:error, error}} ->
{:error, "Invalid JSON returned from LLM: #{inspect(error)}"}
# pass the error as it is to the user consuming API
# one complex use case is
# -> you might want to reformat the json data from a different model via another API call.
# So, a smaller model like claude-haiku for subsequent LLM call
# https://github.com/thmsmlr/instructor_ex/pull/55/files
Logger.error(error: "Invalid JSON returned from LLM: #{inspect(error)}")
{:error, {:invalid_json, error}}

{:validation, changeset, response} ->
if max_retries > 0 do
Expand Down Expand Up @@ -533,34 +540,26 @@ defmodule Instructor do
]
end

defp params_for_mode(mode, response_model, params) do
defp echo_response(_), do: []

defp params_for_mode(mode, response_model, params, adapter) do
json_schema = JSONSchema.from_ecto_schema(response_model)

params =
params
|> Keyword.update(:messages, [], fn messages ->
decoded_json_schema = Jason.decode!(json_schema)
messages =
case adapter do
Instructor.Adapters.Anthropic ->
messages

additional_definitions =
if defs = decoded_json_schema["$defs"] do
"\nHere are some more definitions to adhere too:\n" <> Jason.encode!(defs)
else
""
_ ->
[sys_message(json_schema) | messages]
end

sys_message = %{
role: "system",
content: """
As a genius expert, your task is to understand the content and provide the parsed objects in json that match the following json_schema:\n
#{json_schema}

#{additional_definitions}
"""
}

case mode do
:md_json ->
[sys_message | messages] ++
[sys_message(json_schema) | messages] ++
[
%{
role: "assistant",
Expand All @@ -569,43 +568,77 @@ defmodule Instructor do
]

:json ->
[sys_message | messages]
[sys_message(json_schema) | messages]

:tools ->
messages
end
end)

case mode do
:md_json ->
params |> Keyword.put(:stop, "```")
params =
case mode do
:md_json ->
params

# |> Keyword.put(:stop, "```")

:json ->
params
|> Keyword.put(:response_format, %{
type: "json_object"
})

:tools ->
params
|> Keyword.put(:tools, [
%{
type: "function",
function: %{
"description" =>
"Correctly extracted `Schema` with all the required parameters with correct types",
"name" => "Schema",
"parameters" => json_schema |> Jason.decode!()
}
}
])
|> Keyword.put(:tool_choice, %{
type: "function",
function: %{name: "Schema"}
})
end

:json ->
case adapter do
Instructor.Adapters.Anthropic ->
params
|> Keyword.put(:response_format, %{
type: "json_object"
})
|> Keyword.put(:system, sys_message(json_schema).content)
|> Keyword.put_new(:max_tokens, 1800)

:tools ->
_ ->
params
|> Keyword.put(:tools, [
%{
type: "function",
function: %{
"description" =>
"Correctly extracted `Schema` with all the required parameters with correct types",
"name" => "Schema",
"parameters" => json_schema |> Jason.decode!()
}
}
])
|> Keyword.put(:tool_choice, %{
type: "function",
function: %{name: "Schema"}
})
end
end

defp sys_message(json_schema) do
decoded_json_schema = Jason.decode!(json_schema)

additional_definitions =
if defs = decoded_json_schema["$defs"] do
"\nHere are some more definitions to adhere too:\n" <> Jason.encode!(defs)
else
""
end

%{
role: "system",
content: """
As a genius expert, your task is to understand the content and provide the parsed objects in json that match the following json_schema:\n
#{json_schema}

#{additional_definitions}
"""
}
end

defp call_validate(response_model, changeset, context) do
cond do
not is_ecto_schema(response_model) ->
Expand All @@ -622,6 +655,14 @@ defmodule Instructor do
end
end

defp adapter(config) when is_list(config) do
Keyword.get(
config,
:adapter,
Application.get_env(:instructor, :adapter, Instructor.Adapters.OpenAI)
)
end

defp adapter(%{adapter: adapter}) when is_atom(adapter), do: adapter
defp adapter(_), do: Application.get_env(:instructor, :adapter, Instructor.Adapters.OpenAI)
end
Loading
Loading