Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(sql router): add interface for sql to dag #3630

Merged
merged 1 commit into from
Jan 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion hybridse/include/sdk/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
#include <stdint.h>

#include <memory>
#include <ostream>
#include <string>
#include <unordered_map>
#include <vector>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
/**
* Copyright (c) 2023 OpenMLDB authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com._4paradigm.openmldb.sdk;

import java.util.ArrayList;

public class DAGNode {
public DAGNode(String name, String sql, ArrayList<DAGNode> producers) {
this.name = name;
this.sql = sql;
this.producers = producers;
}

public String name;
public String sql;
public ArrayList<DAGNode> producers;
}
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,13 @@ PreparedStatement getBatchRequestPreparedStmt(String db, String sql,
NS.TableInfo getTableInfo(String db, String table);

List<String> getTableNames(String db);
/**
 * Parse a SQL query into its DAG representation.
 *
 * <p>The returned node represents the outermost query; each entry of a WITH clause
 * becomes a node in {@link DAGNode#producers}, named after the entry's alias, and is
 * itself expanded recursively.
 *
 * @param query SQL query string
 * @return root node of the DAG built from {@code query}
 * @throws SQLException exception if input query not valid for SQL parser
 */
DAGNode SQLToDAG(String query) throws SQLException;

void close();
}
Original file line number Diff line number Diff line change
Expand Up @@ -664,4 +664,30 @@
// Forwards to the underlying sqlRouter's RefreshCatalog() and returns its result unchanged.
public boolean refreshCatalog() {
return sqlRouter.RefreshCatalog();
}

@Override
public DAGNode SQLToDAG(String query) throws SQLException {
Status status = new Status();
final com._4paradigm.openmldb.DAGNode dag = sqlRouter.SQLToDAG(query, status);

try {
if (status.getCode() != 0) {
throw new SQLException(status.ToString());

Check warning on line 675 in java/openmldb-jdbc/src/main/java/com/_4paradigm/openmldb/sdk/impl/SqlClusterExecutor.java

View check run for this annotation

Codecov / codecov/patch

java/openmldb-jdbc/src/main/java/com/_4paradigm/openmldb/sdk/impl/SqlClusterExecutor.java#L675

Added line #L675 was not covered by tests
}
return convertDAG(dag);
} finally {
dag.delete();
status.delete();
}
}

/**
 * Recursively copies a native DAG node and all of its producers into Java {@link DAGNode}s.
 *
 * @param dag native node to convert
 * @return Java node carrying the same name, SQL text and (converted) producers
 */
private static DAGNode convertDAG(com._4paradigm.openmldb.DAGNode dag) {
    final ArrayList<DAGNode> children = new ArrayList<>();
    for (com._4paradigm.openmldb.DAGNode nativeChild : dag.getProducers()) {
        children.add(convertDAG(nativeChild));
    }

    final String nodeName = dag.getName();
    final String nodeSql = dag.getSql();
    return new DAGNode(nodeName, nodeSql, children);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import com._4paradigm.openmldb.common.Pair;
import com._4paradigm.openmldb.proto.NS;
import com._4paradigm.openmldb.sdk.Column;
import com._4paradigm.openmldb.sdk.DAGNode;
import com._4paradigm.openmldb.sdk.Schema;
import com._4paradigm.openmldb.sdk.SdkOption;
import com._4paradigm.openmldb.sdk.SqlExecutor;
Expand Down Expand Up @@ -858,4 +859,60 @@ public void testMergeSQL() throws SQLException {
+ "(select db.main.id as merge_id_3, db.main.c1 as merge_c1_3, sum(c2) over w1 from main window w1 as (union (select \"\" as id, * from t1) partition by c1 order by c2 rows between unbounded preceding and current row)) as out3 "
+ "on out0.merge_id_0 = out3.merge_id_3 and out0.merge_c1_0 = out3.merge_c1_3;");
}

/**
 * Checks that a query with nested WITH clauses is split into the expected DAG:
 * the root node has producers q1 and q2, and q1 in turn has producers q3 and q4.
 */
