Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix panic if the quote doesn't contains pck cert chain #3

Merged
merged 2 commits into from
Dec 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ authors = ["Kevin Wang <[email protected]>"]
hex = { version = "0.4", default-features = false, features = ["alloc"] }
serde = { version = "1.0.215", default-features = false, features = ["derive"] }
base64 = { version = "0.22.1", default-features = false, features = ["alloc"] }
scale = { package = "parity-scale-codec", version = "3.7.0", default-features = false, features = [
scale = { package = "parity-scale-codec", version = "3.6.12", default-features = false, features = [
"derive",
] }
scale-info = { version = "2.11.6", default-features = false, features = ["derive"] }
Expand Down
28 changes: 15 additions & 13 deletions src/collateral.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use alloc::string::{String, ToString};
use anyhow::{anyhow, Result};
use anyhow::{anyhow, Context, Result};
use scale::Decode;

use crate::quote::Quote;
Expand All @@ -12,7 +12,7 @@ fn get_header(resposne: &reqwest::Response, name: &str) -> Result<String> {
let value = resposne
.headers()
.get(name)
.ok_or(anyhow!("Missing {name}"))?
.ok_or_else(|| anyhow!("Missing {name}"))?
.to_str()?;
let value = urlencoding::decode(value)?;
Ok(value.into_owned())
Expand All @@ -36,7 +36,7 @@ pub async fn get_collateral(
#[cfg(not(feature = "js"))] timeout: Duration,
) -> Result<QuoteCollateralV3> {
let quote = Quote::decode(&mut quote)?;
let fmspc = hex::encode_upper(quote.fmspc().map_err(|_| anyhow!("get fmspc error"))?);
let fmspc = hex::encode_upper(quote.fmspc().context("Failed to get FMSPC")?);
let builder = reqwest::Client::builder();
#[cfg(not(feature = "js"))]
let builder = builder.danger_accept_invalid_certs(true).timeout(timeout);
Expand All @@ -63,29 +63,31 @@ pub async fn get_collateral(
};

let tcb_info_json: serde_json::Value =
serde_json::from_str(&raw_tcb_info).map_err(|_| anyhow!("TCB Info should a JSON"))?;
serde_json::from_str(&raw_tcb_info).context("TCB Info should be valid JSON")?;
let tcb_info = tcb_info_json["tcbInfo"].to_string();
let tcb_info_signature = tcb_info_json
.get("signature")
.ok_or(anyhow!("TCB Info should has `signature` field"))?
.context("TCB Info missing 'signature' field")?
.as_str()
.ok_or(anyhow!("TCB Info signature should a hex string"))?;
.context("TCB Info signature must be a string")?;
let tcb_info_signature = hex::decode(tcb_info_signature)
.map_err(|_| anyhow!("TCB Info signature should a hex string"))?;
.ok()
.context("TCB Info signature must be valid hex")?;

let qe_identity_json: serde_json::Value = serde_json::from_str(raw_qe_identity.as_str())
.map_err(|_| anyhow!("QE Identity should a JSON"))?;
let qe_identity_json: serde_json::Value =
serde_json::from_str(&raw_qe_identity).context("QE Identity should be valid JSON")?;
let qe_identity = qe_identity_json
.get("enclaveIdentity")
.ok_or(anyhow!("QE Identity should has `enclaveIdentity` field"))?
.context("QE Identity missing 'enclaveIdentity' field")?
.to_string();
let qe_identity_signature = qe_identity_json
.get("signature")
.ok_or(anyhow!("QE Identity should has `signature` field"))?
.context("QE Identity missing 'signature' field")?
.as_str()
.ok_or(anyhow!("QE Identity signature should a hex string"))?;
.context("QE Identity signature must be a string")?;
let qe_identity_signature = hex::decode(qe_identity_signature)
.map_err(|_| anyhow!("QE Identity signature should a hex string"))?;
.ok()
.context("QE Identity signature must be valid hex")?;

Ok(QuoteCollateralV3 {
tcb_info_issuer_chain,
Expand Down
35 changes: 0 additions & 35 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,43 +39,8 @@
extern crate alloc;

use scale::{Decode, Encode};
use scale_info::TypeInfo;
use serde::{Deserialize, Serialize};

#[derive(Encode, Decode, TypeInfo, Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum Error {
InvalidCertificate,
InvalidSignature,
CodecError,

// DCAP
TCBInfoExpired,
KeyLengthIsInvalid,
PublicKeyIsInvalid,
RsaSignatureIsInvalid,
DerEncodingError,
UnsupportedDCAPQuoteVersion,
UnsupportedDCAPAttestationKeyType,
UnsupportedQuoteAuthData,
UnsupportedDCAPPckCertFormat,
LeafCertificateParsingError,
CertificateChainIsInvalid,
CertificateChainIsTooShort,
IntelExtensionCertificateDecodingError,
IntelExtensionAmbiguity,
CpuSvnLengthMismatch,
CpuSvnDecodingError,
PceSvnDecodingError,
PceSvnLengthMismatch,
FmspcLengthMismatch,
FmspcDecodingError,
FmspcMismatch,
QEReportHashMismatch,
IsvEnclaveReportSignatureIsInvalid,
DerDecodingError,
OidIsMissing,
}

#[derive(Encode, Decode, Clone, PartialEq, Eq, Debug, Serialize, Deserialize)]
pub struct QuoteCollateralV3 {
pub tcb_info_issuer_chain: String,
Expand Down
36 changes: 22 additions & 14 deletions src/quote.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
use alloc::string::String;
use alloc::vec::Vec;

use anyhow::Result;
use anyhow::{anyhow, bail, Context, Result};
use scale::{Decode, Input};
use serde::{Deserialize, Serialize};

use crate::{constants::*, utils, Error};
use crate::{constants::*, utils};

#[derive(Debug, Clone)]
pub struct Data<T> {
Expand Down Expand Up @@ -232,7 +232,7 @@ fn decode_auth_data(ver: u16, input: &mut &[u8]) -> Result<AuthData, scale::Erro
let auth_data = AuthDataV4::decode(input)?;
Ok(AuthData::V4(auth_data))
}
_ => Err(scale::Error::from("unsupported auth data version")),
_ => Err(scale::Error::from("Unsupported auth data version")),
}
}

Expand Down Expand Up @@ -296,7 +296,7 @@ impl Decode for Quote {
TEE_TYPE_TDX => {
report = Report::TD10(TDReport10::decode(input)?);
}
_ => return Err(scale::Error::from("invalid tee type")),
_ => return Err(scale::Error::from("Invalid TEE type")),
},
5 => {
let body = Body::decode(input)?;
Expand All @@ -310,10 +310,10 @@ impl Decode for Quote {
BODY_TD_REPORT15_TYPE => {
report = Report::TD15(TDReport15::decode(input)?);
}
_ => return Err(scale::Error::from("unsupported body type")),
_ => return Err(scale::Error::from("Unsupported body type")),
}
}
_ => return Err(scale::Error::from("unsupported quote version")),
_ => return Err(scale::Error::from("Unsupported quote version")),
}
let data = Data::<u32>::decode(input)?;
let auth_data = decode_auth_data(header.version, &mut &data.data[..])?;
Expand All @@ -334,18 +334,26 @@ impl Quote {
}

/// Get the raw certificate chain from the quote.
pub fn raw_cert_chain(&self) -> &[u8] {
match &self.auth_data {
AuthData::V3(data) => &data.certification_data.body.data,
AuthData::V4(data) => &data.qe_report_data.certification_data.body.data,
pub fn raw_cert_chain(&self) -> Result<&[u8]> {
let cert_data = match &self.auth_data {
AuthData::V3(data) => &data.certification_data,
AuthData::V4(data) => &data.qe_report_data.certification_data,
};
if cert_data.cert_type != 5 {
bail!("Unsupported cert type: {}", cert_data.cert_type);
}
Ok(&cert_data.body.data)
}

/// Get the FMSPC from the quote.
pub fn fmspc(&self) -> Result<Fmspc, Error> {
let raw_cert_chain = self.raw_cert_chain();
let certs = utils::extract_certs(raw_cert_chain)?;
let extension_section = utils::get_intel_extension(&certs[0])?;
pub fn fmspc(&self) -> Result<Fmspc> {
let raw_cert_chain = self
.raw_cert_chain()
.context("Failed to get raw cert chain")?;
let certs = utils::extract_certs(raw_cert_chain).context("Failed to extract certs")?;
let cert = certs.get(0).ok_or(anyhow!("Invalid certificate"))?;
let extension_section =
utils::get_intel_extension(cert).context("Failed to get Intel extension")?;
utils::get_fmspc(&extension_section)
}

Expand Down
89 changes: 45 additions & 44 deletions src/utils.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use alloc::vec::Vec;
use anyhow::{anyhow, bail, Context, Result};
use asn1_der::{
typed::{DerDecodable, Sequence},
DerObject,
Expand All @@ -7,11 +8,10 @@ use webpki::types::CertificateDer;
use x509_cert::Certificate;

use crate::constants::*;
use crate::Error;

pub fn get_intel_extension(der_encoded: &[u8]) -> Result<Vec<u8>, Error> {
let cert: Certificate = der::Decode::from_der(der_encoded)
.map_err(|_| Error::IntelExtensionCertificateDecodingError)?;
pub fn get_intel_extension(der_encoded: &[u8]) -> Result<Vec<u8>> {
let cert: Certificate =
der::Decode::from_der(der_encoded).context("Failed to decode certificate")?;
let mut extension_iter = cert
.tbs_certificate
.extensions
Expand All @@ -21,88 +21,89 @@ pub fn get_intel_extension(der_encoded: &[u8]) -> Result<Vec<u8>, Error> {
.filter(|e| e.extn_id == oids::SGX_EXTENSION)
.map(|e| e.extn_value.clone());

let extension = extension_iter
.next()
.ok_or(Error::IntelExtensionAmbiguity)?;
let extension = extension_iter.next().context("Intel extension not found")?;
if extension_iter.next().is_some() {
//"There should only be one section containing Intel extensions"
return Err(Error::IntelExtensionAmbiguity);
bail!("Intel extension ambiguity");
}
Ok(extension.into_bytes())
}

pub fn find_extension(path: &[&[u8]], raw: &[u8]) -> Result<Vec<u8>, Error> {
let obj = DerObject::decode(raw).map_err(|_| Error::DerDecodingError)?;
let subobj = get_obj(path, obj)?;
pub fn find_extension(path: &[&[u8]], raw: &[u8]) -> Result<Vec<u8>> {
let obj = DerObject::decode(raw).context("Failed to decode DER object")?;
let subobj = get_obj(path, obj).context("Failed to get subobject")?;
Ok(subobj.value().to_vec())
}

fn get_obj<'a>(path: &[&[u8]], mut obj: DerObject<'a>) -> Result<DerObject<'a>, Error> {
fn get_obj<'a>(path: &[&[u8]], mut obj: DerObject<'a>) -> Result<DerObject<'a>> {
for oid in path {
let seq = Sequence::load(obj).map_err(|_| Error::DerDecodingError)?;
obj = sub_obj(oid, seq)?;
let seq = Sequence::load(obj).context("Failed to load sequence")?;
obj = sub_obj(oid, seq).context("Failed to get subobject")?;
}
Ok(obj)
}

fn sub_obj<'a>(oid: &[u8], seq: Sequence<'a>) -> Result<DerObject<'a>, Error> {
fn sub_obj<'a>(oid: &[u8], seq: Sequence<'a>) -> Result<DerObject<'a>> {
for i in 0..seq.len() {
let entry = seq.get(i).map_err(|_| Error::OidIsMissing)?;
let entry = Sequence::load(entry).map_err(|_| Error::DerDecodingError)?;
let name = entry.get(0).map_err(|_| Error::OidIsMissing)?;
let value = entry.get(1).map_err(|_| Error::OidIsMissing)?;
let entry = seq.get(i).context("Failed to get entry")?;
let entry = Sequence::load(entry).context("Failed to load sequence")?;
let name = entry.get(0).context("Failed to get name")?;
let value = entry.get(1).context("Failed to get value")?;
if name.value() == oid {
return Ok(value);
}
}
Err(Error::OidIsMissing)
bail!("Oid is missing");
}

pub fn get_fmspc(extension_section: &[u8]) -> Result<Fmspc, Error> {
let data = find_extension(&[oids::FMSPC.as_bytes()], extension_section)?;
pub fn get_fmspc(extension_section: &[u8]) -> Result<Fmspc> {
let data = find_extension(&[oids::FMSPC.as_bytes()], extension_section)
.context("Failed to find Fmspc")?;
if data.len() != 6 {
return Err(Error::FmspcLengthMismatch);
bail!("Fmspc length mismatch");
}

data.try_into().map_err(|_| Error::FmspcDecodingError)
data.try_into()
.map_err(|_| anyhow!("Failed to decode Fmspc"))
}

pub fn get_cpu_svn(extension_section: &[u8]) -> Result<CpuSvn, Error> {
pub fn get_cpu_svn(extension_section: &[u8]) -> Result<CpuSvn> {
let data = find_extension(
&[oids::TCB.as_bytes(), oids::CPUSVN.as_bytes()],
extension_section,
)?;
if data.len() != 16 {
return Err(Error::CpuSvnLengthMismatch);
bail!("CpuSvn length mismatch");
}

data.try_into().map_err(|_| Error::CpuSvnDecodingError)
data.try_into().map_err(|_| anyhow!("Failed to decode CpuSvn"))
}

pub fn get_pce_svn(extension_section: &[u8]) -> Result<Svn, Error> {
pub fn get_pce_svn(extension_section: &[u8]) -> Result<Svn> {
let data = find_extension(
&[oids::TCB.as_bytes(), oids::PCESVN.as_bytes()],
extension_section,
)?;
)
.context("Failed to find PceSvn")?;

match data.len() {
1 => Ok(u16::from(data[0])),
2 => Ok(u16::from_be_bytes(
data.try_into().map_err(|_| Error::PceSvnDecodingError)?,
data.try_into().map_err(|_| anyhow!("Failed to decode PceSvn"))?,
)),
_ => Err(Error::PceSvnLengthMismatch),
_ => bail!("PceSvn length mismatch"),
}
}

pub fn extract_raw_certs(cert_chain: &[u8]) -> Result<Vec<Vec<u8>>, Error> {
pub fn extract_raw_certs(cert_chain: &[u8]) -> Result<Vec<Vec<u8>>> {
Ok(pem::parse_many(cert_chain)
.map_err(|_| Error::CodecError)?
.context("Failed to parse certs")?
.iter()
.map(|i| i.contents().to_vec())
.collect())
}

pub fn extract_certs<'a>(cert_chain: &'a [u8]) -> Result<Vec<CertificateDer<'a>>, Error> {
pub fn extract_certs<'a>(cert_chain: &'a [u8]) -> Result<Vec<CertificateDer<'a>>> {
let mut certs = Vec::<CertificateDer<'a>>::new();

let raw_certs = extract_raw_certs(cert_chain)?;
Expand All @@ -117,26 +118,26 @@ pub fn extract_certs<'a>(cert_chain: &'a [u8]) -> Result<Vec<CertificateDer<'a>>
/// Encode two 32-byte values in DER format
/// This is meant for 256 bit ECC signatures or public keys
/// TODO: We may could use `asn1_der` crate reimplement this, so we can remove `der` which overlap with `asn1_der`
pub fn encode_as_der(data: &[u8]) -> Result<Vec<u8>, Error> {
pub fn encode_as_der(data: &[u8]) -> Result<Vec<u8>> {
if data.len() != 64 {
return Err(Error::KeyLengthIsInvalid);
bail!("Key length is invalid");
}
let mut sequence = der::asn1::SequenceOf::<der::asn1::UintRef, 2>::new();
sequence
.add(der::asn1::UintRef::new(&data[0..32]).map_err(|_| Error::PublicKeyIsInvalid)?)
.map_err(|_| Error::PublicKeyIsInvalid)?;
.add(der::asn1::UintRef::new(&data[0..32]).context("Failed to add first element")?)
.context("Failed to add second element")?;
sequence
.add(der::asn1::UintRef::new(&data[32..]).map_err(|_| Error::PublicKeyIsInvalid)?)
.map_err(|_| Error::PublicKeyIsInvalid)?;
.add(der::asn1::UintRef::new(&data[32..]).context("Failed to add third element")?)
.context("Failed to add third element")?;
// 72 should be enough in all cases. 2 + 2 x (32 + 3)
let mut asn1 = alloc::vec![0u8; 72];
let mut writer = der::SliceWriter::new(&mut asn1);
writer
.encode(&sequence)
.map_err(|_| Error::DerEncodingError)?;
.context("Failed to encode sequence")?;
Ok(writer
.finish()
.map_err(|_| Error::DerEncodingError)?
.context("Failed to finish writer")?
.to_vec())
}

Expand All @@ -146,7 +147,7 @@ pub fn verify_certificate_chain(
leaf_cert: &webpki::EndEntityCert,
intermediate_certs: &[CertificateDer],
verification_time: u64,
) -> Result<(), Error> {
) -> Result<()> {
let time = webpki::types::UnixTime::since_unix_epoch(core::time::Duration::from_secs(
verification_time / 1000,
));
Expand All @@ -161,7 +162,7 @@ pub fn verify_certificate_chain(
None,
None,
)
.map_err(|_e| Error::CertificateChainIsInvalid)?;
.context("Failed to verify certificate chain")?;

Ok(())
}
Loading
Loading