diff --git a/matrix/requests_signing.go b/matrix/requests_signing.go index 8694139f..57bb1677 100644 --- a/matrix/requests_signing.go +++ b/matrix/requests_signing.go @@ -20,7 +20,7 @@ type signingKey struct { Key string `json:"key"` } -type serverKeyResult struct { +type ServerKeyResult struct { ServerName string `json:"server_name"` ValidUntilTs int64 `json:"valid_until_ts"` VerifyKeys map[string]signingKey `json:"verify_keys"` // unpadded base64 @@ -83,7 +83,7 @@ func QuerySigningKeys(serverName string) (ServerSigningKeys, error) { if err = decoder.Decode(&raw); err != nil { return nil, err } - keyInfo := new(serverKeyResult) + keyInfo := new(ServerKeyResult) if err = raw.ApplyTo(keyInfo); err != nil { return nil, err } @@ -100,51 +100,66 @@ func QuerySigningKeys(serverName string) (ServerSigningKeys, error) { return nil, errors.New("returned server keys would expire too quickly") } - // Convert to something useful - serverKeys := make(ServerSigningKeys) - for keyId, keyObj := range keyInfo.VerifyKeys { - b, err := util.DecodeUnpaddedBase64String(keyObj.Key) - if err != nil { - return nil, errors.Join(fmt.Errorf("bad base64 for key ID '%s' for '%s'", keyId, serverName), err) - } - - serverKeys[keyId] = b + // Convert keys to something useful, and check signatures + serverKeys, err := CheckSigningKeySignatures(serverName, keyInfo, raw) + if err != nil { + return nil, err } - // Check signatures - if len(keyInfo.Signatures) == 0 || len(keyInfo.Signatures[serverName]) == 0 { - return nil, fmt.Errorf("missing signatures from '%s'", serverName) - } - delete(raw, "signatures") - canonical, err := util.EncodeCanonicalJson(raw) + // Cache & return (unlock was deferred) + signingKeyCache.Set(serverName, &serverKeys, cacheUntil) + return serverKeys, nil + }) + return keys, err +} + +func CheckSigningKeySignatures(serverName string, keyInfo *ServerKeyResult, raw database.AnonymousJson) (ServerSigningKeys, error) { + serverKeys := make(ServerSigningKeys) + for keyId, keyObj := range keyInfo.VerifyKeys { + b, err := util.DecodeUnpaddedBase64String(keyObj.Key) if err != nil { - return nil, err + return nil, errors.Join(fmt.Errorf("bad base64 for key ID '%s' for '%s'", keyId, serverName), err) } - for domain, sig := range keyInfo.Signatures { - if domain != serverName { - return nil, fmt.Errorf("unexpected signature from '%s' (expected '%s')", domain, serverName) - } - for keyId, b64 := range sig { - signatureBytes, err := util.DecodeUnpaddedBase64String(b64) - if err != nil { - return nil, errors.Join(fmt.Errorf("bad base64 signature for key ID '%s' for '%s'", keyId, serverName), err) - } + serverKeys[keyId] = b + } - key, ok := serverKeys[keyId] - if !ok { - return nil, fmt.Errorf("unknown key ID '%s' for signature from '%s'", keyId, serverName) - } + if len(keyInfo.Signatures) == 0 || len(keyInfo.Signatures[serverName]) == 0 { + return nil, fmt.Errorf("missing signatures from '%s'", serverName) + } + delete(raw, "signatures") + canonical, err := util.EncodeCanonicalJson(raw) + if err != nil { + return nil, err + } + for domain, sig := range keyInfo.Signatures { + if domain != serverName { + return nil, fmt.Errorf("unexpected signature from '%s' (expected '%s')", domain, serverName) + } + + for keyId, b64 := range sig { + signatureBytes, err := util.DecodeUnpaddedBase64String(b64) + if err != nil { + return nil, errors.Join(fmt.Errorf("bad base64 signature for key ID '%s' for '%s'", keyId, serverName), err) + } + + key, ok := serverKeys[keyId] + if !ok { + return nil, fmt.Errorf("unknown key ID '%s' for signature from '%s'", keyId, serverName) + } - if !ed25519.Verify(key, canonical, signatureBytes) { - return nil, fmt.Errorf("invalid signature '%s' from key ID '%s' for '%s'", b64, keyId, serverName) - } + if !ed25519.Verify(key, canonical, signatureBytes) { + return nil, fmt.Errorf("invalid signature '%s' from key ID '%s' for '%s'", b64, keyId, serverName) } } + } + + // Ensure *all* keys have signed the response + for keyId, _ := range serverKeys { + if _, ok := keyInfo.Signatures[serverName][keyId]; !ok { + return nil, fmt.Errorf("missing signature from key '%s'", keyId) + } + } - // Cache & return (unlock was deferred) - signingKeyCache.Set(serverName, &serverKeys, cacheUntil) - return serverKeys, nil - }) - return keys, err + return serverKeys, nil } diff --git a/test/signing_keys_test.go b/test/signing_keys_test.go new file mode 100644 index 00000000..13edf013 --- /dev/null +++ b/test/signing_keys_test.go @@ -0,0 +1,73 @@ +package test + +import ( + "crypto/ed25519" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/turt2live/matrix-media-repo/database" + "github.com/turt2live/matrix-media-repo/matrix" + "github.com/turt2live/matrix-media-repo/util" +) + +func TestFailInjectedKeys(t *testing.T) { + raw := database.AnonymousJson{ + "old_verify_keys": database.AnonymousJson{}, + "server_name": "x.resolvematrix.dev", + "signatures": database.AnonymousJson{ + "x.resolvematrix.dev": database.AnonymousJson{ + "ed25519:injected": "FB93YAF+fOPyWcsx285Q/xFzRiG5sr7/u1iX9XWaIcOwDyDDwx7daS1eYxuM9PfosWE5vqUyTsCxmB40JTzdCw", + }, + }, + "valid_until_ts": 1701055573679, + "verify_keys": database.AnonymousJson{ + "ed25519:AY4k3ADlto8": database.AnonymousJson{"key": "VF7dl9W/tFWAjZSXm42Ef22k3v4WKBYLXZF9I7ErU00"}, + "ed25519:injected": database.AnonymousJson{"key": "w48CLiV1IkWoEbqJLFmniGUYtxwT+c2zm87X8oEpRO8"}, + }, + } + keyInfo := new(matrix.ServerKeyResult) + err := raw.ApplyTo(keyInfo) + if err != nil { + t.Fatal(err) + } + + _, err = matrix.CheckSigningKeySignatures("x.resolvematrix.dev", keyInfo, raw) + assert.Error(t, err) + assert.Equal(t, "missing signature from key 'ed25519:AY4k3ADlto8'", err.Error()) +} + +func TestRegularKeys(t *testing.T) { + raw := database.AnonymousJson{ + "old_verify_keys": database.AnonymousJson{}, + "server_name": "x.resolvematrix.dev", + "signatures": database.AnonymousJson{ + "x.resolvematrix.dev": database.AnonymousJson{ + "ed25519:AY4k3ADlto8": "3WlsmHFTVjywCoDYyrtx3ies+VufTuBuw1Prlgmoqh+a4XrJT+isEwhTX+I5FBvtJTKTt6vLH3gaP7BA6712CA", + }, + }, + "valid_until_ts": 1701057124839, + "verify_keys": database.AnonymousJson{ + "ed25519:AY4k3ADlto8": database.AnonymousJson{"key": "VF7dl9W/tFWAjZSXm42Ef22k3v4WKBYLXZF9I7ErU00"}, + }, + } + keyInfo := new(matrix.ServerKeyResult) + err := raw.ApplyTo(keyInfo) + if err != nil { + t.Fatal(err) + } + + keys, err := matrix.CheckSigningKeySignatures("x.resolvematrix.dev", keyInfo, raw) + assert.NoError(t, err) + for keyId, keyVal := range keys { + if b64, ok := keyInfo.VerifyKeys[keyId]; !ok { + t.Errorf("got key for '%s' but wasn't expecting it", keyId) + } else { + keySelf, err := util.DecodeUnpaddedBase64String(b64.Key) + if err != nil { + t.Fatal(err) + } + + assert.Equal(t, ed25519.PublicKey(keySelf), keyVal) + } + } +}