@Test(dataProvider = "executor")
public void testSQLToDag(SqlExecutor router) throws SQLException {
    String sql = " WITH q1 as (WITH q3 as (select * from t1 LIMIT 10), q4 as (select * from t2) select * from q3 left join q4 on q3.id = q4.id),"
            +
            "q2 as (select * from t3)" +
            "select * from q1 last join q2 on q1.id = q2.id";

    DAGNode dag = router.SQLToDAG(sql);

    // Root node: the outermost query, unnamed.
    Assert.assertEquals(dag.name, "");
    Assert.assertEquals(dag.sql, "SELECT\n" +
            " *\n" +
            "FROM\n" +
            " q1\n" +
            " LAST JOIN\n" +
            " q2\n" +
            " ON q1.id = q2.id\n");
    Assert.assertEquals(dag.producers.size(), 2);

    // First WITH entry of the root: q1, which has its own WITH clause.
    DAGNode input1 = dag.producers.get(0);
    Assert.assertEquals(input1.name, "q1");
    Assert.assertEquals(input1.sql, "SELECT\n" +
            " *\n" +
            "FROM\n" +
            " q3\n" +
            " LEFT JOIN\n" +
            " q4\n" +
            " ON q3.id = q4.id\n");
    // TestNG's assertEquals takes (actual, expected); keep the same order as the other checks.
    Assert.assertEquals(input1.producers.size(), 2);

    // Second WITH entry of the root: q2, a leaf.
    DAGNode input2 = dag.producers.get(1);
    Assert.assertEquals(input2.name, "q2");
    Assert.assertEquals(input2.sql, "SELECT\n" +
            " *\n" +
            "FROM\n" +
            " t3\n");
    Assert.assertEquals(input2.producers.size(), 0);

    // Producers of q1: q3 and q4, both leaves.
    DAGNode q1In1 = input1.producers.get(0);
    Assert.assertEquals(q1In1.producers.size(), 0);
    Assert.assertEquals(q1In1.name, "q3");
    Assert.assertEquals(q1In1.sql, "SELECT\n" +
            " *\n" +
            "FROM\n" +
            " t1\n" +
            "LIMIT 10\n");

    DAGNode q1In2 = input1.producers.get(1);
    Assert.assertEquals(q1In2.producers.size(), 0);
    Assert.assertEquals(q1In2.name, "q4");
    Assert.assertEquals(q1In2.sql, "SELECT\n" +
            " *\n" +
            "FROM\n" +
            " t2\n");
}
}
83 changes: 83 additions & 0 deletions src/sdk/sql_router.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,17 @@
*/

#include "sdk/sql_router.h"

#include <map>

#include "absl/strings/substitute.h"
#include "base/ddl_parser.h"
#include "glog/logging.h"
#include "schema/schema_adapter.h"
#include "sdk/sql_cluster_router.h"
#include "zetasql/parser/parser.h"
#include "zetasql/public/error_helpers.h"
#include "zetasql/public/error_location.pb.h"

namespace openmldb::sdk {

Expand Down Expand Up @@ -274,4 +280,81 @@
return tables;
}

// Recursively converts a parsed query into a DAGNode labelled `name`.
//
// Every entry of the query's WITH clause (if there is one) becomes a producer of the
// returned node; it is built by recursing into the entry's query with the entry's alias
// as the node name. The node's own SQL is the unparsed query expression, followed by the
// unparsed ORDER BY and LIMIT/OFFSET clauses when present, so the WITH clause itself
// is not part of it.
std::shared_ptr<DAGNode> QueryToDAG(const zetasql::ASTQuery* query, absl::string_view name) {
    std::vector<std::shared_ptr<DAGNode>> inputs;
    const auto* with_clause = query->with_clause();
    if (with_clause != nullptr) {
        for (auto entry : with_clause->with()) {
            auto child = QueryToDAG(entry->query(), entry->alias()->GetAsStringView());
            inputs.push_back(child);
        }
    }

    // Text of the query with the WITH clause left out.
    std::string body_sql = zetasql::Unparse(query->query_expr());

    const auto* order_by = query->order_by();
    if (order_by != nullptr) {
        absl::StrAppend(&body_sql, zetasql::Unparse(order_by));
    }

    const auto* limit_offset = query->limit_offset();
    if (limit_offset != nullptr) {
        absl::StrAppend(&body_sql, zetasql::Unparse(limit_offset));
    }

    return std::make_shared<DAGNode>(name, body_sql, inputs);
}

