diff --git a/dynamiq/connections/connections.py b/dynamiq/connections/connections.py index a2b75528..15bd80cb 100644 --- a/dynamiq/connections/connections.py +++ b/dynamiq/connections/connections.py @@ -999,7 +999,7 @@ def cursor_params(self) -> dict: class AWSRedshift(BaseConnection): host: str = Field(default_factory=partial(get_env_var, "AWS_REDSHIFT_HOST")) - port: int = Field(default_factory=partial(get_env_var, "AWS_REDSHIFT_PORT", 5432)) + port: int = Field(default_factory=partial(get_env_var, "AWS_REDSHIFT_PORT", 5439)) database: str = Field(default_factory=partial(get_env_var, "AWS_REDSHIFT_DATABASE", "db")) user: str = Field(default_factory=partial(get_env_var, "AWS_REDSHIFT_USER", "awsuser")) password: str = Field(default_factory=partial(get_env_var, "AWS_REDSHIFT_PASSWORD", "password")) diff --git a/dynamiq/nodes/tools/__init__.py b/dynamiq/nodes/tools/__init__.py index 3481b03f..aabe8809 100644 --- a/dynamiq/nodes/tools/__init__.py +++ b/dynamiq/nodes/tools/__init__.py @@ -7,6 +7,6 @@ from .python import Python from .retriever import RetrievalTool from .scale_serp import ScaleSerpTool -from .sql_executor import SqlExecutor +from .sql_executor import SQLExecutor from .tavily import TavilyTool from .zenrows import ZenRowsTool diff --git a/dynamiq/nodes/tools/sql_executor.py b/dynamiq/nodes/tools/sql_executor.py index 9de1336c..85185b15 100644 --- a/dynamiq/nodes/tools/sql_executor.py +++ b/dynamiq/nodes/tools/sql_executor.py @@ -11,10 +11,10 @@ class SQLInputSchema(BaseModel): - query: str = Field(..., description="Parameter to provide a query that needs to be executed.") + query: str | None = Field(None, description="Parameter to provide a query that needs to be executed.") -class SqlExecutor(ConnectionNode): +class SQLExecutor(ConnectionNode): """ A tool for SQL query execution. @@ -33,6 +33,7 @@ class SqlExecutor(ConnectionNode): "You can use this tool to execute the query, specified for PostgreSQL, MySQL, Snowflake, AWS Redshift." ) connection: PostgreSQL | MySQL | Snowflake | AWSRedshift + query: str | None = None model_config = ConfigDict(arbitrary_types_allowed=True) @@ -44,8 +45,10 @@ def execute(self, input_data, config: RunnableConfig = None, **kwargs) -> dict[s config = ensure_config(config) self.run_on_node_execute_run(config.callbacks, **kwargs) - query = input_data.query + query = input_data.query or self.query try: + if not query: + raise ValueError("Query cannot be empty") cursor = self.client.cursor( **self.connection.cursor_params if not isinstance(self.connection, (PostgreSQL, AWSRedshift)) else {} ) diff --git a/examples/tools/use_sql.py b/examples/tools/use_sql.py index cc78655c..6b33a786 100644 --- a/examples/tools/use_sql.py +++ b/examples/tools/use_sql.py @@ -1,11 +1,11 @@ from dynamiq.connections import connections -from dynamiq.nodes.tools import SqlExecutor +from dynamiq.nodes.tools import SQLExecutor def basic_requests_snowflake_example(): snowflake_connection = connections.Snowflake() - snowflake_executor = SqlExecutor(connection=snowflake_connection) + snowflake_executor = SQLExecutor(connection=snowflake_connection) snowflake_insert = { "query": """INSERT INTO test1 (Name, Description) VALUES ('Name1', 'Description1'), ('Name2', 'Description2');""" @@ -22,7 +22,7 @@ def basic_requests_snowflake_example(): def basic_requests_mysql_example(): mysql_connection = connections.MySQL() - mysql_executor = SqlExecutor(connection=mysql_connection) + mysql_executor = SQLExecutor(connection=mysql_connection) mysql_insert = { "query": """ INSERT INTO test1 (`Name`, `Description`) diff --git a/tests/integration/nodes/tools/test_sql_executor.py b/tests/integration/nodes/tools/test_sql_executor.py index 210b4c7b..cf47b48a 100644 --- a/tests/integration/nodes/tools/test_sql_executor.py +++ b/tests/integration/nodes/tools/test_sql_executor.py @@ -1,7 +1,7 @@ import pytest from dynamiq.connections import connections -from dynamiq.nodes.tools.sql_executor import SqlExecutor +from dynamiq.nodes.tools.sql_executor import SQLExecutor from dynamiq.runnables import RunnableResult, RunnableStatus @@ -42,7 +42,7 @@ def mock_cursor_with_select(mocker, mock_fetchall_sql_response): ], ) def test_mysql_postgres_select_execute(mock_fetchall_sql_response, connection, mock_cursor_with_select): - sql_tool = SqlExecutor(connection=connection) + sql_tool = SQLExecutor(connection=connection) output = mock_fetchall_sql_response input_data = {"query": """select * from test1"""} @@ -86,7 +86,7 @@ def mock_cursor_with_none_description(mocker): ], ) def test_non_select_queries_execution(mock_fetchall_sql_response, connection, mock_cursor_with_none_description): - sql_tool = SqlExecutor(connection=connection) + sql_tool = SQLExecutor(connection=connection) output = [] input_data = {"query": """select * from test1"""}