Skip to content

Commit

Permalink
feat: update our cross schema check to cross catalog
Browse files Browse the repository at this point in the history
  • Loading branch information
sunng87 committed Jan 9, 2024
1 parent a0a31c8 commit dcecf47
Show file tree
Hide file tree
Showing 7 changed files with 35 additions and 47 deletions.
46 changes: 21 additions & 25 deletions src/catalog/src/table_source.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
use std::collections::HashMap;
use std::sync::Arc;

use common_catalog::consts::INFORMATION_SCHEMA_NAME;
use common_catalog::format_full_table_name;
use datafusion::common::{ResolvedTableReference, TableReference};
use datafusion::datasource::provider_as_source;
Expand All @@ -30,20 +29,20 @@ use crate::CatalogManagerRef;
pub struct DfTableSourceProvider {
catalog_manager: CatalogManagerRef,
resolved_tables: HashMap<String, Arc<dyn TableSource>>,
disallow_cross_schema_query: bool,
disallow_cross_catalog_query: bool,
default_catalog: String,
default_schema: String,
}

impl DfTableSourceProvider {
pub fn new(
catalog_manager: CatalogManagerRef,
disallow_cross_schema_query: bool,
disallow_cross_catalog_query: bool,
query_ctx: &QueryContext,
) -> Self {
Self {
catalog_manager,
disallow_cross_schema_query,
disallow_cross_catalog_query,
resolved_tables: HashMap::new(),
default_catalog: query_ctx.current_catalog().to_owned(),
default_schema: query_ctx.current_schema().to_owned(),
Expand All @@ -54,29 +53,18 @@ impl DfTableSourceProvider {
&'a self,
table_ref: TableReference<'a>,
) -> Result<ResolvedTableReference<'a>> {
if self.disallow_cross_schema_query {
if self.disallow_cross_catalog_query {
match &table_ref {
TableReference::Bare { .. } => (),
TableReference::Partial { schema, .. } => {
ensure!(
schema.as_ref() == self.default_schema
|| schema.as_ref() == INFORMATION_SCHEMA_NAME,
QueryAccessDeniedSnafu {
catalog: &self.default_catalog,
schema: schema.as_ref(),
}
);
}
TableReference::Partial { .. } => {}
TableReference::Full {
catalog, schema, ..
} => {
ensure!(
catalog.as_ref() == self.default_catalog
&& (schema.as_ref() == self.default_schema
|| schema.as_ref() == INFORMATION_SCHEMA_NAME),
catalog.as_ref() == self.default_catalog,
QueryAccessDeniedSnafu {
catalog: catalog.as_ref(),
schema: schema.as_ref()
schema: schema.as_ref(),
}
);
}
Expand Down Expand Up @@ -136,29 +124,29 @@ mod tests {
table: Cow::Borrowed("table_name"),
};
let result = table_provider.resolve_table_ref(table_ref);
let _ = result.unwrap();
assert!(result.is_ok());

let table_ref = TableReference::Partial {
schema: Cow::Borrowed("public"),
table: Cow::Borrowed("table_name"),
};
let result = table_provider.resolve_table_ref(table_ref);
let _ = result.unwrap();
assert!(result.is_ok());

let table_ref = TableReference::Partial {
schema: Cow::Borrowed("wrong_schema"),
table: Cow::Borrowed("table_name"),
};
let result = table_provider.resolve_table_ref(table_ref);
assert!(result.is_err());
assert!(result.is_ok());

let table_ref = TableReference::Full {
catalog: Cow::Borrowed("greptime"),
schema: Cow::Borrowed("public"),
table: Cow::Borrowed("table_name"),
};
let result = table_provider.resolve_table_ref(table_ref);
let _ = result.unwrap();
assert!(result.is_ok());

let table_ref = TableReference::Full {
catalog: Cow::Borrowed("wrong_catalog"),
Expand All @@ -172,20 +160,28 @@ mod tests {
schema: Cow::Borrowed("information_schema"),
table: Cow::Borrowed("columns"),
};
let _ = table_provider.resolve_table_ref(table_ref).unwrap();
let result = table_provider.resolve_table_ref(table_ref);
assert!(result.is_ok());

let table_ref = TableReference::Full {
catalog: Cow::Borrowed("greptime"),
schema: Cow::Borrowed("information_schema"),
table: Cow::Borrowed("columns"),
};
let _ = table_provider.resolve_table_ref(table_ref).unwrap();
assert!(table_provider.resolve_table_ref(table_ref).is_ok());

let table_ref = TableReference::Full {
catalog: Cow::Borrowed("dummy"),
schema: Cow::Borrowed("information_schema"),
table: Cow::Borrowed("columns"),
};
assert!(table_provider.resolve_table_ref(table_ref).is_err());

let table_ref = TableReference::Full {
catalog: Cow::Borrowed("greptime"),
schema: Cow::Borrowed("greptime_private"),
table: Cow::Borrowed("columns"),
};
assert!(table_provider.resolve_table_ref(table_ref).is_ok());
}
}
10 changes: 4 additions & 6 deletions src/frontend/src/instance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,7 @@ pub fn check_permission(
) -> Result<()> {
let need_validate = plugins
.get::<QueryOptions>()
.map(|opts| opts.disallow_cross_schema_query)
.map(|opts| opts.disallow_cross_catalog_query)
.unwrap_or_default();

if !need_validate {
Expand Down Expand Up @@ -520,7 +520,7 @@ mod tests {
let query_ctx = QueryContext::arc();
let plugins: Plugins = Plugins::new();
plugins.insert(QueryOptions {
disallow_cross_schema_query: true,
disallow_cross_catalog_query: true,
});

let sql = r#"
Expand Down Expand Up @@ -556,8 +556,6 @@ mod tests {
}

let wrong = vec![
("", "wrongschema."),
("greptime.", "wrongschema."),
("wrongcatalog.", "public."),
("wrongcatalog.", "wrongschema."),
];
Expand Down Expand Up @@ -607,10 +605,10 @@ mod tests {
let stmt = parse_stmt(sql, &GreptimeDbDialect {}).unwrap();
check_permission(plugins.clone(), &stmt[0], &query_ctx).unwrap();

let sql = "SHOW TABLES FROM wrongschema";
let sql = "SHOW TABLES FROM private";
let stmt = parse_stmt(sql, &GreptimeDbDialect {}).unwrap();
let re = check_permission(plugins.clone(), &stmt[0], &query_ctx);
assert!(re.is_err());
assert!(re.is_ok());

// test describe table
let sql = "DESC TABLE {catalog}{schema}demo;";
Expand Down
2 changes: 1 addition & 1 deletion src/query/src/datafusion/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ impl DfContextProviderAdapter {

let mut table_provider = DfTableSourceProvider::new(
engine_state.catalog_manager().clone(),
engine_state.disallow_cross_schema_query(),
engine_state.disallow_cross_catalog_query(),
query_ctx.as_ref(),
);

Expand Down
4 changes: 2 additions & 2 deletions src/query/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ impl DfLogicalPlanner {

let table_provider = DfTableSourceProvider::new(
self.engine_state.catalog_manager().clone(),
self.engine_state.disallow_cross_schema_query(),
self.engine_state.disallow_cross_catalog_query(),
query_ctx.as_ref(),
);

Expand Down Expand Up @@ -91,7 +91,7 @@ impl DfLogicalPlanner {
async fn plan_pql(&self, stmt: EvalStmt, query_ctx: QueryContextRef) -> Result<LogicalPlan> {
let table_provider = DfTableSourceProvider::new(
self.engine_state.catalog_manager().clone(),
self.engine_state.disallow_cross_schema_query(),
self.engine_state.disallow_cross_catalog_query(),
query_ctx.as_ref(),
);
PromPlanner::stmt_to_plan(table_provider, stmt)
Expand Down
14 changes: 4 additions & 10 deletions src/query/src/query_engine/options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use common_catalog::consts::INFORMATION_SCHEMA_NAME;
use session::context::QueryContextRef;
use snafu::ensure;

use crate::error::{QueryAccessDeniedSnafu, Result};

#[derive(Default, Clone)]
pub struct QueryOptions {
pub disallow_cross_schema_query: bool,
pub disallow_cross_catalog_query: bool,
}

// TODO(shuiyisong): remove one method after #559 is done
Expand All @@ -29,13 +28,8 @@ pub fn validate_catalog_and_schema(
schema: &str,
query_ctx: &QueryContextRef,
) -> Result<()> {
// information_schema is an exception
if schema.eq_ignore_ascii_case(INFORMATION_SCHEMA_NAME) {
return Ok(());
}

ensure!(
catalog == query_ctx.current_catalog() && schema == query_ctx.current_schema(),
catalog == query_ctx.current_catalog(),
QueryAccessDeniedSnafu {
catalog: catalog.to_string(),
schema: schema.to_string(),
Expand All @@ -57,8 +51,8 @@ mod tests {
let context = QueryContext::with("greptime", "public");

validate_catalog_and_schema("greptime", "public", &context).unwrap();
let re = validate_catalog_and_schema("greptime", "wrong_schema", &context);
assert!(re.is_err());
let re = validate_catalog_and_schema("greptime", "private_schema", &context);
assert!(re.is_ok());
let re = validate_catalog_and_schema("wrong_catalog", "public", &context);
assert!(re.is_err());
let re = validate_catalog_and_schema("wrong_catalog", "wrong_schema", &context);
Expand Down
4 changes: 2 additions & 2 deletions src/query/src/query_engine/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -163,9 +163,9 @@ impl QueryEngineState {
self.table_mutation_handler.as_ref()
}

pub(crate) fn disallow_cross_schema_query(&self) -> bool {
pub(crate) fn disallow_cross_catalog_query(&self) -> bool {
self.plugins
.map::<QueryOptions, _, _>(|x| x.disallow_cross_schema_query)
.map::<QueryOptions, _, _>(|x| x.disallow_cross_catalog_query)
.unwrap_or(false)
}

Expand Down
2 changes: 1 addition & 1 deletion src/query/src/tests/query_engine_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ async fn test_query_validate() -> Result<()> {
// set plugins
let plugins = Plugins::new();
plugins.insert(QueryOptions {
disallow_cross_schema_query: true,
disallow_cross_catalog_query: true,
});

let factory = QueryEngineFactory::new_with_plugins(catalog_list, None, None, false, plugins);
Expand Down

0 comments on commit dcecf47

Please sign in to comment.