From 7cb568bd29884f8c3d31deb4e1af45cfd3ee639b Mon Sep 17 00:00:00 2001 From: August Date: Thu, 12 Dec 2024 18:55:40 +0800 Subject: [PATCH 1/3] feat: support swith database in session --- src/frontend/src/binder/mod.rs | 2 +- src/frontend/src/handler/alter_owner.rs | 4 +- src/frontend/src/handler/alter_parallelism.rs | 4 +- src/frontend/src/handler/alter_rename.rs | 30 ++++----- src/frontend/src/handler/alter_set_schema.rs | 4 +- .../src/handler/alter_source_column.rs | 4 +- .../src/handler/alter_source_with_sr.rs | 4 +- .../src/handler/alter_streaming_rate_limit.rs | 4 +- src/frontend/src/handler/alter_swap_rename.rs | 4 +- .../src/handler/alter_table_column.rs | 4 +- src/frontend/src/handler/alter_user.rs | 2 +- src/frontend/src/handler/close_cursor.rs | 2 +- src/frontend/src/handler/comment.rs | 4 +- src/frontend/src/handler/create_aggregate.rs | 2 +- src/frontend/src/handler/create_connection.rs | 2 +- src/frontend/src/handler/create_database.rs | 2 +- src/frontend/src/handler/create_function.rs | 2 +- src/frontend/src/handler/create_index.rs | 4 +- src/frontend/src/handler/create_mv.rs | 2 +- src/frontend/src/handler/create_schema.rs | 2 +- src/frontend/src/handler/create_secret.rs | 2 +- src/frontend/src/handler/create_sink.rs | 4 +- src/frontend/src/handler/create_source.rs | 2 +- .../src/handler/create_sql_function.rs | 2 +- .../src/handler/create_subscription.rs | 2 +- src/frontend/src/handler/create_table.rs | 8 +-- src/frontend/src/handler/create_user.rs | 6 +- src/frontend/src/handler/create_view.rs | 2 +- src/frontend/src/handler/declare_cursor.rs | 2 +- src/frontend/src/handler/drop_connection.rs | 4 +- src/frontend/src/handler/drop_function.rs | 4 +- src/frontend/src/handler/drop_index.rs | 4 +- src/frontend/src/handler/drop_mv.rs | 6 +- src/frontend/src/handler/drop_schema.rs | 2 +- src/frontend/src/handler/drop_secret.rs | 4 +- src/frontend/src/handler/drop_sink.rs | 4 +- src/frontend/src/handler/drop_source.rs | 4 +- src/frontend/src/handler/drop_subscription.rs | 4 +- src/frontend/src/handler/drop_table.rs | 4 +- src/frontend/src/handler/drop_view.rs | 6 +- src/frontend/src/handler/fetch_cursor.rs | 4 +- src/frontend/src/handler/flush.rs | 2 +- src/frontend/src/handler/handle_privilege.rs | 26 ++++---- src/frontend/src/handler/mod.rs | 22 ++++++- src/frontend/src/handler/privilege.rs | 6 +- src/frontend/src/handler/show.rs | 43 +++++++------ src/frontend/src/handler/use_db.rs | 57 +++++++++++++++++ src/frontend/src/handler/util.rs | 4 +- src/frontend/src/planner/relation.rs | 2 +- src/frontend/src/scheduler/local.rs | 2 +- src/frontend/src/session.rs | 64 ++++++++++--------- src/frontend/src/session/cursor_manager.rs | 2 +- src/frontend/src/test_utils.rs | 2 +- src/frontend/src/utils/with_options.rs | 4 +- src/sqlparser/src/ast/mod.rs | 10 +++ src/sqlparser/src/keywords.rs | 1 + src/sqlparser/src/parser.rs | 6 ++ src/utils/pgwire/src/pg_response.rs | 2 + 58 files changed, 261 insertions(+), 162 deletions(-) create mode 100644 src/frontend/src/handler/use_db.rs diff --git a/src/frontend/src/binder/mod.rs b/src/frontend/src/binder/mod.rs index 82fb74d575e86..b401398485738 100644 --- a/src/frontend/src/binder/mod.rs +++ b/src/frontend/src/binder/mod.rs @@ -328,7 +328,7 @@ impl Binder { ) -> Binder { Binder { catalog: session.env().catalog_reader().read_guard(), - db_name: session.database().to_owned(), + db_name: session.database(), session_id: session.id(), context: BindContext::new(), auth_context: session.auth_context(), diff --git a/src/frontend/src/handler/alter_owner.rs b/src/frontend/src/handler/alter_owner.rs index d4b57e32f0ae8..bb6ad749c6b2c 100644 --- a/src/frontend/src/handler/alter_owner.rs +++ b/src/frontend/src/handler/alter_owner.rs @@ -58,11 +58,11 @@ pub async fn handle_alter_owner( stmt_type: StatementType, ) -> Result { let session = handler_args.session; - let db_name = session.database(); + let db_name = &session.database(); let (schema_name, real_obj_name) = Binder::resolve_schema_qualified_name(db_name, obj_name.clone())?; let search_path = session.config().search_path(); - let user_name = &session.auth_context().user_name; + let user_name = &session.user_name(); let schema_path = SchemaPath::new(schema_name.as_deref(), &search_path, user_name); let new_owner_name = Binder::resolve_user_name(vec![new_owner_name].into())?; diff --git a/src/frontend/src/handler/alter_parallelism.rs b/src/frontend/src/handler/alter_parallelism.rs index 2aeec603f47bd..48bb6c1c76ece 100644 --- a/src/frontend/src/handler/alter_parallelism.rs +++ b/src/frontend/src/handler/alter_parallelism.rs @@ -37,11 +37,11 @@ pub async fn handle_alter_parallelism( deferred: bool, ) -> Result { let session = handler_args.session; - let db_name = session.database(); + let db_name = &session.database(); let (schema_name, real_table_name) = Binder::resolve_schema_qualified_name(db_name, obj_name.clone())?; let search_path = session.config().search_path(); - let user_name = &session.auth_context().user_name; + let user_name = &session.user_name(); let schema_path = SchemaPath::new(schema_name.as_deref(), &search_path, user_name); let table_id = { diff --git a/src/frontend/src/handler/alter_rename.rs b/src/frontend/src/handler/alter_rename.rs index b68362279ec67..25b3f0b866361 100644 --- a/src/frontend/src/handler/alter_rename.rs +++ b/src/frontend/src/handler/alter_rename.rs @@ -32,12 +32,12 @@ pub async fn handle_rename_table( new_table_name: ObjectName, ) -> Result { let session = handler_args.session; - let db_name = session.database(); + let db_name = &session.database(); let (schema_name, real_table_name) = Binder::resolve_schema_qualified_name(db_name, table_name.clone())?; let new_table_name = Binder::resolve_table_name(new_table_name)?; let search_path = session.config().search_path(); - let user_name = &session.auth_context().user_name; + let user_name = &session.user_name(); let schema_path = SchemaPath::new(schema_name.as_deref(), &search_path, user_name); @@ -79,12 +79,12 @@ pub async fn handle_rename_index( new_index_name: ObjectName, ) -> Result { let session = handler_args.session; - let db_name = session.database(); + let db_name = &session.database(); let (schema_name, real_index_name) = Binder::resolve_schema_qualified_name(db_name, index_name.clone())?; let new_index_name = Binder::resolve_index_name(new_index_name)?; let search_path = session.config().search_path(); - let user_name = &session.auth_context().user_name; + let user_name = &session.user_name(); let schema_path = SchemaPath::new(schema_name.as_deref(), &search_path, user_name); @@ -113,12 +113,12 @@ pub async fn handle_rename_view( new_view_name: ObjectName, ) -> Result { let session = handler_args.session; - let db_name = session.database(); + let db_name = &session.database(); let (schema_name, real_view_name) = Binder::resolve_schema_qualified_name(db_name, view_name.clone())?; let new_view_name = Binder::resolve_view_name(new_view_name)?; let search_path = session.config().search_path(); - let user_name = &session.auth_context().user_name; + let user_name = &session.user_name(); let schema_path = SchemaPath::new(schema_name.as_deref(), &search_path, user_name); @@ -143,12 +143,12 @@ pub async fn handle_rename_sink( new_sink_name: ObjectName, ) -> Result { let session = handler_args.session; - let db_name = session.database(); + let db_name = &session.database(); let (schema_name, real_sink_name) = Binder::resolve_schema_qualified_name(db_name, sink_name.clone())?; let new_sink_name = Binder::resolve_sink_name(new_sink_name)?; let search_path = session.config().search_path(); - let user_name = &session.auth_context().user_name; + let user_name = &session.user_name(); let schema_path = SchemaPath::new(schema_name.as_deref(), &search_path, user_name); @@ -176,12 +176,12 @@ pub async fn handle_rename_subscription( new_subscription_name: ObjectName, ) -> Result { let session = handler_args.session; - let db_name = session.database(); + let db_name = &session.database(); let (schema_name, real_subscription_name) = Binder::resolve_schema_qualified_name(db_name, subscription_name.clone())?; let new_subscription_name = Binder::resolve_subscription_name(new_subscription_name)?; let search_path = session.config().search_path(); - let user_name = &session.auth_context().user_name; + let user_name = &session.user_name(); let schema_path = SchemaPath::new(schema_name.as_deref(), &search_path, user_name); @@ -210,12 +210,12 @@ pub async fn handle_rename_source( new_source_name: ObjectName, ) -> Result { let session = handler_args.session; - let db_name = session.database(); + let db_name = &session.database(); let (schema_name, real_source_name) = Binder::resolve_schema_qualified_name(db_name, source_name.clone())?; let new_source_name = Binder::resolve_source_name(new_source_name)?; let search_path = session.config().search_path(); - let user_name = &session.auth_context().user_name; + let user_name = &session.user_name(); let schema_path = SchemaPath::new(schema_name.as_deref(), &search_path, user_name); @@ -253,7 +253,7 @@ pub async fn handle_rename_schema( new_schema_name: ObjectName, ) -> Result { let session = handler_args.session; - let db_name = session.database(); + let db_name = &session.database(); let schema_name = Binder::resolve_schema_name(schema_name)?; let new_schema_name = Binder::resolve_schema_name(new_schema_name)?; @@ -275,7 +275,7 @@ pub async fn handle_rename_schema( session.check_privilege_for_drop_alter_db_schema(schema)?; // To rename a schema you must also have the CREATE privilege for the database. - if let Some(user) = user_reader.get_user_by_name(session.user_name()) { + if let Some(user) = user_reader.get_user_by_name(&session.user_name()) { if !user.is_super && !user .check_privilege(&grant_privilege::Object::DatabaseId(db_id), AclMode::Create) @@ -321,7 +321,7 @@ pub async fn handle_rename_database( session.check_privilege_for_drop_alter_db_schema(database)?; // Non-superuser owners must also have the CREATEDB privilege. - if let Some(user) = user_reader.get_user_by_name(session.user_name()) { + if let Some(user) = user_reader.get_user_by_name(&session.user_name()) { if !user.is_super && !user.can_create_db { return Err(ErrorCode::PermissionDenied( "Non-superuser owners must also have the CREATEDB privilege".to_owned(), diff --git a/src/frontend/src/handler/alter_set_schema.rs b/src/frontend/src/handler/alter_set_schema.rs index 2edad7adab61f..1d55759294610 100644 --- a/src/frontend/src/handler/alter_set_schema.rs +++ b/src/frontend/src/handler/alter_set_schema.rs @@ -35,11 +35,11 @@ pub async fn handle_alter_set_schema( func_args: Option>, ) -> Result { let session = handler_args.session; - let db_name = session.database(); + let db_name = &session.database(); let (schema_name, real_obj_name) = Binder::resolve_schema_qualified_name(db_name, obj_name.clone())?; let search_path = session.config().search_path(); - let user_name = &session.auth_context().user_name; + let user_name = &session.user_name(); let schema_path = SchemaPath::new(schema_name.as_deref(), &search_path, user_name); let new_schema_name = Binder::resolve_schema_name(new_schema_name)?; diff --git a/src/frontend/src/handler/alter_source_column.rs b/src/frontend/src/handler/alter_source_column.rs index 3aa4d47f827c1..21f789633f305 100644 --- a/src/frontend/src/handler/alter_source_column.rs +++ b/src/frontend/src/handler/alter_source_column.rs @@ -40,11 +40,11 @@ pub async fn handle_alter_source_column( ) -> Result { // Get original definition let session = handler_args.session.clone(); - let db_name = session.database(); + let db_name = &session.database(); let (schema_name, real_source_name) = Binder::resolve_schema_qualified_name(db_name, source_name.clone())?; let search_path = session.config().search_path(); - let user_name = &session.auth_context().user_name; + let user_name = &session.user_name(); let schema_path = SchemaPath::new(schema_name.as_deref(), &search_path, user_name); diff --git a/src/frontend/src/handler/alter_source_with_sr.rs b/src/frontend/src/handler/alter_source_with_sr.rs index d4cec17b8b460..a3c3a08e4d42e 100644 --- a/src/frontend/src/handler/alter_source_with_sr.rs +++ b/src/frontend/src/handler/alter_source_with_sr.rs @@ -100,11 +100,11 @@ pub fn fetch_source_catalog_with_db_schema_id( session: &SessionImpl, name: &ObjectName, ) -> Result<(Arc, DatabaseId, SchemaId)> { - let db_name = session.database(); + let db_name = &session.database(); let (schema_name, real_source_name) = Binder::resolve_schema_qualified_name(db_name, name.clone())?; let search_path = session.config().search_path(); - let user_name = &session.auth_context().user_name; + let user_name = &session.user_name(); let schema_path = SchemaPath::new(schema_name.as_deref(), &search_path, user_name); diff --git a/src/frontend/src/handler/alter_streaming_rate_limit.rs b/src/frontend/src/handler/alter_streaming_rate_limit.rs index d41104f87d488..568acc38b7fb3 100644 --- a/src/frontend/src/handler/alter_streaming_rate_limit.rs +++ b/src/frontend/src/handler/alter_streaming_rate_limit.rs @@ -32,11 +32,11 @@ pub async fn handle_alter_streaming_rate_limit( rate_limit: i32, ) -> Result { let session = handler_args.session; - let db_name = session.database(); + let db_name = &session.database(); let (schema_name, real_table_name) = Binder::resolve_schema_qualified_name(db_name, table_name.clone())?; let search_path = session.config().search_path(); - let user_name = &session.auth_context().user_name; + let user_name = &session.user_name(); let schema_path = SchemaPath::new(schema_name.as_deref(), &search_path, user_name); diff --git a/src/frontend/src/handler/alter_swap_rename.rs b/src/frontend/src/handler/alter_swap_rename.rs index 3301d4a92d796..988b02624b36f 100644 --- a/src/frontend/src/handler/alter_swap_rename.rs +++ b/src/frontend/src/handler/alter_swap_rename.rs @@ -53,11 +53,11 @@ pub async fn handle_swap_rename( stmt_type: StatementType, ) -> Result { let session = handler_args.session; - let db_name = session.database(); + let db_name = &session.database(); let (src_schema_name, src_obj_name) = Binder::resolve_schema_qualified_name(db_name, source_object)?; let search_path = session.config().search_path(); - let user_name = &session.auth_context().user_name; + let user_name = &session.user_name(); let src_schema_path = SchemaPath::new(src_schema_name.as_deref(), &search_path, user_name); let (target_schema_name, target_obj_name) = Binder::resolve_schema_qualified_name(db_name, target_object)?; diff --git a/src/frontend/src/handler/alter_table_column.rs b/src/frontend/src/handler/alter_table_column.rs index 15554e919c77a..5ecc83dd0507b 100644 --- a/src/frontend/src/handler/alter_table_column.rs +++ b/src/frontend/src/handler/alter_table_column.rs @@ -490,11 +490,11 @@ pub fn fetch_table_catalog_for_alter( session: &SessionImpl, table_name: &ObjectName, ) -> Result> { - let db_name = session.database(); + let db_name = &session.database(); let (schema_name, real_table_name) = Binder::resolve_schema_qualified_name(db_name, table_name.clone())?; let search_path = session.config().search_path(); - let user_name = &session.auth_context().user_name; + let user_name = &session.user_name(); let schema_path = SchemaPath::new(schema_name.as_deref(), &search_path, user_name); diff --git a/src/frontend/src/handler/alter_user.rs b/src/frontend/src/handler/alter_user.rs index 3a7e8193fadf0..4b87dd5c62178 100644 --- a/src/frontend/src/handler/alter_user.rs +++ b/src/frontend/src/handler/alter_user.rs @@ -171,7 +171,7 @@ pub async fn handle_alter_user( .to_prost(); let session_user = user_reader - .get_user_by_name(session.user_name()) + .get_user_by_name(&session.user_name()) .ok_or_else(|| CatalogError::NotFound("user", session.user_name().to_owned()))?; match stmt.mode { diff --git a/src/frontend/src/handler/close_cursor.rs b/src/frontend/src/handler/close_cursor.rs index 1678b85f85358..e4d01df6fde91 100644 --- a/src/frontend/src/handler/close_cursor.rs +++ b/src/frontend/src/handler/close_cursor.rs @@ -26,7 +26,7 @@ pub async fn handle_close_cursor( ) -> Result { let session = handle_args.session.clone(); let cursor_manager = session.get_cursor_manager(); - let db_name = session.database(); + let db_name = &session.database(); if let Some(cursor_name) = stmt.cursor_name { let (_, cursor_name) = Binder::resolve_schema_qualified_name(db_name, cursor_name.clone())?; cursor_manager.remove_cursor(cursor_name).await?; diff --git a/src/frontend/src/handler/comment.rs b/src/frontend/src/handler/comment.rs index 4140805eb5798..e65e6dad12d46 100644 --- a/src/frontend/src/handler/comment.rs +++ b/src/frontend/src/handler/comment.rs @@ -43,7 +43,7 @@ pub async fn handle_comment( }; let (schema, table) = Binder::resolve_schema_qualified_name( - session.database(), + &session.database(), ObjectName(tab.to_vec()), )?; @@ -64,7 +64,7 @@ pub async fn handle_comment( } CommentObject::Table => { let (schema, table) = - Binder::resolve_schema_qualified_name(session.database(), object_name)?; + Binder::resolve_schema_qualified_name(&session.database(), object_name)?; let (database_id, schema_id) = session.get_database_and_schema_id_for_create(schema.clone())?; let table = binder.bind_table(schema.as_deref(), &table, None)?; diff --git a/src/frontend/src/handler/create_aggregate.rs b/src/frontend/src/handler/create_aggregate.rs index c9444d75c4573..28fa736f463d6 100644 --- a/src/frontend/src/handler/create_aggregate.rs +++ b/src/frontend/src/handler/create_aggregate.rs @@ -73,7 +73,7 @@ pub async fn handle_create_aggregate( // resolve database and schema id let session = &handler_args.session; - let db_name = session.database(); + let db_name = &session.database(); let (schema_name, function_name) = Binder::resolve_schema_qualified_name(db_name, name)?; let (database_id, schema_id) = session.get_database_and_schema_id_for_create(schema_name)?; diff --git a/src/frontend/src/handler/create_connection.rs b/src/frontend/src/handler/create_connection.rs index de06ce439b9bb..8a37b943e2d24 100644 --- a/src/frontend/src/handler/create_connection.rs +++ b/src/frontend/src/handler/create_connection.rs @@ -91,7 +91,7 @@ pub async fn handle_create_connection( stmt: CreateConnectionStatement, ) -> Result { let session = handler_args.session.clone(); - let db_name = session.database(); + let db_name = &session.database(); let (schema_name, connection_name) = Binder::resolve_schema_qualified_name(db_name, stmt.connection_name.clone())?; diff --git a/src/frontend/src/handler/create_database.rs b/src/frontend/src/handler/create_database.rs index 9e7d473045092..f9f1166ff8cb3 100644 --- a/src/frontend/src/handler/create_database.rs +++ b/src/frontend/src/handler/create_database.rs @@ -34,7 +34,7 @@ pub async fn handle_create_database( { let user_reader = session.env().user_info_reader(); let reader = user_reader.read_guard(); - if let Some(info) = reader.get_user_by_name(session.user_name()) { + if let Some(info) = reader.get_user_by_name(&session.user_name()) { if !info.can_create_db && !info.is_super { return Err(PermissionDenied("Do not have the privilege".to_owned()).into()); } diff --git a/src/frontend/src/handler/create_function.rs b/src/frontend/src/handler/create_function.rs index 36296a9d88cab..73d12b7e6a685 100644 --- a/src/frontend/src/handler/create_function.rs +++ b/src/frontend/src/handler/create_function.rs @@ -106,7 +106,7 @@ pub async fn handle_create_function( // resolve database and schema id let session = &handler_args.session; - let db_name = session.database(); + let db_name = &session.database(); let (schema_name, function_name) = Binder::resolve_schema_qualified_name(db_name, name)?; let (database_id, schema_id) = session.get_database_and_schema_id_for_create(schema_name)?; diff --git a/src/frontend/src/handler/create_index.rs b/src/frontend/src/handler/create_index.rs index 22ae2f0633ba7..aab9629c8a6cb 100644 --- a/src/frontend/src/handler/create_index.rs +++ b/src/frontend/src/handler/create_index.rs @@ -47,10 +47,10 @@ pub(crate) fn resolve_index_schema( index_name: ObjectName, table_name: ObjectName, ) -> Result<(String, Arc, String)> { - let db_name = session.database(); + let db_name = &session.database(); let (schema_name, table_name) = Binder::resolve_schema_qualified_name(db_name, table_name)?; let search_path = session.config().search_path(); - let user_name = &session.auth_context().user_name; + let user_name = &session.user_name(); let schema_path = SchemaPath::new(schema_name.as_deref(), &search_path, user_name); let index_table_name = Binder::resolve_index_name(index_name)?; diff --git a/src/frontend/src/handler/create_mv.rs b/src/frontend/src/handler/create_mv.rs index 63e2d35dc4dfc..457fc506c04eb 100644 --- a/src/frontend/src/handler/create_mv.rs +++ b/src/frontend/src/handler/create_mv.rs @@ -108,7 +108,7 @@ pub fn gen_create_mv_plan_bound( context.warn_to_user("The session variable CREATE_COMPACTION_GROUP_FOR_MV has been deprecated. It will not take effect."); } - let db_name = session.database(); + let db_name = &session.database(); let (schema_name, table_name) = Binder::resolve_schema_qualified_name(db_name, name)?; let (database_id, schema_id) = session.get_database_and_schema_id_for_create(schema_name)?; diff --git a/src/frontend/src/handler/create_schema.rs b/src/frontend/src/handler/create_schema.rs index 818d912d37ebe..6fed546cf9332 100644 --- a/src/frontend/src/handler/create_schema.rs +++ b/src/frontend/src/handler/create_schema.rs @@ -32,7 +32,7 @@ pub async fn handle_create_schema( owner: Option, ) -> Result { let session = handler_args.session; - let database_name = session.database(); + let database_name = &session.database(); let schema_name = Binder::resolve_schema_name(schema_name)?; if schema_name.starts_with(RESERVED_PG_SCHEMA_PREFIX) { diff --git a/src/frontend/src/handler/create_secret.rs b/src/frontend/src/handler/create_secret.rs index a2ad00e007e89..158d8f177a963 100644 --- a/src/frontend/src/handler/create_secret.rs +++ b/src/frontend/src/handler/create_secret.rs @@ -36,7 +36,7 @@ pub async fn handle_create_secret( .map_err(|e| anyhow::anyhow!(e))?; let session = handler_args.session.clone(); - let db_name = session.database(); + let db_name = &session.database(); let (schema_name, secret_name) = Binder::resolve_schema_qualified_name(db_name, stmt.secret_name.clone())?; diff --git a/src/frontend/src/handler/create_sink.rs b/src/frontend/src/handler/create_sink.rs index c4d6793444104..ddd1e4b17ab62 100644 --- a/src/frontend/src/handler/create_sink.rs +++ b/src/frontend/src/handler/create_sink.rs @@ -106,7 +106,7 @@ pub async fn gen_sink_plan( let session = handler_args.session.clone(); let session = session.as_ref(); let user_specified_columns = !stmt.columns.is_empty(); - let db_name = session.database(); + let db_name = &session.database(); let (sink_schema_name, sink_table_name) = Binder::resolve_schema_qualified_name(db_name, stmt.sink_name.clone())?; @@ -563,7 +563,7 @@ pub fn fetch_incoming_sinks( ) -> Result>> { let reader = session.env().catalog_reader().read_guard(); let mut sinks = Vec::with_capacity(incoming_sink_ids.len()); - let db_name = session.database(); + let db_name = &session.database(); for schema in reader.iter_schemas(db_name)? { for sink in schema.iter_sink() { if incoming_sink_ids.contains(&sink.id.sink_id) { diff --git a/src/frontend/src/handler/create_source.rs b/src/frontend/src/handler/create_source.rs index 09aebe2be26f0..fc46ae18ea540 100644 --- a/src/frontend/src/handler/create_source.rs +++ b/src/frontend/src/handler/create_source.rs @@ -1548,7 +1548,7 @@ pub async fn bind_create_source_or_table_with_connector( source_rate_limit: Option, ) -> Result<(SourceCatalog, DatabaseId, SchemaId)> { let session = &handler_args.session; - let db_name: &str = session.database(); + let db_name: &str = &session.database(); let (schema_name, source_name) = Binder::resolve_schema_qualified_name(db_name, full_name)?; let (database_id, schema_id) = session.get_database_and_schema_id_for_create(schema_name.clone())?; diff --git a/src/frontend/src/handler/create_sql_function.rs b/src/frontend/src/handler/create_sql_function.rs index c90d7d8770d8b..d7e3882b8ad75 100644 --- a/src/frontend/src/handler/create_sql_function.rs +++ b/src/frontend/src/handler/create_sql_function.rs @@ -213,7 +213,7 @@ pub async fn handle_create_sql_function( // resolve database and schema id let session = &handler_args.session; - let db_name = session.database(); + let db_name = &session.database(); let (schema_name, function_name) = Binder::resolve_schema_qualified_name(db_name, name)?; let (database_id, schema_id) = session.get_database_and_schema_id_for_create(schema_name)?; diff --git a/src/frontend/src/handler/create_subscription.rs b/src/frontend/src/handler/create_subscription.rs index 8d4ed82cc82ee..3392c0874d6be 100644 --- a/src/frontend/src/handler/create_subscription.rs +++ b/src/frontend/src/handler/create_subscription.rs @@ -31,7 +31,7 @@ pub fn create_subscription_catalog( context: OptimizerContextRef, stmt: CreateSubscriptionStatement, ) -> Result { - let db_name = session.database(); + let db_name = &session.database(); let (subscription_schema_name, subscription_name) = Binder::resolve_schema_qualified_name(db_name, stmt.subscription_name.clone())?; let (table_schema_name, subscription_from_table_name) = diff --git a/src/frontend/src/handler/create_table.rs b/src/frontend/src/handler/create_table.rs index fc2b4bc833cd4..5a9947bb52da0 100644 --- a/src/frontend/src/handler/create_table.rs +++ b/src/frontend/src/handler/create_table.rs @@ -620,7 +620,7 @@ pub(crate) fn gen_create_table_plan_without_source( )?; let session = context.session_ctx().clone(); - let db_name = session.database(); + let db_name = &session.database(); let (schema_name, name) = Binder::resolve_schema_qualified_name(db_name, table_name)?; let (database_id, schema_id) = session.get_database_and_schema_id_for_create(schema_name.clone())?; @@ -1043,7 +1043,7 @@ pub(super) async fn handle_create_table_plan( )?; let session = &handler_args.session; - let db_name = session.database(); + let db_name = &session.database(); let (schema_name, resolved_table_name) = Binder::resolve_schema_qualified_name(db_name, table_name.clone())?; let (database_id, schema_id) = @@ -1871,7 +1871,7 @@ fn get_source_and_resolved_table_name( cdc_table: CdcTableInfo, table_name: ObjectName, ) -> Result<(Arc, String, DatabaseId, SchemaId)> { - let db_name = session.database(); + let db_name = &session.database(); let (schema_name, resolved_table_name) = Binder::resolve_schema_qualified_name(db_name, table_name)?; let (database_id, schema_id) = @@ -1914,7 +1914,7 @@ fn bind_webhook_info( } = webhook_info; // validate secret_ref - let db_name = session.database(); + let db_name = &session.database(); let (schema_name, secret_name) = Binder::resolve_schema_qualified_name(db_name, secret_ref.secret_name.clone())?; let secret_catalog = session.get_secret_by_name(schema_name, &secret_name)?; diff --git a/src/frontend/src/handler/create_user.rs b/src/frontend/src/handler/create_user.rs index 71a08eb927efa..142d745f58826 100644 --- a/src/frontend/src/handler/create_user.rs +++ b/src/frontend/src/handler/create_user.rs @@ -116,7 +116,7 @@ pub async fn handle_create_user( let database_id = { let catalog_reader = session.env().catalog_reader().read_guard(); catalog_reader - .get_database_by_name(session.database()) + .get_database_by_name(&session.database()) .expect("session database should exist") .id() }; @@ -128,8 +128,8 @@ pub async fn handle_create_user( } let session_user = user_reader - .get_user_by_name(session.user_name()) - .ok_or_else(|| CatalogError::NotFound("user", session.user_name().to_owned()))?; + .get_user_by_name(&session.user_name()) + .ok_or_else(|| CatalogError::NotFound("user", session.user_name()))?; make_prost_user_info(user_name, &stmt.with_options, session_user, database_id)? }; diff --git a/src/frontend/src/handler/create_view.rs b/src/frontend/src/handler/create_view.rs index eb58b0c3f753f..acac5eb2b58b3 100644 --- a/src/frontend/src/handler/create_view.rs +++ b/src/frontend/src/handler/create_view.rs @@ -35,7 +35,7 @@ pub async fn handle_create_view( query: Query, ) -> Result { let session = handler_args.session.clone(); - let db_name = session.database(); + let db_name = &session.database(); let (schema_name, view_name) = Binder::resolve_schema_qualified_name(db_name, name.clone())?; let (database_id, schema_id) = session.get_database_and_schema_id_for_create(schema_name)?; diff --git a/src/frontend/src/handler/declare_cursor.rs b/src/frontend/src/handler/declare_cursor.rs index dd745d6517d5a..0423bb7ac7205 100644 --- a/src/frontend/src/handler/declare_cursor.rs +++ b/src/frontend/src/handler/declare_cursor.rs @@ -58,7 +58,7 @@ async fn handle_declare_subscription_cursor( rw_timestamp: Since, ) -> Result { let session = handle_args.session.clone(); - let db_name = session.database(); + let db_name = &session.database(); let (schema_name, cursor_name) = Binder::resolve_schema_qualified_name(db_name, cursor_name.clone())?; diff --git a/src/frontend/src/handler/drop_connection.rs b/src/frontend/src/handler/drop_connection.rs index b90ae44990740..4172ea481fe61 100644 --- a/src/frontend/src/handler/drop_connection.rs +++ b/src/frontend/src/handler/drop_connection.rs @@ -27,11 +27,11 @@ pub async fn handle_drop_connection( if_exists: bool, ) -> Result { let session = handler_args.session; - let db_name = session.database(); + let db_name = &session.database(); let (schema_name, connection_name) = Binder::resolve_schema_qualified_name(db_name, connection_name)?; let search_path = session.config().search_path(); - let user_name = &session.auth_context().user_name; + let user_name = &session.user_name(); let schema_path = SchemaPath::new(schema_name.as_deref(), &search_path, user_name); diff --git a/src/frontend/src/handler/drop_function.rs b/src/frontend/src/handler/drop_function.rs index 945cf4816bf9e..feca1ea7b29f3 100644 --- a/src/frontend/src/handler/drop_function.rs +++ b/src/frontend/src/handler/drop_function.rs @@ -36,11 +36,11 @@ pub async fn handle_drop_function( let func_desc = func_desc.remove(0); let session = handler_args.session; - let db_name = session.database(); + let db_name = &session.database(); let (schema_name, function_name) = Binder::resolve_schema_qualified_name(db_name, func_desc.name)?; let search_path = session.config().search_path(); - let user_name = &session.auth_context().user_name; + let user_name = &session.user_name(); let schema_path = SchemaPath::new(schema_name.as_deref(), &search_path, user_name); let arg_types = match func_desc.args { diff --git a/src/frontend/src/handler/drop_index.rs b/src/frontend/src/handler/drop_index.rs index 3ff2de4800762..b9c582ee921e4 100644 --- a/src/frontend/src/handler/drop_index.rs +++ b/src/frontend/src/handler/drop_index.rs @@ -31,10 +31,10 @@ pub async fn handle_drop_index( cascade: bool, ) -> Result { let session = handler_args.session; - let db_name = session.database(); + let db_name = &session.database(); let (schema_name, index_name) = Binder::resolve_schema_qualified_name(db_name, index_name)?; let search_path = session.config().search_path(); - let user_name = &session.auth_context().user_name; + let user_name = &session.user_name(); let schema_path = SchemaPath::new(schema_name.as_deref(), &search_path, user_name); let index_id = { diff --git a/src/frontend/src/handler/drop_mv.rs b/src/frontend/src/handler/drop_mv.rs index 9ec6a56d20d3b..4d58a0baae99f 100644 --- a/src/frontend/src/handler/drop_mv.rs +++ b/src/frontend/src/handler/drop_mv.rs @@ -32,17 +32,17 @@ pub async fn handle_drop_mv( cascade: bool, ) -> Result { let session = handler_args.session; - let db_name = session.database(); + let db_name = &session.database(); let (schema_name, table_name) = Binder::resolve_schema_qualified_name(db_name, table_name)?; let search_path = session.config().search_path(); - let user_name = &session.auth_context().user_name; + let user_name = &session.user_name(); let schema_path = SchemaPath::new(schema_name.as_deref(), &search_path, user_name); let (table_id, status) = { let reader = session.env().catalog_reader().read_guard(); let (table, schema_name) = - match reader.get_any_table_by_name(session.database(), schema_path, &table_name) { + match reader.get_any_table_by_name(&session.database(), schema_path, &table_name) { Ok((t, s)) => (t, s), Err(e) => { return if if_exists { diff --git a/src/frontend/src/handler/drop_schema.rs b/src/frontend/src/handler/drop_schema.rs index 4a958bc71ed0f..227eada423a16 100644 --- a/src/frontend/src/handler/drop_schema.rs +++ b/src/frontend/src/handler/drop_schema.rs @@ -42,7 +42,7 @@ pub async fn handle_drop_schema( let schema = { let reader = catalog_reader.read_guard(); - match reader.get_schema_by_name(session.database(), &schema_name) { + match reader.get_schema_by_name(&session.database(), &schema_name) { Ok(schema) => schema.clone(), Err(err) => { // If `if_exist` is true, not return error. diff --git a/src/frontend/src/handler/drop_secret.rs b/src/frontend/src/handler/drop_secret.rs index 4720d73bfa7e6..834a7eb6c50b4 100644 --- a/src/frontend/src/handler/drop_secret.rs +++ b/src/frontend/src/handler/drop_secret.rs @@ -60,11 +60,11 @@ pub fn fetch_secret_catalog_with_db_schema_id( secret_name: &ObjectName, if_exists: bool, ) -> Result, DatabaseId, SchemaId)>> { - let db_name = session.database(); + let db_name = &session.database(); let (schema_name, secret_name) = Binder::resolve_schema_qualified_name(db_name, secret_name.clone())?; let search_path = session.config().search_path(); - let user_name = &session.auth_context().user_name; + let user_name = &session.user_name(); let schema_path = SchemaPath::new(schema_name.as_deref(), &search_path, user_name); diff --git a/src/frontend/src/handler/drop_sink.rs b/src/frontend/src/handler/drop_sink.rs index ab98894330b69..b6331c7eb59de 100644 --- a/src/frontend/src/handler/drop_sink.rs +++ b/src/frontend/src/handler/drop_sink.rs @@ -33,10 +33,10 @@ pub async fn handle_drop_sink( cascade: bool, ) -> Result { let session = handler_args.session.clone(); - let db_name = session.database(); + let db_name = &session.database(); let (schema_name, sink_name) = Binder::resolve_schema_qualified_name(db_name, sink_name)?; let search_path = session.config().search_path(); - let user_name = &session.auth_context().user_name; + let user_name = &session.user_name(); let schema_path = SchemaPath::new(schema_name.as_deref(), &search_path, user_name); let sink = { diff --git a/src/frontend/src/handler/drop_source.rs b/src/frontend/src/handler/drop_source.rs index 711cf2d7caec0..689ea84d867f8 100644 --- a/src/frontend/src/handler/drop_source.rs +++ b/src/frontend/src/handler/drop_source.rs @@ -28,10 +28,10 @@ pub async fn handle_drop_source( cascade: bool, ) -> Result { let session = handler_args.session; - let db_name = session.database(); + let db_name = &session.database(); let (schema_name, source_name) = Binder::resolve_schema_qualified_name(db_name, name)?; let search_path = session.config().search_path(); - let user_name = &session.auth_context().user_name; + let user_name = &session.user_name(); // Check if temporary source exists, if yes, drop it. if let Some(_source) = session.get_temporary_source(&source_name) { diff --git a/src/frontend/src/handler/drop_subscription.rs b/src/frontend/src/handler/drop_subscription.rs index 84c6165604a02..ed3dd9bd66d08 100644 --- a/src/frontend/src/handler/drop_subscription.rs +++ b/src/frontend/src/handler/drop_subscription.rs @@ -27,11 +27,11 @@ pub async fn handle_drop_subscription( cascade: bool, ) -> Result { let session = handler_args.session; - let db_name = session.database(); + let db_name = &session.database(); let (schema_name, subscription_name) = Binder::resolve_schema_qualified_name(db_name, subscription_name)?; let search_path = session.config().search_path(); - let user_name = &session.auth_context().user_name; + let user_name = &session.user_name(); let schema_path = SchemaPath::new(schema_name.as_deref(), &search_path, user_name); let subscription = { diff --git a/src/frontend/src/handler/drop_table.rs b/src/frontend/src/handler/drop_table.rs index b3ef634b7243a..d90cfefb5e706 100644 --- a/src/frontend/src/handler/drop_table.rs +++ b/src/frontend/src/handler/drop_table.rs @@ -35,10 +35,10 @@ pub async fn handle_drop_table( cascade: bool, ) -> Result { let session = handler_args.session.clone(); - let db_name = session.database(); + let db_name = &session.database(); let (schema_name, table_name) = Binder::resolve_schema_qualified_name(db_name, table_name)?; let search_path = session.config().search_path(); - let user_name = &session.auth_context().user_name; + let user_name = &session.user_name(); let schema_path = SchemaPath::new(schema_name.as_deref(), &search_path, user_name); diff --git a/src/frontend/src/handler/drop_view.rs b/src/frontend/src/handler/drop_view.rs index b24760c36c0c2..ae93ec79578cf 100644 --- a/src/frontend/src/handler/drop_view.rs +++ b/src/frontend/src/handler/drop_view.rs @@ -28,17 +28,17 @@ pub async fn handle_drop_view( cascade: bool, ) -> Result { let session = handler_args.session; - let db_name = session.database(); + let db_name = &session.database(); let (schema_name, table_name) = Binder::resolve_schema_qualified_name(db_name, table_name)?; let search_path = session.config().search_path(); - let user_name = &session.auth_context().user_name; + let user_name = &session.user_name(); let schema_path = SchemaPath::new(schema_name.as_deref(), &search_path, user_name); let view_id = { let reader = session.env().catalog_reader().read_guard(); let (view, schema_name) = - match reader.get_view_by_name(session.database(), schema_path, &table_name) { + match reader.get_view_by_name(&session.database(), schema_path, &table_name) { Ok((t, s)) => (t, s), Err(e) => { return if if_exists { diff --git a/src/frontend/src/handler/fetch_cursor.rs b/src/frontend/src/handler/fetch_cursor.rs index 7f0b88826fabe..1fe29cb98956d 100644 --- a/src/frontend/src/handler/fetch_cursor.rs +++ b/src/frontend/src/handler/fetch_cursor.rs @@ -58,7 +58,7 @@ pub async fn handle_fetch_cursor( formats: &Vec, ) -> Result { let session = handler_args.session.clone(); - let db_name = session.database(); + let db_name = &session.database(); let (_, cursor_name) = Binder::resolve_schema_qualified_name(db_name, stmt.cursor_name.clone())?; @@ -105,7 +105,7 @@ pub async fn handle_parse( ) -> Result { if let Statement::FetchCursor { stmt } = &statement { let session = handler_args.session.clone(); - let db_name = session.database(); + let db_name = &session.database(); let (_, cursor_name) = Binder::resolve_schema_qualified_name(db_name, stmt.cursor_name.clone())?; let fields = session diff --git a/src/frontend/src/handler/flush.rs b/src/frontend/src/handler/flush.rs index 666dca64583ba..866badd5176d9 100644 --- a/src/frontend/src/handler/flush.rs +++ b/src/frontend/src/handler/flush.rs @@ -30,7 +30,7 @@ pub(crate) async fn do_flush(session: &SessionImpl) -> Result<()> { .env() .catalog_reader() .read_guard() - .get_database_by_name(session.database())? + .get_database_by_name(&session.database())? .id(); let version_id = client.flush(database_id).await?; diff --git a/src/frontend/src/handler/handle_privilege.rs b/src/frontend/src/handler/handle_privilege.rs index 08f2a25872845..3981344c65295 100644 --- a/src/frontend/src/handler/handle_privilege.rs +++ b/src/frontend/src/handler/handle_privilege.rs @@ -57,14 +57,14 @@ fn make_prost_privilege( GrantObjects::Schemas(schemas) => { for schema in schemas { let schema_name = Binder::resolve_schema_name(schema)?; - let schema = reader.get_schema_by_name(session.database(), &schema_name)?; + let schema = reader.get_schema_by_name(&session.database(), &schema_name)?; grant_objs.push(PbObject::SchemaId(schema.id())); } } GrantObjects::Mviews(tables) => { - let db_name = session.database(); + let db_name = &session.database(); let search_path = session.config().search_path(); - let user_name = &session.auth_context().user_name; + let user_name = &session.user_name(); for name in tables { let (schema_name, table_name) = @@ -86,9 +86,9 @@ fn make_prost_privilege( } } GrantObjects::Tables(tables) => { - let db_name = session.database(); + let db_name = &session.database(); let search_path = session.config().search_path(); - let user_name = &session.auth_context().user_name; + let user_name = &session.user_name(); for name in tables { let (schema_name, table_name) = @@ -123,9 +123,9 @@ fn make_prost_privilege( } } GrantObjects::Sources(sources) => { - let db_name = session.database(); + let db_name = &session.database(); let search_path = session.config().search_path(); - let user_name = &session.auth_context().user_name; + let user_name = &session.user_name(); for name in sources { let (schema_name, source_name) = @@ -137,9 +137,9 @@ fn make_prost_privilege( } } GrantObjects::Sinks(sinks) => { - let db_name = session.database(); + let db_name = &session.database(); let search_path = session.config().search_path(); - let user_name = &session.auth_context().user_name; + let user_name = &session.user_name(); for name in sinks { let (schema_name, sink_name) = @@ -153,21 +153,21 @@ fn make_prost_privilege( GrantObjects::AllSourcesInSchema { schemas } => { for schema in schemas { let schema_name = Binder::resolve_schema_name(schema)?; - let schema = reader.get_schema_by_name(session.database(), &schema_name)?; + let schema = reader.get_schema_by_name(&session.database(), &schema_name)?; grant_objs.push(PbObject::AllSourcesSchemaId(schema.id())); } } GrantObjects::AllMviewsInSchema { schemas } => { for schema in schemas { let schema_name = Binder::resolve_schema_name(schema)?; - let schema = reader.get_schema_by_name(session.database(), &schema_name)?; + let schema = reader.get_schema_by_name(&session.database(), &schema_name)?; grant_objs.push(PbObject::AllTablesSchemaId(schema.id())); } } GrantObjects::AllTablesInSchema { schemas } => { for schema in schemas { let schema_name = Binder::resolve_schema_name(schema)?; - let schema = reader.get_schema_by_name(session.database(), &schema_name)?; + let schema = reader.get_schema_by_name(&session.database(), &schema_name)?; grant_objs.push(PbObject::AllDmlRelationsSchemaId(schema.id())); } } @@ -323,7 +323,7 @@ mod tests { let reader = catalog_reader.read_guard(); ( reader - .get_database_by_name(session.database()) + .get_database_by_name(&session.database()) .unwrap() .id(), reader.get_database_by_name("db1").unwrap().id(), diff --git a/src/frontend/src/handler/mod.rs b/src/frontend/src/handler/mod.rs index b48be779df40b..dba4aba506499 100644 --- a/src/frontend/src/handler/mod.rs +++ b/src/frontend/src/handler/mod.rs @@ -99,6 +99,7 @@ pub mod query; mod recover; pub mod show; mod transaction; +mod use_db; pub mod util; pub mod variable; mod wait; @@ -586,7 +587,25 @@ pub async fn handle( local: _, variable, value, - } => variable::handle_set(handler_args, variable, value), + } => { + // special handle for `use database` + if variable.real_value().eq_ignore_ascii_case("database") { + let x = variable::set_var_to_param_str(&value); + let res = use_db::handle_use_db( + handler_args, + ObjectName::from(vec![Ident::new_unchecked( + x.unwrap_or("default".to_string()), + )]), + ) + .await?; + let mut builder = RwPgResponse::builder(StatementType::SET_VARIABLE); + for notice in res.notices() { + builder = builder.notice(notice); + } + return Ok(builder.into()); + } + variable::handle_set(handler_args, variable, value) + } Statement::SetTimeZone { local: _, value } => { variable::handle_set_time_zone(handler_args, value) } @@ -1125,6 +1144,7 @@ pub async fn handle( object_name, comment, } => comment::handle_comment(handler_args, object_type, object_name, comment).await, + Statement::Use { db_name } => use_db::handle_use_db(handler_args, db_name).await, _ => bail_not_implemented!("Unhandled statement: {}", stmt), } } diff --git a/src/frontend/src/handler/privilege.rs b/src/frontend/src/handler/privilege.rs index 1aff2e69f682f..5cb345aa16855 100644 --- a/src/frontend/src/handler/privilege.rs +++ b/src/frontend/src/handler/privilege.rs @@ -143,7 +143,7 @@ impl SessionImpl { let user_reader = self.env().user_info_reader(); let reader = user_reader.read_guard(); - if let Some(user) = reader.get_user_by_name(self.user_name()) { + if let Some(user) = reader.get_user_by_name(&self.user_name()) { if user.is_super { return Ok(()); } @@ -167,7 +167,7 @@ impl SessionImpl { pub fn is_super_user(&self) -> bool { let reader = self.env().user_info_reader().read_guard(); - if let Some(info) = reader.get_user_by_name(self.user_name()) { + if let Some(info) = reader.get_user_by_name(&self.user_name()) { info.is_super } else { false @@ -193,7 +193,7 @@ impl SessionImpl { .env() .catalog_reader() .read_guard() - .get_schema_by_name(self.database(), schema_name) + .get_schema_by_name(&self.database(), schema_name) .unwrap() .owner(); diff --git a/src/frontend/src/handler/show.rs b/src/frontend/src/handler/show.rs index c73b0b37e25ac..c771cb5bcb2f7 100644 --- a/src/frontend/src/handler/show.rs +++ b/src/frontend/src/handler/show.rs @@ -121,7 +121,7 @@ fn schema_or_search_path( .iter() .map(|s| { if s.eq(USER_NAME_WILD_CARD) { - session.auth_context().user_name.clone() + session.user_name() } else { s.to_string() } @@ -315,7 +315,7 @@ pub async fn handle_show_object( // If the schema is not found, skip it if let Ok(schema_catalog) = catalog_reader .read_guard() - .get_schema_by_name(session.database(), schema.as_ref()) + .get_schema_by_name(&session.database(), schema.as_ref()) { table_names_in_schema .extend(schema_catalog.iter_user_table().map(|t| t.name.clone())); @@ -326,43 +326,43 @@ pub async fn handle_show_object( } ShowObject::InternalTable { schema } => catalog_reader .read_guard() - .get_schema_by_name(session.database(), &schema_or_default(&schema))? + .get_schema_by_name(&session.database(), &schema_or_default(&schema))? .iter_internal_table() .map(|t| t.name.clone()) .collect(), ShowObject::Database => catalog_reader.read_guard().get_all_database_names(), ShowObject::Schema => catalog_reader .read_guard() - .get_all_schema_names(session.database())?, + .get_all_schema_names(&session.database())?, ShowObject::View { schema } => catalog_reader .read_guard() - .get_schema_by_name(session.database(), &schema_or_default(&schema))? + .get_schema_by_name(&session.database(), &schema_or_default(&schema))? .iter_view() .map(|t| t.name.clone()) .collect(), ShowObject::MaterializedView { schema } => catalog_reader .read_guard() - .get_schema_by_name(session.database(), &schema_or_default(&schema))? + .get_schema_by_name(&session.database(), &schema_or_default(&schema))? .iter_created_mvs() .map(|t| t.name.clone()) .collect(), ShowObject::Source { schema } => catalog_reader .read_guard() - .get_schema_by_name(session.database(), &schema_or_default(&schema))? + .get_schema_by_name(&session.database(), &schema_or_default(&schema))? .iter_source() .map(|t| t.name.clone()) .chain(session.temporary_source_manager().keys()) .collect(), ShowObject::Sink { schema } => catalog_reader .read_guard() - .get_schema_by_name(session.database(), &schema_or_default(&schema))? + .get_schema_by_name(&session.database(), &schema_or_default(&schema))? .iter_sink() .map(|t| t.name.clone()) .collect(), ShowObject::Subscription { schema } => { let rows = catalog_reader .read_guard() - .get_schema_by_name(session.database(), &schema_or_default(&schema))? + .get_schema_by_name(&session.database(), &schema_or_default(&schema))? .iter_subscription() .map(|t| ShowSubscriptionRow { name: t.name.clone(), @@ -375,7 +375,7 @@ pub async fn handle_show_object( } ShowObject::Secret { schema } => catalog_reader .read_guard() - .get_schema_by_name(session.database(), &schema_or_default(&schema))? + .get_schema_by_name(&session.database(), &schema_or_default(&schema))? .iter_secret() .map(|t| t.name.clone()) .collect(), @@ -405,7 +405,7 @@ pub async fn handle_show_object( ShowObject::Connection { schema } => { let reader = catalog_reader.read_guard(); let schema = - reader.get_schema_by_name(session.database(), &schema_or_default(&schema))?; + reader.get_schema_by_name(&session.database(), &schema_or_default(&schema))?; let rows = schema .iter_connections() .map(|c| { @@ -460,7 +460,7 @@ pub async fn handle_show_object( ShowObject::Function { schema } => { let reader = catalog_reader.read_guard(); let rows = reader - .get_schema_by_name(session.database(), &schema_or_default(&schema))? + .get_schema_by_name(&session.database(), &schema_or_default(&schema))? .iter_function() .map(|t| ShowFunctionRow { name: t.name.clone(), @@ -513,9 +513,9 @@ pub async fn handle_show_object( ShowProcessListRow { // Since process id and the secret id in the session id are the same in RisingWave, just display the process id. id: format!("{}", s.id().0), - user: s.user_name().to_owned(), + user: s.user_name(), host: format!("{}", s.peer_addr()), - database: s.database().to_owned(), + database: s.database(), time: s .elapse_since_running_sql() .map(|mills| format!("{}ms", mills)), @@ -540,9 +540,9 @@ pub async fn handle_show_object( let mut rows = vec![]; for s in sessions { let session_id = format!("{}", s.id().0); - let user = s.user_name().to_owned(); + let user = s.user_name(); let host = format!("{}", s.peer_addr()); - let database = s.database().to_owned(); + let database = s.database(); s.get_cursor_manager() .iter_query_cursors(|cursor_name: &String, _| { @@ -570,8 +570,8 @@ pub async fn handle_show_object( .collect_vec(); let mut rows = vec![]; for s in sessions { - let ssession_id = format!("{}", s.id().0); - let user = s.user_name().to_owned(); + let session_id = format!("{}", s.id().0); + let user = s.user_name(); let host = format!("{}", s.peer_addr()); let database = s.database().to_owned(); @@ -579,7 +579,7 @@ pub async fn handle_show_object( .iter_subscription_cursors( |cursor_name: &String, cursor: &SubscriptionCursor| { rows.push(ShowSubscriptionCursorRow { - session_id: ssession_id.clone(), + session_id: session_id.clone(), user: user.clone(), host: host.clone(), database: database.clone(), @@ -625,10 +625,11 @@ pub fn handle_show_create_object( ) -> Result { let session = handle_args.session; let catalog_reader = session.env().catalog_reader().read_guard(); + let database = session.database(); let (schema_name, object_name) = - Binder::resolve_schema_qualified_name(session.database(), name.clone())?; + Binder::resolve_schema_qualified_name(&database, name.clone())?; let schema_name = schema_name.unwrap_or(DEFAULT_SCHEMA_NAME.to_owned()); - let schema = catalog_reader.get_schema_by_name(session.database(), &schema_name)?; + let schema = catalog_reader.get_schema_by_name(&database, &schema_name)?; let sql = match show_create_type { ShowCreateType::MaterializedView => { let mv = schema diff --git a/src/frontend/src/handler/use_db.rs b/src/frontend/src/handler/use_db.rs new file mode 100644 index 0000000000000..dcffe310bfd57 --- /dev/null +++ b/src/frontend/src/handler/use_db.rs @@ -0,0 +1,57 @@ +// Copyright 2024 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use pgwire::pg_response::StatementType; +use risingwave_common::acl::AclMode; +use risingwave_common::session_config::SessionConfig; +use risingwave_pb::user::grant_privilege::Object as GrantObject; +use risingwave_sqlparser::ast::ObjectName; + +use crate::error::Result; +use crate::handler::privilege::ObjectCheckItem; +use crate::handler::{HandlerArgs, RwPgResponse}; +use crate::Binder; + +pub async fn handle_use_db( + handler_args: HandlerArgs, + database_name: ObjectName, +) -> Result { + let session = handler_args.session; + let database_name = Binder::resolve_database_name(database_name)?; + + let (database_id, owner_id) = { + let catalog_reader = session.env().catalog_reader(); + let reader = catalog_reader.read_guard(); + let db = reader.get_database_by_name(&database_name)?; + (db.id(), db.owner) + }; + session.check_privileges(&[ObjectCheckItem::new( + owner_id, + AclMode::Connect, + GrantObject::DatabaseId(database_id), + )])?; + + let mut builder = RwPgResponse::builder(StatementType::USE); + builder = builder.notice(format!( + "You are now connected to database \"{}\" as user \"{}\".", + database_name, + session.user_name() + )); + + // reset session config + *session.shared_config().write() = SessionConfig::default(); + session.update_database(database_name); + + Ok(builder.into()) +} diff --git a/src/frontend/src/handler/util.rs b/src/frontend/src/handler/util.rs index e54fa7fb722c4..483187e4da5e1 100644 --- a/src/frontend/src/handler/util.rs +++ b/src/frontend/src/handler/util.rs @@ -315,11 +315,11 @@ pub fn get_table_catalog_by_table_name( session: &SessionImpl, table_name: &ObjectName, ) -> RwResult<(Arc, String)> { - let db_name = session.database(); + let db_name = &session.database(); let (schema_name, real_table_name) = Binder::resolve_schema_qualified_name(db_name, table_name.clone())?; let search_path = session.config().search_path(); - let user_name = &session.auth_context().user_name; + let user_name = &session.user_name(); let schema_path = SchemaPath::new(schema_name.as_deref(), &search_path, user_name); let reader = session.env().catalog_reader().read_guard(); diff --git a/src/frontend/src/planner/relation.rs b/src/frontend/src/planner/relation.rs index 7a08502f8519c..d3bd010a0bc98 100644 --- a/src/frontend/src/planner/relation.rs +++ b/src/frontend/src/planner/relation.rs @@ -135,7 +135,7 @@ impl Planner { } let opt_ctx = self.ctx(); let session = opt_ctx.session_ctx(); - let db_name = session.database(); + let db_name = &session.database(); let catalog_reader = session.env().catalog_reader().read_guard(); let mut source_catalog = None; for schema in catalog_reader.iter_schemas(db_name).unwrap() { diff --git a/src/frontend/src/scheduler/local.rs b/src/frontend/src/scheduler/local.rs index afd0cfcabaa33..7f7d0fabd6300 100644 --- a/src/frontend/src/scheduler/local.rs +++ b/src/frontend/src/scheduler/local.rs @@ -144,7 +144,7 @@ impl LocalQueryExecution { let catalog_reader = self.front_env.catalog_reader().clone(); let user_info_reader = self.front_env.user_info_reader().clone(); let auth_context = self.session.auth_context().clone(); - let db_name = self.session.database().to_owned(); + let db_name = self.session.database(); let search_path = self.session.config().search_path(); let time_zone = self.session.config().timezone(); let strict_mode = self.session.config().batch_expr_strict_mode(); diff --git a/src/frontend/src/session.rs b/src/frontend/src/session.rs index 36f6a5dc12e17..723f4e5407b61 100644 --- a/src/frontend/src/session.rs +++ b/src/frontend/src/session.rs @@ -623,6 +623,7 @@ impl FrontendEnv { } } +#[derive(Clone)] pub struct AuthContext { pub database: String, pub user_name: String, @@ -640,7 +641,7 @@ impl AuthContext { } pub struct SessionImpl { env: FrontendEnv, - auth_context: Arc, + auth_context: Arc>, /// Used for user authentication. user_authenticator: UserAuthenticator, /// Stores the value of configurations. @@ -733,7 +734,7 @@ impl From for RwError { impl SessionImpl { pub fn new( env: FrontendEnv, - auth_context: Arc, + auth_context: AuthContext, user_authenticator: UserAuthenticator, id: SessionId, peer_addr: AddressRef, @@ -742,7 +743,7 @@ impl SessionImpl { let cursor_metrics = env.cursor_metrics.clone(); Self { env, - auth_context, + auth_context: Arc::new(RwLock::new(auth_context)), user_authenticator, config_map: Arc::new(RwLock::new(session_config)), id, @@ -762,11 +763,11 @@ impl SessionImpl { let env = FrontendEnv::mock(); Self { env: FrontendEnv::mock(), - auth_context: Arc::new(AuthContext::new( + auth_context: Arc::new(RwLock::new(AuthContext::new( DEFAULT_DATABASE_NAME.to_owned(), DEFAULT_SUPER_USER.to_owned(), DEFAULT_SUPER_USER_ID, - )), + ))), user_authenticator: UserAuthenticator::None, config_map: Default::default(), // Mock session use non-sense id. @@ -791,19 +792,24 @@ impl SessionImpl { } pub fn auth_context(&self) -> Arc { - self.auth_context.clone() + let ctx = self.auth_context.read(); + Arc::new(ctx.clone()) } - pub fn database(&self) -> &str { - &self.auth_context.database + pub fn database(&self) -> String { + self.auth_context.read().database.clone() } - pub fn user_name(&self) -> &str { - &self.auth_context.user_name + pub fn user_name(&self) -> String { + self.auth_context.read().user_name.clone() } pub fn user_id(&self) -> UserId { - self.auth_context.user_id + self.auth_context.read().user_id + } + + pub fn update_database(&self, database: String) { + self.auth_context.write().database = database; } pub fn shared_config(&self) -> Arc> { @@ -881,13 +887,13 @@ impl SessionImpl { stmt_type: StatementType, if_not_exists: bool, ) -> std::result::Result, CheckRelationError> { - let db_name = self.database(); + let db_name = &self.database(); let catalog_reader = self.env().catalog_reader().read_guard(); let (schema_name, relation_name) = { let (schema_name, relation_name) = Binder::resolve_schema_qualified_name(db_name, name)?; let search_path = self.config().search_path(); - let user_name = &self.auth_context().user_name; + let user_name = &self.user_name(); let schema_name = match schema_name { Some(schema_name) => schema_name, None => catalog_reader @@ -908,12 +914,12 @@ impl SessionImpl { } pub fn check_secret_name_duplicated(&self, name: ObjectName) -> Result<()> { - let db_name = self.database(); + let db_name = &self.database(); let catalog_reader = self.env().catalog_reader().read_guard(); let (schema_name, secret_name) = { let (schema_name, secret_name) = Binder::resolve_schema_qualified_name(db_name, name)?; let search_path = self.config().search_path(); - let user_name = &self.auth_context().user_name; + let user_name = &self.user_name(); let schema_name = match schema_name { Some(schema_name) => schema_name, None => catalog_reader @@ -928,13 +934,13 @@ impl SessionImpl { } pub fn check_connection_name_duplicated(&self, name: ObjectName) -> Result<()> { - let db_name = self.database(); + let db_name = &self.database(); let catalog_reader = self.env().catalog_reader().read_guard(); let (schema_name, connection_name) = { let (schema_name, connection_name) = Binder::resolve_schema_qualified_name(db_name, name)?; let search_path = self.config().search_path(); - let user_name = &self.auth_context().user_name; + let user_name = &self.user_name(); let schema_name = match schema_name { Some(schema_name) => schema_name, None => catalog_reader @@ -953,10 +959,10 @@ impl SessionImpl { &self, schema_name: Option, ) -> Result<(DatabaseId, SchemaId)> { - let db_name = self.database(); + let db_name = &self.database(); let search_path = self.config().search_path(); - let user_name = &self.auth_context().user_name; + let user_name = &self.user_name(); let catalog_reader = self.env().catalog_reader().read_guard(); let schema = match schema_name { @@ -982,9 +988,9 @@ impl SessionImpl { schema_name: Option, connection_name: &str, ) -> Result> { - let db_name = self.database(); + let db_name = &self.database(); let search_path = self.config().search_path(); - let user_name = &self.auth_context().user_name; + let user_name = &self.user_name(); let catalog_reader = self.env().catalog_reader().read_guard(); let schema_path = SchemaPath::new(schema_name.as_deref(), &search_path, user_name); @@ -998,7 +1004,7 @@ impl SessionImpl { schema_id: SchemaId, subscription_name: &str, ) -> Result> { - let db_name = self.database(); + let db_name = &self.database(); let catalog_reader = self.env().catalog_reader().read_guard(); let db_id = catalog_reader.get_database_by_name(db_name)?.id(); @@ -1019,9 +1025,9 @@ impl SessionImpl { schema_name: Option, subscription_name: &str, ) -> Result> { - let db_name = self.database(); + let db_name = &self.database(); let search_path = self.config().search_path(); - let user_name = &self.auth_context().user_name; + let user_name = &self.user_name(); let catalog_reader = self.env().catalog_reader().read_guard(); let schema_path = SchemaPath::new(schema_name.as_deref(), &search_path, user_name); @@ -1059,9 +1065,9 @@ impl SessionImpl { schema_name: Option, secret_name: &str, ) -> Result> { - let db_name = self.database(); + let db_name = &self.database(); let search_path = self.config().search_path(); - let user_name = &self.auth_context().user_name; + let user_name = &self.user_name(); let catalog_reader = self.env().catalog_reader().read_guard(); let schema_path = SchemaPath::new(schema_name.as_deref(), &search_path, user_name); @@ -1415,11 +1421,7 @@ impl SessionManagerImpl { let session_impl: Arc = SessionImpl::new( self.env.clone(), - Arc::new(AuthContext::new( - database_name.to_owned(), - user_name.to_owned(), - user.id, - )), + AuthContext::new(database_name.to_owned(), user_name.to_owned(), user.id), user_authenticator, id, peer_addr, diff --git a/src/frontend/src/session/cursor_manager.rs b/src/frontend/src/session/cursor_manager.rs index 723e82db659da..4f011687c7b2b 100644 --- a/src/frontend/src/session/cursor_manager.rs +++ b/src/frontend/src/session/cursor_manager.rs @@ -1031,7 +1031,7 @@ impl CursorManager { handler_args: HandlerArgs, ) -> Result { let session = handler_args.session.clone(); - let db_name = session.database(); + let db_name = &session.database(); let (_, cursor_name) = Binder::resolve_schema_qualified_name(db_name, cursor_name.clone())?; match self.cursor_map.lock().await.get(&cursor_name).ok_or_else(|| { ErrorCode::InternalError(format!("Cannot find cursor `{}`", cursor_name)) diff --git a/src/frontend/src/test_utils.rs b/src/frontend/src/test_utils.rs index 41f17b566019b..c602a6aadad90 100644 --- a/src/frontend/src/test_utils.rs +++ b/src/frontend/src/test_utils.rs @@ -203,7 +203,7 @@ impl LocalFrontend { ) -> Arc { Arc::new(SessionImpl::new( self.env.clone(), - Arc::new(AuthContext::new(database, user_name, user_id)), + AuthContext::new(database, user_name, user_id), UserAuthenticator::None, // Local Frontend use a non-sense id. (0, 0), diff --git a/src/frontend/src/utils/with_options.rs b/src/frontend/src/utils/with_options.rs index 7080dcf7dad78..934d2ecc23815 100644 --- a/src/frontend/src/utils/with_options.rs +++ b/src/frontend/src/utils/with_options.rs @@ -199,7 +199,7 @@ pub(crate) fn resolve_connection_ref_and_secret_ref( object: TelemetryDatabaseObject, ) -> RwResult<(WithOptionsSecResolved, PbConnectionType, Option)> { let connector_name = with_options.get_connector(); - let db_name: &str = session.database(); + let db_name: &str = &session.database(); let (mut options, secret_refs, connection_refs) = with_options.clone().into_parts(); let mut connection_id = None; @@ -339,7 +339,7 @@ pub(crate) fn resolve_secret_ref_in_with_options( ) -> RwResult { let (options, secret_refs, _) = with_options.into_parts(); let mut resolved_secret_refs = BTreeMap::new(); - let db_name: &str = session.database(); + let db_name: &str = &session.database(); for (key, secret_ref) in secret_refs { let (schema_name, secret_name) = Binder::resolve_schema_qualified_name(db_name, secret_ref.secret_name.clone())?; diff --git a/src/sqlparser/src/ast/mod.rs b/src/sqlparser/src/ast/mod.rs index c0fce4a2d4780..a8a9cc8a5da9e 100644 --- a/src/sqlparser/src/ast/mod.rs +++ b/src/sqlparser/src/ast/mod.rs @@ -1621,6 +1621,12 @@ pub enum Statement { Wait, /// Trigger stream job recover Recover, + /// `USE ` + /// + /// Note: this is a RisingWave specific statement and used to switch the current database. + Use { + db_name: ObjectName, + }, } impl fmt::Display for Statement { @@ -2230,6 +2236,10 @@ impl fmt::Display for Statement { write!(f, "RECOVER")?; Ok(()) } + Statement::Use { db_name } => { + write!(f, "USE {}", db_name)?; + Ok(()) + } } } } diff --git a/src/sqlparser/src/keywords.rs b/src/sqlparser/src/keywords.rs index 84f3d0eb42a68..a05e46068d94b 100644 --- a/src/sqlparser/src/keywords.rs +++ b/src/sqlparser/src/keywords.rs @@ -553,6 +553,7 @@ define_keywords!( UPDATE, UPPER, USAGE, + USE, USER, USING, UUID, diff --git a/src/sqlparser/src/parser.rs b/src/sqlparser/src/parser.rs index f9ee9f7275433..adf200f9a7ffd 100644 --- a/src/sqlparser/src/parser.rs +++ b/src/sqlparser/src/parser.rs @@ -331,6 +331,7 @@ impl Parser<'_> { Keyword::FLUSH => Ok(Statement::Flush), Keyword::WAIT => Ok(Statement::Wait), Keyword::RECOVER => Ok(Statement::Recover), + Keyword::USE => Ok(self.parse_use()?), _ => self.expected_at(checkpoint, "statement"), }, Token::LParen => { @@ -5531,6 +5532,11 @@ impl Parser<'_> { comment, }) } + + fn parse_use(&mut self) -> PResult { + let db_name = self.parse_object_name()?; + Ok(Statement::Use { db_name }) + } } impl Word { diff --git a/src/utils/pgwire/src/pg_response.rs b/src/utils/pgwire/src/pg_response.rs index b46f0b3fabcea..c0b2e3dfd32aa 100644 --- a/src/utils/pgwire/src/pg_response.rs +++ b/src/utils/pgwire/src/pg_response.rs @@ -113,6 +113,7 @@ pub enum StatementType { WAIT, KILL, RECOVER, + USE, } impl std::fmt::Display for StatementType { @@ -322,6 +323,7 @@ impl StatementType { Statement::CloseCursor { .. } => Ok(StatementType::CLOSE_CURSOR), Statement::Flush => Ok(StatementType::FLUSH), Statement::Wait => Ok(StatementType::WAIT), + Statement::Use { .. } => Ok(StatementType::USE), _ => Err("unsupported statement type".to_owned()), } } From 8d08c4b61a0b0dd466f770949adc00e322d44d65 Mon Sep 17 00:00:00 2001 From: August Date: Tue, 24 Dec 2024 16:24:43 +0800 Subject: [PATCH 2/3] reset search path only --- src/frontend/src/handler/use_db.rs | 6 +++--- src/frontend/src/session.rs | 7 +++++++ 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/src/frontend/src/handler/use_db.rs b/src/frontend/src/handler/use_db.rs index dcffe310bfd57..be5a82dcb47ad 100644 --- a/src/frontend/src/handler/use_db.rs +++ b/src/frontend/src/handler/use_db.rs @@ -14,7 +14,6 @@ use pgwire::pg_response::StatementType; use risingwave_common::acl::AclMode; -use risingwave_common::session_config::SessionConfig; use risingwave_pb::user::grant_privilege::Object as GrantObject; use risingwave_sqlparser::ast::ObjectName; @@ -49,8 +48,9 @@ pub async fn handle_use_db( session.user_name() )); - // reset session config - *session.shared_config().write() = SessionConfig::default(); + // reset search_path + session.reset_config("search_path")?; + session.update_database(database_name); Ok(builder.into()) diff --git a/src/frontend/src/session.rs b/src/frontend/src/session.rs index fad856f2273f1..9e183b7355e97 100644 --- a/src/frontend/src/session.rs +++ b/src/frontend/src/session.rs @@ -827,6 +827,13 @@ impl SessionImpl { .map_err(Into::into) } + pub fn reset_config(&self, key: &str) -> Result { + self.config_map + .write() + .reset(key, &mut ()) + .map_err(Into::into) + } + pub fn set_config_report( &self, key: &str, From 2f9a68daba869af394803bb5a61c99fb1f5a70b9 Mon Sep 17 00:00:00 2001 From: August Date: Mon, 30 Dec 2024 15:19:40 +0800 Subject: [PATCH 3/3] clippy --- src/frontend/src/handler/mod.rs | 7 +++---- src/frontend/src/handler/use_db.rs | 5 +---- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/src/frontend/src/handler/mod.rs b/src/frontend/src/handler/mod.rs index 795a847427654..d4326baad0ad2 100644 --- a/src/frontend/src/handler/mod.rs +++ b/src/frontend/src/handler/mod.rs @@ -594,10 +594,9 @@ pub async fn handle( let res = use_db::handle_use_db( handler_args, ObjectName::from(vec![Ident::new_unchecked( - x.unwrap_or("default".to_string()), + x.unwrap_or("default".to_owned()), )]), - ) - .await?; + )?; let mut builder = RwPgResponse::builder(StatementType::SET_VARIABLE); for notice in res.notices() { builder = builder.notice(notice); @@ -1156,7 +1155,7 @@ pub async fn handle( object_name, comment, } => comment::handle_comment(handler_args, object_type, object_name, comment).await, - Statement::Use { db_name } => use_db::handle_use_db(handler_args, db_name).await, + Statement::Use { db_name } => use_db::handle_use_db(handler_args, db_name), _ => bail_not_implemented!("Unhandled statement: {}", stmt), } } diff --git a/src/frontend/src/handler/use_db.rs b/src/frontend/src/handler/use_db.rs index be5a82dcb47ad..ba0b10f90dc68 100644 --- a/src/frontend/src/handler/use_db.rs +++ b/src/frontend/src/handler/use_db.rs @@ -22,10 +22,7 @@ use crate::handler::privilege::ObjectCheckItem; use crate::handler::{HandlerArgs, RwPgResponse}; use crate::Binder; -pub async fn handle_use_db( - handler_args: HandlerArgs, - database_name: ObjectName, -) -> Result { +pub fn handle_use_db(handler_args: HandlerArgs, database_name: ObjectName) -> Result { let session = handler_args.session; let database_name = Binder::resolve_database_name(database_name)?;