Skip to content

Commit

Permalink
Move schema_name / catalog_name parameters into resolve function and out of trait
Browse files Browse the repository at this point in the history
  • Loading branch information
westonpace committed Jan 3, 2025
1 parent 29e8976 commit a6e17f1
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 46 deletions.
12 changes: 2 additions & 10 deletions datafusion-examples/examples/remote_catalog.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ async fn main() -> Result<()> {
// Now we can asynchronously resolve the table references to get a cached catalog
// that we can use for our query
let resolved_catalog = remote_catalog_adapter
.resolve(&references, state.config())
.resolve(&references, state.config(), "datafusion", "remote_schema")
.await?;

// This resolved catalog only makes sense for this query and so we create a clone
Expand Down Expand Up @@ -177,20 +177,12 @@ impl RemoteCatalogInterface {
}
}

/// Implements the DataFusion SchemaProvider API for tables
/// Implements an async version of the DataFusion SchemaProvider API for tables
/// stored in a remote catalog.
struct RemoteCatalogDatafusionAdapter(Arc<RemoteCatalogInterface>);

#[async_trait]
impl AsyncSchemaProvider for RemoteCatalogDatafusionAdapter {
fn name(&self) -> &str {
"remote_schema"
}

fn catalog_name(&self) -> &str {
"datafusion"
}

async fn table(&self, name: &str) -> Result<Option<Arc<dyn TableProvider>>> {
// Fetch information about the table from the remote catalog
//
Expand Down
54 changes: 18 additions & 36 deletions datafusion/catalog/src/async.rs
Original file line number Diff line number Diff line change
Expand Up @@ -192,16 +192,6 @@ impl CatalogProviderList for ResolvedCatalogProviderList {
/// method can be slow and asynchronous as it is only called once, before planning.
#[async_trait]
pub trait AsyncSchemaProvider: Send + Sync {
/// Return the name of the schema provided by this provider
///
/// If a table reference's schema name does not match this name then the reference will be ignored
/// when calculating the cached set of tables (this allows other providers to supply the table)
fn name(&self) -> &str;
/// Return the name of the catalog this provider belongs to
///
/// If a table reference's catalog name does not match this name then the reference will be ignored
/// when calculating the cached set of tables (this allows other providers to supply the table)
fn catalog_name(&self) -> &str;
/// Lookup a table in the schema provider
async fn table(&self, name: &str) -> Result<Option<Arc<dyn TableProvider>>>;
/// Creates a cached provider that can be used to execute a query containing given references
Expand All @@ -216,24 +206,26 @@ pub trait AsyncSchemaProvider: Send + Sync {
&self,
references: &[TableReference],
config: &SessionConfig,
catalog_name: &str,
schema_name: &str,
) -> Result<Arc<dyn SchemaProvider>> {
let mut cached_tables = HashMap::<String, Option<Arc<dyn TableProvider>>>::new();

for reference in references {
let catalog_name = reference
let ref_catalog_name = reference
.catalog()
.unwrap_or(&config.options().catalog.default_catalog);

// Maybe this is a reference to some other catalog provided in another way
if catalog_name != self.catalog_name() {
if ref_catalog_name != catalog_name {
continue;
}

let schema_name = reference
let ref_schema_name = reference
.schema()
.unwrap_or(&config.options().catalog.default_schema);

if schema_name != self.name() {
if ref_schema_name != schema_name {
continue;
}

Expand All @@ -250,7 +242,7 @@ pub trait AsyncSchemaProvider: Send + Sync {

Ok(Arc::new(ResolvedSchemaProvider {
cached_tables,
owner_name: Some(self.catalog_name().to_string()),
owner_name: Some(catalog_name.to_string()),
}))
}
}
Expand All @@ -263,12 +255,6 @@ pub trait AsyncSchemaProvider: Send + Sync {
#[async_trait]
pub trait AsyncCatalogProvider: Send + Sync {
/// Returns the name of the catalog being provided
///
/// If a reference's catalog name does not match this name then the reference will be ignored.
/// This allows other providers to potentially provide the reference.
fn name(&self) -> &str;

/// Lookup a schema in the provider
async fn schema(&self, name: &str) -> Result<Option<Arc<dyn AsyncSchemaProvider>>>;

Expand All @@ -285,17 +271,18 @@ pub trait AsyncCatalogProvider: Send + Sync {
&self,
references: &[TableReference],
config: &SessionConfig,
catalog_name: &str,
) -> Result<Arc<dyn CatalogProvider>> {
let mut cached_schemas =
HashMap::<String, Option<ResolvedSchemaProviderBuilder>>::new();

for reference in references {
let catalog_name = reference
let ref_catalog_name = reference
.catalog()
.unwrap_or(&config.options().catalog.default_catalog);

// Maybe this is a reference to some other catalog provided in another way
if catalog_name != self.name() {
if ref_catalog_name != catalog_name {
continue;
}

Expand Down Expand Up @@ -491,12 +478,6 @@ mod tests {

#[async_trait]
impl AsyncSchemaProvider for MockAsyncSchemaProvider {
fn name(&self) -> &str {
MOCK_SCHEMA
}
fn catalog_name(&self) -> &str {
MOCK_CATALOG
}
async fn table(&self, name: &str) -> Result<Option<Arc<dyn TableProvider>>> {
self.lookup_count.fetch_add(1, Ordering::Release);
if name == MOCK_TABLE {
Expand All @@ -523,8 +504,10 @@ mod tests {
not_found_tables: &[&str],
) {
let async_provider = MockAsyncSchemaProvider::default();
let cached_provider =
async_provider.resolve(&refs, &test_config()).await.unwrap();
let cached_provider = async_provider
.resolve(&refs, &test_config(), MOCK_CATALOG, MOCK_SCHEMA)
.await
.unwrap();

assert_eq!(
async_provider.lookup_count.load(Ordering::Acquire),
Expand Down Expand Up @@ -587,9 +570,6 @@ mod tests {

#[async_trait]
impl AsyncCatalogProvider for MockAsyncCatalogProvider {
fn name(&self) -> &str {
MOCK_CATALOG
}
async fn schema(
&self,
name: &str,
Expand All @@ -612,8 +592,10 @@ mod tests {
not_found_schemas: &[&str],
) {
let async_provider = MockAsyncCatalogProvider::default();
let cached_provider =
async_provider.resolve(&refs, &test_config()).await.unwrap();
let cached_provider = async_provider
.resolve(&refs, &test_config(), MOCK_CATALOG)
.await
.unwrap();

assert_eq!(
async_provider.lookup_count.load(Ordering::Acquire),
Expand Down

0 comments on commit a6e17f1

Please sign in to comment.