feat: implement authentication and access filtering for pgwire using JWT
Solomon committed Dec 14, 2023
1 parent c9df9af commit a5ccf78
Showing 5 changed files with 165 additions and 46 deletions.
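At a high level, the commit threads a per-session `Access` value, derived from a verified JWT, from the pgwire connection layer down into the DataFusion table scans, so `get_records` can filter rows by the caller's permissions. Below is a self-contained sketch of that flow, not code from this commit: the `Access` and `SqlExecutor` stand-ins and the `verify_jwt` helper are assumptions for illustration.

    use std::sync::Arc;
    use once_cell::sync::OnceCell;

    // Stand-ins for the real dozer-api types; illustration only.
    #[derive(Clone, Debug)]
    struct Access {
        roles: Vec<String>,
    }

    struct SqlExecutor {
        access: Arc<OnceCell<Access>>,
    }

    impl SqlExecutor {
        // Mirrors SQLExecutor::set_access from this commit.
        fn set_access(&self, access: Access) -> Result<(), Access> {
            self.access.set(access)
        }
    }

    // Assumed JWT verification: a real implementation would check the
    // signature and expiry and map the claims to an Access value.
    fn verify_jwt(token: &str) -> Option<Access> {
        (!token.is_empty()).then(|| Access { roles: vec!["reader".into()] })
    }

    fn main() {
        let executor = SqlExecutor { access: Arc::new(OnceCell::new()) };
        // On connection startup, the pgwire layer verifies the client's JWT
        // and pins the resulting Access to the session exactly once.
        if let Some(access) = verify_jwt("header.payload.signature") {
            executor.set_access(access).expect("session already authenticated");
        }
        // Later, scan() hands this Access to get_records to filter rows.
        println!("session access: {:?}", executor.access.get());
    }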
8 changes: 1 addition & 7 deletions Cargo.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion dozer-api/Cargo.toml
@@ -60,5 +60,5 @@ pgwire = "0.16.1"
 tempdir = "0.3.7"
 postgres-types = "0.2"
 futures-sink = "0.3.29"
-async-once-cell = "0.5.3"
 genawaiter = "0.99.1"
+once_cell = "1.18.0"
60 changes: 47 additions & 13 deletions dozer-api/src/sql/datafusion/mod.rs
@@ -48,14 +48,26 @@ use dozer_types::types::Schema as DozerSchema;
 use futures_util::future::try_join_all;
 use futures_util::stream::BoxStream;
 use futures_util::StreamExt;
+use once_cell::sync::OnceCell;
 
 use crate::api_helper::get_records;
+use crate::auth::Access;
 use crate::CacheEndpoint;
 
 use predicate_pushdown::{predicate_pushdown, supports_predicates_pushdown};
 
 pub(crate) struct SQLExecutor {
-    ctx: Arc<SessionContext>,
+    ctx: SessionContext,
+    access: Arc<OnceCell<Access>>,
 }
+
+impl Clone for SQLExecutor {
+    fn clone(&self) -> Self {
+        Self {
+            ctx: SessionContext::new_with_state(self.ctx.state()),
+            access: Default::default(),
+        }
+    }
+}
 
 #[derive(Clone)]
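A note on the manual `Clone` impl above: `SessionContext::new_with_state(self.ctx.state())` rebuilds a context from the current session state, while `access: Default::default()` deliberately gives every clone an empty cell, so an executor cloned for a new connection stays unauthenticated until `set_access` is called. A minimal stand-in demonstrating that reset, assuming only the `once_cell` crate:

    use std::sync::Arc;
    use once_cell::sync::OnceCell;

    // Stand-in mirroring the Clone impl above: the access cell is reset
    // for every clone instead of being shared with the original.
    struct Executor {
        access: Arc<OnceCell<String>>, // stand-in for Arc<OnceCell<Access>>
    }

    impl Clone for Executor {
        fn clone(&self) -> Self {
            Self {
                access: Default::default(), // fresh, unauthenticated cell
            }
        }
    }

    fn main() {
        let original = Executor { access: Arc::new(OnceCell::new()) };
        original.access.set("claims-for-alice".to_string()).unwrap();

        // The clone does not inherit the original's authentication.
        let per_connection = original.clone();
        assert!(per_connection.access.get().is_none());
        assert!(original.access.get().is_some());
    }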
@@ -520,8 +532,9 @@ impl SQLExecutor {
         let ctx = SessionContext::new_with_config(
             SessionConfig::new().with_default_catalog_and_schema("dozer", "public"),
         );
+        let access = Arc::new(OnceCell::<Access>::new());
         for cache_endpoint in cache_endpoints {
-            let data_source = CacheEndpointDataSource::new(cache_endpoint.clone());
+            let data_source = CacheEndpointDataSource::new(cache_endpoint.clone(), access.clone());
             let _provider = ctx
                 .register_table(
                     TableReference::Bare {
@@ -538,7 +551,11 @@
 
         pg_catalog::create(&ctx).await?;
 
-        Ok(Self { ctx: Arc::new(ctx) })
+        Ok(Self { ctx, access })
     }
 
+    pub fn set_access(&self, access: Access) -> Result<(), Access> {
+        self.access.set(access)
+    }
+
     pub async fn execute(&self, plan: LogicalPlan) -> Result<DataFrame, DataFusionError> {
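`set_access` builds on `OnceCell::set`, which succeeds only for the first caller and returns the rejected value in `Err` afterwards; that is where the `Result<(), Access>` signature comes from. A short sketch of those semantics, assuming only the `once_cell` crate:

    use once_cell::sync::OnceCell;

    fn main() {
        let cell: OnceCell<&str> = OnceCell::new();

        // The first set wins: a session authenticates once.
        assert_eq!(cell.set("first-access"), Ok(()));

        // A second set is rejected and the value is handed back in Err.
        assert_eq!(cell.set("second-access"), Err("second-access"));

        // Readers see the value fixed at authentication time.
        assert_eq!(cell.get(), Some(&"first-access"));
    }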
@@ -613,10 +630,11 @@ impl SQLExecutor {
 pub struct CacheEndpointDataSource {
     cache_endpoint: Arc<CacheEndpoint>,
     schema: SchemaRef,
+    access: Arc<OnceCell<Access>>,
 }
 
 impl CacheEndpointDataSource {
-    pub fn new(cache_endpoint: Arc<CacheEndpoint>) -> Self {
+    pub fn new(cache_endpoint: Arc<CacheEndpoint>, access: Arc<OnceCell<Access>>) -> Self {
         let schema = {
             let cache_reader = &cache_endpoint.cache_reader();
             let schema = &cache_reader.get_schema().0;
@@ -625,6 +643,7 @@ impl CacheEndpointDataSource {
         Self {
             cache_endpoint,
             schema,
+            access,
         }
     }
 }
@@ -657,6 +676,7 @@ impl TableProvider for CacheEndpointDataSource {
             projection,
             filters.to_vec(),
             limit,
+            self.access.clone(),
         )?))
     }
 
@@ -678,6 +698,7 @@ pub struct CacheEndpointExec {
     projected_schema: SchemaRef,
     filters: Vec<Expr>,
     limit: Option<usize>,
+    access: Arc<OnceCell<Access>>,
 }
 
 impl CacheEndpointExec {
@@ -688,6 +709,7 @@
         projection: Option<&Vec<usize>>,
         filters: Vec<Expr>,
         limit: Option<usize>,
+        access: Arc<OnceCell<Access>>,
     ) -> Result<Self> {
         let projected_schema = match projection {
             Some(p) => Arc::new(schema.project(p)?),
@@ -700,6 +722,7 @@
             projection: projection.cloned().map(Into::into),
             filters,
             limit,
+            access,
         })
     }
 }
@@ -742,24 +765,35 @@ impl ExecutionPlan for CacheEndpointExec {
         _partition: usize,
         _ctx: Arc<TaskContext>,
     ) -> Result<SendableRecordBatchStream> {
-        let stream = futures_util::stream::iter({
+        let stream = {
             let cache_reader = &self.cache_endpoint.cache_reader();
             let mut expr = QueryExpression {
                 limit: self.limit,
                 filter: predicate_pushdown(self.filters.iter()),
                 ..Default::default()
             };
-            debug!("Using predicate pushdown {:?}", expr.filter);
-            let records = get_records(
+            let access = self.access.get().cloned();
+            debug!(
+                "Using predicate pushdown {:?} with access {:?}",
+                expr.filter, access
+            );
+            match get_records(
                 cache_reader,
                 &mut expr,
                 &self.cache_endpoint.endpoint.name,
-                None,
-            )
-            .unwrap();
-
-            transpose(cache_reader.get_schema().0.clone(), records)
-        });
+                access,
+            ) {
+                Ok(records) => futures_util::stream::iter(transpose(
+                    cache_reader.get_schema().0.clone(),
+                    records,
+                ))
+                .boxed(),
+                Err(err) => futures_util::stream::once(futures_util::future::ready(Err(
+                    DataFusionError::External(err.into()),
+                )))
+                .boxed(),
+            }
+        };
         Ok(Box::pin(RecordBatchStreamAdapter::new(
             self.projected_schema.clone(),
             match self.projection.clone() {
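The `scan` rewrite above also removes the old `.unwrap()`: when `get_records` fails, for example because the access check rejects the query, the error now surfaces as a single-element stream yielding `Err(DataFusionError::External(..))` instead of panicking the worker. A minimal sketch of that pattern, assuming the `datafusion`, `futures-util`, and `tokio` crates:

    use datafusion::error::DataFusionError;
    use futures_util::{future, stream, StreamExt};

    #[tokio::main]
    async fn main() {
        // One-shot stream that yields a single Err, mirroring the Err arm above.
        let cause: Box<dyn std::error::Error + Send + Sync> = "access denied".into();
        let mut batches = stream::once(future::ready(
            Err::<u32, _>(DataFusionError::External(cause)),
        ))
        .boxed();

        // The consumer sees exactly one item: the error.
        assert!(matches!(
            batches.next().await,
            Some(Err(DataFusionError::External(_)))
        ));
        assert!(batches.next().await.is_none());
    }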
