Skip to content

Commit

Permalink
feat(sql router): new interface (#3630)
Browse files Browse the repository at this point in the history
SQLTODAG(string) -> DAG {name, sql, producers[]}
  • Loading branch information
aceforeverd authored Jan 8, 2024
1 parent ebabc22 commit b33ebe2
Show file tree
Hide file tree
Showing 9 changed files with 291 additions and 1 deletion.
1 change: 0 additions & 1 deletion hybridse/include/sdk/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
#include <stdint.h>

#include <memory>
#include <ostream>
#include <string>
#include <unordered_map>
#include <vector>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
/**
* Copyright (c) 2023 OpenMLDB authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com._4paradigm.openmldb.sdk;

import java.util.ArrayList;

public class DAGNode {
public DAGNode(String name, String sql, ArrayList<DAGNode> producers) {
this.name = name;
this.sql = sql;
this.producers = producers;
}

public String name;
public String sql;
public ArrayList<DAGNode> producers;
}
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,13 @@ PreparedStatement getBatchRequestPreparedStmt(String db, String sql,
NS.TableInfo getTableInfo(String db, String table);

List<String> getTableNames(String db);
/**
* Parse SQL query into DAG representation
*
* @param query SQL query string
* @throws SQLException exception if input query not valid for SQL parser
*/
DAGNode SQLToDAG(String query) throws SQLException;

void close();
}
Original file line number Diff line number Diff line change
Expand Up @@ -665,4 +665,30 @@ public boolean updateOfflineTableInfo(NS.TableInfo info) {
public boolean refreshCatalog() {
return sqlRouter.RefreshCatalog();
}

@Override
public DAGNode SQLToDAG(String query) throws SQLException {
Status status = new Status();
final com._4paradigm.openmldb.DAGNode dag = sqlRouter.SQLToDAG(query, status);

try {
if (status.getCode() != 0) {
throw new SQLException(status.ToString());
}
return convertDAG(dag);
} finally {
dag.delete();
status.delete();
}
}

private static DAGNode convertDAG(com._4paradigm.openmldb.DAGNode dag) {
ArrayList<DAGNode> convertedProducers = new ArrayList<>();
for (com._4paradigm.openmldb.DAGNode producer : dag.getProducers()) {
final DAGNode converted = convertDAG(producer);
convertedProducers.add(converted);
}

return new DAGNode(dag.getName(), dag.getSql(), convertedProducers);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import com._4paradigm.openmldb.common.Pair;
import com._4paradigm.openmldb.proto.NS;
import com._4paradigm.openmldb.sdk.Column;
import com._4paradigm.openmldb.sdk.DAGNode;
import com._4paradigm.openmldb.sdk.Schema;
import com._4paradigm.openmldb.sdk.SdkOption;
import com._4paradigm.openmldb.sdk.SqlExecutor;
Expand Down Expand Up @@ -870,4 +871,60 @@ public void testMergeSQL() throws SQLException {
+ "(select db.main.id as merge_id_3, db.main.c1 as merge_c1_3, sum(c2) over w1 from main window w1 as (union (select \"\" as id, * from t1) partition by c1 order by c2 rows between unbounded preceding and current row)) as out3 "
+ "on out0.merge_id_0 = out3.merge_id_3 and out0.merge_c1_0 = out3.merge_c1_3;");
}

@Test(dataProvider = "executor")
public void testSQLToDag(SqlExecutor router) throws SQLException {
String sql = " WITH q1 as (WITH q3 as (select * from t1 LIMIT 10), q4 as (select * from t2) select * from q3 left join q4 on q3.id = q4.id),"
+
"q2 as (select * from t3)" +
"select * from q1 last join q2 on q1.id = q2.id";

DAGNode dag = router.SQLToDAG(sql);

Assert.assertEquals(dag.name, "");
Assert.assertEquals(dag.sql, "SELECT\n" +
" *\n" +
"FROM\n" +
" q1\n" +
" LAST JOIN\n" +
" q2\n" +
" ON q1.id = q2.id\n");
Assert.assertEquals(dag.producers.size(), 2);

DAGNode input1 = dag.producers.get(0);
Assert.assertEquals(input1.name, "q1");
Assert.assertEquals(input1.sql, "SELECT\n" +
" *\n" +
"FROM\n" +
" q3\n" +
" LEFT JOIN\n" +
" q4\n" +
" ON q3.id = q4.id\n");
Assert.assertEquals(2, input1.producers.size());

DAGNode input2 = dag.producers.get(1);
Assert.assertEquals(input2.name, "q2");
Assert.assertEquals(input2.sql, "SELECT\n" +
" *\n" +
"FROM\n" +
" t3\n");
Assert.assertEquals(input2.producers.size(), 0);

DAGNode q1In1 = input1.producers.get(0);
Assert.assertEquals(q1In1.producers.size(), 0);
Assert.assertEquals(q1In1.name, "q3");
Assert.assertEquals(q1In1.sql, "SELECT\n" +
" *\n" +
"FROM\n" +
" t1\n" +
"LIMIT 10\n");

DAGNode q1In2 = input1.producers.get(1);
Assert.assertEquals(q1In2.producers.size(), 0);
Assert.assertEquals(q1In2.name, "q4");
Assert.assertEquals(q1In2.sql, "SELECT\n" +
" *\n" +
"FROM\n" +
" t2\n");
}
}
83 changes: 83 additions & 0 deletions src/sdk/sql_router.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,17 @@
*/

