Skip to content

Commit

Permalink
Test X-Matrix auth header stuff
Browse files Browse the repository at this point in the history
  • Loading branch information
turt2live committed Nov 26, 2023
1 parent 590c475 commit ecddf3f
Show file tree
Hide file tree
Showing 6 changed files with 106 additions and 29 deletions.
9 changes: 9 additions & 0 deletions common/config/access.go
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,15 @@ func GetDomain(domain string) *DomainRepoConfig {
return domains[domain]
}

func AddDomainForTesting(domain string, config *DomainRepoConfig) {
Get() // Ensure the "main" config was loaded first
if config == nil {
c := NewDefaultDomainConfig()
config = &c
}
domains[domain] = config
}

func DomainConfigFrom(c MainRepoConfig) DomainRepoConfig {
// HACK: We should be better at this kind of inheritance
dc := NewDefaultDomainConfig()
Expand Down
12 changes: 6 additions & 6 deletions matrix/requests_signing.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,26 +30,26 @@ type serverKeyResult struct {

type ServerSigningKeys map[string]ed25519.PublicKey

var signingKeySf = new(typedsf.Group[*ServerSigningKeys])
var signingKeySf = new(typedsf.Group[ServerSigningKeys])
var signingKeyCache = cache.New(cache.NoExpiration, 30*time.Second)
var signingKeyRWLock = new(sync.RWMutex)

func querySigningKeyCache(serverName string) *ServerSigningKeys {
func querySigningKeyCache(serverName string) ServerSigningKeys {
if val, ok := signingKeyCache.Get(serverName); ok {
return val.(*ServerSigningKeys)
return val.(ServerSigningKeys)
}
return nil
}

func QuerySigningKeys(serverName string) (*ServerSigningKeys, error) {
func QuerySigningKeys(serverName string) (ServerSigningKeys, error) {
signingKeyRWLock.RLock()
keys := querySigningKeyCache(serverName)
signingKeyRWLock.RUnlock()
if keys != nil {
return keys, nil
}

keys, err, _ := signingKeySf.Do(serverName, func() (*ServerSigningKeys, error) {
keys, err, _ := signingKeySf.Do(serverName, func() (ServerSigningKeys, error) {
ctx := rcontext.Initial().LogWithFields(logrus.Fields{
"keysForServer": serverName,
})
Expand Down Expand Up @@ -144,7 +144,7 @@ func QuerySigningKeys(serverName string) (*ServerSigningKeys, error) {

// Cache & return (unlock was deferred)
signingKeyCache.Set(serverName, &serverKeys, cacheUntil)
return &serverKeys, nil
return serverKeys, nil
})
return keys, err
}
72 changes: 52 additions & 20 deletions matrix/xmatrix.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@ import (
"crypto/ed25519"
"errors"
"fmt"
"github.com/turt2live/matrix-media-repo/util"
"net/http"

"github.com/turt2live/matrix-media-repo/database"
"github.com/turt2live/matrix-media-repo/util"
)

var ErrNoXMatrixAuth = errors.New("no X-Matrix auth headers")
Expand All @@ -15,51 +17,81 @@ func ValidateXMatrixAuth(request *http.Request, expectNoContent bool) (string, e
panic("development error: X-Matrix auth validation can only be done with an empty body for now")
}

auths, err := util.GetXMatrixAuth(request)
auths, err := util.GetXMatrixAuth(request.Header.Values("Authorization"))
if err != nil {
return "", err
}

if len(auths) == 0 {
return "", ErrNoXMatrixAuth
}

obj := map[string]interface{}{
"method": request.Method,
"uri": request.RequestURI,
"origin": auths[0].Origin,
"destination": auths[0].Destination,
"content": "{}",
}
canonical, err := util.EncodeCanonicalJson(obj)
keys, err := QuerySigningKeys(auths[0].Origin)
if err != nil {
return "", err
}

keys, err := QuerySigningKeys(auths[0].Origin)
err = ValidateXMatrixAuthHeader(request.Method, request.RequestURI, &database.AnonymousJson{}, auths, keys)
if err != nil {
return "", err
}
return auths[0].Origin, nil
}

func ValidateXMatrixAuthHeader(requestMethod string, requestUri string, content any, headers []util.XMatrixAuth, originKeys ServerSigningKeys) error {
if len(headers) == 0 {
return ErrNoXMatrixAuth
}

obj := map[string]interface{}{
"method": requestMethod,
"uri": requestUri,
"origin": headers[0].Origin,
"destination": headers[0].Destination,
"content": content,
}
canonical, err := util.EncodeCanonicalJson(obj)
if err != nil {
return err
}

for _, h := range auths {
for _, h := range headers {
if h.Origin != obj["origin"] {
return "", errors.New("auth is from multiple servers")
return errors.New("auth is from multiple servers")
}
if h.Destination != obj["destination"] {
return "", errors.New("auth is for multiple servers")
return errors.New("auth is for multiple servers")
}
if h.Destination != "" && !util.IsServerOurs(h.Destination) {
return "", errors.New("unknown destination")
return errors.New("unknown destination")
}

if key, ok := (*keys)[h.KeyId]; ok {
if key, ok := (originKeys)[h.KeyId]; ok {
if !ed25519.Verify(key, canonical, h.Signature) {
return "", fmt.Errorf("failed signatures on '%s'", h.KeyId)
return fmt.Errorf("failed signatures on '%s'", h.KeyId)
}
} else {
return "", fmt.Errorf("unknown key '%s'", h.KeyId)
return fmt.Errorf("unknown key '%s'", h.KeyId)
}
}

return auths[0].Origin, nil
return nil
}

func CreateXMatrixHeader(origin string, destination string, requestMethod string, requestUri string, content any, key *ed25519.PrivateKey, keyVersion string) (string, error) {
obj := map[string]interface{}{
"method": requestMethod,
"uri": requestUri,
"origin": origin,
"destination": destination,
"content": content,
}
canonical, err := util.EncodeCanonicalJson(obj)
if err != nil {
return "", err
}

b := ed25519.Sign(*key, canonical)
sig := util.EncodeUnpaddedBase64ToString(b)

return fmt.Sprintf("X-Matrix origin=\"%s\",destination=\"%s\",key=\"ed25519:%s\",sig=\"%s\"", origin, destination, keyVersion, sig), nil
}
37 changes: 37 additions & 0 deletions test/xmatrix_header_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package test

import (
"crypto/ed25519"
"testing"

"github.com/turt2live/matrix-media-repo/common/config"
"github.com/turt2live/matrix-media-repo/database"
"github.com/turt2live/matrix-media-repo/matrix"
"github.com/turt2live/matrix-media-repo/util"
)

func TestXMatrixAuthHeader(t *testing.T) {
config.AddDomainForTesting("localhost", nil)

pub, priv, err := ed25519.GenerateKey(nil)
if err != nil {
t.Fatal(err)
}

header, err := matrix.CreateXMatrixHeader("localhost:8008", "localhost", "GET", "/_matrix/media/v3/download/example.org/abc", &database.AnonymousJson{}, &priv, "0")
if err != nil {
t.Fatal(err)
}

auths, err := util.GetXMatrixAuth([]string{header})
if err != nil {
t.Fatal(err)
}

keys := make(matrix.ServerSigningKeys)
keys["ed25519:0"] = pub
err = matrix.ValidateXMatrixAuthHeader("GET", "/_matrix/media/v3/download/example.org/abc", &database.AnonymousJson{}, auths, keys)
if err != nil {
t.Error(err)
}
}
2 changes: 1 addition & 1 deletion util/canonical_json.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import (
"encoding/json"
)

func EncodeCanonicalJson(obj map[string]interface{}) ([]byte, error) {
func EncodeCanonicalJson(obj any) ([]byte, error) {
b, err := json.Marshal(obj)
if err != nil {
return nil, err
Expand Down
3 changes: 1 addition & 2 deletions util/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,7 @@ func GetLogSafeUrl(r *http.Request) string {
return copyUrl.String()
}

func GetXMatrixAuth(request *http.Request) ([]XMatrixAuth, error) {
headers := request.Header.Values("Authorization")
func GetXMatrixAuth(headers []string) ([]XMatrixAuth, error) {
auths := make([]XMatrixAuth, 0)
for _, h := range headers {
if !strings.HasPrefix(h, "X-Matrix ") {
Expand Down

0 comments on commit ecddf3f

Please sign in to comment.