diff --git a/CHANGELOG.md b/CHANGELOG.md index 9fa6f4c..9f128c9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,13 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). -## [v0.1.11] +## [v0.1.12] - 2024-07-15 +### Added +- Using Openfort SSS library to split/reconstruct encryption keys. +- Add `shld_shamir_migrations` to manage the migrations of the Shamir secret sharing library. +- Add Migration Jobs to manage the migrations of the Shamir secret sharing library. + + +## [v0.1.11] - 2024-07-11 ### Added - Encryption Sessions, allow projects to register a on time use session with an encryption part to encrypt/decrypt a secret. \ No newline at end of file diff --git a/di/wire.go b/di/wire.go index e87c5f7..8f691e4 100644 --- a/di/wire.go +++ b/di/wire.go @@ -18,6 +18,7 @@ import ( "go.openfort.xyz/shield/internal/adapters/repositories/sql/sharerepo" "go.openfort.xyz/shield/internal/adapters/repositories/sql/userrepo" "go.openfort.xyz/shield/internal/applications/projectapp" + "go.openfort.xyz/shield/internal/applications/shamirjob" "go.openfort.xyz/shield/internal/applications/shareapp" "go.openfort.xyz/shield/internal/core/ports/factories" "go.openfort.xyz/shield/internal/core/ports/repositories" @@ -137,6 +138,16 @@ func ProvideShareService() (s services.ShareService, err error) { return } +func ProvideShamirJob() (j *shamirjob.Job, err error) { + wire.Build( + shamirjob.New, + ProvideSQLProjectRepository, + ProvideSQLShareRepository, + ) + + return +} + func ProvideShareApplication() (a *shareapp.ShareApplication, err error) { wire.Build( shareapp.New, @@ -144,6 +155,7 @@ func ProvideShareApplication() (a *shareapp.ShareApplication, err error) { ProvideSQLShareRepository, ProvideSQLProjectRepository, ProvideEncryptionFactory, + ProvideShamirJob, ) return diff --git a/di/wire_gen.go b/di/wire_gen.go index 468c3d7..3eea368 100644 --- a/di/wire_gen.go +++ b/di/wire_gen.go @@ -20,6 +20,7 @@ import ( "go.openfort.xyz/shield/internal/adapters/repositories/sql/sharerepo" "go.openfort.xyz/shield/internal/adapters/repositories/sql/userrepo" "go.openfort.xyz/shield/internal/applications/projectapp" + "go.openfort.xyz/shield/internal/applications/shamirjob" "go.openfort.xyz/shield/internal/applications/shareapp" "go.openfort.xyz/shield/internal/core/ports/factories" "go.openfort.xyz/shield/internal/core/ports/repositories" @@ -150,6 +151,19 @@ func ProvideShareService() (services.ShareService, error) { return shareService, nil } +func ProvideShamirJob() (*shamirjob.Job, error) { + projectRepository, err := ProvideSQLProjectRepository() + if err != nil { + return nil, err + } + shareRepository, err := ProvideSQLShareRepository() + if err != nil { + return nil, err + } + job := shamirjob.New(projectRepository, shareRepository) + return job, nil +} + func ProvideShareApplication() (*shareapp.ShareApplication, error) { shareService, err := ProvideShareService() if err != nil { @@ -167,7 +181,11 @@ func ProvideShareApplication() (*shareapp.ShareApplication, error) { if err != nil { return nil, err } - shareApplication := shareapp.New(shareService, shareRepository, projectRepository, encryptionFactory) + job, err := ProvideShamirJob() + if err != nil { + return nil, err + } + shareApplication := shareapp.New(shareService, shareRepository, projectRepository, encryptionFactory, job) return shareApplication, nil } diff --git a/go.mod b/go.mod index b5c9fb8..15088b7 100644 --- a/go.mod +++ b/go.mod @@ -48,6 +48,7 @@ require ( github.com/tidwall/pretty v1.2.1 // indirect github.com/tidwall/rtred v0.1.2 // indirect github.com/tidwall/tinyqueue v0.1.1 // indirect + go.openfort.xyz/shamir-secret-sharing-go v0.0.1 // indirect golang.org/x/sync v0.7.0 // indirect golang.org/x/text v0.16.0 // indirect golang.org/x/time v0.5.0 // indirect diff --git a/go.sum b/go.sum index 6a25aba..763090a 100644 --- a/go.sum +++ b/go.sum @@ -91,6 +91,10 @@ github.com/tidwall/rtred v0.1.2/go.mod h1:hd69WNXQ5RP9vHd7dqekAz+RIdtfBogmglkZSR github.com/tidwall/tinyqueue v0.1.1 h1:SpNEvEggbpyN5DIReaJ2/1ndroY8iyEGxPYxoSaymYE= github.com/tidwall/tinyqueue v0.1.1/go.mod h1:O/QNHwrnjqr6IHItYrzoHAKYhBkLI67Q096fQP5zMYw= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +go.openfort.xyz/shamir-secret-sharing-go v0.0.0-20240711164902-c1a2df301b09 h1:fPnhsnHpX8xztLlMqXU1xM7ohNkUxW5bCkzk0n/e2c0= +go.openfort.xyz/shamir-secret-sharing-go v0.0.0-20240711164902-c1a2df301b09/go.mod h1:EdgAbmmHezcyCO3ucE+JFMCGwcvV9CSoEsQuHkRM9js= +go.openfort.xyz/shamir-secret-sharing-go v0.0.1 h1:/WeytO+6yXSH9op3zU8Eirxyj8UK/0MbhbS32uy+yYQ= +go.openfort.xyz/shamir-secret-sharing-go v0.0.1/go.mod h1:EdgAbmmHezcyCO3ucE+JFMCGwcvV9CSoEsQuHkRM9js= go.uber.org/atomic v1.7.0 h1:ADUqmZGgLDDfbSL9ZmPxKTybcoEYHgpYfELNoN+7hsw= go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/ratelimit v0.3.1 h1:K4qVE+byfv/B3tC+4nYWP7v/6SimcO7HzHekoMNBma0= diff --git a/internal/adapters/encryption/deprecated_sss_reconstruction_strategy/strategy.go b/internal/adapters/encryption/deprecated_sss_reconstruction_strategy/strategy.go new file mode 100644 index 0000000..112a4c4 --- /dev/null +++ b/internal/adapters/encryption/deprecated_sss_reconstruction_strategy/strategy.go @@ -0,0 +1,50 @@ +package depsssrec + +import ( + "go.openfort.xyz/shield/internal/core/domain/errors" + "go.openfort.xyz/shield/internal/core/ports/strategies" + "go.openfort.xyz/shield/pkg/cypher" +) + +const ( + MaxReties = 5 +) + +type SSSReconstructionStrategy struct{} + +func NewSSSReconstructionStrategy() strategies.ReconstructionStrategy { + return &SSSReconstructionStrategy{} +} + +func (s *SSSReconstructionStrategy) Split(data string) (storedPart string, projectPart string, err error) { + for i := 0; i < MaxReties; i++ { + storedPart, projectPart, err = cypher.SplitEncryptionKey(data) + if err != nil { + continue + } + + err = s.validateSplit(data, storedPart, projectPart) + if err == nil { + return + } + } + + return +} + +func (s *SSSReconstructionStrategy) Reconstruct(storedPart string, projectPart string) (string, error) { + return cypher.ReconstructEncryptionKey(storedPart, projectPart) +} + +func (s *SSSReconstructionStrategy) validateSplit(data string, storedPart string, projectPart string) error { + reconstructed, err := s.Reconstruct(storedPart, projectPart) + if err != nil { + return err + } + + if data != reconstructed { + return errors.ErrReconstructedKeyMismatch + } + + return nil +} diff --git a/internal/adapters/encryption/factory.go b/internal/adapters/encryption/factory.go index 1c4afa6..1782d0b 100644 --- a/internal/adapters/encryption/factory.go +++ b/internal/adapters/encryption/factory.go @@ -2,6 +2,7 @@ package encryption import ( aesencryptionstrategy "go.openfort.xyz/shield/internal/adapters/encryption/aes_encryption_strategy" + depsssrec "go.openfort.xyz/shield/internal/adapters/encryption/deprecated_sss_reconstruction_strategy" plnbldr "go.openfort.xyz/shield/internal/adapters/encryption/plain_builder" sessbldr "go.openfort.xyz/shield/internal/adapters/encryption/session_builder" sssrec "go.openfort.xyz/shield/internal/adapters/encryption/sss_reconstruction_strategy" @@ -24,19 +25,28 @@ func NewEncryptionFactory(encryptionPartsRepo repositories.EncryptionPartsReposi } } -func (e *encryptionFactory) CreateEncryptionKeyBuilder(builderType factories.EncryptionKeyBuilderType) (builders.EncryptionKeyBuilder, error) { +func (e *encryptionFactory) CreateEncryptionKeyBuilder(builderType factories.EncryptionKeyBuilderType, projectMigrated bool) (builders.EncryptionKeyBuilder, error) { + var reconstructionStrategy strategies.ReconstructionStrategy + if projectMigrated { + reconstructionStrategy = sssrec.NewSSSReconstructionStrategy() + } else { + reconstructionStrategy = depsssrec.NewSSSReconstructionStrategy() + } switch builderType { case factories.Plain: - return plnbldr.NewEncryptionKeyBuilder(e.projectRepo, sssrec.NewSSSReconstructionStrategy()), nil + return plnbldr.NewEncryptionKeyBuilder(e.projectRepo, reconstructionStrategy), nil case factories.Session: - return sessbldr.NewEncryptionKeyBuilder(e.encryptionPartsRepo, e.projectRepo, sssrec.NewSSSReconstructionStrategy()), nil + return sessbldr.NewEncryptionKeyBuilder(e.encryptionPartsRepo, e.projectRepo, reconstructionStrategy), nil } return nil, errors.ErrInvalidEncryptionKeyBuilderType } -func (e *encryptionFactory) CreateReconstructionStrategy() strategies.ReconstructionStrategy { - return sssrec.NewSSSReconstructionStrategy() +func (e *encryptionFactory) CreateReconstructionStrategy(projectMigrated bool) strategies.ReconstructionStrategy { + if projectMigrated { + return sssrec.NewSSSReconstructionStrategy() + } + return depsssrec.NewSSSReconstructionStrategy() } func (e *encryptionFactory) CreateEncryptionStrategy(key string) strategies.EncryptionStrategy { diff --git a/internal/adapters/encryption/plain_builder/builder.go b/internal/adapters/encryption/plain_builder/builder.go index 1dfed9c..e570437 100644 --- a/internal/adapters/encryption/plain_builder/builder.go +++ b/internal/adapters/encryption/plain_builder/builder.go @@ -39,6 +39,14 @@ func (b *plainBuilder) SetDatabasePart(ctx context.Context, identifier string) e return nil } +func (b *plainBuilder) GetProjectPart(_ context.Context) string { + return b.projectPart +} + +func (b *plainBuilder) GetDatabasePart(_ context.Context) string { + return b.databasePart +} + func (b *plainBuilder) Build(_ context.Context) (string, error) { if b.projectPart == "" { return "", domainErrors.ErrProjectPartRequired diff --git a/internal/adapters/encryption/session_builder/builder.go b/internal/adapters/encryption/session_builder/builder.go index d6fd856..dcbf1bb 100644 --- a/internal/adapters/encryption/session_builder/builder.go +++ b/internal/adapters/encryption/session_builder/builder.go @@ -54,6 +54,14 @@ func (b *sessionBuilder) SetDatabasePart(ctx context.Context, identifier string) return nil } +func (b *sessionBuilder) GetProjectPart(_ context.Context) string { + return b.projectPart +} + +func (b *sessionBuilder) GetDatabasePart(_ context.Context) string { + return b.databasePart +} + func (b *sessionBuilder) Build(_ context.Context) (string, error) { if b.projectPart == "" { return "", domainErrors.ErrProjectPartRequired diff --git a/internal/adapters/encryption/sss_reconstruction_strategy/strategy.go b/internal/adapters/encryption/sss_reconstruction_strategy/strategy.go index b058a0e..5c67b9a 100644 --- a/internal/adapters/encryption/sss_reconstruction_strategy/strategy.go +++ b/internal/adapters/encryption/sss_reconstruction_strategy/strategy.go @@ -1,9 +1,11 @@ package sssrec import ( + "encoding/base64" + + sss "go.openfort.xyz/shamir-secret-sharing-go" "go.openfort.xyz/shield/internal/core/domain/errors" "go.openfort.xyz/shield/internal/core/ports/strategies" - "go.openfort.xyz/shield/pkg/cypher" ) const ( @@ -16,24 +18,60 @@ func NewSSSReconstructionStrategy() strategies.ReconstructionStrategy { return &SSSReconstructionStrategy{} } -func (s *SSSReconstructionStrategy) Split(data string) (storedPart string, projectPart string, err error) { +func (s *SSSReconstructionStrategy) Split(data string) (string, string, error) { for i := 0; i < MaxReties; i++ { - storedPart, projectPart, err = cypher.SplitEncryptionKey(data) + rawKey, err := base64.StdEncoding.DecodeString(data) if err != nil { - continue + return "", "", err + } + + parts, err := sss.Split(2, 2, rawKey) + if err != nil { + return "", "", err + } + + if len(parts) != 2 { + return "", "", errors.ErrFailedToSplitKey } + storedPart := base64.StdEncoding.EncodeToString(parts[0]) + projectPart := base64.StdEncoding.EncodeToString(parts[1]) + err = s.validateSplit(data, storedPart, projectPart) if err == nil { - return + return storedPart, projectPart, nil } } - return + return "", "", errors.ErrFailedToSplitKey } func (s *SSSReconstructionStrategy) Reconstruct(storedPart string, projectPart string) (string, error) { - return cypher.ReconstructEncryptionKey(storedPart, projectPart) + rawStoredPart, err := base64.StdEncoding.DecodeString(storedPart) + if err != nil { + return "", err + } + + // Backward compatibility with old keys + if len(rawStoredPart) == 32 { + rawStoredPart = append([]byte{1}, rawStoredPart...) + } + + rawProjectPart, err := base64.StdEncoding.DecodeString(projectPart) + if err != nil { + return "", err + } + // Backward compatibility with old keys + if len(rawProjectPart) == 32 { + rawProjectPart = append([]byte{2}, rawProjectPart...) + } + + combined, err := sss.Combine([][]byte{rawStoredPart, rawProjectPart}) + if err != nil { + return "", err + } + + return base64.StdEncoding.EncodeToString(combined), nil } func (s *SSSReconstructionStrategy) validateSplit(data string, storedPart string, projectPart string) error { diff --git a/internal/adapters/repositories/mocks/projectmockrepo/repo.go b/internal/adapters/repositories/mocks/projectmockrepo/repo.go index 8dfe9e6..5ab6c43 100644 --- a/internal/adapters/repositories/mocks/projectmockrepo/repo.go +++ b/internal/adapters/repositories/mocks/projectmockrepo/repo.go @@ -52,3 +52,13 @@ func (m *MockProjectRepository) SetEncryptionPart(ctx context.Context, projectID args := m.Mock.Called(ctx, projectID, part) return args.Error(0) } + +func (m *MockProjectRepository) CreateMigration(ctx context.Context, projectID string, success bool) error { + args := m.Mock.Called(ctx, projectID, success) + return args.Error(0) +} + +func (m *MockProjectRepository) HasSuccessfulMigration(ctx context.Context, projectID string) (bool, error) { + args := m.Mock.Called(ctx, projectID) + return args.Bool(0), args.Error(1) +} diff --git a/internal/adapters/repositories/mocks/sharemockrepo/repo.go b/internal/adapters/repositories/mocks/sharemockrepo/repo.go index 3db6ec7..8d5c91b 100644 --- a/internal/adapters/repositories/mocks/sharemockrepo/repo.go +++ b/internal/adapters/repositories/mocks/sharemockrepo/repo.go @@ -32,8 +32,8 @@ func (m *MockShareRepository) Delete(ctx context.Context, shareID string) error return args.Error(0) } -func (m *MockShareRepository) ListDecryptedByProjectID(ctx context.Context, projectID string) ([]*share.Share, error) { - args := m.Mock.Called(ctx, projectID) +func (m *MockShareRepository) ListProjectIDAndEntropy(ctx context.Context, projectID string, entropy share.Entropy) ([]*share.Share, error) { + args := m.Mock.Called(ctx, projectID, entropy) if args.Get(0) == nil { return nil, args.Error(1) } @@ -49,3 +49,8 @@ func (m *MockShareRepository) Update(ctx context.Context, shr *share.Share) erro args := m.Mock.Called(ctx, shr) return args.Error(0) } + +func (m *MockShareRepository) BulkUpdate(ctx context.Context, shrs []*share.Share) error { + args := m.Mock.Called(ctx, shrs) + return args.Error(0) +} diff --git a/internal/adapters/repositories/sql/migrations/20240712130440_shamir_migrations.sql b/internal/adapters/repositories/sql/migrations/20240712130440_shamir_migrations.sql new file mode 100644 index 0000000..d429758 --- /dev/null +++ b/internal/adapters/repositories/sql/migrations/20240712130440_shamir_migrations.sql @@ -0,0 +1,26 @@ +-- +goose Up +CREATE TABLE IF NOT EXISTS shld_shamir_migrations ( + id VARCHAR(36) PRIMARY KEY, + project_id VARCHAR(36) NOT NULL, + timestamp TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + success BOOLEAN NOT NULL DEFAULT FALSE +); +CREATE INDEX idx_project_id_success ON shld_shamir_migrations(project_id, success); +DROP TABLE IF EXISTS shld_allowed_origins; +-- +goose StatementBegin +-- +goose StatementEnd + +-- +goose Down +DROP INDEX idx_project_id_success ON shld_shamir_migrations; +DROP TABLE IF EXISTS shld_shamir_migrations; +CREATE TABLE IF NOT EXISTS shld_allowed_origins ( + id VARCHAR(36) PRIMARY KEY, + project_id VARCHAR(36) NOT NULL, + origin VARCHAR(255) NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + deleted_at TIMESTAMP DEFAULT NULL +); +ALTER TABLE shld_allowed_origins ADD CONSTRAINT fk_origin_project FOREIGN KEY (project_id) REFERENCES shld_projects(id) ON DELETE CASCADE; +-- +goose StatementBegin +-- +goose StatementEnd diff --git a/internal/adapters/repositories/sql/projectrepo/repo.go b/internal/adapters/repositories/sql/projectrepo/repo.go index 45499d2..67a3a64 100644 --- a/internal/adapters/repositories/sql/projectrepo/repo.go +++ b/internal/adapters/repositories/sql/projectrepo/repo.go @@ -125,3 +125,34 @@ func (r *repository) SetEncryptionPart(ctx context.Context, projectID, part stri return nil } + +func (r *repository) CreateMigration(ctx context.Context, projectID string, success bool) error { + r.logger.InfoContext(ctx, "creating migration", slog.String("project_id", projectID), slog.Bool("success", success)) + + migration := &Migration{ + ID: uuid.NewString(), + ProjectID: projectID, + Success: success, + } + + err := r.db.Create(migration).Error + if err != nil { + r.logger.ErrorContext(ctx, "error creating migration", logger.Error(err)) + return err + } + + return nil +} + +func (r *repository) HasSuccessfulMigration(ctx context.Context, projectID string) (bool, error) { + r.logger.InfoContext(ctx, "checking for successful migration", slog.String("project_id", projectID)) + + var count int64 + err := r.db.Model(&Migration{}).Where("project_id = ? AND success = ?", projectID, true).Count(&count).Error + if err != nil { + r.logger.ErrorContext(ctx, "error checking for successful migration", logger.Error(err)) + return false, err + } + + return count > 0, nil +} diff --git a/internal/adapters/repositories/sql/projectrepo/types.go b/internal/adapters/repositories/sql/projectrepo/types.go index a828030..c699607 100644 --- a/internal/adapters/repositories/sql/projectrepo/types.go +++ b/internal/adapters/repositories/sql/projectrepo/types.go @@ -29,3 +29,14 @@ type EncryptionPart struct { func (EncryptionPart) TableName() string { return "shld_encryption_parts" } + +type Migration struct { + ID string `gorm:"column:id;primaryKey"` + ProjectID string `gorm:"column:project_id"` + Timestamp time.Time `gorm:"column:timestamp;autoCreateTime"` + Success bool `gorm:"column:success"` +} + +func (Migration) TableName() string { + return "shld_shamir_migrations" +} diff --git a/internal/adapters/repositories/sql/sharerepo/repo.go b/internal/adapters/repositories/sql/sharerepo/repo.go index d5d5877..eef517f 100644 --- a/internal/adapters/repositories/sql/sharerepo/repo.go +++ b/internal/adapters/repositories/sql/sharerepo/repo.go @@ -76,14 +76,14 @@ func (r *repository) Delete(ctx context.Context, shareID string) error { return nil } -func (r *repository) ListDecryptedByProjectID(ctx context.Context, projectID string) ([]*share.Share, error) { +func (r *repository) ListProjectIDAndEntropy(ctx context.Context, projectID string, entropy share.Entropy) ([]*share.Share, error) { r.logger.InfoContext(ctx, "listing shares", slog.String("project_id", projectID)) var dbShares []*Share err := r.db.Joins("JOIN shld_users ON shld_shares.user_id = shld_users.id"). Joins("JOIN shld_projects ON shld_users.project_id = shld_projects.id"). Where("shld_projects.id = ?", projectID). - Where("shld_shares.entropy = ?", EntropyNone). + Where("shld_shares.entropy = ?", r.parser.mapDomainEntropy[entropy]). Find(&dbShares).Error if err != nil { r.logger.ErrorContext(ctx, "error listing shares", logger.Error(err)) @@ -122,3 +122,25 @@ func (r *repository) Update(ctx context.Context, shr *share.Share) error { return nil } + +func (r *repository) BulkUpdate(ctx context.Context, shrs []*share.Share) error { + r.logger.InfoContext(ctx, "bulk updating shares") + + var dbShares []*Share + for _, shr := range shrs { + dbShares = append(dbShares, r.parser.toDatabase(shr)) + } + + return r.db.Transaction(func(tx *gorm.DB) error { + for _, dbShr := range dbShares { + err := tx.Save(dbShr).Error + if err != nil { + r.logger.ErrorContext(ctx, "error updating share", logger.Error(err)) + return err + } + } + + r.logger.InfoContext(ctx, "bulk updated shares", slog.Int("count", len(shrs))) + return nil + }) +} diff --git a/internal/applications/projectapp/app.go b/internal/applications/projectapp/app.go index e5b684b..1cf8056 100644 --- a/internal/applications/projectapp/app.go +++ b/internal/applications/projectapp/app.go @@ -60,12 +60,12 @@ func (a *ProjectApplication) CreateProject(ctx context.Context, name string, opt if o.generateEncryptionKey { part, err := a.registerEncryptionKey(ctx, proj.ID) if err != nil { + a.logger.ErrorContext(ctx, "failed to register encryption key", logger.Error(err)) errD := a.projectRepo.Delete(ctx, proj.ID) if errD != nil { a.logger.Error("failed to delete project", logger.Error(errD)) err = errors.Join(err, errD) } - a.logger.ErrorContext(ctx, "failed to register encryption key", logger.Error(err)) return nil, fromDomainError(err) } @@ -274,7 +274,13 @@ func (a *ProjectApplication) EncryptProjectShares(ctx context.Context, externalP a.logger.InfoContext(ctx, "encrypting project shares") projectID := contexter.GetProjectID(ctx) - builder, err := a.encryptionFactory.CreateEncryptionKeyBuilder(factories.Plain) + isMigrated, err := a.projectRepo.HasSuccessfulMigration(ctx, projectID) + if err != nil { + a.logger.ErrorContext(ctx, "failed to check migration", logger.Error(err)) + return ErrInternal + } + + builder, err := a.encryptionFactory.CreateEncryptionKeyBuilder(factories.Plain, isMigrated) if err != nil { a.logger.ErrorContext(ctx, "failed to create encryption key builder", logger.Error(err)) return ErrInternal @@ -298,7 +304,7 @@ func (a *ProjectApplication) EncryptProjectShares(ctx context.Context, externalP return ErrInvalidEncryptionPart } - shares, err := a.sharesRepo.ListDecryptedByProjectID(ctx, projectID) + shares, err := a.sharesRepo.ListProjectIDAndEntropy(ctx, projectID, share.EntropyNone) if err != nil { a.logger.ErrorContext(ctx, "failed to list shares", logger.Error(err)) return fromDomainError(err) @@ -377,7 +383,7 @@ func (a *ProjectApplication) registerEncryptionKey(ctx context.Context, projectI return "", ErrInternal } - reconstructionStrategy := a.encryptionFactory.CreateReconstructionStrategy() + reconstructionStrategy := a.encryptionFactory.CreateReconstructionStrategy(true) storedPart, projectPart, err := reconstructionStrategy.Split(key) if err != nil { a.logger.Error("failed to split encryption key", logger.Error(err)) @@ -389,5 +395,11 @@ func (a *ProjectApplication) registerEncryptionKey(ctx context.Context, projectI return "", err } + err = a.projectRepo.CreateMigration(ctx, projectID, true) + if err != nil { + a.logger.Error("failed to create migration", logger.Error(err)) + return "", ErrInternal + } + return projectPart, nil } diff --git a/internal/applications/projectapp/app_test.go b/internal/applications/projectapp/app_test.go index ea2f912..fb8bdab 100644 --- a/internal/applications/projectapp/app_test.go +++ b/internal/applications/projectapp/app_test.go @@ -67,6 +67,7 @@ func TestProjectApplication_CreateProject(t *testing.T) { projectRepo.On("Create", mock.Anything, mock.AnythingOfType("*project.Project")).Return(nil) projectRepo.On("GetEncryptionPart", mock.Anything, mock.Anything).Return("", nil) projectRepo.On("SetEncryptionPart", mock.Anything, mock.Anything, mock.Anything).Return(nil) + projectRepo.On("CreateMigration", mock.Anything, mock.Anything, mock.Anything).Return(nil) }, }, { @@ -844,7 +845,7 @@ func TestProjectApplication_EncryptProjectShares(t *testing.T) { t.Fatalf(key) } - reconstructor := encryptionFactory.CreateReconstructionStrategy() + reconstructor := encryptionFactory.CreateReconstructionStrategy(true) storedPart, projectPart, err := reconstructor.Split(key) if err != nil { t.Fatalf("failed to generate encryption key: %v", err) @@ -889,10 +890,11 @@ func TestProjectApplication_EncryptProjectShares(t *testing.T) { mock: func() { projectRepo.ExpectedCalls = nil shareRepo.ExpectedCalls = nil - shareRepo.On("ListDecryptedByProjectID", mock.Anything, mock.Anything).Return([]*share.Share{plainShare, encryptedShare}, nil) + shareRepo.On("ListProjectIDAndEntropy", mock.Anything, mock.Anything, mock.Anything).Return([]*share.Share{plainShare, encryptedShare}, nil) shareRepo.On("Update", mock.Anything, mock.Anything).Return(nil) projectRepo.On("GetEncryptionPart", mock.Anything, mock.Anything).Return(storedPart, nil) shareRepo.On("UpdateProjectEncryption", mock.Anything, mock.Anything, mock.Anything).Return(nil) + projectRepo.On("HasSuccessfulMigration", mock.Anything, mock.Anything).Return(true, nil) }, }, { @@ -901,6 +903,7 @@ func TestProjectApplication_EncryptProjectShares(t *testing.T) { mock: func() { projectRepo.ExpectedCalls = nil projectRepo.On("GetEncryptionPart", mock.Anything, mock.Anything).Return("", domainErrors.ErrEncryptionPartNotFound) + projectRepo.On("HasSuccessfulMigration", mock.Anything, mock.Anything).Return(true, nil) }, wantErr: ErrEncryptionNotConfigured, }, @@ -910,6 +913,7 @@ func TestProjectApplication_EncryptProjectShares(t *testing.T) { mock: func() { projectRepo.ExpectedCalls = nil projectRepo.On("GetEncryptionPart", mock.Anything, mock.Anything).Return("", errors.New("repository error")) + projectRepo.On("HasSuccessfulMigration", mock.Anything, mock.Anything).Return(true, nil) }, wantErr: ErrInternal, }, @@ -919,6 +923,7 @@ func TestProjectApplication_EncryptProjectShares(t *testing.T) { mock: func() { projectRepo.ExpectedCalls = nil projectRepo.On("GetEncryptionPart", mock.Anything, mock.Anything).Return("invalid", nil) + projectRepo.On("HasSuccessfulMigration", mock.Anything, mock.Anything).Return(true, nil) }, wantErr: ErrInvalidEncryptionPart, }, @@ -929,7 +934,8 @@ func TestProjectApplication_EncryptProjectShares(t *testing.T) { projectRepo.ExpectedCalls = nil shareRepo.ExpectedCalls = nil projectRepo.On("GetEncryptionPart", mock.Anything, mock.Anything).Return(storedPart, nil) - shareRepo.On("ListDecryptedByProjectID", mock.Anything, mock.Anything).Return(nil, errors.New("repository error")) + shareRepo.On("ListProjectIDAndEntropy", mock.Anything, mock.Anything, mock.Anything).Return(nil, errors.New("repository error")) + projectRepo.On("HasSuccessfulMigration", mock.Anything, mock.Anything).Return(true, nil) }, wantErr: ErrInternal, }, @@ -940,8 +946,9 @@ func TestProjectApplication_EncryptProjectShares(t *testing.T) { projectRepo.ExpectedCalls = nil projectRepo.On("GetEncryptionPart", mock.Anything, mock.Anything).Return(storedPart, nil) shareRepo.ExpectedCalls = nil - shareRepo.On("ListDecryptedByProjectID", mock.Anything, mock.Anything).Return([]*share.Share{plainShare2}, nil) + shareRepo.On("ListProjectIDAndEntropy", mock.Anything, mock.Anything, mock.Anything).Return([]*share.Share{plainShare2}, nil) shareRepo.On("UpdateProjectEncryption", mock.Anything, "share_id", mock.Anything).Return(errors.New("repository error")) + projectRepo.On("HasSuccessfulMigration", mock.Anything, mock.Anything).Return(true, nil) }, wantErr: ErrInternal, }, @@ -981,6 +988,7 @@ func TestProjectApplication_RegisterEncryptionKey(t *testing.T) { projectRepo.ExpectedCalls = nil projectRepo.On("GetEncryptionPart", mock.Anything, "project_id").Return("", domainErrors.ErrEncryptionPartNotFound) projectRepo.On("SetEncryptionPart", mock.Anything, "project_id", mock.Anything).Return(nil) + projectRepo.On("CreateMigration", mock.Anything, mock.Anything, mock.Anything).Return(nil) }, }, { diff --git a/internal/applications/shamirjob/job.go b/internal/applications/shamirjob/job.go new file mode 100644 index 0000000..6345498 --- /dev/null +++ b/internal/applications/shamirjob/job.go @@ -0,0 +1,100 @@ +package shamirjob + +import ( + "context" + "log/slog" + "sync" + + aesenc "go.openfort.xyz/shield/internal/adapters/encryption/aes_encryption_strategy" + sssrec "go.openfort.xyz/shield/internal/adapters/encryption/sss_reconstruction_strategy" + "go.openfort.xyz/shield/internal/core/domain/share" + "go.openfort.xyz/shield/internal/core/ports/repositories" + "go.openfort.xyz/shield/internal/core/ports/strategies" + "go.openfort.xyz/shield/pkg/logger" +) + +type Job struct { + projectRepo repositories.ProjectRepository + shareRepo repositories.ShareRepository + reconstructionStrategy strategies.ReconstructionStrategy + logger *slog.Logger + mu sync.Mutex +} + +func New(projectRepo repositories.ProjectRepository, shareRepo repositories.ShareRepository) *Job { + return &Job{ + projectRepo: projectRepo, + shareRepo: shareRepo, + reconstructionStrategy: sssrec.NewSSSReconstructionStrategy(), + logger: logger.New("shamirjob"), + } +} + +func (j *Job) Execute(ctx context.Context, projectID string, storedPart, projectPart, key string) (err error) { + j.mu.Lock() + defer j.mu.Unlock() + + j.logger.InfoContext(ctx, "executing job", slog.String("project_id", projectID)) + + isMigrated, err := j.projectRepo.HasSuccessfulMigration(ctx, projectID) + if err != nil { + j.logger.ErrorContext(ctx, "error checking migration", logger.Error(err)) + return err + } + + if isMigrated { + j.logger.InfoContext(ctx, "project already migrated") + return nil + } + + defer func() { + success := err == nil + err = j.projectRepo.CreateMigration(ctx, projectID, success) + if err != nil { + j.logger.ErrorContext(ctx, "error creating migration", logger.Error(err)) + } + }() + + j.logger.InfoContext(ctx, "reconstructing key") + reconstructedKey, err := j.reconstructionStrategy.Reconstruct(storedPart, projectPart) + if err != nil { + j.logger.ErrorContext(ctx, "error reconstructing key", logger.Error(err)) + return err + } + + decryptStrategy := aesenc.NewAESEncryptionStrategy(key) + encryptStrategy := aesenc.NewAESEncryptionStrategy(reconstructedKey) + + j.logger.InfoContext(ctx, "loading shares") + shares, err := j.shareRepo.ListProjectIDAndEntropy(ctx, projectID, share.EntropyProject) + if err != nil { + return err + } + j.logger.InfoContext(ctx, "loaded shares", slog.Int("count", len(shares))) + + j.logger.InfoContext(ctx, "re-encrypting shares") + for _, shr := range shares { + decr, err := decryptStrategy.Decrypt(shr.Secret) + if err != nil { + j.logger.ErrorContext(ctx, "error decrypting", logger.Error(err)) + return err + } + + encr, err := encryptStrategy.Encrypt(decr) + if err != nil { + j.logger.ErrorContext(ctx, "error encrypting", logger.Error(err)) + return err + } + + shr.Secret = encr + } + + j.logger.InfoContext(ctx, "updating shares") + err = j.shareRepo.BulkUpdate(ctx, shares) + if err != nil { + j.logger.ErrorContext(ctx, "error updating shares", logger.Error(err)) + return err + } + + return nil +} diff --git a/internal/applications/shareapp/app.go b/internal/applications/shareapp/app.go index 98ab5e9..d03a814 100644 --- a/internal/applications/shareapp/app.go +++ b/internal/applications/shareapp/app.go @@ -4,6 +4,8 @@ import ( "context" "log/slog" + "go.openfort.xyz/shield/internal/applications/shamirjob" + "go.openfort.xyz/shield/internal/core/ports/factories" "go.openfort.xyz/shield/internal/core/domain/share" @@ -19,15 +21,17 @@ type ShareApplication struct { projectRepo repositories.ProjectRepository logger *slog.Logger encryptionFactory factories.EncryptionFactory + shamirJob *shamirjob.Job } -func New(shareSvc services.ShareService, shareRepo repositories.ShareRepository, projectRepo repositories.ProjectRepository, encryptionFactory factories.EncryptionFactory) *ShareApplication { +func New(shareSvc services.ShareService, shareRepo repositories.ShareRepository, projectRepo repositories.ProjectRepository, encryptionFactory factories.EncryptionFactory, shamirJob *shamirjob.Job) *ShareApplication { return &ShareApplication{ shareSvc: shareSvc, shareRepo: shareRepo, projectRepo: projectRepo, logger: logger.New("share_application"), encryptionFactory: encryptionFactory, + shamirJob: shamirJob, } } @@ -184,7 +188,13 @@ func (a *ShareApplication) reconstructEncryptionKey(ctx context.Context, projID return "", ErrEncryptionPartRequired } - builder, err := a.encryptionFactory.CreateEncryptionKeyBuilder(builderType) + isMigrated, err := a.projectRepo.HasSuccessfulMigration(ctx, projID) + if err != nil { + a.logger.ErrorContext(ctx, "failed to check migration", logger.Error(err)) + return "", ErrInternal + } + + builder, err := a.encryptionFactory.CreateEncryptionKeyBuilder(builderType, isMigrated) if err != nil { a.logger.ErrorContext(ctx, "failed to create encryption key builder", logger.Error(err)) return "", ErrInternal @@ -208,5 +218,15 @@ func (a *ShareApplication) reconstructEncryptionKey(ctx context.Context, projID return "", ErrInvalidEncryptionPart } + if !isMigrated { + ctxWithoutCancel := context.WithoutCancel(ctx) + go func() { + err = a.shamirJob.Execute(ctxWithoutCancel, projID, builder.GetDatabasePart(ctxWithoutCancel), builder.GetProjectPart(ctxWithoutCancel), encryptionKey) + if err != nil { + a.logger.ErrorContext(ctx, "failed to execute shamir job", logger.Error(err)) + } + }() + } + return encryptionKey, nil } diff --git a/internal/applications/shareapp/app_test.go b/internal/applications/shareapp/app_test.go index ea6f2cc..57650d4 100644 --- a/internal/applications/shareapp/app_test.go +++ b/internal/applications/shareapp/app_test.go @@ -9,6 +9,7 @@ import ( "go.openfort.xyz/shield/internal/adapters/repositories/mocks/encryptionpartsmockrepo" "go.openfort.xyz/shield/internal/adapters/repositories/mocks/projectmockrepo" "go.openfort.xyz/shield/internal/adapters/repositories/mocks/sharemockrepo" + "go.openfort.xyz/shield/internal/applications/shamirjob" domainErrors "go.openfort.xyz/shield/internal/core/domain/errors" "go.openfort.xyz/shield/internal/core/domain/share" "go.openfort.xyz/shield/internal/core/services/sharesvc" @@ -25,13 +26,13 @@ func TestShareApplication_GetShare(t *testing.T) { encryptionPartsRepo := new(encryptionpartsmockrepo.MockEncryptionPartsRepository) encryptionFactory := encryption.NewEncryptionFactory(encryptionPartsRepo, projectRepo) shareSvc := sharesvc.New(shareRepo, encryptionFactory) - app := New(shareSvc, shareRepo, projectRepo, encryptionFactory) + app := New(shareSvc, shareRepo, projectRepo, encryptionFactory, &shamirjob.Job{}) key, err := random.GenerateRandomString(32) if err != nil { t.Fatalf(key) } - reconstructor := encryptionFactory.CreateReconstructionStrategy() + reconstructor := encryptionFactory.CreateReconstructionStrategy(true) storedPart, projectPart, err := reconstructor.Split(key) if err != nil { t.Fatalf("failed to generate encryption key: %v", err) @@ -83,6 +84,7 @@ func TestShareApplication_GetShare(t *testing.T) { projectRepo.ExpectedCalls = nil shareRepo.On("GetByUserID", mock.Anything, "user_id").Return(&tmpEncryptedShare, nil) projectRepo.On("GetEncryptionPart", mock.Anything, "project_id").Return(storedPart, nil) + projectRepo.On("HasSuccessfulMigration", mock.Anything, "project_id").Return(true, nil) }, opts: []Option{ WithEncryptionPart(projectPart), @@ -101,6 +103,8 @@ func TestShareApplication_GetShare(t *testing.T) { projectRepo.On("GetEncryptionPart", mock.Anything, "project_id").Return(storedPart, nil) encryptionPartsRepo.On("Get", mock.Anything, "sessionID").Return(projectPart, nil) encryptionPartsRepo.On("Delete", mock.Anything, "sessionID").Return(nil) + projectRepo.On("HasSuccessfulMigration", mock.Anything, "project_id").Return(true, nil) + }, opts: []Option{ WithEncryptionSession("sessionID"), @@ -125,6 +129,8 @@ func TestShareApplication_GetShare(t *testing.T) { projectRepo.ExpectedCalls = nil shareRepo.On("GetByUserID", mock.Anything, "user_id").Return(&tmpEncryptedShare, nil) projectRepo.On("GetEncryptionPart", mock.Anything, "project_id").Return("", domainErrors.ErrEncryptionPartNotFound) + projectRepo.On("HasSuccessfulMigration", mock.Anything, "project_id").Return(true, nil) + }, opts: []Option{ WithEncryptionPart(projectPart), @@ -139,6 +145,7 @@ func TestShareApplication_GetShare(t *testing.T) { projectRepo.ExpectedCalls = nil shareRepo.On("GetByUserID", mock.Anything, "user_id").Return(&tmpEncryptedShare, nil) projectRepo.On("GetEncryptionPart", mock.Anything, "project_id").Return(storedPart, nil) + projectRepo.On("HasSuccessfulMigration", mock.Anything, "project_id").Return(true, nil) }, opts: []Option{ WithEncryptionPart("invalid-key"), @@ -152,6 +159,7 @@ func TestShareApplication_GetShare(t *testing.T) { projectRepo.ExpectedCalls = nil shareRepo.On("GetByUserID", mock.Anything, "user_id").Return(decryptedShare, nil) projectRepo.On("GetEncryptionPart", mock.Anything, "project_id").Return(storedPart, nil) + projectRepo.On("HasSuccessfulMigration", mock.Anything, "project_id").Return(true, nil) }, opts: []Option{ WithEncryptionPart(projectPart), @@ -184,6 +192,7 @@ func TestShareApplication_GetShare(t *testing.T) { projectRepo.ExpectedCalls = nil shareRepo.On("GetByUserID", mock.Anything, "user_id").Return(&tmpEncryptedShare, nil) projectRepo.On("GetEncryptionPart", mock.Anything, "project_id").Return("", errors.New("repository error")) + projectRepo.On("HasSuccessfulMigration", mock.Anything, "project_id").Return(true, nil) }, opts: []Option{ WithEncryptionPart(projectPart), @@ -198,6 +207,7 @@ func TestShareApplication_GetShare(t *testing.T) { projectRepo.ExpectedCalls = nil shareRepo.On("GetByUserID", mock.Anything, "user_id").Return(&tmpEncryptedShare, nil) projectRepo.On("GetEncryptionPart", mock.Anything, "project_id").Return("", domainErrors.ErrEncryptionPartNotFound) + projectRepo.On("HasSuccessfulMigration", mock.Anything, "project_id").Return(true, nil) }, opts: []Option{ WithEncryptionPart(projectPart), @@ -224,13 +234,13 @@ func TestShareApplication_RegisterShare(t *testing.T) { encryptionPartsRepo := new(encryptionpartsmockrepo.MockEncryptionPartsRepository) encryptionFactory := encryption.NewEncryptionFactory(encryptionPartsRepo, projectRepo) shareSvc := sharesvc.New(shareRepo, encryptionFactory) - app := New(shareSvc, shareRepo, projectRepo, encryptionFactory) + app := New(shareSvc, shareRepo, projectRepo, encryptionFactory, &shamirjob.Job{}) key, err := random.GenerateRandomString(32) if err != nil { t.Fatalf(key) } - storedPart, projectPart, err := encryptionFactory.CreateReconstructionStrategy().Split(key) + storedPart, projectPart, err := encryptionFactory.CreateReconstructionStrategy(true).Split(key) if err != nil { t.Fatalf("failed to generate encryption key: %v", err) } @@ -279,6 +289,7 @@ func TestShareApplication_RegisterShare(t *testing.T) { shareRepo.On("GetByUserID", mock.Anything, mock.Anything, mock.Anything).Return(nil, domainErrors.ErrShareNotFound) shareRepo.On("Create", mock.Anything, encryptedShare).Return(nil) projectRepo.On("GetEncryptionPart", mock.Anything, "project_id").Return(storedPart, nil) + projectRepo.On("HasSuccessfulMigration", mock.Anything, "project_id").Return(true, nil) }, opts: []Option{ WithEncryptionPart(projectPart), @@ -302,6 +313,7 @@ func TestShareApplication_RegisterShare(t *testing.T) { projectRepo.ExpectedCalls = nil shareRepo.On("GetByUserID", mock.Anything, mock.Anything, mock.Anything).Return(nil, domainErrors.ErrShareNotFound) projectRepo.On("GetEncryptionPart", mock.Anything, "project_id").Return("", domainErrors.ErrEncryptionPartNotFound) + projectRepo.On("HasSuccessfulMigration", mock.Anything, "project_id").Return(true, nil) }, opts: []Option{ WithEncryptionPart(projectPart), @@ -316,6 +328,7 @@ func TestShareApplication_RegisterShare(t *testing.T) { projectRepo.ExpectedCalls = nil shareRepo.On("GetByUserID", mock.Anything, mock.Anything, mock.Anything).Return(nil, domainErrors.ErrShareNotFound) projectRepo.On("GetEncryptionPart", mock.Anything, "project_id").Return(storedPart, nil) + projectRepo.On("HasSuccessfulMigration", mock.Anything, "project_id").Return(true, nil) }, opts: []Option{ WithEncryptionPart("invalid-key"), @@ -359,7 +372,7 @@ func TestShareApplication_DeleteShare(t *testing.T) { encryptionPartsRepo := new(encryptionpartsmockrepo.MockEncryptionPartsRepository) encryptionFactory := encryption.NewEncryptionFactory(encryptionPartsRepo, projectRepo) shareSvc := sharesvc.New(shareRepo, encryptionFactory) - app := New(shareSvc, shareRepo, projectRepo, encryptionFactory) + app := New(shareSvc, shareRepo, projectRepo, encryptionFactory, &shamirjob.Job{}) tc := []struct { name string @@ -420,7 +433,7 @@ func TestShareApplication_UpdateShare(t *testing.T) { encryptionPartsRepo := new(encryptionpartsmockrepo.MockEncryptionPartsRepository) encryptionFactory := encryption.NewEncryptionFactory(encryptionPartsRepo, projectRepo) shareSvc := sharesvc.New(shareRepo, encryptionFactory) - app := New(shareSvc, shareRepo, projectRepo, encryptionFactory) + app := New(shareSvc, shareRepo, projectRepo, encryptionFactory, &shamirjob.Job{}) updates := &share.Share{ ID: "share-id", Secret: "secret", diff --git a/internal/core/domain/errors/project.go b/internal/core/domain/errors/project.go index 7d7c4d0..6416ab3 100644 --- a/internal/core/domain/errors/project.go +++ b/internal/core/domain/errors/project.go @@ -12,4 +12,5 @@ var ( ErrReconstructedKeyMismatch = errors.New("reconstructed key mismatch") ErrProjectPartRequired = errors.New("project part is required") ErrDatabasePartRequired = errors.New("database part is required") + ErrFailedToSplitKey = errors.New("failed to split key") ) diff --git a/internal/core/ports/builders/encryption.go b/internal/core/ports/builders/encryption.go index 02abd77..cdb6307 100644 --- a/internal/core/ports/builders/encryption.go +++ b/internal/core/ports/builders/encryption.go @@ -7,5 +7,9 @@ import ( type EncryptionKeyBuilder interface { SetProjectPart(ctx context.Context, identifier string) error SetDatabasePart(ctx context.Context, identifier string) error + + GetProjectPart(ctx context.Context) string + GetDatabasePart(ctx context.Context) string + Build(ctx context.Context) (string, error) } diff --git a/internal/core/ports/factories/encryption.go b/internal/core/ports/factories/encryption.go index e4ba092..1dda7f9 100644 --- a/internal/core/ports/factories/encryption.go +++ b/internal/core/ports/factories/encryption.go @@ -6,8 +6,8 @@ import ( ) type EncryptionFactory interface { - CreateEncryptionKeyBuilder(builderType EncryptionKeyBuilderType) (builders.EncryptionKeyBuilder, error) - CreateReconstructionStrategy() strategies.ReconstructionStrategy + CreateEncryptionKeyBuilder(builderType EncryptionKeyBuilderType, projectMigrated bool) (builders.EncryptionKeyBuilder, error) + CreateReconstructionStrategy(projectMigrated bool) strategies.ReconstructionStrategy CreateEncryptionStrategy(key string) strategies.EncryptionStrategy } diff --git a/internal/core/ports/repositories/project.go b/internal/core/ports/repositories/project.go index e7246d7..30b909e 100644 --- a/internal/core/ports/repositories/project.go +++ b/internal/core/ports/repositories/project.go @@ -14,4 +14,7 @@ type ProjectRepository interface { GetEncryptionPart(ctx context.Context, projectID string) (string, error) SetEncryptionPart(ctx context.Context, projectID, part string) error + + CreateMigration(ctx context.Context, projectID string, success bool) error + HasSuccessfulMigration(ctx context.Context, projectID string) (bool, error) } diff --git a/internal/core/ports/repositories/shares.go b/internal/core/ports/repositories/shares.go index 4f681b7..52873d4 100644 --- a/internal/core/ports/repositories/shares.go +++ b/internal/core/ports/repositories/shares.go @@ -10,7 +10,8 @@ type ShareRepository interface { Create(ctx context.Context, shr *share.Share) error GetByUserID(ctx context.Context, userID string) (*share.Share, error) Delete(ctx context.Context, shareID string) error - ListDecryptedByProjectID(ctx context.Context, projectID string) ([]*share.Share, error) + ListProjectIDAndEntropy(ctx context.Context, projectID string, entropy share.Entropy) ([]*share.Share, error) UpdateProjectEncryption(ctx context.Context, shareID string, encrypted string) error Update(ctx context.Context, shr *share.Share) error + BulkUpdate(ctx context.Context, shrs []*share.Share) error } diff --git a/internal/core/services/sharesvc/svc_test.go b/internal/core/services/sharesvc/svc_test.go index e219ca5..806b57a 100644 --- a/internal/core/services/sharesvc/svc_test.go +++ b/internal/core/services/sharesvc/svc_test.go @@ -8,7 +8,6 @@ import ( "go.openfort.xyz/shield/internal/adapters/repositories/mocks/projectmockrepo" domainErrors "go.openfort.xyz/shield/internal/core/domain/errors" "go.openfort.xyz/shield/internal/core/ports/services" - "go.openfort.xyz/shield/pkg/cypher" "go.openfort.xyz/shield/pkg/random" "testing" @@ -40,13 +39,13 @@ func TestCreateShare(t *testing.T) { t.Fatalf(key) } - reconstructor := encryptionFactory.CreateReconstructionStrategy() + reconstructor := encryptionFactory.CreateReconstructionStrategy(true) storedPart, projectPart, err := reconstructor.Split(key) if err != nil { t.Fatalf("failed to generate encryption key: %v", err) } - encryptionKey, err := cypher.ReconstructEncryptionKey(storedPart, projectPart) + encryptionKey, err := reconstructor.Reconstruct(storedPart, projectPart) if err != nil { t.Fatalf("failed to reconstruct encryption key: %v", err) } diff --git a/pkg/cypher/cypher.go b/pkg/cypher/cypher.go index c78e0cb..e120485 100644 --- a/pkg/cypher/cypher.go +++ b/pkg/cypher/cypher.go @@ -103,9 +103,9 @@ func ReconstructEncryptionKey(part1, part2 string) (string, error) { return "", err } - subset := make(map[byte][]byte) - subset[1] = rawPart1 - subset[2] = rawPart2 + subset := make(map[byte][]byte, 2) + subset[0] = rawPart1 + subset[1] = rawPart2 key := sss.Combine(subset)