#include "sdk/sql_router.h"

#include <map>

#include "absl/strings/substitute.h"
#include "base/ddl_parser.h"
#include "glog/logging.h"
#include "schema/schema_adapter.h"
#include "sdk/sql_cluster_router.h"
#include "zetasql/parser/parser.h"
#include "zetasql/public/error_helpers.h"
#include "zetasql/public/error_location.pb.h"

namespace openmldb::sdk {

Expand Down Expand Up @@ -274,4 +280,81 @@ std::vector<std::pair<std::string, std::string>> GetDependentTables(
return tables;
}

std::shared_ptr<DAGNode> QueryToDAG(const zetasql::ASTQuery* query, absl::string_view name) {
std::vector<std::shared_ptr<DAGNode>> producers;
if (query->with_clause() != nullptr) {
for (auto with_entry : query->with_clause()->with()) {
producers.push_back(QueryToDAG(with_entry->query(), with_entry->alias()->GetAsStringView()));
}
}

// SQL without WITH clause
std::string sql = zetasql::Unparse(query->query_expr());
if (query->order_by() != nullptr) {
absl::StrAppend(&sql, zetasql::Unparse(query->order_by()));
}
if (query->limit_offset() != nullptr) {
absl::StrAppend(&sql, zetasql::Unparse(query->limit_offset()));
}

return std::make_shared<DAGNode>(name, sql, producers);
}

std::shared_ptr<DAGNode> SQLRouter::SQLToDAG(const std::string& query, hybridse::sdk::Status* status) {
std::unique_ptr<zetasql::ParserOutput> parser_output;
zetasql::ParserOptions parser_opts;
zetasql::LanguageOptions language_opts;
language_opts.EnableLanguageFeature(zetasql::FEATURE_V_1_3_COLUMN_DEFAULT_VALUE);
parser_opts.set_language_options(&language_opts);
auto zetasql_status = zetasql::ParseStatement(query, parser_opts, &parser_output);
zetasql::ErrorLocation location;
if (!zetasql_status.ok()) {
zetasql::ErrorLocation location;
GetErrorLocation(zetasql_status, &location);
status->msg = zetasql::FormatError(zetasql_status);
status->code = hybridse::common::kSyntaxError;
return {};
}

auto stmt = parser_output->statement();
if (stmt == nullptr) {
status->msg = "not a statement";
status->code = hybridse::common::kSyntaxError;
return {};
}

if (stmt->node_kind() != zetasql::AST_QUERY_STATEMENT) {
status->msg = "not a query";
status->code = hybridse::common::kSyntaxError;
return {};
}

auto const query_stmt = stmt->GetAsOrNull<zetasql::ASTQueryStatement>();
if (query_stmt == nullptr) {
status->msg = "not a query";
status->code = hybridse::common::kSyntaxError;
return {};
}

status->code = hybridse::common::kOk;
return QueryToDAG(query_stmt->query(), "");
}

bool DAGNode::operator==(const DAGNode& rhs) const noexcept {
return name == rhs.name && sql == rhs.sql &&
absl::c_equal(producers, rhs.producers,
[](const std::shared_ptr<DAGNode>& left, const std::shared_ptr<DAGNode>& right) {
return left != nullptr && right != nullptr && *left == *right;
});
}

std::ostream& operator<<(std::ostream& os, const DAGNode& obj) { return os << obj.DebugString(); }

std::string DAGNode::DebugString() const {
return absl::Substitute("{$0, $1, [$2]}", name, sql,
absl::StrJoin(producers, ",", [](std::string* out, const std::shared_ptr<DAGNode>& e) {
absl::StrAppend(out, (e == nullptr ? "" : e->DebugString()));
}));
}

} // namespace openmldb::sdk
21 changes: 21 additions & 0 deletions src/sdk/sql_router.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,22 @@ class ExplainInfo {
virtual const std::string& GetRequestDbName() = 0;
};

