diff --git a/Cargo.lock b/Cargo.lock index 2380452..ed21fad 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -121,6 +121,12 @@ version = "0.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8" +[[package]] +name = "base64" +version = "0.21.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567" + [[package]] name = "base64" version = "0.22.1" @@ -359,6 +365,15 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7" +[[package]] +name = "deranged" +version = "0.3.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b42b6fa04a440b495c8b04d0e71b707c585f83cb9cb28cf8cd0d976c315e31b4" +dependencies = [ + "powerfmt", +] + [[package]] name = "either" version = "1.13.0" @@ -536,8 +551,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7" dependencies = [ "cfg-if 1.0.0", + "js-sys", "libc", "wasi", + "wasm-bindgen", ] [[package]] @@ -978,6 +995,21 @@ version = "0.12.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "078e285eafdfb6c4b434e0d31e8cfcb5115b651496faca5749b88fafd4f23bfd" +[[package]] +name = "jsonwebtoken" +version = "9.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9ae10193d25051e74945f1ea2d0b42e03cc3b890f7e4cc5faa44997d808193f" +dependencies = [ + "base64 0.21.7", + "js-sys", + "pem", + "ring", + "serde", + "serde_json", + "simple_asn1", +] + [[package]] name = "kernel32-sys" version = "0.2.2" @@ -1186,6 +1218,31 @@ dependencies = [ "winapi 0.3.9", ] +[[package]] +name = "num-bigint" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9" +dependencies = [ + "num-integer", + "num-traits", +] + +[[package]] +name = "num-conv" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9" + +[[package]] +name = "num-integer" +version = "0.1.46" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" +dependencies = [ + "num-traits", +] + [[package]] name = "num-traits" version = "0.2.19" @@ -1309,6 +1366,16 @@ dependencies = [ "windows-targets 0.52.6", ] +[[package]] +name = "pem" +version = "3.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e459365e590736a54c3fa561947c84837534b8e9af6fc5bf781307e82658fae" +dependencies = [ + "base64 0.22.1", + "serde", +] + [[package]] name = "percent-encoding" version = "2.3.1" @@ -1393,6 +1460,12 @@ version = "1.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "da544ee218f0d287a911e9c99a39a8c9bc8fcad3cb8db5959940044ecfc67265" +[[package]] +name = "powerfmt" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391" + [[package]] name = "ppv-lite86" version = "0.2.17" @@ -1416,6 +1489,7 @@ dependencies = [ "chrono", "criterion", "json", + "jsonwebtoken", "log", "mockito", "openssl", @@ -1433,9 +1507,9 @@ dependencies = [ [[package]] name = "pyo3" -version = "0.22.2" +version = "0.22.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "831e8e819a138c36e212f3af3fd9eeffed6bf1510a805af35b0edee5ffa59433" +checksum = "f402062616ab18202ae8319da13fa4279883a2b8a9d9f83f20dbade813ce1884" dependencies = [ "cfg-if 1.0.0", "indoc", @@ -1451,9 +1525,9 @@ dependencies = [ [[package]] name = "pyo3-build-config" -version = "0.22.2" +version = "0.22.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e8730e591b14492a8945cdff32f089250b05f5accecf74aeddf9e8272ce1fa8" +checksum = "b14b5775b5ff446dd1056212d778012cbe8a0fbffd368029fd9e25b514479c38" dependencies = [ "once_cell", "target-lexicon", @@ -1461,9 +1535,9 @@ dependencies = [ [[package]] name = "pyo3-ffi" -version = "0.22.2" +version = "0.22.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e97e919d2df92eb88ca80a037969f44e5e70356559654962cbb3316d00300c6" +checksum = "9ab5bcf04a2cdcbb50c7d6105de943f543f9ed92af55818fd17b660390fc8636" dependencies = [ "libc", "pyo3-build-config", @@ -1482,9 +1556,9 @@ dependencies = [ [[package]] name = "pyo3-macros" -version = "0.22.2" +version = "0.22.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eb57983022ad41f9e683a599f2fd13c3664d7063a3ac5714cae4b7bee7d3f206" +checksum = "0fd24d897903a9e6d80b968368a34e1525aeb719d568dba8b3d4bfa5dc67d453" dependencies = [ "proc-macro2", "pyo3-macros-backend", @@ -1494,9 +1568,9 @@ dependencies = [ [[package]] name = "pyo3-macros-backend" -version = "0.22.2" +version = "0.22.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec480c0c51ddec81019531705acac51bcdbeae563557c982aa8263bb96880372" +checksum = "36c011a03ba1e50152b4b394b479826cad97e7a21eb52df179cd91ac411cbfbe" dependencies = [ "heck", "proc-macro2", @@ -2040,6 +2114,18 @@ version = "2.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1de1d4f81173b03af4c0cbed3c898f6bff5b870e4a7f5d6f4057d62a7a4b686e" +[[package]] +name = "simple_asn1" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "adc4e5204eb1910f40f9cfa375f6f05b68c3abac4b6fd879c8ff5e7ae8a0a085" +dependencies = [ + "num-bigint", + "num-traits", + "thiserror", + "time", +] + [[package]] name = "slab" version = "0.4.9" @@ -2202,6 +2288,37 @@ dependencies = [ "once_cell", ] +[[package]] +name = "time" +version = "0.3.36" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5dfd88e563464686c916c7e46e623e520ddc6d79fa6641390f2e3fa86e83e885" +dependencies = [ + "deranged", + "itoa 1.0.11", + "num-conv", + "powerfmt", + "serde", + "time-core", + "time-macros", +] + +[[package]] +name = "time-core" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ef927ca75afb808a4d64dd374f00a2adf8d0fcff8e7b184af886c3c87ec4a3f3" + +[[package]] +name = "time-macros" +version = "0.2.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f252a68540fde3a3877aeea552b832b40ab9a69e318efd078774a01ddee1ccf" +dependencies = [ + "num-conv", + "time-core", +] + [[package]] name = "tinytemplate" version = "1.2.1" diff --git a/Cargo.toml b/Cargo.toml index 8ee142c..c516946 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,7 +12,7 @@ crate-type = ["cdylib", "rlib"] [dependencies] -pyo3 = { version = "0.22.1" } +pyo3 = { version = "0.22.4"} pyo3-log = { version = "0.11.0" } log = { version = "~0.4.4", default-features = false, features = ["std"] } serde = { version = "1.0", features = ["derive"] } @@ -23,6 +23,7 @@ reqwest = { version = "0.12.7", features = ["json", "rustls-tls", "rustls-tls-na tokio = { version = "1.40", features = ["rt-multi-thread", "full"] } url = { version = "2.5", features = [] } json = { version = "0.12" } +jsonwebtoken = {version = "9.3.0"} state = { version = "0.6"} diff --git a/py_src/fusion/fusion_filesystem.py b/py_src/fusion/fusion_filesystem.py index 889bb42..b8fa1f8 100644 --- a/py_src/fusion/fusion_filesystem.py +++ b/py_src/fusion/fusion_filesystem.py @@ -697,10 +697,13 @@ async def _meth(url: Any, kw: Any) -> None: await self._async_raise_not_found_for_status(resp, url) return await resp.json() # type: ignore except Exception as ex: # noqa: BLE001 + # wait 3 seconds before retrying + await asyncio.sleep(3 * (ex_cnt + 1)) + logger.debug(f"Failed to upload file: {ex}") ex_cnt += 1 last_ex = ex - raise Exception(f"Failed to upload file: {last_ex}, failed after {ex_cnt} exceptions.") + raise Exception(f"Failed to upload file: {last_ex}, failed after {ex_cnt} exceptions. {last_ex}") context = nullcontext(lpath) diff --git a/requirements-dev.lock b/requirements-dev.lock index 752e66e..8367941 100644 --- a/requirements-dev.lock +++ b/requirements-dev.lock @@ -34,8 +34,6 @@ annotated-types==0.6.0 # via pydantic anyio==4.3.0 # via jupyter-server -appnope==0.1.4 - # via ipykernel argcomplete==3.3.0 # via nox argon2-cffi==23.1.0 @@ -55,6 +53,7 @@ attrs==23.2.0 # via aiohttp-sse-client # via jsonschema # via referencing +auditwheel==6.1.0 azure-core==1.30.1 # via adlfs # via azure-identity @@ -120,6 +119,7 @@ cryptography==42.0.5 # via moto # via msal # via pyjwt + # via secretstorage debugpy==1.8.1 # via ipykernel decorator==5.1.1 @@ -228,6 +228,9 @@ jaraco-functools==4.0.1 # via keyring jedi==0.19.1 # via ipython +jeepney==0.8.0 + # via keyring + # via secretstorage jinja2==3.1.4 # via jupyter-server # via mike @@ -414,6 +417,7 @@ oauthlib==3.2.2 overrides==7.7.0 # via jupyter-server packaging==24.0 + # via auditwheel # via ipykernel # via jupyter-server # via jupytext @@ -434,6 +438,7 @@ pandocfilters==1.5.1 # via nbconvert parso==0.8.4 # via jedi +patchelf==0.17.2.1 pathspec==0.12.1 # via mkdocs pexpect==4.9.0 @@ -494,6 +499,8 @@ pydantic-core==2.18.2 # via pydantic pydantic-settings==2.2.1 # via bump-my-version +pyelftools==0.31 + # via auditwheel pygments==2.17.2 # via ipython # via mkdocs-jupyter @@ -621,6 +628,8 @@ s3fs==2024.3.1 # via pyfusion s3transfer==0.10.1 # via boto3 +secretstorage==3.3.3 + # via keyring send2trash==1.8.3 # via jupyter-server # via nbclassic @@ -767,6 +776,7 @@ xmltodict==0.13.0 yarl==1.9.4 # via aiohttp # via aiohttp-sse-client +ziglang==0.13.0 zipp==3.18.1 # via importlib-metadata # via importlib-resources diff --git a/requirements.lock b/requirements.lock index 3952119..51bc4e9 100644 --- a/requirements.lock +++ b/requirements.lock @@ -29,8 +29,6 @@ aiosignal==1.3.1 # via aiohttp anyio==4.3.0 # via jupyter-server -appnope==0.1.4 - # via ipykernel argon2-cffi==23.1.0 # via jupyter-server # via nbclassic diff --git a/src/auth.rs b/src/auth.rs index 5107659..0271b71 100644 --- a/src/auth.rs +++ b/src/auth.rs @@ -1,6 +1,7 @@ use bincode::{deserialize, serialize}; use chrono::{NaiveDate, Utc}; use json; +use jsonwebtoken::{encode, Algorithm, EncodingKey, Header}; use pyo3::exceptions::{PyFileNotFoundError, PyValueError}; use pyo3::import_exception; use pyo3::prelude::*; @@ -207,6 +208,8 @@ struct FusionCredsPersistent { //#[serde(deserialize_with = "deserialize_fusion_e2e")] fusion_e2e: Option, headers: Option>, + kid: Option, + private_key: Option, } #[pyclass(module = "fusion._fusion")] @@ -302,11 +305,13 @@ pub struct FusionCredentials { #[pyo3(get, set)] username: Option, + #[pyo3(get, set)] password: Option, #[pyo3(get, set)] resource: Option, + #[pyo3(get, set)] auth_url: Option, @@ -330,6 +335,12 @@ pub struct FusionCredentials { #[serde(skip)] http_client: Option, + + #[pyo3(get, set)] + kid: Option, + + #[pyo3(get, set)] + private_key: Option, } impl Default for FusionCredentials { @@ -347,11 +358,23 @@ impl Default for FusionCredentials { grant_type: "client_credentials".to_string(), fusion_e2e: None, headers: HashMap::new(), + kid: None, + private_key: None, http_client: None, } } } +#[derive(Serialize, Debug)] +struct Claims { + iss: String, + aud: String, + sub: String, + iat: i64, + exp: i64, + jti: String, +} + #[pymethods] impl FusionCredentials { fn __getstate__(&self) -> PyResult> { @@ -363,7 +386,6 @@ impl FusionCredentials { Ok(()) } - #[allow(clippy::type_complexity)] fn __getnewargs__( &self, ) -> PyResult<( @@ -395,7 +417,7 @@ impl FusionCredentials { } #[classmethod] - #[pyo3(signature = (client_id=None, client_secret=None, resource=None, auth_url=None, proxies=None, fusion_e2e=None, headers=None))] + #[pyo3(signature = (client_id=None, client_secret=None, resource=None, auth_url=None, proxies=None, fusion_e2e=None, headers=None, kid=None, private_key=None))] fn from_client_id( _cls: &Bound<'_, PyType>, client_id: Option, @@ -405,6 +427,8 @@ impl FusionCredentials { proxies: Option>, fusion_e2e: Option, headers: Option>, + kid: Option, + private_key: Option, ) -> PyResult { Ok(Self { client_id, @@ -415,6 +439,8 @@ impl FusionCredentials { grant_type: "client_credentials".to_string(), fusion_e2e, headers: headers.unwrap_or_default(), + kid, + private_key, fusion_token: HashMap::new(), bearer_token: None, username: None, @@ -432,7 +458,7 @@ impl FusionCredentials { } #[classmethod] - #[pyo3(signature = (client_id=None, username=None, password=None, resource=None, auth_url=None, proxies=None, fusion_e2e=None, headers=None))] + #[pyo3(signature = (client_id=None, username=None, password=None, resource=None, auth_url=None, proxies=None, fusion_e2e=None, headers=None, kid=None, private_key=None))] fn from_user_id( _cls: &Bound<'_, PyType>, client_id: Option, @@ -443,6 +469,8 @@ impl FusionCredentials { proxies: Option>, fusion_e2e: Option, headers: Option>, + kid: Option, + private_key: Option, ) -> PyResult { Ok(Self { client_id, @@ -454,6 +482,8 @@ impl FusionCredentials { grant_type: "password".to_string(), fusion_e2e, headers: headers.unwrap_or_default(), + kid, + private_key, fusion_token: HashMap::new(), bearer_token: None, client_secret: None, @@ -499,12 +529,14 @@ impl FusionCredentials { username: None, password: None, http_client: None, + kid: None, + private_key: None, }) } - #[allow(clippy::too_many_arguments)] + // #[allow(clippy::too_many_arguments)] #[new] - #[pyo3(signature = (client_id=None, client_secret=None, username=None, password=None, resource=None, auth_url=None, bearer_token=None, proxies=None, grant_type=None, fusion_e2e=None, headers=None))] + #[pyo3(signature = (client_id=None, client_secret=None, username=None, password=None, resource=None, auth_url=None, bearer_token=None, proxies=None, grant_type=None, fusion_e2e=None, headers=None, kid=None, private_key=None))] fn new( client_id: Option, client_secret: Option, @@ -517,6 +549,8 @@ impl FusionCredentials { grant_type: Option, fusion_e2e: Option, headers: Option>, + kid: Option, + private_key: Option, ) -> PyResult { Ok(FusionCredentials { client_id, @@ -531,6 +565,8 @@ impl FusionCredentials { grant_type: grant_type.unwrap_or_else(|| "client_credentials".to_string()), fusion_e2e, headers: headers.unwrap_or_default(), + kid, + private_key, http_client: None, }) } @@ -541,52 +577,105 @@ impl FusionCredentials { py: Python, force: bool, max_remain_secs: u32, - ) -> PyResult<()> { + ) -> PyResult { if !force { if let Some(token) = self.bearer_token.as_ref() { if !token.is_expirable() { - return Ok(()); + return Ok(false); } if let Some(expires_in_secs) = token.expires_in_secs() { if expires_in_secs > max_remain_secs as i64 { - return Ok(()); + return Ok(false); } } } } - self._ensure_http_client()?; let client = self.http_client.as_ref().ok_or_else(|| { CredentialError::new_err( "HTTP client not initialized. Use from_* methods to create credentials", ) })?; - let payload = match self.grant_type.clone().as_str() { - "client_credentials" => { - vec![ - ("grant_type", self.grant_type.as_str()), - ("client_id", self.client_id.as_ref().unwrap()), - ("client_secret", self.client_secret.as_ref().unwrap()), - ("aud", self.resource.as_ref().unwrap()), - ] - } - "password" => { - vec![ - ("grant_type", self.grant_type.as_str()), - ("client_id", self.client_id.as_ref().unwrap()), - ("username", self.username.as_ref().unwrap()), - ("password", self.password.as_ref().unwrap()), - ("resource", self.resource.as_ref().unwrap()), - ] - } - "bearer" => { - // Nothing to do - return Ok(()); - } - _ => { - return Err(PyValueError::new_err("Unrecognized grant type")); + + let payload = if let (Some(_kid), Some(private_key)) = (&self.kid, &self.private_key) { + // Create JWT claims + let claims = Claims { + iss: self.client_id.clone().unwrap_or_default(), + aud: self.auth_url.clone().unwrap_or_default(), + sub: self.client_id.clone().unwrap_or_default(), + iat: Utc::now().timestamp(), + exp: Utc::now().timestamp() + 3600, + jti: "id001".to_string(), + }; + // Encode the JWT + let private_key_bytes = private_key.as_bytes(); + let encoding_key = + EncodingKey::from_rsa_pem(private_key_bytes).expect("Invalid RSA private key"); + let mut header = Header::new(Algorithm::RS256); + header.kid = Some(self.kid.clone().unwrap_or_default()); + let private_key_jwt = + encode(&header, &claims, &encoding_key).expect("Failed to encode JWT"); + + // Build the payload vector + vec![ + ("grant_type".to_string(), self.grant_type.clone()), + ( + "client_id".to_string(), + self.client_id.clone().unwrap_or_default(), + ), + ( + "client_assertion_type".to_string(), + "urn:ietf:params:oauth:client-assertion-type:jwt-bearer".to_string(), + ), + ("client_assertion".to_string(), private_key_jwt), + ( + "resource".to_string(), + self.resource.clone().unwrap_or_default(), + ), + ] + } else { + match self.grant_type.as_str() { + "client_credentials" => vec![ + ("grant_type".to_string(), self.grant_type.clone()), + ( + "client_id".to_string(), + self.client_id.clone().unwrap_or_default(), + ), + ( + "client_secret".to_string(), + self.client_secret.clone().unwrap_or_default(), + ), + ("aud".to_string(), self.resource.clone().unwrap_or_default()), + ], + "password" => vec![ + ("grant_type".to_string(), self.grant_type.clone()), + ( + "client_id".to_string(), + self.client_id.clone().unwrap_or_default(), + ), + ( + "username".to_string(), + self.username.clone().unwrap_or_default(), + ), + ( + "password".to_string(), + self.password.clone().unwrap_or_default(), + ), + ( + "resource".to_string(), + self.resource.clone().unwrap_or_default(), + ), + ], + "bearer" => { + // Nothing to do + return Ok(true); + } + _ => { + return Err(PyValueError::new_err("Unrecognized grant type")); + } } }; + let rt = &get_tokio_runtime(py).0; let response_res: PyResult = rt.block_on(async { @@ -624,7 +713,7 @@ impl FusionCredentials { } } self.put_bearer_token(token, expires_in_secs); - Ok(()) + Ok((true)) } #[pyo3(signature = (bearer_token, expires_in_secs=None))] @@ -718,7 +807,7 @@ impl FusionCredentials { } } - self._refresh_bearer_token(py, false, 30)?; + let is_bearer_refreshed = self._refresh_bearer_token(py, false, 15 * 60)?; let bearer_token_tup = self .bearer_token .as_ref() @@ -744,7 +833,7 @@ impl FusionCredentials { std::collections::hash_map::Entry::Occupied(mut entry) => { let token = entry.get_mut(); if let Some(expires_in_secs) = token.expires_in_secs() { - if expires_in_secs > 30 { + if expires_in_secs < 15 * 60 || is_bearer_refreshed { (None, Some(self._gen_fusion_token(py, fusion_tk_url)?)) } else { (Some(token.as_fusion_header()?), None) @@ -820,6 +909,8 @@ impl FusionCredentials { Some(untyped_proxies(credentials.proxies)), credentials.fusion_e2e, credentials.headers, + credentials.kid, + credentials.private_key, )?, "bearer" => FusionCredentials::from_bearer_token( cls, @@ -839,6 +930,8 @@ impl FusionCredentials { Some(untyped_proxies(credentials.proxies)), credentials.fusion_e2e, credentials.headers, + credentials.kid, + credentials.private_key, )?, _ => { return Err(pyo3::exceptions::PyValueError::new_err( @@ -849,7 +942,6 @@ impl FusionCredentials { Ok(full_creds) } } - // Tests #[cfg(test)] @@ -1050,6 +1142,8 @@ mod tests { Some("grant_type".to_string()), Some("fusion_e2e".to_string()), Some(HashMap::new()), + Some("kid".to_string()), + Some("private_key".to_string()), ) .unwrap(); @@ -1094,6 +1188,8 @@ mod tests { Some("grant_type".to_string()), Some("fusion_e2e".to_string()), Some(HashMap::new()), + Some("kid".to_string()), + Some("private_key".to_string()), ) .unwrap(); @@ -1124,6 +1220,8 @@ mod tests { Some("grant_type".to_string()), Some("fusion_e2e".to_string()), Some(HashMap::new()), + Some("kid".to_string()), + Some("private_key".to_string()), ) .unwrap(); @@ -1155,6 +1253,8 @@ mod tests { Some("grant_type".to_string()), Some("fusion_e2e".to_string()), Some(HashMap::new()), + Some("kid".to_string()), + Some("private_key".to_string()), ) .unwrap(); @@ -1183,6 +1283,8 @@ mod tests { assert!(headers.is_some()); assert_eq!(grant_type, Some("grant_type".to_string())); assert_eq!(fusion_e2e, Some("fusion_e2e".to_string())); + // assert_eq!(kid, Some("kid".to_string())); + // assert_eq!(private_key, Some("private_key".to_string())); } #[test] @@ -1198,6 +1300,8 @@ mod tests { None, None, None, + None, + None, ) .unwrap(); @@ -1223,6 +1327,8 @@ mod tests { None, None, None, + None, + None, ) .unwrap();