// Parses `query` and returns its DAG representation.
//
// On success `status->code` is set to kOk and the root DAGNode (named "") is returned.
// On failure an empty pointer is returned and `status` is set to kSyntaxError with a
// message: the formatted parser error if parsing failed, "not a statement" if the parser
// produced no statement, or "not a query" if the statement is not a query statement.
//
// `status` must not be null; it is written on every path.
std::shared_ptr<DAGNode> SQLRouter::SQLToDAG(const std::string& query, hybridse::sdk::Status* status) {
    std::unique_ptr<zetasql::ParserOutput> parser_output;
    zetasql::ParserOptions parser_opts;
    zetasql::LanguageOptions language_opts;
    language_opts.EnableLanguageFeature(zetasql::FEATURE_V_1_3_COLUMN_DEFAULT_VALUE);
    parser_opts.set_language_options(&language_opts);

    auto zetasql_status = zetasql::ParseStatement(query, parser_opts, &parser_output);
    if (!zetasql_status.ok()) {
        // The error text comes from FormatError; no separate ErrorLocation lookup is needed
        // because nothing here consumed it.
        status->msg = zetasql::FormatError(zetasql_status);
        status->code = hybridse::common::kSyntaxError;
        return {};
    }

    auto stmt = parser_output->statement();
    if (stmt == nullptr) {
        status->msg = "not a statement";
        status->code = hybridse::common::kSyntaxError;
        return {};
    }

    if (stmt->node_kind() != zetasql::AST_QUERY_STATEMENT) {
        status->msg = "not a query";
        status->code = hybridse::common::kSyntaxError;
        return {};
    }

    // Kept as a second line of defence in case the kind check and the cast ever disagree.
    auto const query_stmt = stmt->GetAsOrNull<zetasql::ASTQueryStatement>();
    if (query_stmt == nullptr) {
        status->msg = "not a query";
        status->code = hybridse::common::kSyntaxError;
        return {};
    }

    status->code = hybridse::common::kOk;
    // The DAG copies all text out of the AST, so it stays valid after parser_output is destroyed.
    return QueryToDAG(query_stmt->query(), "");
}

// Deep equality: two nodes are equal when their names and SQL text match and their
// producer lists have the same length with pairwise-equal producers.
//
// Two producer slots are equal when they hold the same pointer (which includes both
// being null) or when both are non-null and the pointed-to nodes compare equal. A null
// slot never equals a non-null one. Treating null/null as equal keeps the relation
// reflexive, so a node always compares equal to itself.
bool DAGNode::operator==(const DAGNode& rhs) const noexcept {
    if (name != rhs.name || sql != rhs.sql) {
        return false;
    }
    if (producers.size() != rhs.producers.size()) {
        return false;
    }

    for (size_t i = 0; i < producers.size(); ++i) {
        const std::shared_ptr<DAGNode>& left = producers[i];
        const std::shared_ptr<DAGNode>& right = rhs.producers[i];
        if (left == right) {
            // Same node, or both slots empty.
            continue;
        }
        if (left == nullptr || right == nullptr || !(*left == *right)) {
            return false;
        }
    }
    return true;
}

// Streams the node's DebugString() representation into `os` and returns the stream.
std::ostream& operator<<(std::ostream& os, const DAGNode& obj) {
    os << obj.DebugString();
    return os;
}

Check warning on line 351 in src/sdk/sql_router.cc

View check run for this annotation

Codecov / codecov/patch

src/sdk/sql_router.cc#L351

Added line #L351 was not covered by tests

std::string DAGNode::DebugString() const {
return absl::Substitute("{$0, $1, [$2]}", name, sql,
absl::StrJoin(producers, ",", [](std::string* out, const std::shared_ptr<DAGNode>& e) {
absl::StrAppend(out, (e == nullptr ? "" : e->DebugString()));
}));

Check warning on line 357 in src/sdk/sql_router.cc

View check run for this annotation

Codecov / codecov/patch

src/sdk/sql_router.cc#L353-L357

Added lines #L353 - L357 were not covered by tests
}

} // namespace openmldb::sdk
21 changes: 21 additions & 0 deletions src/sdk/sql_router.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,22 @@
virtual const std::string& GetRequestDbName() = 0;
};

// One node of the DAG that SQLRouter::SQLToDAG builds from a SQL query.
struct DAGNode {
    // Name of the node.
    std::string name;
    // SQL text of the node.
    std::string sql;
    // Nodes feeding this node; shared ownership.
    std::vector<std::shared_ptr<DAGNode>> producers;

    // Creates a node without producers.
    DAGNode(absl::string_view name, absl::string_view sql) : name(name), sql(sql) {}

    // Creates a node with the given producers; the vector is copied.
    DAGNode(absl::string_view name, absl::string_view sql, const std::vector<std::shared_ptr<DAGNode>>& producers)
        : name(name), sql(sql), producers(producers) {}

    // Compares name, SQL text and producers (by pointed-to value) with `op`.
    bool operator==(const DAGNode& op) const noexcept;

