Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow using password-protected private keys #797

Open
wants to merge 14 commits into
base: develop
Choose a base branch
from
150 changes: 150 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion tough-ssm/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ impl KeySource for SsmKeySource {
})?
.as_bytes()
.to_vec();
let sign = Box::new(parse_keypair(&data).context(error::KeyPairParseSnafu)?);
let sign = Box::new(parse_keypair(&data, None).context(error::KeyPairParseSnafu)?);
Ok(sign)
}

Expand Down
2 changes: 2 additions & 0 deletions tough/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ edition = "2018"
async-recursion = "1"
async-trait = "0.1"
aws-lc-rs = "1"
base64 = "0.22"
bytes = "1"
chrono = { version = "0.4", default-features = false, features = ["std", "alloc", "serde", "clock"] }
dyn-clone = "1"
Expand All @@ -23,6 +24,7 @@ log = "0.4"
olpc-cjson = { version = "0.1", path = "../olpc-cjson" }
pem = "3"
percent-encoding = "2"
pkcs8 = { version = "0.10", features = ["encryption", "pem", "std"] }
reqwest = { version = "0.12", optional = true, default-features = false, features = ["stream"] }
rustls = "0.23"
serde = { version = "1", features = ["derive"] }
Expand Down
10 changes: 8 additions & 2 deletions tough/src/editor/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,10 @@ mod tests {
#[tokio::test]
async fn empty_repository() {
let root_key = key_path();
let key_source = LocalKeySource { path: root_key };
let key_source = LocalKeySource {
path: root_key,
password: None,
};
let root_path = root_path();

let editor = RepositoryEditor::new(root_path).await.unwrap();
Expand Down Expand Up @@ -112,7 +115,10 @@ mod tests {
async fn complete_repository() {
let root = root_path();
let root_key = key_path();
let key_source = LocalKeySource { path: root_key };
let key_source = LocalKeySource {
path: root_key,
password: None,
};
let timestamp_expiration = Utc::now().checked_add_signed(days(3)).unwrap();
let timestamp_version = NonZeroU64::new(1234).unwrap();
let snapshot_expiration = Utc::now().checked_add_signed(days(21)).unwrap();
Expand Down
5 changes: 4 additions & 1 deletion tough/src/key_source.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ pub trait KeySource: Debug + Send + Sync {
pub struct LocalKeySource {
/// The path to a local key file in PEM pkcs8 or RSA format.
pub path: PathBuf,
/// Optional password for the key file.
pub password: Option<String>,
}

/// Implements the `KeySource` trait for a `LocalKeySource` (file)
Expand All @@ -44,7 +46,8 @@ impl KeySource for LocalKeySource {
let data = tokio::fs::read(&self.path)
.await
.context(error::FileReadSnafu { path: &self.path })?;
Ok(Box::new(parse_keypair(&data)?))
let password: Option<&str> = self.password.as_deref();
Ok(Box::new(parse_keypair(&data, password)?))
}

async fn write(
Expand Down
43 changes: 36 additions & 7 deletions tough/src/sign.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,12 @@ use crate::sign::SignKeyPair::RSA;
use async_trait::async_trait;
use aws_lc_rs::rand::SecureRandom;
use aws_lc_rs::signature::{EcdsaKeyPair, Ed25519KeyPair, KeyPair, RsaKeyPair};
use base64::{engine::general_purpose::STANDARD, Engine as _};
use pkcs8::der::Decode;
use snafu::ResultExt;
use std::collections::HashMap;
use std::error::Error;

use std::str;
/// This trait must be implemented for each type of key with which you will
/// sign things.
#[async_trait]
Expand Down Expand Up @@ -166,17 +168,44 @@ impl Sign for SignKeyPair {
}
}

/// Decrypts an RSA private key in PEM format using the given password.
/// Returns the decrypted key in PKCS8 format
pub fn decrypt_key(
encrypted_key: &[u8],
password: &str,
) -> std::result::Result<Vec<u8>, Box<dyn std::error::Error>> {
let pem_str = std::str::from_utf8(encrypted_key)?;
let pem = pem::parse(pem_str)?;
let encrypted_private_key_document = pkcs8::EncryptedPrivateKeyInfo::from_der(pem.contents())?;
let decrypted_private_key_document =
encrypted_private_key_document.decrypt(password.as_bytes())?;
let decrypted_key_bytes: Vec<u8> = decrypted_private_key_document.as_bytes().to_vec();
let decrypted_key_base64 = STANDARD.encode(decrypted_key_bytes);
let pem_key =
format!("-----BEGIN PRIVATE KEY-----\n{decrypted_key_base64}\n-----END PRIVATE KEY-----");
let pem_key_bytes = pem_key.as_bytes().to_vec();
Ok(pem_key_bytes)
}

/// Parses a supplied keypair and if it is recognized, returns an object that
/// implements the Sign trait
/// Accepted Keys: ED25519 pkcs8, Ecdsa pkcs8, RSA
pub fn parse_keypair(key: &[u8]) -> Result<impl Sign> {
if let Ok(ed25519_key_pair) = Ed25519KeyPair::from_pkcs8(key) {
pub fn parse_keypair(key: &[u8], password: Option<&str>) -> Result<impl Sign> {
let decrypted_key = if let Some(pw) = password {
decrypt_key(key, pw).unwrap_or_else(|_| key.to_vec())
} else {
key.to_vec()
};
let decrypted_key_slice: &[u8] = &decrypted_key;

if let Ok(ed25519_key_pair) = Ed25519KeyPair::from_pkcs8(decrypted_key_slice) {
Ok(SignKeyPair::ED25519(ed25519_key_pair))
} else if let Ok(ecdsa_key_pair) =
EcdsaKeyPair::from_pkcs8(&aws_lc_rs::signature::ECDSA_P256_SHA256_ASN1_SIGNING, key)
{
} else if let Ok(ecdsa_key_pair) = EcdsaKeyPair::from_pkcs8(
&aws_lc_rs::signature::ECDSA_P256_SHA256_ASN1_SIGNING,
decrypted_key_slice,
) {
Ok(SignKeyPair::ECDSA(ecdsa_key_pair))
} else if let Ok(pem) = pem::parse(key) {
} else if let Ok(pem) = pem::parse(decrypted_key_slice) {
match pem.tag() {
"PRIVATE KEY" => {
if let Ok(rsa_key_pair) = RsaKeyPair::from_pkcs8(pem.contents()) {
Expand Down
Loading