Skip to content

Commit

Permalink
fixup! (DO NOT MERGE) crypto: allow run goolm side-by-side with libolm
Browse files Browse the repository at this point in the history
Signed-off-by: Sumner Evans <[email protected]>
  • Loading branch information
sumnerevans committed Jan 17, 2025
1 parent f4a44df commit f96a85b
Show file tree
Hide file tree
Showing 11 changed files with 150 additions and 53 deletions.
3 changes: 2 additions & 1 deletion crypto/account.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/crypto/canonicaljson"
"maunium.net/go/mautrix/crypto/goolm/account"
"maunium.net/go/mautrix/crypto/libolm"

Check failure on line 19 in crypto/account.go

View workflow job for this annotation

GitHub Actions / Build (old, goolm)

no required module provides package maunium.net/go/mautrix/crypto/libolm; to add it:

Check failure on line 19 in crypto/account.go

View workflow job for this annotation

GitHub Actions / Build (latest, goolm)

no required module provides package maunium.net/go/mautrix/crypto/libolm; to add it:

Check failure on line 19 in crypto/account.go

View workflow job for this annotation

GitHub Actions / Build (old, goolm)

no required module provides package maunium.net/go/mautrix/crypto/libolm; to add it:

Check failure on line 19 in crypto/account.go

View workflow job for this annotation

GitHub Actions / Build (latest, goolm)

no required module provides package maunium.net/go/mautrix/crypto/libolm; to add it:
"maunium.net/go/mautrix/crypto/olm"
"maunium.net/go/mautrix/crypto/signatures"
"maunium.net/go/mautrix/id"
Expand All @@ -31,7 +32,7 @@ type OlmAccount struct {
}

