diff --git a/libs/community/langchain_community/document_loaders/parsers/language/language_parser.py b/libs/community/langchain_community/document_loaders/parsers/language/language_parser.py
index f44d74e690654..e1d4e5ec664b9 100644
--- a/libs/community/langchain_community/document_loaders/parsers/language/language_parser.py
+++ b/libs/community/langchain_community/document_loaders/parsers/language/language_parser.py
@@ -24,6 +24,7 @@
 from langchain_community.document_loaders.parsers.language.ruby import RubySegmenter
 from langchain_community.document_loaders.parsers.language.rust import RustSegmenter
 from langchain_community.document_loaders.parsers.language.scala import ScalaSegmenter
+from langchain_community.document_loaders.parsers.language.sql import SQLSegmenter
 from langchain_community.document_loaders.parsers.language.typescript import (
     TypeScriptSegmenter,
 )
@@ -47,6 +48,7 @@
     "php": "php",
     "ex": "elixir",
     "exs": "elixir",
+    "sql": "sql",
 }
 
 LANGUAGE_SEGMENTERS: Dict[str, Any] = {
@@ -67,6 +69,7 @@
     "java": JavaSegmenter,
     "php": PHPSegmenter,
     "elixir": ElixirSegmenter,
+    "sql": SQLSegmenter,
 }
 
 Language = Literal[
@@ -94,6 +97,7 @@
     "lua",
     "perl",
     "elixir",
+    "sql",
 ]
 
 
@@ -123,6 +127,7 @@ class LanguageParser(BaseBlobParser):
     - Ruby: "ruby" (*)
     - Rust: "rust" (*)
     - Scala: "scala" (*)
+    - SQL: "sql" (*)
     - TypeScript: "ts" (*)
 
     Items marked with (*) require the packages `tree_sitter` and