struct DAGNode {
DAGNode(absl::string_view name, absl::string_view sql) : name(name), sql(sql) {}
DAGNode(absl::string_view name, absl::string_view sql, const std::vector<std::shared_ptr<DAGNode>>& producers)
: name(name), sql(sql), producers(producers) {}

std::string name;
std::string sql;
std::vector<std::shared_ptr<DAGNode>> producers;

bool operator==(const DAGNode& op) const noexcept;

std::string DebugString() const;

friend std::ostream& operator<<(std::ostream& os, const DAGNode& obj);
};

class QueryFuture {
public:
QueryFuture() {}
Expand Down Expand Up @@ -234,6 +250,11 @@ class SQLRouter {
virtual bool IsOnlineMode() = 0;

virtual std::string GetDatabase() = 0;

// parse SQL query into DAG representation
//
// Optional CONFIG clause from SQL query statement is skipped in output DAG
std::shared_ptr<DAGNode> SQLToDAG(const std::string& query, hybridse::sdk::Status* status);
};

std::shared_ptr<SQLRouter> NewClusterSQLRouter(const SQLRouterOptions& options);
Expand Down
3 changes: 3 additions & 0 deletions src/sdk/sql_router_sdk.i
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
%template(VectorUint32) std::vector<uint32_t>;
%template(VectorString) std::vector<std::string>;

%shared_ptr(openmldb::sdk::DAGNode);
%{
#include "sdk/sql_router.h"
#include "sdk/result_set.h"
Expand Down Expand Up @@ -117,3 +118,5 @@ using openmldb::sdk::DefaultValueContainer;

%template(DBTable) std::pair<std::string, std::string>;
%template(DBTableVector) std::vector<std::pair<std::string, std::string>>;

%template(DAGNodeList) std::vector<std::shared_ptr<openmldb::sdk::DAGNode>>;
63 changes: 63 additions & 0 deletions src/sdk/sql_router_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1226,6 +1226,69 @@ TEST_F(SQLRouterTest, DDLParseMethodsCombineIndex) {
ddl_list.at(0));
}

TEST_F(SQLRouterTest, SQLToDAG) {
auto sql = R"(WITH q1 as (
WITH q3 as (select * from t1 ORDER BY ts),
q4 as (select * from t2 LIMIT 10)
select * from q3 left join q4 on q3.key = q4.key
),
q2 as (select * from t3)
select * from q1 last join q2 on q1.id = q2.id)";


hybridse::sdk::Status status;
auto dag = router_->SQLToDAG(sql, &status);
ASSERT_TRUE(status.IsOK());

std::string_view q3 = R"(SELECT
*
FROM
t1
ORDER BY ts
)";
std::string_view q4 = R"(SELECT
*
FROM
t2
LIMIT 10
)";
std::string_view q2 = R"(SELECT
*
FROM
t3
)";
std::string_view q1 = R"(SELECT
*
FROM
q3
LEFT JOIN
q4
ON q3.key = q4.key
)";
std::string_view q = R"(SELECT
*
FROM
q1
LAST JOIN
q2
ON q1.id = q2.id
)";

std::shared_ptr<DAGNode> dag_q3 = std::make_shared<DAGNode>("q3", q3);
std::shared_ptr<DAGNode> dag_q4 = std::make_shared<DAGNode>("q4", q4);

std::shared_ptr<DAGNode> dag_q1 =
std::make_shared<DAGNode>("q1", q1, std::vector<std::shared_ptr<DAGNode>>({dag_q3, dag_q4}));
std::shared_ptr<DAGNode> dag_q2 = std::make_shared<DAGNode>("q2", q2);

std::shared_ptr<DAGNode> expect =
std::make_shared<DAGNode>("", q, std::vector<std::shared_ptr<DAGNode>>({dag_q1, dag_q2}));

EXPECT_EQ(*dag, *expect);
}

} // namespace openmldb::sdk

int main(int argc, char** argv) {
Expand Down

0 comments on commit b33ebe2

Please sign in to comment.