
Commit 84a9968

Fix some broken Gemini tests, make SSEStreamParsing more resilient
1 parent 6d7fa80

7 files changed: +127 −63 lines

lib/instructor/adapters/gemini.ex

+68 −29
@@ -23,10 +23,11 @@ defmodule Instructor.Adapters.Gemini do
   """

   @behaviour Instructor.Adapter
+  alias Instructor.SSEStreamParser
   alias Instructor.Adapters
   alias Instructor.JSONSchema

-  @supported_modes [:tools, :json_schema]
+  @supported_modes [:json_schema]

   @doc """
   Run a completion against Google's Gemini API
@@ -102,7 +103,7 @@ defmodule Instructor.Adapters.Gemini do
     generation_config =
       generation_config
       |> Map.put("response_mime_type", "application/json")
-      |> Map.put("response_schema", map_schema(schema))
+      |> Map.put("response_schema", normalize_json_schema(schema))

     params
     |> Map.put(:generationConfig, generation_config)
@@ -116,7 +117,7 @@ defmodule Instructor.Adapters.Gemini do
           %{
             name: tool["name"],
             description: tool["description"],
-            parameters: map_schema(tool["parameters"])
+            parameters: normalize_json_schema(tool["parameters"])
           }
         end)
     }
@@ -155,23 +156,7 @@ defmodule Instructor.Adapters.Gemini do
       json: params,
       rpc_function: :streamGenerateContent,
       into: fn {:data, data}, {req, resp} ->
-        chunks =
-          data
-          |> String.split("\n")
-          |> Enum.filter(fn line ->
-            String.starts_with?(line, "data: {")
-          end)
-          |> Enum.map(fn line ->
-            line
-            |> String.replace_prefix("data: ", "")
-            |> Jason.decode!()
-            |> then(&parse_stream_chunk_for_mode(mode, &1))
-          end)
-
-        for chunk <- chunks do
-          send(pid, chunk)
-        end
-
+        send(pid, data)
         {:cont, {req, resp}}
       end
     )
@@ -196,6 +181,8 @@ defmodule Instructor.Adapters.Gemini do
       end,
       fn task -> Task.await(task) end
     )
+    |> SSEStreamParser.parse()
+    |> Stream.map(fn chunk -> parse_stream_chunk_for_mode(mode, chunk) end)
   end

   defp do_chat_completion(mode, params, config) do
@@ -266,18 +253,70 @@ defmodule Instructor.Adapters.Gemini do
         chunk
     end

-  defp map_schema(schema) do
-    JSONSchema.traverse_and_update(schema, fn
-      %{"type" => _} = x
-      when is_map_key(x, "format") or is_map_key(x, "pattern") or
-             is_map_key(x, "title") or is_map_key(x, "additionalProperties") ->
-        Map.drop(x, ["format", "pattern", "title", "additionalProperties"])
+  defp normalize_json_schema(schema) do
+    JSONSchema.traverse_and_update(
+      schema,
+      fn
+        {%{"type" => _} = x, path}
+        when is_map_key(x, "format") or is_map_key(x, "pattern") or
+               is_map_key(x, "title") or is_map_key(x, "additionalProperties") ->
+          x
+          |> Map.drop(["format", "pattern", "title", "additionalProperties"])
+          |> case do
+            %{"type" => "object", "properties" => properties} when map_size(properties) == 0 ->
+              raise """
+              Invalid JSON Schema: object with no properties at path: #{inspect(path)}
+
+              Gemini does not support empty objects. This is likely because you have a naked :map type
+              without any fields at #{inspect(path)}. Try switching to an embedded schema instead.
+              """
+
+            x ->
+              x
+          end

-      x ->
-        x
-    end)
+        {x, _path} ->
+          x
+      end,
+      include_path: true
+    )
+    |> inline_defs()
+  end
+
+  defp inline_defs(schema) do
+    # First extract the definitions map for reference
+    {defs, schema} = Map.pop(schema, "$defs", %{})
+
+    # Traverse and replace all $refs with their definitions
+    traverse_and_inline(schema, defs)
+  end
+
+  defp traverse_and_inline(schema, defs) when is_map(schema) do
+    cond do
+      # If we find a $ref, replace it with the inlined definition
+      Map.has_key?(schema, "$ref") ->
+        ref = schema["$ref"]
+        def_key = String.replace_prefix(ref, "#/$defs/", "")
+        definition = Map.get(defs, def_key, %{})
+        # Recursively inline any nested refs in the definition
+        traverse_and_inline(definition, defs)
+
+      # Otherwise traverse all values in the map
+      true ->
+        schema
+        |> Enum.map(fn {k, v} -> {k, traverse_and_inline(v, defs)} end)
+        |> Enum.into(%{})
+    end
   end

+  # Handle arrays by traversing each element
+  defp traverse_and_inline(schema, defs) when is_list(schema) do
+    Enum.map(schema, &traverse_and_inline(&1, defs))
+  end
+
+  # Base case - return non-map/list values as is
+  defp traverse_and_inline(schema, _defs), do: schema
+
   defp snake_to_camel(snake_case_string) do
     snake_case_string
     |> String.split("_")
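
For reference, here is a rough standalone sketch of what the new $defs inlining does (illustrative only, not the adapter's private helpers; the module name and the schema literal below are made up): every $ref pointing into #/$defs/... is replaced by the referenced definition, so Gemini receives a fully expanded response_schema.

# Hedged sketch of the $defs-inlining idea above; names are illustrative.
defmodule InlineDefsSketch do
  def inline(schema) do
    # Pull out the definitions, then walk the rest of the schema.
    {defs, schema} = Map.pop(schema, "$defs", %{})
    walk(schema, defs)
  end

  # A $ref node is swapped for its (recursively inlined) definition.
  defp walk(%{"$ref" => ref}, defs) do
    key = String.replace_prefix(ref, "#/$defs/", "")
    walk(Map.get(defs, key, %{}), defs)
  end

  defp walk(map, defs) when is_map(map), do: Map.new(map, fn {k, v} -> {k, walk(v, defs)} end)
  defp walk(list, defs) when is_list(list), do: Enum.map(list, &walk(&1, defs))
  defp walk(other, _defs), do: other
end

# Example: the "address" $ref is replaced by its definition.
InlineDefsSketch.inline(%{
  "type" => "object",
  "properties" => %{"address" => %{"$ref" => "#/$defs/Address"}},
  "$defs" => %{
    "Address" => %{"type" => "object", "properties" => %{"city" => %{"type" => "string"}}}
  }
})
# => %{"type" => "object",
#      "properties" => %{"address" => %{"type" => "object",
#                        "properties" => %{"city" => %{"type" => "string"}}}}}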

lib/instructor/json_schema.ex

+27 −9
@@ -83,7 +83,7 @@ defmodule Instructor.JSONSchema do
   end

   defp uses_use_instructor(ecto_schema) when is_ecto_schema(ecto_schema) do
-    function_exported?(ecto_schema, :__llm_doc__, 0)
+    {:__llm_doc__, 0} in ecto_schema.__info__(:functions)
   end

   defp uses_use_instructor(_), do: false
@@ -414,28 +414,46 @@ defmodule Instructor.JSONSchema do

   ## Parameters
     - tree: The tree structure to traverse (can be a map, list, or any other type)
-    - fun: A function that takes an element and returns either:
+    - fun: A function that takes either:
+      - Just the element if include_path: false (default)
+      - A tuple of {element, path} if include_path: true, where path is a list of keys to reach this element
+      The function should return either:
       - An updated element
       - nil to remove the element
       - The original element if no changes are needed
+    - opts: Optional keyword list of options
+      - include_path: boolean, when true includes the path to each element in the callback (default: false)

   ## Returns
     The updated tree structure
   """
-  def traverse_and_update(tree, fun) when is_map(tree) do
+  def traverse_and_update(tree, fun, opts \\ []) do
+    do_traverse_and_update(tree, fun, [], opts)
+  end
+
+  defp do_traverse_and_update(tree, fun, path, opts) when is_map(tree) do
     tree
-    |> Enum.map(fn {k, v} -> {k, traverse_and_update(v, fun)} end)
+    |> Enum.map(fn {k, v} -> {k, do_traverse_and_update(v, fun, path ++ [k], opts)} end)
     |> Enum.filter(fn {_, v} -> v != nil end)
     |> Enum.into(%{})
-    |> fun.()
+    |> maybe_call_with_path(fun, path, opts)
   end

-  def traverse_and_update(tree, fun) when is_list(tree) do
+  defp do_traverse_and_update(tree, fun, path, opts) when is_list(tree) do
     tree
-    |> Enum.map(fn elem -> traverse_and_update(elem, fun) end)
+    |> Enum.with_index()
+    |> Enum.map(fn {elem, idx} -> do_traverse_and_update(elem, fun, path ++ [idx], opts) end)
     |> Enum.filter(&(&1 != nil))
-    |> fun.()
+    |> maybe_call_with_path(fun, path, opts)
   end

-  def traverse_and_update(tree, fun), do: fun.(tree)
+  defp do_traverse_and_update(tree, fun, path, opts), do: maybe_call_with_path(tree, fun, path, opts)
+
+  defp maybe_call_with_path(value, fun, path, opts) do
+    if Keyword.get(opts, :include_path, false) do
+      fun.({value, path})
+    else
+      fun.(value)
+    end
+  end
 end
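
A minimal usage sketch of the new include_path option (the schema literal and the IO.inspect label are only illustrative; traverse_and_update is the public function changed above):

schema = %{
  "type" => "object",
  "properties" => %{"name" => %{"type" => "string", "format" => "date-time"}}
}

Instructor.JSONSchema.traverse_and_update(
  schema,
  fn
    # With include_path: true the callback receives {node, path} tuples.
    {%{"format" => _} = node, path} ->
      IO.inspect(path, label: "dropping format at")  # => ["properties", "name"]
      Map.delete(node, "format")

    {node, _path} ->
      node
  end,
  include_path: true
)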

lib/instructor/sse_stream_parser.ex

+10 −3
@@ -18,15 +18,15 @@ defmodule Instructor.SSEStreamParser do
       fn acc -> {[acc], nil} end,
       fn _acc -> nil end
     )
-    |> Stream.filter(fn line -> line != "" end)
+    |> Stream.filter(fn line -> String.trim(line) != "" end)
     |> Stream.transform(
       fn -> {:root, ""} end,
       fn
         "data: [DONE]" <> _, {:root, ""} ->
           {:halt, {:root, ""}}

         "data: " <> data, {:root, ""} ->
-          {[{:ok, Jason.decode!(data)}], {:root, ""}}
+          {[{:ok, decode_json!(data)}], {:root, ""}}

         "event: " <> _, {_, _} ->
           {[], {:root, ""}}
@@ -36,7 +36,7 @@ defmodule Instructor.SSEStreamParser do
       end,
       fn
         {:json, acc} ->
-          {[{:error, Jason.decode!(acc)}], {:root, ""}}
+          {[{:error, decode_json!(acc)}], {:root, ""}}

         {:root, ""} ->
           {[], {:root, ""}}
@@ -51,4 +51,11 @@ defmodule Instructor.SSEStreamParser do
         raise "Error from LLM: #{inspect(error)}"
     end)
   end
+
+  defp decode_json!(data) do
+    case Jason.decode(data) do
+      {:ok, decoded} -> decoded
+      {:error, err} -> raise "Error decoding: #{inspect(err)} \n\n #{inspect(data)}"
+    end
+  end
 end
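
Roughly, the parser can be exercised on any enumerable of raw SSE chunks; a hedged sketch, assuming each chunk arrives as one complete "data: " line and that successful chunks are emitted as plain decoded maps:

["data: {\"delta\": \"Hel\"}\n", "data: {\"delta\": \"lo\"}\n", "data: [DONE]\n"]
|> Instructor.SSEStreamParser.parse()
|> Enum.to_list()
# expected, roughly: [%{"delta" => "Hel"}, %{"delta" => "lo"}]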

lib/instructor/validator.ex

+16 −15
@@ -60,19 +60,6 @@ defmodule Instructor.Validator do
     end
   end

-  defmodule Validation do
-    use Ecto.Schema
-
-    @doc """
-    Validate if an attribute is correct and if not, return an error message
-    """
-    @primary_key false
-    embedded_schema do
-      field(:valid?, :boolean)
-      field(:reason, :string)
-    end
-  end
-
   @doc """
   Validate a changeset field using a language model

@@ -95,6 +82,20 @@ defmodule Instructor.Validator do
     end
   """
   def validate_with_llm(changeset, field, statement, opts \\ []) do
+    defmodule Validation do
+      use Ecto.Schema
+      use Instructor
+
+      @llm_doc """
+      Validate if an attribute is correct and if not, return an error message
+      """
+      @primary_key false
+      embedded_schema do
+        field(:valid?, :boolean)
+        field(:reason, :string)
+      end
+    end
+
     Ecto.Changeset.validate_change(changeset, field, fn field, value ->
       {:ok, response} =
         Instructor.chat_completion(
@@ -116,10 +117,10 @@ defmodule Instructor.Validator do
         )

       case response do
-        %Validation{valid?: true} ->
+        %{valid?: true} ->
           []

-        %Validation{reason: reason} ->
+        %{reason: reason} ->
           [
             {field, "is invalid, #{reason}"}
           ]
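
For context, a hedged sketch of how the reworked validator might be called from a changeset (MyApp.Tweet and its :text field are made up for illustration, and a configured Instructor adapter is assumed):

defmodule MyApp.Tweet do
  use Ecto.Schema

  @primary_key false
  embedded_schema do
    field(:text, :string)
  end
end

# validate_with_llm comes from the diff above; the statement is what the LLM checks.
changeset =
  %MyApp.Tweet{}
  |> Ecto.Changeset.cast(%{text: "example tweet"}, [:text])
  |> Instructor.Validator.validate_with_llm(:text, "is not hate speech")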

test/instructor_test.exs

+1 −1
@@ -268,7 +268,7 @@ defmodule InstructorTest do
         stream: true,
         response_model: {:array, President},
         messages: [
-          %{role: "user", content: "What are the first 3 presidents of the United States?"}
+          %{role: "user", content: "Who were the first 3 presidents of the United States?"}
         ]
       )
     )

test/json_schema_test.exs

+4 −4
@@ -74,12 +74,12 @@ defmodule JSONSchemaTest do

   test "includes documentation" do
     json_schema =
-      JSONSchema.from_ecto_schema(InstructorTest.DemoWithDocumentation)
+      JSONSchema.from_ecto_schema(InstructorTest.DemoWithUseInstructorAndNewDoc)
       |> Jason.decode!()

     expected_json_schema =
       %{
-        "description" => "Hello world\n",
+        "description" => "Hello world",
         "properties" => %{
           "string" => %{
             "title" => "string",
@@ -88,7 +88,7 @@ defmodule JSONSchemaTest do
           }
         },
         "required" => ["string"],
-        "title" => "InstructorTest.DemoWithDocumentation",
+        "title" => "InstructorTest.DemoWithUseInstructorAndNewDoc",
         "type" => "object",
         "additionalProperties" => false
       }
@@ -140,7 +140,7 @@ defmodule JSONSchemaTest do
   test "basic types" do
     defmodule Demo do
       use Ecto.Schema
-
+      use Instructor
       # Be explicit about all fields in this test
       @primary_key false
       embedded_schema do

test/support/test_schemas.ex

+1 −2
@@ -41,11 +41,10 @@ defmodule InstructorTest.DemoWithUseInstructorButOldDoc do
 end

 defmodule InstructorTest.DemoWithUseInstructorAndNewDoc do
-  use Instructor
   use Ecto.Schema
+  use Instructor

   @llm_doc "Hello world"
-
   @primary_key false
   embedded_schema do
     field(:string, :string)
