From b7e592c2a46cbb480e3233d97c47b66a66636102 Mon Sep 17 00:00:00 2001 From: yangwucheng Date: Mon, 1 Jul 2024 21:50:10 +0800 Subject: [PATCH] feat(open-mysql-db): pandas support (#3868) * feat(open-mysql-db): refactor 1. remove unnecessary instance var port 2. fix cause null bug 3. remove unnecessary throws 4. fix ctx.close() sequence bug 5. config sessionTimeout and requestTimeout 6. add docs of SqlEngine * feat(open-mysql-db): refactor * feat(open-mysql-db): revert password * feat(open-mysql-db): mock commit and schema table count * feat(open-mysql-db): replace data type text with string * feat(open-mysql-db): remove null --------- Co-authored-by: yangwucheng --- .../open-mysql-db/python-testcases/main.py | 33 ++++++++++ .../java/cn/paxos/mysql/MySqlListener.java | 3 + .../mysql/server/OpenmldbMysqlServer.java | 61 ++++++++++++++++++- 3 files changed, 95 insertions(+), 2 deletions(-) create mode 100644 extensions/open-mysql-db/python-testcases/main.py diff --git a/extensions/open-mysql-db/python-testcases/main.py b/extensions/open-mysql-db/python-testcases/main.py new file mode 100644 index 00000000000..a673e63def3 --- /dev/null +++ b/extensions/open-mysql-db/python-testcases/main.py @@ -0,0 +1,33 @@ +import pandas as pd +from sqlalchemy import create_engine + +if __name__ == '__main__': + # Create a Pandas DataFrame (replace this with your actual data) + data = {'id': [1, 2, 3], + 'name': ['Alice', 'Bob', 'Charlie'], + 'age': [25, 30, 35], + 'score': [1.1, 2.2, 3.3], + 'ts': [pd.Timestamp.utcnow().timestamp(), pd.Timestamp.utcnow().timestamp(), + pd.Timestamp.utcnow().timestamp()], + 'dt': [pd.to_datetime('20240101', format='%Y%m%d'), pd.to_datetime('20240201', format='%Y%m%d'), + pd.to_datetime('20240301', format='%Y%m%d')], + } + df = pd.DataFrame(data) + + # Create a MySQL database engine using SQLAlchemy + engine = create_engine('mysql+pymysql://root:root@127.0.0.1:3307/demo_db') + + # Replace 'username', 'password', 'host', and 'db_name' with your actual 
database credentials + + # Define the name of the table in the database where you want to write the data + table_name = 'demo_table1' + + # Write the DataFrame 'df' into the MySQL table + df.to_sql(table_name, engine, if_exists='replace', index=False) + + # 'if_exists' parameter options: + # - 'fail': If the table already exists, an error will be raised. + # - 'replace': If the table already exists, it will be replaced. + # - 'append': If the table already exists, data will be appended to it. + + print("Data written to MySQL table successfully!") diff --git a/extensions/open-mysql-db/src/main/java/cn/paxos/mysql/MySqlListener.java b/extensions/open-mysql-db/src/main/java/cn/paxos/mysql/MySqlListener.java index a775a3ac78f..e9dc36d4715 100644 --- a/extensions/open-mysql-db/src/main/java/cn/paxos/mysql/MySqlListener.java +++ b/extensions/open-mysql-db/src/main/java/cn/paxos/mysql/MySqlListener.java @@ -217,6 +217,9 @@ private void handleQuery( && !queryStringWithoutComment.startsWith("set @@execute_mode=")) { // ignore SET command ctx.writeAndFlush(OkResponse.builder().sequenceId(query.getSequenceId() + 1).build()); + } else if (queryStringWithoutComment.equalsIgnoreCase("COMMIT")) { + // ignore COMMIT command + ctx.writeAndFlush(OkResponse.builder().sequenceId(query.getSequenceId() + 1).build()); } else if (useDbMatcher.matches()) { sqlEngine.useDatabase(getConnectionId(ctx), useDbMatcher.group(1)); ctx.writeAndFlush(OkResponse.builder().sequenceId(query.getSequenceId() + 1).build()); diff --git a/extensions/open-mysql-db/src/main/java/com/_4paradigm/openmldb/mysql/server/OpenmldbMysqlServer.java b/extensions/open-mysql-db/src/main/java/com/_4paradigm/openmldb/mysql/server/OpenmldbMysqlServer.java index 0e9a0ce8ec3..b5af68c7e75 100644 --- a/extensions/open-mysql-db/src/main/java/com/_4paradigm/openmldb/mysql/server/OpenmldbMysqlServer.java +++ b/extensions/open-mysql-db/src/main/java/com/_4paradigm/openmldb/mysql/server/OpenmldbMysqlServer.java @@ -10,6 +10,7 @@ 
import com._4paradigm.openmldb.jdbc.SQLResultSet; import com._4paradigm.openmldb.mysql.mock.MockResult; import com._4paradigm.openmldb.mysql.util.TypeUtil; +import com._4paradigm.openmldb.proto.NS; import com._4paradigm.openmldb.sdk.Column; import com._4paradigm.openmldb.sdk.Schema; import com._4paradigm.openmldb.sdk.SdkOption; @@ -56,6 +57,12 @@ public class OpenmldbMysqlServer { Pattern.compile( "(?i)SELECT COUNT\\(\\*\\) FROM information_schema\\.TABLES WHERE TABLE_SCHEMA = '(.+)'"); + // SELECT COUNT(*) FROM information_schema.tables WHERE table_schema = 'xzs' AND table_name = + // 't_exam_paper' + private final Pattern selectCountSchemaTablesPattern = + Pattern.compile( + "(?i)SELECT COUNT\\(\\*\\) FROM information_schema\\.TABLES WHERE TABLE_SCHEMA = '(.+)' AND table_name = '(.+)'"); + // SELECT COUNT(*) FROM information_schema.COLUMNS WHERE TABLE_SCHEMA = 'xzs' private final Pattern selectCountColumnsPattern = Pattern.compile( @@ -182,6 +189,10 @@ public void query( return; } + if (mockSelectSchemaTableCount(connectionId, resultSetWriter, sql)) { + return; + } + // This mock must execute before mockPatternQuery // SELECT COUNT(*) FROM information_schema.TABLES WHERE TABLE_SCHEMA = 'demo_db' // UNION SELECT COUNT(*) FROM information_schema.COLUMNS WHERE TABLE_SCHEMA = @@ -240,6 +251,7 @@ public void query( return; } + String originalSql = sql; if (sql.startsWith("SHOW FULL TABLES")) { // SHOW FULL TABLES WHERE Table_type != 'VIEW' Matcher showTablesFromDbMatcher = showTablesFromDbPattern.matcher(sql); @@ -250,6 +262,14 @@ public void query( } else { sql = "SHOW TABLES"; } + } else if (sql.matches("(?i)(?s)^\\s*CREATE TABLE.*$")) { + // convert data type TEXT to STRING + sql = sql.replaceAll("(?i) TEXT", " STRING"); + // sql = sql.replaceAll("(?i) DATETIME", " DATE"); + if (!sql.toLowerCase().contains(" not null") + && sql.toLowerCase().contains(" null")) { + sql = sql.replaceAll("(?i) null", ""); + } } else { Matcher crateDatabaseMatcher = 
createDatabasePattern.matcher(sql); Matcher selectLimitMatcher = selectLimitPattern.matcher(sql); @@ -264,7 +284,7 @@ public void query( if (sql.toLowerCase().startsWith("select") || sql.toLowerCase().startsWith("show")) { SQLResultSet resultSet = (SQLResultSet) stmt.getResultSet(); - outputResultSet(resultSetWriter, resultSet, sql); + outputResultSet(resultSetWriter, resultSet, originalSql); } System.out.println("Success to execute OpenMLDB SQL: " + sql); @@ -622,6 +642,36 @@ private boolean mockPatternQuery(ResultSetWriter resultSetWriter, String sql) { return false; } + private boolean mockSelectSchemaTableCount( + int connectionId, ResultSetWriter resultSetWriter, String sql) throws SQLException { + // SELECT COUNT(*) FROM information_schema.tables WHERE table_schema = 'xzs' AND + // table_name = 't_exam_paper' + Matcher selectCountSchemaTablesMatcher = selectCountSchemaTablesPattern.matcher(sql); + if (selectCountSchemaTablesMatcher.matches()) { + // COUNT(*) + List columns = new ArrayList<>(); + columns.add(new QueryResultColumn("COUNT(*)", "VARCHAR(255)")); + resultSetWriter.writeColumns(columns); + + List row; + String dbName = selectCountSchemaTablesMatcher.group(1); + String tableName = selectCountSchemaTablesMatcher.group(2); + row = new ArrayList<>(); + NS.TableInfo tableInfo = + sqlClusterExecutorMap.get(connectionId).getTableInfo(dbName, tableName); + if (tableInfo == null || tableInfo.getName().equals("")) { + row.add("0"); + } else { + row.add("1"); + } + resultSetWriter.writeRow(row); + + resultSetWriter.finish(); + return true; + } + return false; + } + private boolean mockSelectCountUnion( int connectionId, ResultSetWriter resultSetWriter, String sql) throws SQLException { // SELECT COUNT(*) FROM information_schema.TABLES WHERE TABLE_SCHEMA = 'demo_db' @@ -713,7 +763,8 @@ public void outputResultSet(ResultSetWriter resultSetWriter, SQLResultSet result // Add schema for (int i = 0; i < columnCount; i++) { String columnName = 
schema.getColumnName(i); - if (sql.equalsIgnoreCase("show table status") && columnName.equalsIgnoreCase("table_id")) { + if ((sql.startsWith("SHOW FULL TABLES") || sql.equalsIgnoreCase("show table status")) + && columnName.equalsIgnoreCase("table_id")) { tableIdColumnIndex = i; continue; } @@ -721,6 +772,9 @@ public void outputResultSet(ResultSetWriter resultSetWriter, SQLResultSet result columns.add( new QueryResultColumn(columnName, TypeUtil.openmldbTypeToMysqlTypeString(columnType))); } + if (sql.startsWith("SHOW FULL TABLES")) { + columns.add(new QueryResultColumn("Table_type", "VARCHAR(255)")); + } resultSetWriter.writeColumns(columns); @@ -739,6 +793,9 @@ public void outputResultSet(ResultSetWriter resultSetWriter, SQLResultSet result String columnValue = TypeUtil.getResultSetStringColumn(resultSet, i + 1, type); row.add(columnValue); } + if (sql.startsWith("SHOW FULL TABLES")) { + row.add("BASE TABLE"); + } resultSetWriter.writeRow(row); }