Skip to content

Commit

Permalink
Clean up code, add subpackages
Browse files Browse the repository at this point in the history
  • Loading branch information
KacperMalachowski committed Sep 18, 2024
1 parent fd5f94e commit a6ace7c
Show file tree
Hide file tree
Showing 8 changed files with 300 additions and 137 deletions.
1 change: 1 addition & 0 deletions cmd/cloud-run/signifysecretrotator/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ WORKDIR /app
COPY ./cmd/cloud-run/signifysecretrotator/**/*.py .
COPY ./cmd/cloud-run/signifysecretrotator/requirements.txt .

RUN apk add g++
RUN pip install --no-cache-dir --upgrade -r requirements.txt && \
apk add --no-cache ca-certificates

Expand Down
Binary file not shown.
3 changes: 1 addition & 2 deletions cmd/cloud-run/signifysecretrotator/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
flask==2.3.2
cloudevents==1.9.0
gunicorn==22.0.0
google-cloud-secret-manager==2.20.2
cryptography==43.0.1
google-cloud-secret-manager==2.20.2
Empty file.
123 changes: 123 additions & 0 deletions cmd/cloud-run/signifysecretrotator/signify/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
"""Custom client for signify API"""

import json
import tempfile
import requests
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography import x509
from cryptography.hazmat.primitives import serialization, hashes
from cryptography.hazmat.primitives.serialization import pkcs7


class SignifyClient:
"""Wraps signify API"""

def __init__(
self, token_url: str, certificate_service_url: str, client_id: str
) -> None:
self.token_url = token_url
self.certificate_service_url = certificate_service_url
self.client_id = client_id

def fetch_access_token(self, certificate: bytes, private_key: bytes) -> str:
"""fetches access token from given token_url using certificate and private key"""
# Use temporary file for old cert and key because requests library needs file paths,
# the code is running in known environment controlled by us
with (
tempfile.NamedTemporaryFile() as old_cert_file,
tempfile.NamedTemporaryFile() as old_key_file,
):

old_cert_file.write(certificate)
old_cert_file.flush()

old_key_file.write(private_key)
old_key_file.flush()

access_token_response = requests.post(
self.token_url,
cert=(old_cert_file.name, old_key_file.name),
data={
"grant_type": "client_credentials",
"client_id": self.client_id,
},
timeout=30,
)

if access_token_response.status_code != 200:
raise requests.HTTPError(
f"Got not-success status code {access_token_response.status_code}",
response=access_token_response,
)

decoded_response = access_token_response.json()

if "access_token" not in decoded_response:
raise ValueError(
f"Got unexpected response structure: {decoded_response}"
)

return decoded_response["access_token"]

def fetch_new_certificate(
self, cert_data: bytes, private_key: rsa.RSAPrivateKey, access_token: str
):
"""Fetch new certificates from given certificate service"""

csr = self._prepare_csr(cert_data, private_key)

crt_create_payload = self._prepare_cert_request_paylaod(csr)

cert_create_response = requests.post(
self.certificate_service_url,
headers={
"Authorization": f"Bearer {access_token}",
"Content-Type": "application/json",
"Accept": "application/json",
},
data=crt_create_payload,
timeout=10,
)

if cert_create_response.status_code != 200:
raise requests.HTTPError(
f"Got un-success statsu code {cert_create_response.status_code}"
)

decoded_response = cert_create_response.json()

if (
"certificateChain" not in decoded_response
or "value" not in decoded_response["certificateChain"]
):
raise ValueError(
f"Cannot issue new certifacte, invalid response format: {decoded_response}"
)

pkcs7_certs = decoded_response["certificateChain"]["value"].encode()

return pkcs7.load_pem_pkcs7_certificates(pkcs7_certs)

def _prepare_cert_request_paylaod(self, csr: x509.CertificateSigningRequest):
return json.dumps(
{
"csr": {
"value": csr.public_bytes(serialization.Encoding.PEM).decode(
"utf-8"
)
},
"validity": {"value": 7, "type": "DAYS"},
"policy": "sap-cloud-platform-clients",
}
)