    // Human-readable rendering of the node and its producers.
    std::string DebugString() const;

    // Writes DebugString() of `obj` to `os`.
    friend std::ostream& operator<<(std::ostream& os, const DAGNode& obj);
};

class QueryFuture {
public:
QueryFuture() {}
Expand Down Expand Up @@ -234,6 +250,11 @@
virtual bool IsOnlineMode() = 0;

virtual std::string GetDatabase() = 0;

// Parse SQL query into DAG representation.
//
// Returns the root node of the DAG; each WITH-clause entry of a query becomes a
// producer of that query's node, named after the entry's alias.
//
// `status` is an output parameter and must not be null: it is set to kOk on success,
// and to kSyntaxError with a message when the input cannot be parsed or is not a
// query statement, in which case an empty pointer is returned.
//
// Optional CONFIG clause from SQL query statement is skipped in output DAG
std::shared_ptr<DAGNode> SQLToDAG(const std::string& query, hybridse::sdk::Status* status);
};

std::shared_ptr<SQLRouter> NewClusterSQLRouter(const SQLRouterOptions& options);
Expand Down
3 changes: 3 additions & 0 deletions src/sdk/sql_router_sdk.i
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
%template(VectorUint32) std::vector<uint32_t>;
%template(VectorString) std::vector<std::string>;

%shared_ptr(openmldb::sdk::DAGNode);
%{
#include "sdk/sql_router.h"
#include "sdk/result_set.h"
Expand Down Expand Up @@ -117,3 +118,5 @@ using openmldb::sdk::DefaultValueContainer;

%template(DBTable) std::pair<std::string, std::string>;
%template(DBTableVector) std::vector<std::pair<std::string, std::string>>;

%template(DAGNodeList) std::vector<std::shared_ptr<openmldb::sdk::DAGNode>>;
63 changes: 63 additions & 0 deletions src/sdk/sql_router_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1226,6 +1226,69 @@ TEST_F(SQLRouterTest, DDLParseMethodsCombineIndex) {
ddl_list.at(0));
}

// Verifies that SQLToDAG turns a query with nested WITH clauses into the expected tree:
// root -> {q1, q2}, q1 -> {q3, q4}. ORDER BY (q3) and LIMIT (q4) are included so that
// those clauses are checked to survive in the per-node SQL.
TEST_F(SQLRouterTest, SQLToDAG) {
auto sql = R"(WITH q1 as (
WITH q3 as (select * from t1 ORDER BY ts),
q4 as (select * from t2 LIMIT 10)

select * from q3 left join q4 on q3.key = q4.key
),
q2 as (select * from t3)

select * from q1 last join q2 on q1.id = q2.id)";


hybridse::sdk::Status status;
auto dag = router_->SQLToDAG(sql, &status);
// Abort here on failure: dag is dereferenced below.
ASSERT_TRUE(status.IsOK());

// Expected SQL text of every node, in the unparsed form that SQLToDAG stores.
std::string_view q3 = R"(SELECT
*
FROM
t1
ORDER BY ts
)";
std::string_view q4 = R"(SELECT
*
FROM
t2
LIMIT 10
)";
std::string_view q2 = R"(SELECT
*
FROM
t3
)";
std::string_view q1 = R"(SELECT
*
FROM
q3
LEFT JOIN
q4
ON q3.key = q4.key
)";
std::string_view q = R"(SELECT
*
FROM
q1
LAST JOIN
q2
ON q1.id = q2.id
)";

// Build the expected tree bottom-up: leaves first, then the nodes that consume them.
std::shared_ptr<DAGNode> dag_q3 = std::make_shared<DAGNode>("q3", q3);
std::shared_ptr<DAGNode> dag_q4 = std::make_shared<DAGNode>("q4", q4);

std::shared_ptr<DAGNode> dag_q1 =
std::make_shared<DAGNode>("q1", q1, std::vector<std::shared_ptr<DAGNode>>({dag_q3, dag_q4}));
std::shared_ptr<DAGNode> dag_q2 = std::make_shared<DAGNode>("q2", q2);

// The root node is unnamed.
std::shared_ptr<DAGNode> expect =
std::make_shared<DAGNode>("", q, std::vector<std::shared_ptr<DAGNode>>({dag_q1, dag_q2}));

// Deep comparison through DAGNode::operator==.
EXPECT_EQ(*dag, *expect);
}

} // namespace openmldb::sdk

int main(int argc, char** argv) {
Expand Down
Loading