From 0b0b29e2f15e5c4f35a8e4188225ee3097bf3e1b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?B=C3=A1lint=20Ujv=C3=A1ri?=
 <58116288+bosi95@users.noreply.github.com>
Date: Tue, 26 Mar 2024 10:53:35 +0100
Subject: [PATCH] Session refactor (#24)

* pr comment fix

* add comment to session.NewFromKeystore
---
 pkg/dynamicaccess/session.go      |  9 ++++--
 pkg/dynamicaccess/session_test.go | 51 ++++++++++++++++++++++++++++---
 2 files changed, 53 insertions(+), 7 deletions(-)

diff --git a/pkg/dynamicaccess/session.go b/pkg/dynamicaccess/session.go
index 9d7634ffd01..68cea8ccbfb 100644
--- a/pkg/dynamicaccess/session.go
+++ b/pkg/dynamicaccess/session.go
@@ -21,11 +21,15 @@ type session struct {
 }
 
 func (s *session) Key(publicKey *ecdsa.PublicKey, nonces [][]byte) ([][]byte, error) {
-	x, _ := publicKey.Curve.ScalarMult(publicKey.X, publicKey.Y, s.key.D.Bytes())
-	if x == nil {
+	x, y := publicKey.Curve.ScalarMult(publicKey.X, publicKey.Y, s.key.D.Bytes())
+	if x == nil || y == nil {
 		return nil, errors.New("shared secret is point at infinity")
 	}
 
+	if len(nonces) == 0 {
+		return [][]byte{(*x).Bytes()}, nil
+	}
+
 	keys := make([][]byte, 0, len(nonces))
 	for _, nonce := range nonces {
 		key, err := crypto.LegacyKeccak256(append(x.Bytes(), nonce...))
@@ -44,6 +48,7 @@ func NewDefaultSession(key *ecdsa.PrivateKey) Session {
 	}
 }
 
+// Currently implemented only in mock/session.go
 func NewFromKeystore(ks keystore.Service, tag, password string) Session {
 	return nil
 }
diff --git a/pkg/dynamicaccess/session_test.go b/pkg/dynamicaccess/session_test.go
index 0cfee7691da..501d1abd2b6 100644
--- a/pkg/dynamicaccess/session_test.go
+++ b/pkg/dynamicaccess/session_test.go
@@ -52,9 +52,11 @@ func TestSessionKey(t *testing.T) {
 	}
 	si2 := dynamicaccess.NewDefaultSession(key2)
 
-	nonces := make([][]byte, 1)
-	if _, err := io.ReadFull(rand.Reader, nonces[0]); err != nil {
-		t.Fatal(err)
+	nonces := make([][]byte, 2)
+	for i := range nonces {
+		if _, err := io.ReadFull(rand.Reader, nonces[i]); err != nil {
+			t.Fatal(err)
+		}
 	}
 
 	keys1, err := si1.Key(&key2.PublicKey, nonces)
@@ -66,6 +68,38 @@ func TestSessionKey(t *testing.T) {
 		t.Fatal(err)
 	}
 
+	if !bytes.Equal(keys1[0], keys2[0]) {
+		t.Fatalf("shared secrets do not match %s, %s", hex.EncodeToString(keys1[0]), hex.EncodeToString(keys2[0]))
+	}
+	if !bytes.Equal(keys1[1], keys2[1]) {
+		t.Fatalf("shared secrets do not match %s, %s", hex.EncodeToString(keys1[0]), hex.EncodeToString(keys2[0]))
+	}
+}
+
+func TestSessionKeyWithoutNonces(t *testing.T) {
+	t.Parallel()
+
+	key1, err := crypto.GenerateSecp256k1Key()
+	if err != nil {
+		t.Fatal(err)
+	}
+	si1 := dynamicaccess.NewDefaultSession(key1)
+
+	key2, err := crypto.GenerateSecp256k1Key()
+	if err != nil {
+		t.Fatal(err)
+	}
+	si2 := dynamicaccess.NewDefaultSession(key2)
+
+	keys1, err := si1.Key(&key2.PublicKey, nil)
+	if err != nil {
+		t.Fatal(err)
+	}
+	keys2, err := si2.Key(&key1.PublicKey, nil)
+	if err != nil {
+		t.Fatal(err)
+	}
+
 	if !bytes.Equal(keys1[0], keys2[0]) {
 		t.Fatalf("shared secrets do not match %s, %s", hex.EncodeToString(keys1[0]), hex.EncodeToString(keys2[0]))
 	}
@@ -81,6 +115,7 @@ func TestSessionKeyFromKeystore(t *testing.T) {
 	password2 := "password2"
 
 	si1 := mock.NewFromKeystore(ks, tag1, password1, mockKeyFunc)
+	// si1 := dynamicaccess.NewFromKeystore(ks, tag1, password1)
 	exists, err := ks.Exists(tag1)
 	if err != nil {
 		t.Fatal(err)
@@ -97,6 +132,7 @@ func TestSessionKeyFromKeystore(t *testing.T) {
 	}
 
 	si2 := mock.NewFromKeystore(ks, tag2, password2, mockKeyFunc)
+	// si2 := dynamicaccess.NewFromKeystore(ks, tag2, password2)
 	exists, err = ks.Exists(tag2)
 	if err != nil {
 		t.Fatal(err)
@@ -113,8 +149,10 @@ func TestSessionKeyFromKeystore(t *testing.T) {
 	}
 
 	nonces := make([][]byte, 1)
-	if _, err := io.ReadFull(rand.Reader, nonces[0]); err != nil {
-		t.Fatal(err)
+	for i := range nonces {
+		if _, err := io.ReadFull(rand.Reader, nonces[i]); err != nil {
+			t.Fatal(err)
+		}
 	}
 
 	keys1, err := si1.Key(&key2.PublicKey, nonces)
@@ -129,4 +167,7 @@ func TestSessionKeyFromKeystore(t *testing.T) {
 	if !bytes.Equal(keys1[0], keys2[0]) {
 		t.Fatalf("shared secrets do not match %s, %s", hex.EncodeToString(keys1[0]), hex.EncodeToString(keys2[0]))
 	}
+	// if !bytes.Equal(keys1[1], keys2[1]) {
+	// 	t.Fatalf("shared secrets do not match %s, %s", hex.EncodeToString(keys1[0]), hex.EncodeToString(keys2[0]))
+	// }
 }