Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: change SqlExecutor node name to SQLExecutor #126

Merged
merged 3 commits into from
Jan 29, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion dynamiq/connections/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -999,7 +999,7 @@ def cursor_params(self) -> dict:

class AWSRedshift(BaseConnection):
host: str = Field(default_factory=partial(get_env_var, "AWS_REDSHIFT_HOST"))
port: int = Field(default_factory=partial(get_env_var, "AWS_REDSHIFT_PORT", 5432))
port: int = Field(default_factory=partial(get_env_var, "AWS_REDSHIFT_PORT", 5439))
database: str = Field(default_factory=partial(get_env_var, "AWS_REDSHIFT_DATABASE", "db"))
user: str = Field(default_factory=partial(get_env_var, "AWS_REDSHIFT_USER", "awsuser"))
password: str = Field(default_factory=partial(get_env_var, "AWS_REDSHIFT_PASSWORD", "password"))
Expand Down
2 changes: 1 addition & 1 deletion dynamiq/nodes/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,6 @@
from .python import Python
from .retriever import RetrievalTool
from .scale_serp import ScaleSerpTool
from .sql_executor import SqlExecutor
from .sql_executor import SQLExecutor
from .tavily import TavilyTool
from .zenrows import ZenRowsTool
9 changes: 6 additions & 3 deletions dynamiq/nodes/tools/sql_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@


class SQLInputSchema(BaseModel):
query: str = Field(..., description="Parameter to provide a query that needs to be executed.")
query: str | None = Field(None, description="Parameter to provide a query that needs to be executed.")


class SqlExecutor(ConnectionNode):
class SQLExecutor(ConnectionNode):
"""
A tool for SQL query execution.

Expand All @@ -33,6 +33,7 @@ class SqlExecutor(ConnectionNode):
"You can use this tool to execute the query, specified for PostgreSQL, MySQL, Snowflake, AWS Redshift."
)
connection: PostgreSQL | MySQL | Snowflake | AWSRedshift
query: str | None = None

model_config = ConfigDict(arbitrary_types_allowed=True)

Expand All @@ -44,8 +45,10 @@ def execute(self, input_data, config: RunnableConfig = None, **kwargs) -> dict[s
config = ensure_config(config)
self.run_on_node_execute_run(config.callbacks, **kwargs)

query = input_data.query
query = input_data.query or self.query
try:
if not query:
raise ValueError("Query cannot be empty")
cursor = self.client.cursor(
**self.connection.cursor_params if not isinstance(self.connection, (PostgreSQL, AWSRedshift)) else {}
)
Expand Down
6 changes: 3 additions & 3 deletions examples/tools/use_sql.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from dynamiq.connections import connections
from dynamiq.nodes.tools import SqlExecutor
from dynamiq.nodes.tools import SQLExecutor


def basic_requests_snowflake_example():
snowflake_connection = connections.Snowflake()

snowflake_executor = SqlExecutor(connection=snowflake_connection)
snowflake_executor = SQLExecutor(connection=snowflake_connection)
snowflake_insert = {
"query": """INSERT INTO test1 (Name, Description)
VALUES ('Name1', 'Description1'), ('Name2', 'Description2');"""
Expand All @@ -22,7 +22,7 @@ def basic_requests_snowflake_example():
def basic_requests_mysql_example():
mysql_connection = connections.MySQL()

mysql_executor = SqlExecutor(connection=mysql_connection)
mysql_executor = SQLExecutor(connection=mysql_connection)
mysql_insert = {
"query": """
INSERT INTO test1 (`Name`, `Description`)
Expand Down
6 changes: 3 additions & 3 deletions tests/integration/nodes/tools/test_sql_executor.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest

from dynamiq.connections import connections
from dynamiq.nodes.tools.sql_executor import SqlExecutor
from dynamiq.nodes.tools.sql_executor import SQLExecutor
from dynamiq.runnables import RunnableResult, RunnableStatus


Expand Down Expand Up @@ -42,7 +42,7 @@ def mock_cursor_with_select(mocker, mock_fetchall_sql_response):
],
)
def test_mysql_postgres_select_execute(mock_fetchall_sql_response, connection, mock_cursor_with_select):
sql_tool = SqlExecutor(connection=connection)
sql_tool = SQLExecutor(connection=connection)
output = mock_fetchall_sql_response
input_data = {"query": """select * from test1"""}

Expand Down Expand Up @@ -86,7 +86,7 @@ def mock_cursor_with_none_description(mocker):
],
)
def test_non_select_queries_execution(mock_fetchall_sql_response, connection, mock_cursor_with_none_description):
sql_tool = SqlExecutor(connection=connection)
sql_tool = SQLExecutor(connection=connection)
output = []
input_data = {"query": """select * from test1"""}

Expand Down
Loading