-
Notifications
You must be signed in to change notification settings - Fork 9
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
8 changed files
with
341 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,156 @@ | ||
/* | ||
Copyright © 2025 Acronis International GmbH. | ||
Released under MIT license. | ||
*/ | ||
|
||
// Package jwk provides JSON Web Key (JWK) structure and methods to decode it to public and private keys. | ||
package jwk | ||
|
||
import ( | ||
"crypto" | ||
"crypto/rsa" | ||
"encoding/base64" | ||
"encoding/binary" | ||
"errors" | ||
"fmt" | ||
"math/big" | ||
) | ||
|
||
const typeRSA = "RSA" | ||
|
||
var supportedKeyTypes = map[string]struct{}{ | ||
typeRSA: {}, | ||
} | ||
|
||
// Key defines JSON Web Key structure. | ||
type Key struct { | ||
Alg string `json:"alg"` // algorithm | ||
Crv string `json:"crv,omitempty"` // curve - for EC keys | ||
D string `json:"d"` // private exponent | ||
DP string `json:"dp"` // d mod (p-1) | ||
DQ string `json:"dq"` // d mod (p-1) | ||
E string `json:"e"` // public exponent | ||
K string `json:"k,omitempty"` // symmetric key | ||
Kid string `json:"kid"` // Key ID | ||
Kty string `json:"kty"` // Key Type | ||
N string `json:"n"` // modulus | ||
P string `json:"p"` // prime factor 1 | ||
Q string `json:"q"` // prime factor 2 | ||
QI string `json:"qi"` // q^(-1) mod p | ||
Use string `json:"use"` | ||
X string `json:"x,omitempty"` // x coordinate - for EC keys | ||
Y string `json:"y,omitempty"` // y coordinate - for EC keys | ||
} | ||
|
||
// DecodePublicKey decodes Key to public key. | ||
func (j *Key) DecodePublicKey() (crypto.PublicKey, error) { | ||
if _, ok := supportedKeyTypes[j.Kty]; !ok { | ||
return nil, fmt.Errorf("unsupported key type %s", j.Kty) | ||
} | ||
|
||
var result interface{} | ||
|
||
if j.Kty == typeRSA { | ||
if j.N == "" || j.E == "" { | ||
return nil, errors.New("malformed JWK RSA key: missing N or E") | ||
} | ||
|
||
e, err := decodeBase64URLToBigInt(j.E) | ||
if err != nil { | ||
return nil, errors.New("malformed JWK RSA key") | ||
} | ||
eBytes := e.Bytes() | ||
if len(eBytes) < 4 { | ||
ndata := make([]byte, 4) | ||
copy(ndata[4-len(eBytes):], eBytes) | ||
eBytes = ndata | ||
} | ||
|
||
pubKey := &rsa.PublicKey{ | ||
N: &big.Int{}, | ||
E: int(binary.BigEndian.Uint32(eBytes)), | ||
} | ||
|
||
n, err := decodeBase64URLToBigInt(j.N) | ||
if err != nil { | ||
return nil, errors.New("malformed JWK RSA key") | ||
} | ||
pubKey.N = n | ||
|
||
result = pubKey | ||
} | ||
|
||
return result, nil | ||
} | ||
|
||
// DecodePrivateKey decodes Key to private key. | ||
func (j *Key) DecodePrivateKey() (crypto.PrivateKey, error) { | ||
if _, ok := supportedKeyTypes[j.Kty]; !ok { | ||
return nil, fmt.Errorf("unsupported key type %s", j.Kty) | ||
} | ||
|
||
var result interface{} | ||
var err error | ||
|
||
if j.Kty == typeRSA { | ||
if j.D == "" { | ||
return nil, errors.New("malformed JWK RSA private exponent") | ||
} | ||
|
||
// Decode base64url-encoded Key components | ||
components := []string{j.N, j.E, j.D, j.P, j.Q, j.DP, j.DQ, j.QI} | ||
decodedComponents := make([]*big.Int, len(components)) | ||
|
||
for i, component := range components { | ||
decodedComponents[i], err = decodeBase64URLToBigInt(component) | ||
if err != nil { | ||
return nil, fmt.Errorf("malformed Key RSA component: %w", err) | ||
} | ||
} | ||
|
||
n := decodedComponents[0] | ||
e := decodedComponents[1] | ||
d := decodedComponents[2] | ||
p := decodedComponents[3] | ||
q := decodedComponents[4] | ||
dp := decodedComponents[5] | ||
dq := decodedComponents[6] | ||
qi := decodedComponents[7] | ||
|
||
// Convert Key to *rsa.PrivateKey. | ||
rsaPrivateKey := &rsa.PrivateKey{ | ||
PublicKey: rsa.PublicKey{ | ||
N: n, | ||
E: int(e.Int64()), | ||
}, | ||
D: d, | ||
Primes: []*big.Int{p, q}, | ||
Precomputed: rsa.PrecomputedValues{ | ||
Dp: dp, | ||
Dq: dq, | ||
Qinv: qi, | ||
}, | ||
} | ||
|
||
rsaPrivateKey.Precompute() | ||
|
||
err = rsaPrivateKey.Validate() | ||
if err != nil { | ||
return nil, fmt.Errorf("invalid RSA private key: %w", err) | ||
} | ||
|
||
result = rsaPrivateKey | ||
} | ||
|
||
return result, err | ||
} | ||
|
||
// decodeBase64URLToBigInt is a helper function to decode base64url without padding. | ||
func decodeBase64URLToBigInt(encoded string) (*big.Int, error) { | ||
data, err := base64.RawURLEncoding.DecodeString(encoded) | ||
if err != nil { | ||
return nil, fmt.Errorf("failed to decode base64url: %w", err) | ||
} | ||
return new(big.Int).SetBytes(data), nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,177 @@ | ||
package jwk_test | ||
|
||
import ( | ||
"crypto/rsa" | ||
"encoding/base64" | ||
"math/big" | ||
"testing" | ||
|
||
"github.com/stretchr/testify/require" | ||
|
||
"github.com/acronis/go-authkit/internal/jwk" | ||
) | ||
|
||
func encodeBigIntToBase64URL(n *big.Int) string { | ||
return base64.RawURLEncoding.EncodeToString(n.Bytes()) | ||
} | ||
|
||
func TestDecodePublicKey(t *testing.T) { | ||
// Define meaningful primes | ||
p := big.NewInt(61) | ||
q := big.NewInt(53) | ||
|
||
// Compute modulus n = p * q | ||
n := new(big.Int).Mul(p, q) // n = 61 × 53 = 3233 | ||
e := big.NewInt(65537) // common public exponent | ||
|
||
key := &jwk.Key{ | ||
Kty: "RSA", | ||
N: encodeBigIntToBase64URL(n), | ||
E: encodeBigIntToBase64URL(e), | ||
} | ||
|
||
pubKey, err := key.DecodePublicKey() | ||
require.NoError(t, err) | ||
require.NotNil(t, pubKey) | ||
require.IsType(t, &rsa.PublicKey{}, pubKey) | ||
require.Equal(t, n, pubKey.(*rsa.PublicKey).N) | ||
require.Equal(t, int(e.Int64()), pubKey.(*rsa.PublicKey).E) | ||
} | ||
|
||
func TestDecodePublicKeyFails(t *testing.T) { | ||
key := &jwk.Key{ | ||
Kty: "MEOW", | ||
N: "invalid", | ||
E: "invalid", | ||
} | ||
|
||
pubKey, err := key.DecodePublicKey() | ||
require.Error(t, err) | ||
require.ErrorContains(t, err, "unsupported key type") | ||
require.Nil(t, pubKey) | ||
|
||
key = &jwk.Key{ | ||
Kty: "RSA", | ||
} | ||
|
||
pubKey, err = key.DecodePublicKey() | ||
require.Error(t, err, "N and E are missing") | ||
require.ErrorContains(t, err, "malformed JWK RSA key") | ||
require.Nil(t, pubKey) | ||
|
||
key = &jwk.Key{ | ||
Kty: "RSA", | ||
N: "invalid", | ||
E: "!!invalid!!", | ||
} | ||
|
||
pubKey, err = key.DecodePublicKey() | ||
require.Error(t, err, "E is invalid") | ||
require.ErrorContains(t, err, "malformed JWK RSA key") | ||
require.Nil(t, pubKey) | ||
|
||
key = &jwk.Key{ | ||
Kty: "RSA", | ||
N: "!!invalid!!", | ||
E: "invalid", | ||
} | ||
|
||
pubKey, err = key.DecodePublicKey() | ||
require.Error(t, err, "N is invalid") | ||
require.ErrorContains(t, err, "malformed JWK RSA key") | ||
require.Nil(t, pubKey) | ||
} | ||
|
||
func TestDecodePrivateKey(t *testing.T) { | ||
p := big.NewInt(11) | ||
q := big.NewInt(13) | ||
n := new(big.Int).Mul(p, q) // modulus calc: n = p * q | ||
|
||
e := big.NewInt(65537) // Common public exponent | ||
phi := new(big.Int).Mul( | ||
new(big.Int).Sub(p, big.NewInt(1)), | ||
new(big.Int).Sub(q, big.NewInt(1)), | ||
) // phi(n) | ||
|
||
d := new(big.Int).ModInverse(e, phi) // Compute private exponent | ||
if d == nil { | ||
t.Fatal("Failed to compute modular inverse for d") | ||
} | ||
|
||
dp := new(big.Int).Mod(d, new(big.Int).Sub(p, big.NewInt(1))) // dp = d mod (p-1) | ||
dq := new(big.Int).Mod(d, new(big.Int).Sub(q, big.NewInt(1))) // dq = d mod (q-1) | ||
qi := new(big.Int).ModInverse(q, p) // qi = q^(-1) mod p | ||
if qi == nil { | ||
t.Fatal("Failed to compute modular inverse for qi") | ||
} | ||
|
||
key := &jwk.Key{ | ||
Kty: "RSA", | ||
N: base64.RawURLEncoding.EncodeToString(n.Bytes()), | ||
E: encodeBigIntToBase64URL(e), | ||
D: encodeBigIntToBase64URL(d), | ||
P: base64.RawURLEncoding.EncodeToString(p.Bytes()), | ||
Q: base64.RawURLEncoding.EncodeToString(q.Bytes()), | ||
DP: encodeBigIntToBase64URL(dp), | ||
DQ: encodeBigIntToBase64URL(dq), | ||
QI: encodeBigIntToBase64URL(qi), | ||
} | ||
|
||
privKey, err := key.DecodePrivateKey() | ||
require.NoError(t, err) | ||
require.NotNil(t, privKey) | ||
require.IsType(t, &rsa.PrivateKey{}, privKey) | ||
require.Equal(t, n, privKey.(*rsa.PrivateKey).PublicKey.N) | ||
require.Equal(t, int(e.Int64()), privKey.(*rsa.PrivateKey).PublicKey.E) | ||
} | ||
|
||
func TestDecodePrivateKeyFails(t *testing.T) { | ||
key := &jwk.Key{ | ||
Kty: "MEOW", | ||
} | ||
|
||
privKey, err := key.DecodePrivateKey() | ||
require.Error(t, err, "unsupported key type") | ||
require.ErrorContains(t, err, "unsupported key type") | ||
require.Nil(t, privKey) | ||
|
||
n := big.NewInt(111) // bad modulus, not a mul of p and q | ||
|
||
key = &jwk.Key{ | ||
Kty: "RSA", | ||
N: base64.RawURLEncoding.EncodeToString(n.Bytes()), | ||
P: base64.RawURLEncoding.EncodeToString(big.NewInt(11).Bytes()), | ||
Q: base64.RawURLEncoding.EncodeToString(big.NewInt(13).Bytes()), | ||
} | ||
|
||
privKey, err = key.DecodePrivateKey() | ||
require.Error(t, err, "bad modulus triggers crypto error") | ||
require.ErrorContains(t, err, "malformed JWK RSA private exponent") | ||
require.Nil(t, privKey) | ||
|
||
key = &jwk.Key{ | ||
Kty: "RSA", | ||
N: base64.RawURLEncoding.EncodeToString(n.Bytes()), | ||
P: base64.RawURLEncoding.EncodeToString(big.NewInt(11).Bytes()), | ||
Q: base64.RawURLEncoding.EncodeToString(big.NewInt(13).Bytes()), | ||
D: "!!invalid!!", | ||
} | ||
|
||
privKey, err = key.DecodePrivateKey() | ||
require.Error(t, err, "malformed D") | ||
require.ErrorContains(t, err, "malformed Key RSA component") | ||
require.Nil(t, privKey) | ||
|
||
key = &jwk.Key{ | ||
Kty: "RSA", | ||
N: base64.RawURLEncoding.EncodeToString(n.Bytes()), | ||
P: base64.RawURLEncoding.EncodeToString(big.NewInt(11).Bytes()), | ||
Q: base64.RawURLEncoding.EncodeToString(big.NewInt(13).Bytes()), | ||
D: "asdasd", | ||
} | ||
|
||
privKey, err = key.DecodePrivateKey() | ||
require.Error(t, err, "exported D is not a valid base64url") | ||
require.ErrorContains(t, err, "public exponent too small") | ||
require.Nil(t, privKey) | ||
} |
Oops, something went wrong.