def _prepare_csr(self, cert_data: bytes, private_key: rsa.RSAPrivateKey):
old_cert = x509.load_pem_x509_certificate(cert_data)

csr = (
x509.CertificateSigningRequestBuilder()
.subject_name(old_cert.subject)
.sign(private_key, hashes.SHA256())
)

return csr
141 changes: 141 additions & 0 deletions cmd/cloud-run/signifysecretrotator/signify/test_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
"""Tests for signify client module"""

import base64
import unittest
from unittest.mock import patch, MagicMock

import requests

# pylint: disable=import-error
# False positive see: https://github.com/pylint-dev/pylint/issues/3984
from client import SignifyClient
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.serialization import Encoding, PrivateFormat

# pylint: disable=import-error
# False positive see: https://github.com/pylint-dev/pylint/issues/3984
import test_fixtures


class TestSignifyClient(unittest.TestCase):
"""
Unit tests for the SignifyClient class.
"""

def setUp(self):
"""
Set up method to initialize the SignifyClient object and necessary data for tests.
"""
self.token_url = "https://example.com/token"
self.certificate_service_url = "https://example.com/certificate"
self.client = SignifyClient(
token_url=self.token_url,
certificate_service_url=self.certificate_service_url,
client_id="fake_client_id",
)
self.certificate = base64.b64decode(
test_fixtures.mocked_secret_data["certData"]
)
self.private_key = rsa.generate_private_key(
public_exponent=65537, key_size=2048
)
self.access_token = "fake_access_token"

@patch("requests.post")
def test_fetch_access_token_success(self, mock_post):
"""
Test successful fetch of access token.
Mock the requests.post to return a successful response with the expected access token.
"""
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {"access_token": self.access_token}
mock_post.return_value = mock_response
private_key_bytes = self.private_key.private_bytes(
Encoding.PEM, PrivateFormat.PKCS8, serialization.NoEncryption()
)

token = self.client.fetch_access_token(
self.certificate,
private_key_bytes,
)

self.assertEqual(token, self.access_token)
mock_post.assert_called_once()

@patch("requests.post")
def test_fetch_access_token_failed_status_code(self, mock_post):
"""
Test fetch of access token when the response status code is not 200.
"""
mock_response = MagicMock()
mock_response.status_code = 400
mock_response.json.return_value = {}
mock_post.return_value = mock_response
private_key_bytes = self.private_key.private_bytes(
Encoding.PEM, PrivateFormat.PKCS8, serialization.NoEncryption()
)

with self.assertRaises(requests.HTTPError):
self.client.fetch_access_token(
certificate=self.certificate, private_key=private_key_bytes
)
mock_post.assert_called_once()

@patch("requests.post")
def test_fetch_access_token_unexpected_response(self, mock_post):
"""
Test fetch of access token when the response does not contain the expected structure.
"""
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {}
mock_post.return_value = mock_response
private_key_bytes = self.private_key.private_bytes(
Encoding.PEM, PrivateFormat.PKCS8, serialization.NoEncryption()
)

with self.assertRaises(ValueError):
self.client.fetch_access_token(
certificate=self.certificate, private_key=private_key_bytes
)
mock_post.assert_called_once()

@patch("requests.post")
def test_fetch_new_certificate_success(self, mock_post):
"""
Test successful fetch of a new certificate.
"""
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = test_fixtures.mocked_cert_create_response
mock_post.return_value = mock_response

certs = self.client.fetch_new_certificate(
self.certificate, self.private_key, self.access_token
)

self.assertEqual(len(certs), 1)
mock_post.assert_called_once()

@patch("requests.post")
def test_fetch_new_certificate_failed_status_code(self, mock_post):
"""
Test fetch of a new certificate when the response status code is not 200.s
"""
mock_response = MagicMock()
mock_response.status_code = 400
mock_response.json.return_value = {}
mock_post.return_value = mock_response

with self.assertRaises(KeyError):
self.client.fetch_new_certificate(
self.certificate, self.private_key, self.access_token
)
mock_post.assert_called_once()


if __name__ == "__main__":
unittest.main()
Loading

0 comments on commit a6ace7c

Please sign in to comment.