diff --git a/src/_bcrypt/src/lib.rs b/src/_bcrypt/src/lib.rs index 2da2b021..3159bf95 100644 --- a/src/_bcrypt/src/lib.rs +++ b/src/_bcrypt/src/lib.rs @@ -83,14 +83,14 @@ fn hashpw<'p>( // salt here is not just the salt bytes, but rather an encoded value // containing a version number, number of rounds, and the salt. // Should be [prefix, cost, hash]. This logic is copied from `bcrypt` - let raw_parts: Vec<_> = salt + let [raw_version, raw_cost, remainder]: [&[u8]; 3] = salt .split(|&b| b == b'$') .filter(|s| !s.is_empty()) - .collect(); - if raw_parts.len() != 3 { - return Err(pyo3::exceptions::PyValueError::new_err("Invalid salt")); - } - let version = match raw_parts[0] { + .collect::>() + .try_into() + .map_err(|_| pyo3::exceptions::PyValueError::new_err("Invalid salt"))?; + + let version = match raw_version { b"2y" => bcrypt::Version::TwoY, b"2b" => bcrypt::Version::TwoB, b"2a" => bcrypt::Version::TwoA, @@ -99,15 +99,20 @@ fn hashpw<'p>( return Err(pyo3::exceptions::PyValueError::new_err("Invalid salt")); } }; - let cost = std::str::from_utf8(raw_parts[1]) + let cost = std::str::from_utf8(raw_cost) .map_err(|_| pyo3::exceptions::PyValueError::new_err("Invalid salt"))? .parse::() .map_err(|_| pyo3::exceptions::PyValueError::new_err("Invalid salt"))?; + + if remainder.len() < 22 { + return Err(pyo3::exceptions::PyValueError::new_err("Invalid salt")); + } + // The last component can contain either just the salt, or the salt and // the result hash, depending on if the `salt` value come from `hashpw` or // `gensalt`. let raw_salt = BASE64_ENGINE - .decode(&raw_parts[2][..22]) + .decode(&remainder[..22]) .map_err(|_| pyo3::exceptions::PyValueError::new_err("Invalid salt"))? .try_into() .map_err(|_| pyo3::exceptions::PyValueError::new_err("Invalid salt"))?; diff --git a/tests/test_bcrypt.py b/tests/test_bcrypt.py index 68c00fb4..b0e0182a 100644 --- a/tests/test_bcrypt.py +++ b/tests/test_bcrypt.py @@ -275,6 +275,8 @@ def test_checkpw_bad_salt(): b"password", b"$2b$3$mdEQPMOtfPX.WGZNXgF66OhmBlOGKEd66SQ7DyJPGucYYmvTJYviy", ) + with pytest.raises(ValueError): + bcrypt.checkpw(b"password", b"$2b$12$incorrect") def test_checkpw_str_password():