Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for VALUES query #1038

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
16 changes: 12 additions & 4 deletions dask_planner/src/sql/logical.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ pub mod sort;
pub mod subquery_alias;
pub mod table_scan;
pub mod use_schema;
pub mod values;
pub mod window;

use datafusion_python::{
Expand Down Expand Up @@ -143,16 +144,23 @@ impl PyLogicalPlan {
to_py_plan(self.current_node.as_ref())
}

/// LogicalPlan::Window as PyWindow
pub fn window(&self) -> PyResult<window::PyWindow> {
/// LogicalPlan::TableScan as PyTableScan
pub fn table_scan(&self) -> PyResult<table_scan::PyTableScan> {
to_py_plan(self.current_node.as_ref())
}

/// LogicalPlan::TableScan as PyTableScan
pub fn table_scan(&self) -> PyResult<table_scan::PyTableScan> {
/// LogicalPlan::Values as PyValues
///
/// Casts the wrapped plan node to a `PyValues`; the underlying `TryFrom`
/// conversion errors with a Python `TypeError` when the current node is
/// not a `LogicalPlan::Values`.
pub fn values(&self) -> PyResult<values::PyValues> {
    to_py_plan(self.current_node.as_ref())
}

/// LogicalPlan::Window as PyWindow
///
/// Casts the wrapped plan node to a `PyWindow`; fails with a Python
/// error when the current node is not a window plan.
pub fn window(&self) -> PyResult<window::PyWindow> {
    to_py_plan(self.current_node.as_ref())
}

// Custom LogicalPlan Nodes

/// LogicalPlan::CreateMemoryTable as PyCreateMemoryTable
pub fn create_memory_table(&self) -> PyResult<create_memory_table::PyCreateMemoryTable> {
to_py_plan(self.current_node.as_ref())
Expand Down
45 changes: 45 additions & 0 deletions dask_planner/src/sql/logical/values.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
use std::sync::Arc;

use datafusion_python::datafusion_expr::{logical_plan::Values, LogicalPlan};
use pyo3::prelude::*;

use crate::{
expression::{py_expr_list, PyExpr},
sql::exceptions::py_type_err,
};

/// Python wrapper around a DataFusion `LogicalPlan::Values` node, i.e. the
/// literal rows produced by a SQL `VALUES (...), (...)` clause.
#[pyclass(name = "Values", module = "dask_planner", subclass)]
#[derive(Clone)]
pub struct PyValues {
    // The underlying DataFusion `Values` node holding the literal rows.
    values: Values,
    // The same node re-wrapped as a `LogicalPlan`; `py_expr_list` needs a
    // plan reference when converting the row expressions (see `get_values`).
    plan: Arc<LogicalPlan>,
}

#[pymethods]
impl PyValues {
    /// Returns the literal rows of the `VALUES` clause as a nested list:
    /// one `Vec<PyExpr>` per row, in query order. Errors if any cell
    /// expression cannot be converted to a `PyExpr`.
    ///
    /// NOTE(review): the previous docstring described CREATE MODEL and was a
    /// copy-paste from an unrelated node; this method only exposes the rows
    /// of `LogicalPlan::Values`.
    #[pyo3(name = "getValues")]
    fn get_values(&self) -> PyResult<Vec<Vec<PyExpr>>> {
        self.values
            .values
            .iter()
            .map(|e| py_expr_list(&self.plan, e))
            .collect()
    }
}

impl TryFrom<LogicalPlan> for PyValues {
    type Error = PyErr;

    /// Converts a generic `LogicalPlan` into a `PyValues`, producing a
    /// Python `TypeError` when the plan is not a `Values` node.
    fn try_from(logical_plan: LogicalPlan) -> Result<Self, Self::Error> {
        if let LogicalPlan::Values(values) = logical_plan {
            // Keep a plan-wrapped copy alongside the raw node: expression
            // conversion requires a `LogicalPlan` reference.
            let plan = Arc::new(LogicalPlan::Values(values.clone()));
            Ok(PyValues { values, plan })
        } else {
            Err(py_type_err("unexpected plan"))
        }
    }
}
13 changes: 5 additions & 8 deletions dask_sql/physical/rel/logical/values.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

if TYPE_CHECKING:
import dask_sql
from dask_sql.java import org
from dask_planner.rust import LogicalPlan


class DaskValuesPlugin(BaseRelPlugin):
Expand All @@ -26,15 +26,12 @@ class DaskValuesPlugin(BaseRelPlugin):
data samples.
"""

class_name = "com.dask.sql.nodes.DaskValues"
class_name = "Values"

def convert(
self, rel: "org.apache.calcite.rel.RelNode", context: "dask_sql.Context"
) -> DataContainer:
def convert(self, rel: "LogicalPlan", context: "dask_sql.Context") -> DataContainer:
# There should not be any input. This is the first step.
self.assert_inputs(rel, 0)

rex_expression_rows = list(rel.getTuples())
rex_expression_rows = rel.values().getValues()
rows = []
for rex_expression_row in rex_expression_rows:
# We convert each of the cells in the row
Expand All @@ -44,7 +41,7 @@ def convert(
# their index.
rows.append(
{
str(i): RexConverter.convert(rex_cell, None, context=context)
str(i): RexConverter.convert(rel, rex_cell, None, context=context)
for i, rex_cell in enumerate(rex_expression_row)
}
)
Expand Down
52 changes: 52 additions & 0 deletions tests/integration/test_values.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import pandas as pd
import pytest

from tests.utils import assert_eq


def test_values(c):
    """A bare VALUES clause should materialize as an in-memory table."""
    query = """
    SELECT * FROM (VALUES (1, 2), (1, 3)) as tbl(column1, column2)
    """
    result_df = c.sql(query)
    expected_df = pd.DataFrame(
        [[1, 2], [1, 3]], columns=["column1", "column2"]
    )
    assert_eq(result_df, expected_df, check_index=False)


def test_values_join(c):
    """A VALUES relation should join against a registered table."""
    query = """
    SELECT * FROM df_simple, (VALUES (1, 2), (1, 3)) as tbl(aa, bb)
    WHERE a = aa
    """
    result_df = c.sql(query)
    expected_df = pd.DataFrame(
        [[1, 1.1, 1, 2], [1, 1.1, 1, 3]],
        columns=["a", "b", "aa", "bb"],
    )
    assert_eq(result_df, expected_df, check_index=False)


@pytest.mark.xfail(reason="Datafusion doesn't handle values relations cleanly")
def test_values_join_alias(c):
    """Join a VALUES relation through its table alias.

    Currently expected to fail: DataFusion does not cleanly resolve
    qualified references (e.g. ``tbl.aa``) into an aliased VALUES relation.
    """
    result_df = c.sql(
        """
    SELECT * FROM df_simple, (VALUES (1, 2), (1, 3)) as tbl(aa, bb)
    WHERE a = tbl.aa
    """
    )
    expected_df = pd.DataFrame(
        {"a": [1, 1], "b": [1.1, 1.1], "aa": [1, 1], "bb": [2, 3]}
    )
    assert_eq(result_df, expected_df, check_index=False)

    result_df = c.sql(
        """
    SELECT * FROM df_simple t1, (VALUES (1, 2), (1, 3)) as t2(a, b)
    WHERE t1.a = t2.a
    """
    )
    # FIX: the VALUES relation is aliased t2(a, b), so the qualified column
    # names are t2.a / t2.b — the previous expectation used t2.aa / t2.bb,
    # which does not match the alias list in the query above.
    expected_df = pd.DataFrame(
        {"t1.a": [1, 1], "t1.b": [1.1, 1.1], "t2.a": [1, 1], "t2.b": [2, 3]}
    )
    assert_eq(result_df, expected_df, check_index=False)