diff --git a/biscuit-auth/src/error.rs b/biscuit-auth/src/error.rs index 4e61adbe..db5b0335 100644 --- a/biscuit-auth/src/error.rs +++ b/biscuit-auth/src/error.rs @@ -292,6 +292,8 @@ pub enum RunLimit { TooManyIterations, #[error("spent too much time verifying")] Timeout, + #[error("Unexpected query results, expected {0} got {1}")] + UnexpectedQueryResult(usize, usize), } #[cfg(test)] diff --git a/biscuit-auth/src/token/authorizer.rs b/biscuit-auth/src/token/authorizer.rs index 040270a0..0fdb8209 100644 --- a/biscuit-auth/src/token/authorizer.rs +++ b/biscuit-auth/src/token/authorizer.rs @@ -154,6 +154,39 @@ impl Authorizer { self.query_with_limits(rule, limits) } + /// Run a query over the authorizer's Datalog engine to gather data. + /// If there is more than one result, this function will throw an error. + /// + /// ```rust + /// # use biscuit_auth::KeyPair; + /// # use biscuit_auth::Biscuit; + /// let keypair = KeyPair::new(); + /// let builder = Biscuit::builder().fact("user(\"John Doe\", 42)").unwrap(); + /// + /// let biscuit = builder.build(&keypair).unwrap(); + /// + /// let mut authorizer = biscuit.authorizer().unwrap(); + /// let res: (String, i64) = authorizer.query_exactly_one("data($name, $id) <- user($name, $id)").unwrap(); + /// assert_eq!(res.0, "John Doe"); + /// assert_eq!(res.1, 42); + /// ``` + pub fn query_exactly_one, T: TryFrom, E: Into>( + &mut self, + rule: R, + ) -> Result + where + error::Token: From<>::Error>, + { + let mut res: Vec = self.query(rule)?; + if res.len() == 1 { + Ok(res.remove(0)) + } else { + Err(error::Token::RunLimit( + error::RunLimit::UnexpectedQueryResult(1, res.len()), + )) + } + } + /// run a query over the authorizer's Datalog engine to gather data /// /// this only sees facts from the authorizer and the authority block @@ -1048,6 +1081,62 @@ mod tests { assert_eq!(res[0].0, "John Doe"); } + #[test] + fn query_exactly_one_authorizer_from_token_string() { + use crate::Biscuit; + use crate::KeyPair; + let keypair = KeyPair::new(); + let builder = Biscuit::builder().fact("user(\"John Doe\")").unwrap(); + + let biscuit = builder.build(&keypair).unwrap(); + + let mut authorizer = biscuit.authorizer().unwrap(); + let res: (String,) = authorizer + .query_exactly_one("data($name) <- user($name)") + .unwrap(); + assert_eq!(res.0, "John Doe"); + } + + #[test] + fn query_exactly_one_no_results() { + use crate::Biscuit; + use crate::KeyPair; + let keypair = KeyPair::new(); + let builder = Biscuit::builder(); + + let biscuit = builder.build(&keypair).unwrap(); + + let mut authorizer = biscuit.authorizer().unwrap(); + let res: Result<(String,), error::Token> = + authorizer.query_exactly_one("data($name) <- user($name)"); + assert_eq!( + res.unwrap_err().to_string(), + "Reached Datalog execution limits" + ); + } + + #[test] + fn query_exactly_one_too_many_results() { + use crate::Biscuit; + use crate::KeyPair; + let keypair = KeyPair::new(); + let builder = Biscuit::builder() + .fact("user(\"John Doe\")") + .unwrap() + .fact("user(\"Jane Doe\")") + .unwrap(); + + let biscuit = builder.build(&keypair).unwrap(); + + let mut authorizer = biscuit.authorizer().unwrap(); + let res: Result<(String,), error::Token> = + authorizer.query_exactly_one("data($name) <- user($name)"); + assert_eq!( + res.unwrap_err().to_string(), + "Reached Datalog execution limits" + ); + } + #[test] fn authorizer_with_scopes() { let root = KeyPair::new(); diff --git a/biscuit-capi/src/lib.rs b/biscuit-capi/src/lib.rs index 2900c26c..c3c7e35e 100644 --- a/biscuit-capi/src/lib.rs +++ b/biscuit-capi/src/lib.rs @@ -98,6 +98,7 @@ pub enum ErrorKind { FormatSignatureInvalidSignatureGeneration, AlreadySealed, Execution, + UnexpectedQueryResult, } #[no_mangle] @@ -175,6 +176,9 @@ pub extern "C" fn error_kind() -> ErrorKind { Token::RunLimit(RunLimit::TooManyFacts) => ErrorKind::TooManyFacts, Token::RunLimit(RunLimit::TooManyIterations) => ErrorKind::TooManyIterations, Token::RunLimit(RunLimit::Timeout) => ErrorKind::Timeout, + Token::RunLimit(RunLimit::UnexpectedQueryResult(_, _)) => { + ErrorKind::UnexpectedQueryResult + } Token::ConversionError(_) => ErrorKind::ConversionError, Token::Base64(_) => ErrorKind::FormatDeserializationError, Token::Execution(_) => ErrorKind::Execution,