Skip to content

Commit

Permalink
fix llm_inputs duplication problem in intermediate_steps in SQLDataba…
Browse files Browse the repository at this point in the history
…seChain (langchain-ai#10279)

Use `.copy()` to fix the bug that the first `llm_inputs` element is
overwritten by the second `llm_inputs` element in `intermediate_steps`.

***Problem description:***
In [line 127](

https://github.com/langchain-ai/langchain/blob/c732d8fffd39d2b02bdc393c37d2ccdd48f7626d/libs/experimental/langchain_experimental/sql/base.py#L127C17-L127C17),
the `llm_inputs` of the sql generation step is appended as the first
element of `intermediate_steps`:
```
            intermediate_steps.append(llm_inputs)  # input: sql generation
```

However, `llm_inputs` is a mutable dict, it is updated in [line
179](https://github.com/langchain-ai/langchain/blob/master/libs/experimental/langchain_experimental/sql/base.py#L179)
for the final answer step:
```
                llm_inputs["input"] = input_text
```
Then, the updated `llm_inputs` is appended as another element of
`intermediate_steps` in [line
180](https://github.com/langchain-ai/langchain/blob/c732d8fffd39d2b02bdc393c37d2ccdd48f7626d/libs/experimental/langchain_experimental/sql/base.py#L180):
```
                intermediate_steps.append(llm_inputs)  # input: final answer
```

As a result, the final `intermediate_steps` returned in [line
189](https://github.com/langchain-ai/langchain/blob/c732d8fffd39d2b02bdc393c37d2ccdd48f7626d/libs/experimental/langchain_experimental/sql/base.py#L189C43-L189C43)
actually contains two same `llm_inputs` elements, i.e., the `llm_inputs`
for the sql generation step overwritten by the one for final answer step
by mistake. Users are not able to get the actual `llm_inputs` for the
sql generation step from `intermediate_steps`

Simply calling `.copy()` when appending `llm_inputs` to
`intermediate_steps` can solve this problem.
  • Loading branch information
xieqihui authored Oct 6, 2023
1 parent d78f418 commit 57ade13
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions libs/experimental/langchain_experimental/sql/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def _call(
llm_inputs[k] = inputs[k]
intermediate_steps: List = []
try:
intermediate_steps.append(llm_inputs) # input: sql generation
intermediate_steps.append(llm_inputs.copy()) # input: sql generation
sql_cmd = self.llm_chain.predict(
callbacks=_run_manager.get_child(),
**llm_inputs,
Expand Down Expand Up @@ -180,7 +180,7 @@ def _call(
_run_manager.on_text("\nAnswer:", verbose=self.verbose)
input_text += f"{sql_cmd}\nSQLResult: {result}\nAnswer:"
llm_inputs["input"] = input_text
intermediate_steps.append(llm_inputs) # input: final answer
intermediate_steps.append(llm_inputs.copy()) # input: final answer
final_result = self.llm_chain.predict(
callbacks=_run_manager.get_child(),
**llm_inputs,
Expand Down

0 comments on commit 57ade13

Please sign in to comment.