Skip to content

Commit

Permalink
dekaf: Implement field selection based on data from the built spec
Browse files Browse the repository at this point in the history
  • Loading branch information
jshearer committed Jan 13, 2025
1 parent a5cb792 commit 457cb62
Show file tree
Hide file tree
Showing 3 changed files with 121 additions and 79 deletions.
8 changes: 3 additions & 5 deletions crates/dekaf/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,7 @@ pub struct TaskAuth {
client: flow_client::Client,
task_name: String,
config: DekafConfig,
spec: models::MaterializationDef,
ops_logs_journal: String,
built_spec: proto_flow::flow::MaterializationSpec,
ops_stats_journal: String,

// When access token expires
Expand Down Expand Up @@ -188,7 +187,7 @@ impl App {
);
// Ask the agent for information about this task, as well as a short-lived
// control-plane access token authorized to interact with the avro schemas table
let (client, claims, ops_logs_journal, ops_stats_journal, task_spec) =
let (client, claims, _, ops_stats_journal, task_spec) =
topology::fetch_dekaf_task_auth(
self.client_base.clone(),
&username,
Expand All @@ -208,8 +207,7 @@ impl App {
Ok(SessionAuthentication::Task(TaskAuth {
task_name: username,
config,
spec: task_spec,
ops_logs_journal,
built_spec: task_spec,
ops_stats_journal,
client: client.with_fresh_gazette_client(),
exp: time::OffsetDateTime::UNIX_EPOCH + time::Duration::seconds(claims.exp as i64),
Expand Down
190 changes: 117 additions & 73 deletions crates/dekaf/src/topology.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ use flow_client::fetch_task_authorization;
use futures::{StreamExt, TryFutureExt, TryStreamExt};
use gazette::{broker, journal, uuid};
use itertools::Itertools;
use models::MaterializationBinding;
use models::RawValue;
use proto_flow::flow;
use std::time::Duration;
use std::{iter, time::Duration};

impl UserAuth {
/// Fetch the names of all collections which the current user may read.
Expand Down Expand Up @@ -38,12 +38,12 @@ impl UserAuth {
impl TaskAuth {
pub async fn fetch_all_collection_names(&self) -> anyhow::Result<Vec<String>> {
Ok(self
.spec
.built_spec
.bindings
.iter()
.map(|b| {
serde_json::from_value::<crate::connector::DekafResourceConfig>(
b.resource.to_value(),
serde_json::from_str::<crate::connector::DekafResourceConfig>(
&b.resource_config_json,
)
})
.map_ok(|val| val.topic_name)
Expand All @@ -53,14 +53,19 @@ impl TaskAuth {
pub fn get_binding_for_topic(
&self,
topic_name: &str,
) -> anyhow::Result<Option<(MaterializationBinding, DekafResourceConfig)>> {
) -> anyhow::Result<
Option<(
proto_flow::flow::materialization_spec::Binding,
DekafResourceConfig,
)>,
> {
Ok(self
.spec
.built_spec
.bindings
.iter()
.map(|b| {
serde_json::from_value::<crate::connector::DekafResourceConfig>(
b.resource.to_value(),
serde_json::from_str::<crate::connector::DekafResourceConfig>(
&b.resource_config_json,
)
.map(|parsed| (b, parsed))
})
Expand All @@ -87,7 +92,10 @@ impl SessionAuthentication {
.get_binding_for_topic(topic_name)?
.ok_or(anyhow::anyhow!("Unrecognized topic {topic_name}"))?;

Ok(binding.source.collection().to_string())
Ok(binding
.collection
.context("missing collection in materialization binding")?
.name)
}
}
}
Expand Down Expand Up @@ -134,23 +142,19 @@ impl Collection {
) -> anyhow::Result<Option<Self>> {
let not_before = uuid::Clock::default();

if let SessionAuthentication::Task(task_auth) = auth {
let binding = if let SessionAuthentication::Task(task_auth) = auth {
if let Some((binding, _)) = task_auth.get_binding_for_topic(topic_name)? {
if binding.disable {
bail!(
"Binding for topic {topic_name} is disabled in {}",
task_auth.task_name
)
}
} else if let Some(suggested_binding) = task_auth
.spec
.bindings
.iter()
.find(|b| b.source.collection().to_string() == topic_name)
{
let correct_topic_name = serde_json::from_value::<
Some(binding)
} else if let Some(suggested_binding) = task_auth.built_spec.bindings.iter().find(|b| {
b.collection
.as_ref()
.expect("missing collection in materialization binding")
.name
== topic_name
}) {
let correct_topic_name = serde_json::from_str::<
crate::connector::DekafResourceConfig,
>(suggested_binding.resource.to_value())?
>(&suggested_binding.resource_config_json)?
.topic_name;
bail!(
"{topic_name} is not a binding of {}. Did you mean {}?",
Expand All @@ -160,14 +164,16 @@ impl Collection {
} else {
bail!("{topic_name} is not a binding of {}", task_auth.task_name)
}
}
} else {
None
};

let collection_name = &auth.get_collection_for_topic(topic_name)?;

let Some(spec) = Self::fetch_spec(&pg_client, collection_name).await? else {
let Some(collection_spec) = Self::fetch_spec(&pg_client, collection_name).await? else {
return Ok(None);
};
let partition_template_name = spec
let partition_template_name = collection_spec
.partition_template
.as_ref()
.map(|spec| spec.name.to_owned())
Expand All @@ -180,44 +186,82 @@ impl Collection {

tracing::debug!(?partitions, "Got partitions");

let key_ptr: Vec<doc::Pointer> =
spec.key.iter().map(|p| doc::Pointer::from_str(p)).collect();
let uuid_ptr = doc::Pointer::from_str(&spec.uuid_ptr);

// Extract projections from the spec
let projections = spec.projections.clone();
let key_ptr: Vec<doc::Pointer> = collection_spec
.key
.iter()
.map(|p| doc::Pointer::from_str(p))
.collect();
let uuid_ptr = doc::Pointer::from_str(&collection_spec.uuid_ptr);

let json_schema = if spec.read_schema_json.is_empty() {
&spec.write_schema_json
let json_schema = if collection_spec.read_schema_json.is_empty() {
&collection_spec.write_schema_json
} else {
&spec.read_schema_json
&collection_spec.read_schema_json
};

let json_schema = doc::validation::build_bundle(json_schema)?;
let validator = doc::Validator::new(json_schema)?;
let shape = doc::Shape::infer(&validator.schemas()[0], validator.schema_index());
let collection_schema_shape =
doc::Shape::infer(&validator.schemas()[0], validator.schema_index());

// Create value shape by merging all projected fields in the schema
let mut value_shape = doc::Shape::nothing();

// Add projected fields to value shape
for projection in &projections {
let ptr = doc::Pointer::from_str(&projection.ptr);
let (field_shape, exists) = shape.locate(&ptr);
if exists.cannot() {
tracing::warn!(
projection = projection.ptr,
"Projection field not found in schema"
);
continue;
}
let (mut field_selected_shape, projections) = if let Some(binding) = binding {
let selection = binding
.field_selection
.context("missing field selection in materialization binding")?;

let projected_shape = build_shape_at_pointer(&ptr, field_shape);
value_shape = doc::Shape::union(value_shape, projected_shape);
}
let selected_projections = selection
.keys
.iter()
.chain(selection.values.iter())
.chain(iter::once(&selection.document))
.filter(|field| field.len() > 0)
.map(|field| {
let projection = collection_spec
.projections
.iter()
.find(|proj| proj.field == *field);
if let Some(projection) = projection {
Some(projection.clone())
} else {
tracing::warn!(
?field,
"Missing projection for field on materialization built spec"
);
None
}
})
.flatten() // transform from Option<T> to T by filtering out Nones
.collect_vec();

let mapped_shape = selected_projections.iter().fold(
doc::Shape::nothing(),
|value_shape, projection| {
let source_ptr = doc::Pointer::from_str(&projection.ptr);
let (source_shape, exists) = collection_schema_shape.locate(&source_ptr);
if exists.cannot() {
tracing::warn!(
projection = ?source_ptr,
"Projection field not found in schema"
);
value_shape
} else {
let nested_shape = build_shape_at_pointer(
&doc::Pointer::from_str(&projection.field),
source_shape,
);
doc::Shape::union(value_shape, nested_shape)
}
},
);

(mapped_shape, selected_projections)
} else {
(collection_schema_shape, collection_spec.projections.clone())
};

if matches!(auth.deletions(), DeletionMode::CDC) {
if let Some(meta) = value_shape
if let Some(meta) = field_selected_shape
.object
.properties
.iter_mut()
Expand Down Expand Up @@ -252,7 +296,7 @@ impl Collection {
}
}

let (key_schema, value_schema) = avro::shape_to_avro(shape, &key_ptr);
let (key_schema, value_schema) = avro::shape_to_avro(field_selected_shape, &key_ptr);

tracing::debug!(
collection_name,
Expand All @@ -266,7 +310,7 @@ impl Collection {
key_schema,
not_before,
partitions,
spec,
spec: collection_spec,
uuid_ptr,
value_schema,
projections,
Expand Down Expand Up @@ -545,7 +589,7 @@ pub async fn fetch_dekaf_task_auth(
AccessTokenClaims,
String,
String,
models::MaterializationDef,
proto_flow::flow::MaterializationSpec,
)> {
let request_token = flow_client::client::build_task_authorization_request_token(
shard_template_id,
Expand Down Expand Up @@ -580,32 +624,32 @@ pub async fn fetch_dekaf_task_auth(
break response;
};
let claims = flow_client::parse_jwt_claims(token.as_str())?;

Ok((
client.with_user_access_token(Some(token)),
claims,
ops_logs_journal,
ops_stats_journal,
task_spec.ok_or(anyhow::anyhow!(
"task_spec is only None when we need to retry the auth request"
))?,
serde_json::from_str(
task_spec
.ok_or(anyhow::anyhow!(
"task_spec is only None when we need to retry the auth request"
))?
.get(),
)?,
))
}

pub async fn extract_dekaf_config(
spec: &models::MaterializationDef,
spec: &proto_flow::flow::MaterializationSpec,
) -> anyhow::Result<DekafConfig> {
match &spec.endpoint {
models::MaterializationEndpoint::Dekaf(dekaf_endpoint) => {
let decrypted = unseal::decrypt_sops(&dekaf_endpoint.config).await?;
let config = serde_json::from_str::<models::DekafConfig>(&spec.config_json)?;

let dekaf_config = serde_json::from_value::<DekafConfig>(decrypted.to_value())?;
Ok(dekaf_config)
}
models::MaterializationEndpoint::Connector(_)
| models::MaterializationEndpoint::Local(_) => {
bail!("not a Dekaf materialization")
}
}
let decrypted_endpoint_config =
unseal::decrypt_sops(&RawValue::from_str(&config.config.to_string())?).await?;

let dekaf_config = serde_json::from_str::<DekafConfig>(&decrypted_endpoint_config.to_string())?;
Ok(dekaf_config)
}

/// Nests the provided shape under a JSON pointer path by creating the necessary object hierarchy.
Expand Down
2 changes: 1 addition & 1 deletion go.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

PROFILE="${PROFILE:-release}"

export CGO_LDFLAGS="-L $(pwd)/target/${CARGO_BUILD_TARGET}/${PROFILE} -L $(pwd)/target/${CARGO_BUILD_TARGET}/${PROFILE}/librocksdb-exp -lbindings -lrocksdb -lsnappy -lstdc++ -lssl -lcrypto -ldl -lm"
export CGO_LDFLAGS="-L $(pwd)/target/${CARGO_BUILD_TARGET}/${PROFILE} -L $(pwd)/target/${CARGO_BUILD_TARGET}/${PROFILE}/librocksdb-exp -lbindings -lrocksdb -lsnappy -lstdc++ -ldl -lm"
if [ "$(uname)" == "Darwin" ]; then
export CGO_CFLAGS="-I $(pwd)/target/${CARGO_BUILD_TARGET}/${PROFILE}/librocksdb-exp/include -I $(brew --prefix)/include -I $(brew --prefix)/opt/sqlite3/include"
export CC="$(brew --prefix)/opt/llvm/bin/clang"
Expand Down

0 comments on commit 457cb62

Please sign in to comment.