diff --git a/libs/community/langchain_community/document_loaders/parsers/language/sql.py b/libs/community/langchain_community/document_loaders/parsers/language/sql.py
new file mode 100644
index 0000000000000..1c11b7b363758
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/parsers/language/sql.py
@@ -0,0 +1,69 @@
+from typing import TYPE_CHECKING
+
+from langchain_community.document_loaders.parsers.language.tree_sitter_segmenter import (  # noqa: E501
+    TreeSitterSegmenter,
+)
+
+if TYPE_CHECKING:
+    from tree_sitter import Language
+
+CHUNK_QUERY = """
+    [
+        (create_table_statement) @create
+        (select_statement) @select
+        (insert_statement) @insert
+        (update_statement) @update
+        (delete_statement) @delete
+    ]
+"""
+
+
+class SQLSegmenter(TreeSitterSegmenter):
+    """Code segmenter for SQL.
+
+    This class uses Tree-sitter to segment SQL code into its
+    constituent statements (e.g., SELECT, CREATE TABLE).
+    It also provides functionality to extract these
+    statements and simplify the code into commented descriptions.
+    """
+
+    def get_language(self) -> "Language":
+        """Return the SQL language grammar for Tree-sitter."""
+        from tree_sitter_languages import get_language
+
+        return get_language("sql")
+
+    def get_chunk_query(self) -> str:
+        """Return the Tree-sitter query for SQL segmentation."""
+        return CHUNK_QUERY
+
+    def extract_functions_classes(self) -> list[str]:
+        """Extract SQL statements from the code.
+
+        Surrounding whitespace is stripped from every statement, and a
+        semicolon is appended to any statement that does not already end
+        with one, so all returned statements are terminated consistently.
+        """
+        statements = []
+        for chunk in super().extract_functions_classes():
+            stmt = chunk.strip()
+            statements.append(stmt if stmt.endswith(";") else f"{stmt};")
+        return statements
+
+    def simplify_code(self) -> str:
+        """Simplify the extracted SQL code into comments.
+
+        Each extracted statement becomes one ``-- Code for: ...`` line.
+        Whitespace runs (including newlines) inside a statement are
+        collapsed to single spaces so that a multi-line statement cannot
+        spill onto uncommented lines.
+        """
+        comments = []
+        for stmt in self.extract_functions_classes():
+            flattened = " ".join(stmt.split())
+            comments.append(self.make_line_comment(f"Code for: {flattened}"))
+        return "\n".join(comments)
+
+    def make_line_comment(self, text: str) -> str:
+        """Create a line comment in SQL style."""
+        return f"-- {text}"
diff --git a/libs/community/tests/unit_tests/document_loaders/parsers/language/test_sql.py b/libs/community/tests/unit_tests/document_loaders/parsers/language/test_sql.py
new file mode 100644
index 0000000000000..37b22052ea243
--- /dev/null
+++ b/libs/community/tests/unit_tests/document_loaders/parsers/language/test_sql.py
@@ -0,0 +1,69 @@
+import unittest
+
+import pytest
+
+from langchain_community.document_loaders.parsers.language.sql import SQLSegmenter
+
+
+@pytest.mark.requires("tree_sitter", "tree_sitter_languages")
+class TestSQLSegmenter(unittest.TestCase):
+    """Unit tests for the SQLSegmenter class."""
+
+    def setUp(self) -> None:
+        """Set up example code and expected results for testing."""
+        self.example_code = """
+        CREATE TABLE users (id INT, name TEXT);
+
+        -- A select query
+        SELECT id, name FROM users WHERE id = 1;
+
+        INSERT INTO users (id, name) VALUES (2, 'Alice');
+
+        UPDATE users SET name = 'Bob' WHERE id = 2;
+
+        DELETE FROM users WHERE id = 2;
+        """
+
+        self.expected_simplified_code = (
+            "-- Code for: CREATE TABLE users (id INT, name TEXT);\n"
+            "-- Code for: SELECT id, name FROM users WHERE id = 1;\n"
+            "-- Code for: INSERT INTO users (id, name) VALUES (2, 'Alice');\n"
+            "-- Code for: UPDATE users SET name = 'Bob' WHERE id = 2;\n"
+            "-- Code for: DELETE FROM users WHERE id = 2;"
+        )
+
+        self.expected_extracted_code = [
+            "CREATE TABLE users (id INT, name TEXT);",
+            "SELECT id, name FROM users WHERE id = 1;",
+            "INSERT INTO users (id, name) VALUES (2, 'Alice');",
+            "UPDATE users SET name = 'Bob' WHERE id = 2;",
+            "DELETE FROM users WHERE id = 2;",
+        ]
+
+    def test_is_valid(self) -> None:
+        """Test the validity of SQL code."""
+        # Valid SQL code should return True
+        self.assertTrue(SQLSegmenter("SELECT * FROM test").is_valid())
+        # Invalid code (non-SQL text) should return False
+        self.assertFalse(SQLSegmenter("random text").is_valid())
+
+    def test_extract_functions_classes(self) -> None:
+        """Test extracting SQL statements from code."""
+        segmenter = SQLSegmenter(self.example_code)
+        extracted_code = segmenter.extract_functions_classes()
+        # Verify the extracted code matches expected SQL statements
+        self.assertEqual(extracted_code, self.expected_extracted_code)
+
+    def test_simplify_code(self) -> None:
+        """Test simplifying SQL code into commented descriptions."""
+        segmenter = SQLSegmenter(self.example_code)
+        simplified_code = segmenter.simplify_code()
+        self.assertEqual(simplified_code, self.expected_simplified_code)
+
+    def test_simplify_code_multiline_statement(self) -> None:
+        """Test that a multi-line statement is simplified to one comment line."""
+        segmenter = SQLSegmenter("SELECT id, name\nFROM users\nWHERE id = 1;")
+        self.assertEqual(
+            segmenter.simplify_code(),
+            "-- Code for: SELECT id, name FROM users WHERE id = 1;",
+        )