diff --git a/ibis-server/app/model/__init__.py b/ibis-server/app/model/__init__.py
index 0bd588d2a..8edb978e7 100644
--- a/ibis-server/app/model/__init__.py
+++ b/ibis-server/app/model/__init__.py
@@ -43,6 +43,10 @@ class QueryPostgresDTO(QueryDTO):
     connection_info: ConnectionUrl | PostgresConnectionInfo = connection_info_field
 
 
+class QueryPySparkDTO(QueryDTO):
+    connection_info: ConnectionUrl | PySparkConnectionInfo = connection_info_field
+
+
 class QuerySnowflakeDTO(QueryDTO):
     connection_info: SnowflakeConnectionInfo = connection_info_field
 
@@ -109,6 +113,17 @@ class PostgresConnectionInfo(BaseModel):
     password: SecretStr
 
 
+class PySparkConnectionInfo(BaseModel):
+    app_name: SecretStr = Field(examples=["wrenai"])
+    master: SecretStr = Field(
+        default="local[*]",
+        description="Spark master URL (e.g., 'local[*]', 'spark://master:7077')",
+    )
+    configs: dict[str, str] | None = Field(
+        default=None, description="Additional Spark configurations"
+    )
+
+
 class SnowflakeConnectionInfo(BaseModel):
     user: SecretStr
     password: SecretStr
@@ -137,6 +152,7 @@ class TrinoConnectionInfo(BaseModel):
     | MSSqlConnectionInfo
     | MySqlConnectionInfo
     | PostgresConnectionInfo
+    | PySparkConnectionInfo
     | SnowflakeConnectionInfo
     | TrinoConnectionInfo
 )
diff --git a/ibis-server/app/model/data_source.py b/ibis-server/app/model/data_source.py
index ea271f6aa..ef95ff4ba 100644
--- a/ibis-server/app/model/data_source.py
+++ b/ibis-server/app/model/data_source.py
@@ -7,6 +7,7 @@
 import ibis
 from google.oauth2 import service_account
 from ibis import BaseBackend
+from pyspark.sql import SparkSession
 
 from app.model import (
     BigQueryConnectionInfo,
@@ -16,6 +17,7 @@
     MSSqlConnectionInfo,
     MySqlConnectionInfo,
     PostgresConnectionInfo,
+    PySparkConnectionInfo,
     QueryBigQueryDTO,
     QueryCannerDTO,
     QueryClickHouseDTO,
@@ -23,6 +25,7 @@
     QueryMSSqlDTO,
     QueryMySqlDTO,
     QueryPostgresDTO,
+    QueryPySparkDTO,
     QuerySnowflakeDTO,
     QueryTrinoDTO,
     SnowflakeConnectionInfo,
@@ -37,6 +40,7 @@ class DataSource(StrEnum):
     mssql = auto()
     mysql = auto()
     postgres = auto()
+    pyspark = auto()
     snowflake = auto()
     trino = auto()
 
@@ -60,6 +64,7 @@ class DataSourceExtension(Enum):
     mssql = QueryMSSqlDTO
     mysql = QueryMySqlDTO
     postgres = QueryPostgresDTO
+    pyspark = QueryPySparkDTO
     snowflake = QuerySnowflakeDTO
     trino = QueryTrinoDTO
 
@@ -143,6 +148,20 @@ def get_postgres_connection(info: PostgresConnectionInfo) -> BaseBackend:
             password=info.password.get_secret_value(),
         )
 
+    @staticmethod
+    def get_pyspark_connection(info: PySparkConnectionInfo) -> BaseBackend:
+        builder = SparkSession.builder.appName(info.app_name.get_secret_value()).master(
+            info.master.get_secret_value()
+        )
+
+        if info.configs:
+            for key, value in info.configs.items():
+                builder = builder.config(key, value)
+
+        # Create or get existing Spark session
+        spark_session = builder.getOrCreate()
+        return ibis.pyspark.connect(session=spark_session)
+
     @staticmethod
     def get_snowflake_connection(info: SnowflakeConnectionInfo) -> BaseBackend:
         return ibis.snowflake.connect(
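Note: the new connection path mirrors the existing backends: the route's DTO carries a `PySparkConnectionInfo`, and `get_pyspark_connection` turns it into an ibis backend over a shared SparkSession. A minimal sketch of that flow, assuming a local Spark installation and `ibis-framework` built with the `pyspark` extra (the `spark.sql.session.timeZone` key is an illustrative config, not something this diff sets):

```python
# Sketch only: replays what get_pyspark_connection does, outside the server.
import ibis
from pyspark.sql import SparkSession

builder = (
    SparkSession.builder.appName("wrenai")  # the model's app_name example
    .master("local[*]")  # default from PySparkConnectionInfo.master
    .config("spark.sql.session.timeZone", "UTC")  # illustrative configs entry
)
spark_session = builder.getOrCreate()  # creates or reuses the JVM-backed session
con = ibis.pyspark.connect(session=spark_session)  # same call as the new method
```

Because `getOrCreate()` reuses any live session in the process, later requests with different `configs` may silently get the first session's settings; that is inherent to the Spark builder API rather than to this method.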
diff --git a/ibis-server/poetry.lock b/ibis-server/poetry.lock
index 54e22efd7..9381e8266 100644
--- a/ibis-server/poetry.lock
+++ b/ibis-server/poetry.lock
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand.
 
 [[package]]
 name = "aiohappyeyeballs"
@@ -1562,6 +1562,7 @@ db-dtypes = {version = ">=0.3,<2", optional = true, markers = "extra == \"bigque
 google-cloud-bigquery = {version = ">=3,<4", optional = true, markers = "extra == \"bigquery\""}
 google-cloud-bigquery-storage = {version = ">=2,<3", optional = true, markers = "extra == \"bigquery\""}
 numpy = {version = ">=1.23.2,<3", optional = true, markers = "extra == \"bigquery\" or extra == \"clickhouse\" or extra == \"dask\" or extra == \"datafusion\" or extra == \"druid\" or extra == \"duckdb\" or extra == \"exasol\" or extra == \"flink\" or extra == \"impala\" or extra == \"mssql\" or extra == \"mysql\" or extra == \"oracle\" or extra == \"pandas\" or extra == \"polars\" or extra == \"postgres\" or extra == \"pyspark\" or extra == \"snowflake\" or extra == \"sqlite\" or extra == \"risingwave\" or extra == \"trino\""}
+packaging = {version = ">=21.3,<25", optional = true, markers = "extra == \"dask\" or extra == \"duckdb\" or extra == \"oracle\" or extra == \"pandas\" or extra == \"polars\" or extra == \"pyspark\""}
 pandas = {version = ">=1.5.3,<3", optional = true, markers = "extra == \"bigquery\" or extra == \"clickhouse\" or extra == \"dask\" or extra == \"datafusion\" or extra == \"druid\" or extra == \"duckdb\" or extra == \"exasol\" or extra == \"flink\" or extra == \"impala\" or extra == \"mssql\" or extra == \"mysql\" or extra == \"oracle\" or extra == \"pandas\" or extra == \"polars\" or extra == \"postgres\" or extra == \"pyspark\" or extra == \"snowflake\" or extra == \"sqlite\" or extra == \"risingwave\" or extra == \"trino\""}
 parsy = ">=2,<3"
 psycopg2 = {version = ">=2.8.4,<3", optional = true, markers = "extra == \"postgres\" or extra == \"risingwave\""}
@@ -1570,6 +1571,7 @@ pyarrow-hotfix = {version = ">=0.4,<1", optional = true, markers = "extra == \"b
 pydata-google-auth = {version = ">=1.4.0,<2", optional = true, markers = "extra == \"bigquery\""}
 pymysql = {version = ">=1,<2", optional = true, markers = "extra == \"mysql\""}
 pyodbc = {version = ">=4.0.39,<6", optional = true, markers = "extra == \"mssql\""}
+pyspark = {version = ">=3.3.3,<4", optional = true, markers = "extra == \"pyspark\""}
 python-dateutil = ">=2.8.2,<3"
 pytz = ">=2022.7"
 rich = {version = ">=12.4.4,<14", optional = true, markers = "extra == \"bigquery\" or extra == \"clickhouse\" or extra == \"dask\" or extra == \"datafusion\" or extra == \"druid\" or extra == \"duckdb\" or extra == \"exasol\" or extra == \"flink\" or extra == \"impala\" or extra == \"mssql\" or extra == \"mysql\" or extra == \"oracle\" or extra == \"pandas\" or extra == \"polars\" or extra == \"postgres\" or extra == \"pyspark\" or extra == \"snowflake\" or extra == \"sqlite\" or extra == \"risingwave\" or extra == \"trino\""}
@@ -2424,6 +2426,17 @@ files = [
     {file = "psycopg2-2.9.10.tar.gz", hash = "sha256:12ec0b40b0273f95296233e8750441339298e6a572f7039da5b260e3c8b60e11"},
 ]
 
+[[package]]
+name = "py4j"
+version = "0.10.9.7"
+description = "Enables Python programs to dynamically access arbitrary Java objects"
+optional = false
+python-versions = "*"
+files = [
+    {file = "py4j-0.10.9.7-py2.py3-none-any.whl", hash = "sha256:85defdfd2b2376eb3abf5ca6474b51ab7e0de341c75a02f46dc9b5976f5a5c1b"},
+    {file = "py4j-0.10.9.7.tar.gz", hash = "sha256:0b6e5315bb3ada5cf62ac651d107bb2ebc02def3dee9d9548e3baac644ea8dbb"},
+]
+
 [[package]]
 name = "pyarrow"
 version = "17.0.0"
@@ -2873,6 +2886,26 @@ cryptography = ">=41.0.5,<44"
 docs = ["sphinx (!=5.2.0,!=5.2.0.post0,!=7.2.5)", "sphinx-rtd-theme"]
 test = ["pretend", "pytest (>=3.0.1)", "pytest-rerunfailures"]
 
+[[package]]
+name = "pyspark"
+version = "3.5.3"
+description = "Apache Spark Python API"
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "pyspark-3.5.3.tar.gz", hash = "sha256:68b7cc0c0c570a7d8644f49f40d2da8709b01d30c9126cc8cf93b4f84f3d9747"},
+]
+
+[package.dependencies]
+py4j = "0.10.9.7"
+
+[package.extras]
+connect = ["googleapis-common-protos (>=1.56.4)", "grpcio (>=1.56.0)", "grpcio-status (>=1.56.0)", "numpy (>=1.15,<2)", "pandas (>=1.0.5)", "pyarrow (>=4.0.0)"]
+ml = ["numpy (>=1.15,<2)"]
+mllib = ["numpy (>=1.15,<2)"]
+pandas-on-spark = ["numpy (>=1.15,<2)", "pandas (>=1.0.5)", "pyarrow (>=4.0.0)"]
+sql = ["numpy (>=1.15,<2)", "pandas (>=1.0.5)", "pyarrow (>=4.0.0)"]
+
 [[package]]
 name = "pytest"
 version = "8.3.3"
@@ -4146,4 +4179,4 @@ cffi = ["cffi (>=1.11)"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.11,<3.12"
-content-hash = "3d79487fb66b6bcee3b45f90db93e2ea2747b25888972f3dcbf2e3f0a69052b1"
+content-hash = "78181092bfc9a825884f9eb771af6fc72a82e02ab41b08aa546157ceda5677e5"
["pretend", "pytest (>=3.0.1)", "pytest-rerunfailures"] +[[package]] +name = "pyspark" +version = "3.5.3" +description = "Apache Spark Python API" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pyspark-3.5.3.tar.gz", hash = "sha256:68b7cc0c0c570a7d8644f49f40d2da8709b01d30c9126cc8cf93b4f84f3d9747"}, +] + +[package.dependencies] +py4j = "0.10.9.7" + +[package.extras] +connect = ["googleapis-common-protos (>=1.56.4)", "grpcio (>=1.56.0)", "grpcio-status (>=1.56.0)", "numpy (>=1.15,<2)", "pandas (>=1.0.5)", "pyarrow (>=4.0.0)"] +ml = ["numpy (>=1.15,<2)"] +mllib = ["numpy (>=1.15,<2)"] +pandas-on-spark = ["numpy (>=1.15,<2)", "pandas (>=1.0.5)", "pyarrow (>=4.0.0)"] +sql = ["numpy (>=1.15,<2)", "pandas (>=1.0.5)", "pyarrow (>=4.0.0)"] + [[package]] name = "pytest" version = "8.3.3" @@ -4146,4 +4179,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.0" python-versions = ">=3.11,<3.12" -content-hash = "3d79487fb66b6bcee3b45f90db93e2ea2747b25888972f3dcbf2e3f0a69052b1" +content-hash = "78181092bfc9a825884f9eb771af6fc72a82e02ab41b08aa546157ceda5677e5" diff --git a/ibis-server/pyproject.toml b/ibis-server/pyproject.toml index 594d5a77a..66df718ad 100644 --- a/ibis-server/pyproject.toml +++ b/ibis-server/pyproject.toml @@ -16,6 +16,7 @@ ibis-framework = { version = "9.5.0", extras = [ "mssql", "mysql", "postgres", + "pyspark", "snowflake", "trino", ] } @@ -42,6 +43,7 @@ sqlalchemy = "2.0.36" pre-commit = "4.0.1" ruff = "0.8.0" trino = ">=0.321,<1" +pyspark = "3.5.1" psycopg2 = ">=2.8.4,<3" clickhouse-connect = "0.8.7" @@ -54,6 +56,7 @@ markers = [ "mssql: mark a test as a mssql test", "mysql: mark a test as a mysql test", "postgres: mark a test as a postgres test", + "pyspark: mark a test as a pyspark test", "snowflake: mark a test as a snowflake test", "trino: mark a test as a trino test", "beta: mark a test as a test for beta versions of the engine", diff --git a/ibis-server/tests/routers/v2/connector/test_pyspark.py b/ibis-server/tests/routers/v2/connector/test_pyspark.py new file mode 100644 index 000000000..6e74479c7 --- /dev/null +++ b/ibis-server/tests/routers/v2/connector/test_pyspark.py @@ -0,0 +1,191 @@ +import base64 + +# import os +import orjson +import pytest +from fastapi.testclient import TestClient + +from app.main import app +from app.model.validator import rules + +pytestmark = pytest.mark.pyspark + +base_url = "/v2/connector/pyspark" + +connection_info = { + "app_name": "MyApp", + "master": "local", +} + +manifest = { + "catalog": "my_catalog", + "schema": "my_schema", + "models": [ + { + "name": "Orders", + "properties": {}, + "refSql": "select * from tpch.orders", + "columns": [ + {"name": "orderkey", "expression": "O_ORDERKEY", "type": "integer"}, + {"name": "custkey", "expression": "O_CUSTKEY", "type": "integer"}, + { + "name": "orderstatus", + "expression": "O_ORDERSTATUS", + "type": "varchar", + }, + { + "name": "totalprice", + "expression": "O_TOTALPRICE", + "type": "float", + }, + {"name": "orderdate", "expression": "O_ORDERDATE", "type": "date"}, + { + "name": "order_cust_key", + "expression": "concat(O_ORDERKEY, '_', O_CUSTKEY)", + "type": "varchar", + }, + { + "name": "timestamp", + "expression": "cast('2024-01-01T23:59:59' as timestamp)", + "type": "timestamp", + }, + { + "name": "timestamptz", + "expression": "cast('2024-01-01T23:59:59' as timestamp with time zone)", + "type": "timestamp", + }, + { + "name": "test_null_time", + "expression": "cast(NULL as timestamp)", + "type": "timestamp", + }, + ], + "primaryKey": "orderkey", + }, + ], +} 
diff --git a/ibis-server/tests/routers/v2/connector/test_pyspark.py b/ibis-server/tests/routers/v2/connector/test_pyspark.py
new file mode 100644
index 000000000..6e74479c7
--- /dev/null
+++ b/ibis-server/tests/routers/v2/connector/test_pyspark.py
@@ -0,0 +1,191 @@
+import base64
+
+# import os
+import orjson
+import pytest
+from fastapi.testclient import TestClient
+
+from app.main import app
+from app.model.validator import rules
+
+pytestmark = pytest.mark.pyspark
+
+base_url = "/v2/connector/pyspark"
+
+connection_info = {
+    "app_name": "MyApp",
+    "master": "local",
+}
+
+manifest = {
+    "catalog": "my_catalog",
+    "schema": "my_schema",
+    "models": [
+        {
+            "name": "Orders",
+            "properties": {},
+            "refSql": "select * from tpch.orders",
+            "columns": [
+                {"name": "orderkey", "expression": "O_ORDERKEY", "type": "integer"},
+                {"name": "custkey", "expression": "O_CUSTKEY", "type": "integer"},
+                {
+                    "name": "orderstatus",
+                    "expression": "O_ORDERSTATUS",
+                    "type": "varchar",
+                },
+                {
+                    "name": "totalprice",
+                    "expression": "O_TOTALPRICE",
+                    "type": "float",
+                },
+                {"name": "orderdate", "expression": "O_ORDERDATE", "type": "date"},
+                {
+                    "name": "order_cust_key",
+                    "expression": "concat(O_ORDERKEY, '_', O_CUSTKEY)",
+                    "type": "varchar",
+                },
+                {
+                    "name": "timestamp",
+                    "expression": "cast('2024-01-01T23:59:59' as timestamp)",
+                    "type": "timestamp",
+                },
+                {
+                    "name": "timestamptz",
+                    "expression": "cast('2024-01-01T23:59:59' as timestamp with time zone)",
+                    "type": "timestamp",
+                },
+                {
+                    "name": "test_null_time",
+                    "expression": "cast(NULL as timestamp)",
+                    "type": "timestamp",
+                },
+            ],
+            "primaryKey": "orderkey",
+        },
+    ],
+}
+
+
+@pytest.fixture
+def manifest_str():
+    return base64.b64encode(orjson.dumps(manifest)).decode("utf-8")
+
+
+with TestClient(app) as client:
+    # def test_query(manifest_str):
+    #     response = client.post(
+    #         url=f"{base_url}/query",
+    #         json={
+    #             "connectionInfo": connection_info,
+    #             "manifestStr": manifest_str,
+    #             "sql": 'SELECT * FROM "Orders" ORDER BY "orderkey" LIMIT 1',
+    #         },
+    #     )
+    #     assert response.status_code == 200
+    #     result = response.json()
+    #     assert len(result["columns"]) == len(manifest["models"][0]["columns"])
+    #     assert len(result["data"]) == 1
+    #     assert result["data"][0] == [
+    #         1,
+    #         36901,
+    #         "O",
+    #         "173665.47",
+    #         "1996-01-02",
+    #         "1_36901",
+    #         "2024-01-01 23:59:59.000000",
+    #         "2024-01-01 23:59:59.000000 UTC",
+    #         None,
+    #     ]
+    #     assert result["dtypes"] == {
+    #         "orderkey": "int64",
+    #         "custkey": "int64",
+    #         "orderstatus": "object",
+    #         "totalprice": "object",
+    #         "orderdate": "object",
+    #         "order_cust_key": "object",
+    #         "timestamp": "object",
+    #         "timestamptz": "object",
+    #         "test_null_time": "datetime64[ns]",
+    #     }
+
+    def test_query_without_manifest():
+        response = client.post(
+            url=f"{base_url}/query",
+            json={
+                "connectionInfo": connection_info,
+                "sql": 'SELECT * FROM "Orders" LIMIT 1',
+            },
+        )
+        assert response.status_code == 422
+        result = response.json()
+        assert result["detail"][0] is not None
+        assert result["detail"][0]["type"] == "missing"
+        assert result["detail"][0]["loc"] == ["body", "manifestStr"]
+        assert result["detail"][0]["msg"] == "Field required"
+
+    def test_query_without_sql(manifest_str):
+        response = client.post(
+            url=f"{base_url}/query",
+            json={"connectionInfo": connection_info, "manifestStr": manifest_str},
+        )
+        assert response.status_code == 422
+        result = response.json()
+        assert result["detail"][0] is not None
+        assert result["detail"][0]["type"] == "missing"
+        assert result["detail"][0]["loc"] == ["body", "sql"]
+        assert result["detail"][0]["msg"] == "Field required"
+
+    def test_query_without_connection_info(manifest_str):
+        response = client.post(
+            url=f"{base_url}/query",
+            json={
+                "manifestStr": manifest_str,
+                "sql": 'SELECT * FROM "Orders" LIMIT 1',
+            },
+        )
+        assert response.status_code == 422
+        result = response.json()
+        assert result["detail"][0] is not None
+        assert result["detail"][0]["type"] == "missing"
+        assert result["detail"][0]["loc"] == ["body", "connectionInfo"]
+        assert result["detail"][0]["msg"] == "Field required"
+
+    # def test_query_with_dry_run(manifest_str):
+    #     response = client.post(
+    #         url=f"{base_url}/query",
+    #         params={"dryRun": True},
+    #         json={
+    #             "connectionInfo": connection_info,
+    #             "manifestStr": manifest_str,
+    #             "sql": 'SELECT * FROM "Orders" LIMIT 1',
+    #         },
+    #     )
+    #     assert response.status_code == 204
+
+    def test_query_with_dry_run_and_invalid_sql(manifest_str):
+        response = client.post(
+            url=f"{base_url}/query",
+            params={"dryRun": True},
+            json={
+                "connectionInfo": connection_info,
+                "manifestStr": manifest_str,
+                "sql": "SELECT * FROM X",
+            },
+        )
+        assert response.status_code == 422
+        assert response.text is not None
+
+    def test_validate_with_unknown_rule(manifest_str):
+        response = client.post(
+            url=f"{base_url}/validate/unknown_rule",
+            json={
+                "connectionInfo": connection_info,
+                "manifestStr": manifest_str,
+                "parameters": {"modelName": "Orders", "columnName": "orderkey"},
+            },
+        )
+        assert response.status_code == 404
+        assert (
+            response.text
+            == f"The rule `unknown_rule` is not in the rules, rules: {rules}"
+        )
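The enabled tests only cover request validation; the end-to-end query and dry-run tests are left commented out, presumably because they need a working Spark runtime and TPC-H data in CI. For reference, a successful query would follow the same shape as the commented-out `test_query`; a sketch reusing the module's `client`, `base_url`, `connection_info`, and the `manifest_str` fixture:

```python
# Sketch: a happy-path query against the new route; passing requires a local
# Spark able to resolve tpch.orders, which is why the real test is disabled.
response = client.post(
    url=f"{base_url}/query",
    json={
        "connectionInfo": connection_info,
        "manifestStr": manifest_str,  # base64-encoded manifest from the fixture
        "sql": 'SELECT * FROM "Orders" ORDER BY "orderkey" LIMIT 1',
    },
)
assert response.status_code == 200
```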