func NewOlmAccount() *OlmAccount {
libolmAccount, err := olm.NewAccount()
libolmAccount, err := libolm.NewAccount()
if err != nil {
panic(err)
}
Expand Down
24 changes: 21 additions & 3 deletions crypto/decryptmegolm.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
package crypto

import (
"bytes"
"context"
"encoding/json"
"errors"
Expand Down Expand Up @@ -203,7 +204,11 @@ func (mach *OlmMachine) checkUndecryptableMessageIndexDuplication(ctx context.Co
log.Warn().Err(decodeErr).Msg("Failed to parse message index to check if it's a duplicate for message that failed to decrypt")
return 0, fmt.Errorf("%w (also failed to parse message index)", olm.UnknownMessageIndex)
}
firstKnown := sess.Internal.FirstKnownIndex()
firstKnown := sess.InternalLibolm.FirstKnownIndex()
firstKnownGoolm := sess.InternalGoolm.FirstKnownIndex()
if firstKnown != firstKnownGoolm {
panic(fmt.Sprintf("firstKnown not the same %d != %d", firstKnown, firstKnownGoolm))
}
log = log.With().Uint("message_index", messageIndex).Uint32("first_known_index", firstKnown).Logger()
if ok, err := mach.CryptoStore.ValidateMessageIndex(ctx, sess.SenderKey, content.SessionID, evt.ID, messageIndex, evt.Timestamp); err != nil {
log.Debug().Err(err).Msg("Failed to check if message index is duplicate")
Expand All @@ -228,7 +233,16 @@ func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *eve
} else if content.SenderKey != "" && content.SenderKey != sess.SenderKey {
return sess, nil, 0, SenderKeyMismatch
}
plaintext, messageIndex, err := sess.Internal.Decrypt(content.MegolmCiphertext)
plaintextGoolm, messageIndexGoolm, errGoolm := sess.InternalGoolm.Decrypt(content.MegolmCiphertext)
plaintext, messageIndex, err := sess.InternalLibolm.Decrypt(content.MegolmCiphertext)
if !bytes.Equal(plaintextGoolm, plaintext) {
panic("plaintext different")
} else if messageIndexGoolm != messageIndex {
panic(fmt.Sprintf("message index different %d != %d", messageIndexGoolm, messageIndex))
} else if err != nil && errGoolm == nil {
panic(fmt.Sprintf("goolm didn't error %v", err))
}

if err != nil {
if errors.Is(err, olm.UnknownMessageIndex) && mach.RatchetKeysOnDecrypt {
messageIndex, err = mach.checkUndecryptableMessageIndexDuplication(ctx, sess, evt, content)
Expand Down Expand Up @@ -277,7 +291,11 @@ func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *eve
if len(sess.RatchetSafety.MissedIndices) > 0 {
ratchetTargetIndex = uint32(sess.RatchetSafety.MissedIndices[0])
}
ratchetCurrentIndex := sess.Internal.FirstKnownIndex()
ratchetCurrentIndexGoolm := sess.InternalGoolm.FirstKnownIndex()
ratchetCurrentIndex := sess.InternalLibolm.FirstKnownIndex()
if ratchetCurrentIndexGoolm != ratchetCurrentIndex {
panic(fmt.Sprintf("ratchet current index different %d != %d", ratchetCurrentIndexGoolm, ratchetCurrentIndex))
}
log := zerolog.Ctx(ctx).With().
Uint32("prev_ratchet_index", ratchetCurrentIndex).
Uint32("new_ratchet_index", ratchetTargetIndex).
Expand Down
13 changes: 10 additions & 3 deletions crypto/keybackup.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ import (

"maunium.net/go/mautrix"
"maunium.net/go/mautrix/crypto/backup"
"maunium.net/go/mautrix/crypto/olm"
"maunium.net/go/mautrix/crypto/goolm/session"
"maunium.net/go/mautrix/crypto/libolm"
"maunium.net/go/mautrix/crypto/signatures"
"maunium.net/go/mautrix/id"
)
Expand Down Expand Up @@ -144,7 +145,12 @@ func (mach *OlmMachine) ImportRoomKeyFromBackup(ctx context.Context, version id.
return nil, fmt.Errorf("ignoring room key in backup with weird algorithm %s", keyBackupData.Algorithm)
}

igsInternal, err := olm.InboundGroupSessionImport([]byte(keyBackupData.SessionKey))
igsInternalGoolm, err := session.NewMegolmInboundSessionFromExport([]byte(keyBackupData.SessionKey))
if err != nil {
return nil, err
}

igsInternal, err := libolm.InboundGroupSessionImport([]byte(keyBackupData.SessionKey))
if err != nil {
return nil, fmt.Errorf("failed to import inbound group session: %w", err)
} else if igsInternal.ID() != sessionID {
Expand All @@ -169,7 +175,8 @@ func (mach *OlmMachine) ImportRoomKeyFromBackup(ctx context.Context, version id.
}

igs := &InboundGroupSession{
Internal: igsInternal,
InternalLibolm: igsInternal,
InternalGoolm: igsInternalGoolm,
SigningKey: keyBackupData.SenderClaimedKeys.Ed25519,
SenderKey: keyBackupData.SenderKey,
RoomID: roomID,
Expand Down
7 changes: 6 additions & 1 deletion crypto/keyexport.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"fmt"
"math"

"go.mau.fi/util/exerrors"
"go.mau.fi/util/random"
"golang.org/x/crypto/pbkdf2"

Expand Down Expand Up @@ -81,10 +82,14 @@ func makeExportKeys(passphrase string) (encryptionKey, hashKey, salt, iv []byte)
func exportSessions(sessions []*InboundGroupSession) ([]ExportedSession, error) {
export := make([]ExportedSession, len(sessions))
for i, session := range sessions {
key, err := session.Internal.Export(session.Internal.FirstKnownIndex())
key, err := session.InternalLibolm.Export(session.InternalLibolm.FirstKnownIndex())
if err != nil {
return nil, fmt.Errorf("failed to export session: %w", err)
}
keyGoolm := exerrors.Must(session.InternalGoolm.Export(session.InternalGoolm.FirstKnownIndex()))
if !bytes.Equal(key, keyGoolm) {
panic("keys not equal")
}
export[i] = ExportedSession{
Algorithm: id.AlgorithmMegolmV1,
ForwardingChains: session.ForwardingChains,
Expand Down
33 changes: 20 additions & 13 deletions crypto/keyimport.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@ import (
"fmt"
"time"

"maunium.net/go/mautrix/crypto/olm"
"go.mau.fi/util/exerrors"

"maunium.net/go/mautrix/crypto/goolm/session"
"maunium.net/go/mautrix/crypto/libolm"
"maunium.net/go/mautrix/id"
)

Expand Down Expand Up @@ -92,38 +95,42 @@ func decryptKeyExport(passphrase string, exportData []byte) ([]ExportedSession,
return sessionsJSON, nil
}

func (mach *OlmMachine) importExportedRoomKey(ctx context.Context, session ExportedSession) (bool, error) {
if session.Algorithm != id.AlgorithmMegolmV1 {
func (mach *OlmMachine) importExportedRoomKey(ctx context.Context, sess ExportedSession) (bool, error) {
if sess.Algorithm != id.AlgorithmMegolmV1 {
return false, ErrInvalidExportedAlgorithm
}

igsInternal, err := olm.InboundGroupSessionImport([]byte(session.SessionKey))
igsInternal, err := libolm.InboundGroupSessionImport([]byte(sess.SessionKey))
if err != nil {
return false, fmt.Errorf("failed to import session: %w", err)
} else if igsInternal.ID() != session.SessionID {
} else if igsInternal.ID() != sess.SessionID {
return false, ErrMismatchingExportedSessionID
}
igs := &InboundGroupSession{
Internal: igsInternal,
SigningKey: session.SenderClaimedKeys.Ed25519,
SenderKey: session.SenderKey,
RoomID: session.RoomID,
InternalLibolm: igsInternal,
InternalGoolm: exerrors.Must(session.NewMegolmInboundSessionFromExport([]byte(sess.SessionKey))),
SigningKey: sess.SenderClaimedKeys.Ed25519,
SenderKey: sess.SenderKey,
RoomID: sess.RoomID,
// TODO should we add something here to mark the signing key as unverified like key requests do?
ForwardingChains: session.ForwardingChains,
ForwardingChains: sess.ForwardingChains,

ReceivedAt: time.Now().UTC(),
}
existingIGS, _ := mach.CryptoStore.GetGroupSession(ctx, igs.RoomID, igs.ID())
firstKnownIndex := igs.Internal.FirstKnownIndex()
if existingIGS != nil && existingIGS.Internal.FirstKnownIndex() <= firstKnownIndex {
firstKnownIndex := igs.InternalLibolm.FirstKnownIndex()
if firstKnownIndex != igs.InternalGoolm.FirstKnownIndex() {
panic("indexes different")
}
if existingIGS != nil && existingIGS.InternalLibolm.FirstKnownIndex() <= firstKnownIndex {
// We already have an equivalent or better session in the store, so don't override it.
return false, nil
}
err = mach.CryptoStore.PutGroupSession(ctx, igs)
if err != nil {
return false, fmt.Errorf("failed to store imported session: %w", err)
}
mach.markSessionReceived(ctx, session.RoomID, igs.ID(), firstKnownIndex)
mach.markSessionReceived(ctx, sess.RoomID, igs.ID(), firstKnownIndex)
return true, nil
}

Expand Down
22 changes: 18 additions & 4 deletions crypto/keysharing.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,15 @@
package crypto

import (
"bytes"
"context"
"errors"
"time"

"github.com/rs/zerolog"
"go.mau.fi/util/exerrors"

"maunium.net/go/mautrix/crypto/goolm/session"
"maunium.net/go/mautrix/crypto/olm"
"maunium.net/go/mautrix/id"

Expand Down Expand Up @@ -178,7 +181,8 @@ func (mach *OlmMachine) importForwardedRoomKey(ctx context.Context, evt *Decrypt
log.Warn().Uint32("first_known_index", firstKnownIndex).Msg("Importing partial session")
}
igs := &InboundGroupSession{
Internal: igsInternal,
InternalLibolm: igsInternal,
InternalGoolm: exerrors.Must(session.NewMegolmInboundSessionFromExport([]byte(content.SessionKey))),
SigningKey: evt.Keys.Ed25519,
SenderKey: content.SenderKey,
RoomID: content.RoomID,
Expand All @@ -191,7 +195,10 @@ func (mach *OlmMachine) importForwardedRoomKey(ctx context.Context, evt *Decrypt
IsScheduled: content.IsScheduled,
}
existingIGS, _ := mach.CryptoStore.GetGroupSession(ctx, igs.RoomID, igs.ID())
if existingIGS != nil && existingIGS.Internal.FirstKnownIndex() <= igs.Internal.FirstKnownIndex() {
if igs.InternalLibolm.FirstKnownIndex() != igs.InternalGoolm.FirstKnownIndex() {
panic("different indices")
}
if existingIGS != nil && existingIGS.InternalLibolm.FirstKnownIndex() <= igs.InternalLibolm.FirstKnownIndex() {
// We already have an equivalent or better session in the store, so don't override it.
return false
}
Expand Down Expand Up @@ -339,14 +346,21 @@ func (mach *OlmMachine) HandleRoomKeyRequest(ctx context.Context, sender id.User
log = log.With().Stringer("unexpected_session_id", internalID).Logger()
}

firstKnownIndex := igs.Internal.FirstKnownIndex()
firstKnownIndex := igs.InternalLibolm.FirstKnownIndex()
if igs.InternalLibolm.FirstKnownIndex() != igs.InternalGoolm.FirstKnownIndex() {
panic("different indices")
}
log = log.With().Uint32("first_known_index", firstKnownIndex).Logger()
exportedKey, err := igs.Internal.Export(firstKnownIndex)
exportedKey, err := igs.InternalLibolm.Export(firstKnownIndex)
if err != nil {
log.Error().Err(err).Msg("Failed to export group session to forward")
mach.rejectKeyRequest(ctx, KeyShareRejectInternalError, device, content.Body)
return
}
exportedKeyGoolm, err := igs.InternalGoolm.Export(firstKnownIndex)
if !bytes.Equal(exportedKey, exportedKeyGoolm) {
panic("keys different")
}
if igs.ForwardingChains == nil {
igs.ForwardingChains = []string{}
}
Expand Down
5 changes: 4 additions & 1 deletion crypto/machine.go
Original file line number Diff line number Diff line change
Expand Up @@ -579,7 +579,10 @@ func (mach *OlmMachine) createGroupSession(ctx context.Context, senderKey id.Sen
log.Err(err).Str("session_id", sessionID.String()).Msg("Failed to store new inbound group session")
return fmt.Errorf("failed to store new inbound group session: %w", err)
}
mach.markSessionReceived(ctx, roomID, sessionID, igs.Internal.FirstKnownIndex())
if igs.InternalLibolm.FirstKnownIndex() != igs.InternalGoolm.FirstKnownIndex() {
panic("different index")
}
mach.markSessionReceived(ctx, roomID, sessionID, igs.InternalLibolm.FirstKnownIndex())
log.Debug().
Str("session_id", sessionID.String()).
Str("sender_key", senderKey.String()).
Expand Down
6 changes: 4 additions & 2 deletions crypto/machine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,12 @@ func TestRatchetMegolmSession(t *testing.T) {
assert.NoError(t, err)
inSess, err := mach.CryptoStore.GetGroupSession(context.TODO(), "meow", outSess.ID())
require.NoError(t, err)
assert.Equal(t, uint32(0), inSess.Internal.FirstKnownIndex())
assert.Equal(t, uint32(0), inSess.InternalLibolm.FirstKnownIndex())
assert.Equal(t, uint32(0), inSess.InternalGoolm.FirstKnownIndex())
err = inSess.RatchetTo(10)
assert.NoError(t, err)
assert.Equal(t, uint32(10), inSess.Internal.FirstKnownIndex())
assert.Equal(t, uint32(10), inSess.InternalLibolm.FirstKnownIndex())
assert.Equal(t, uint32(10), inSess.InternalGoolm.FirstKnownIndex())
}

func TestOlmMachineOlmMegolmSessions(t *testing.T) {
Expand Down
34 changes: 26 additions & 8 deletions crypto/sessions.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,13 @@ package crypto
import (
"bytes"
"errors"
"fmt"
"time"

"go.mau.fi/util/exerrors"

"maunium.net/go/mautrix/crypto/goolm/session"
"maunium.net/go/mautrix/crypto/libolm"
"maunium.net/go/mautrix/crypto/olm"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
Expand Down Expand Up @@ -107,7 +110,8 @@ type RatchetSafety struct {
}

type InboundGroupSession struct {
Internal olm.InboundGroupSession
InternalLibolm olm.InboundGroupSession
InternalGoolm olm.InboundGroupSession

SigningKey id.Ed25519
SenderKey id.Curve25519
Expand All @@ -126,12 +130,17 @@ type InboundGroupSession struct {
}

func NewInboundGroupSession(senderKey id.SenderKey, signingKey id.Ed25519, roomID id.RoomID, sessionKey string, maxAge time.Duration, maxMessages int, isScheduled bool) (*InboundGroupSession, error) {
igs, err := olm.NewInboundGroupSession([]byte(sessionKey))
igs, err := libolm.NewInboundGroupSession([]byte(sessionKey))
if err != nil {
return nil, err
}
igsGoolm, err := session.NewMegolmInboundSession([]byte(sessionKey))
if err != nil {
return nil, err
}
return &InboundGroupSession{
Internal: igs,
InternalLibolm: igs,
InternalGoolm: igsGoolm,
SigningKey: signingKey,
SenderKey: senderKey,
RoomID: roomID,
Expand All @@ -145,22 +154,31 @@ func NewInboundGroupSession(senderKey id.SenderKey, signingKey id.Ed25519, roomI

func (igs *InboundGroupSession) ID() id.SessionID {
if igs.id == "" {
igs.id = igs.Internal.ID()
igs.id = igs.InternalLibolm.ID()
if igs.id != igs.InternalGoolm.ID() {
panic(fmt.Sprintf("id different %s %s", igs.id, igs.InternalGoolm.ID()))
}
}
return igs.id
}

func (igs *InboundGroupSession) RatchetTo(index uint32) error {
exported, err := igs.Internal.Export(index)
exported, err := igs.InternalLibolm.Export(index)
if err != nil {
return err
}
imported, err := olm.InboundGroupSessionImport(exported)
exportedGoolm, err := igs.InternalGoolm.Export(index)
if err != nil {
panic(err)
} else if !bytes.Equal(exported, exportedGoolm) {
panic("bytes not equal")
}
igs.InternalLibolm, err = libolm.InboundGroupSessionImport(exported)
if err != nil {
return err
}
igs.Internal = imported
return nil
igs.InternalGoolm, err = session.NewMegolmInboundSessionFromExport(exportedGoolm)
return err
}

type OGSState int
Expand Down
Loading

0 comments on commit f96a85b

Please sign in to comment.