Skip to content

Commit

Permalink
core: fix CommaSeparatedListOutputParser to handle columns that may c…
Browse files Browse the repository at this point in the history
…ontain commas in it (#26365)

- **Description:**
Currently, CommaSeparatedListOutputParser can't handle items that
contain a comma: it treats every comma as a delimiter, even when the
comma appears inside a quoted item.
Ex. 
"foo, foo2", "bar", "baz"

It will create 4 columns: "foo", "foo2", "bar", "baz"

This should be 3 columns:

"foo, foo2", "bar", "baz"

- **Dependencies:**
Added two imports; both are from the Python standard library, so no new third-party dependency is introduced.

import csv
from io import StringIO

- **Twitter handle:** @jkyamog

- [ ] **Add tests and docs**: 
1. Added a simple unit test, test_multiple_items_with_comma.

---------

Co-authored-by: Erick Friis <[email protected]>
Co-authored-by: Bagatur <[email protected]>
Co-authored-by: Bagatur <[email protected]>
  • Loading branch information
4 people authored Nov 1, 2024
1 parent 9fedb04 commit 830cad7
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 1 deletion.
11 changes: 10 additions & 1 deletion libs/core/langchain_core/output_parsers/list.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from __future__ import annotations

import csv
import re
from abc import abstractmethod
from collections import deque
from collections.abc import AsyncIterator, Iterator
from io import StringIO
from typing import Optional as Optional
from typing import TypeVar, Union

Expand Down Expand Up @@ -162,7 +164,14 @@ def parse(self, text: str) -> list[str]:
Returns:
A list of strings.
"""
return [part.strip() for part in text.split(",")]
try:
reader = csv.reader(
StringIO(text), quotechar='"', delimiter=",", skipinitialspace=True
)
return [item for sublist in reader for item in sublist]
except csv.Error:
# keep old logic for backup
return [part.strip() for part in text.split(",")]

@property
def _type(self) -> str:
Expand Down
19 changes: 19 additions & 0 deletions libs/core/tests/unit_tests/output_parsers/test_list_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,25 @@ def test_multiple_items() -> None:
assert list(parser.transform(iter([text]))) == [[a] for a in expected]


def test_multiple_items_with_comma() -> None:
    """Parse a comma-separated string where one quoted item contains a comma.

    The quoted item must come back as a single element, both from ``parse``
    and from ``transform`` under several different chunkings of the input.
    """
    list_parser = CommaSeparatedListOutputParser()
    raw = '"foo, foo2",bar,baz'
    items = ["foo, foo2", "bar", "baz"]
    # The streaming assertions below expect one single-element list per item.
    streamed = [[item] for item in items]

    assert list_parser.parse(raw) == items

    # Input fed one character at a time.
    assert add(list_parser.transform(ch for ch in raw)) == items
    assert list(list_parser.transform(ch for ch in raw)) == streamed

    # Input fed line by line, keeping line endings.
    line_chunks = (line for line in raw.splitlines(keepends=True))
    assert list(list_parser.transform(line_chunks)) == streamed

    # Input split on spaces; the space is restored at the front of every
    # chunk after the first so the concatenation equals the original text.
    space_chunks = (
        piece if idx == 0 else " " + piece
        for idx, piece in enumerate(raw.split(" "))
    )
    assert list(list_parser.transform(space_chunks)) == streamed

    # Whole input delivered as one chunk.
    assert list(list_parser.transform(iter([raw]))) == streamed


def test_numbered_list() -> None:
parser = NumberedListOutputParser()
text1 = (
Expand Down

0 comments on commit 830cad7

Please sign in to comment.