From 830cad7bc089837da3d6872c7d400c14b869b097 Mon Sep 17 00:00:00 2001 From: Jun Yamog Date: Sat, 2 Nov 2024 11:42:24 +1300 Subject: [PATCH] core: fix CommaSeparatedListOutputParser to handle columns that may contain commas in it (#26365) - **Description:** Currently CommaSeparatedListOutputParser can't handle strings that may contain commas within a column. It would parse any commas as the delimiter. Ex. "foo, foo2", "bar", "baz" It will create 4 columns: "foo", "foo2", "bar", "baz" This should be 3 columns: "foo, foo2", "bar", "baz" - **Dependencies:** Added 2 additional imports, but they are built in python packages. import csv from io import StringIO - **Twitter handle:** @jkyamog - [ ] **Add tests and docs**: 1. added simple unit test test_multiple_items_with_comma --------- Co-authored-by: Erick Friis Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com> Co-authored-by: Bagatur --- .../langchain_core/output_parsers/list.py | 11 ++++++++++- .../output_parsers/test_list_parser.py | 19 +++++++++++++++++++ 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/libs/core/langchain_core/output_parsers/list.py b/libs/core/langchain_core/output_parsers/list.py index 858ba86c79fa2..ebaca8f8ca94f 100644 --- a/libs/core/langchain_core/output_parsers/list.py +++ b/libs/core/langchain_core/output_parsers/list.py @@ -1,9 +1,11 @@ from __future__ import annotations +import csv import re from abc import abstractmethod from collections import deque from collections.abc import AsyncIterator, Iterator +from io import StringIO from typing import Optional as Optional from typing import TypeVar, Union @@ -162,7 +164,14 @@ def parse(self, text: str) -> list[str]: Returns: A list of strings. """ - return [part.strip() for part in text.split(",")] + try: + reader = csv.reader( + StringIO(text), quotechar='"', delimiter=",", skipinitialspace=True + ) + return [item for sublist in reader for item in sublist] + except csv.Error: + # keep old logic for backup + return [part.strip() for part in text.split(",")] @property def _type(self) -> str: diff --git a/libs/core/tests/unit_tests/output_parsers/test_list_parser.py b/libs/core/tests/unit_tests/output_parsers/test_list_parser.py index 3f43edfa2aed8..11bd11b6a0b92 100644 --- a/libs/core/tests/unit_tests/output_parsers/test_list_parser.py +++ b/libs/core/tests/unit_tests/output_parsers/test_list_parser.py @@ -64,6 +64,25 @@ def test_multiple_items() -> None: assert list(parser.transform(iter([text]))) == [[a] for a in expected] +def test_multiple_items_with_comma() -> None: + """Test that a string with multiple comma-separated items with 1 item containing a + comma is parsed to a list.""" + parser = CommaSeparatedListOutputParser() + text = '"foo, foo2",bar,baz' + expected = ["foo, foo2", "bar", "baz"] + + assert parser.parse(text) == expected + assert add(parser.transform(t for t in text)) == expected + assert list(parser.transform(t for t in text)) == [[a] for a in expected] + assert list(parser.transform(t for t in text.splitlines(keepends=True))) == [ + [a] for a in expected + ] + assert list( + parser.transform(" " + t if i > 0 else t for i, t in enumerate(text.split(" "))) + ) == [[a] for a in expected] + assert list(parser.transform(iter([text]))) == [[a] for a in expected] + + def test_numbered_list() -> None: parser = NumberedListOutputParser() text1 = (