From 046a10630b5e18cda3f9087ec3828d1bc6d21bd7 Mon Sep 17 00:00:00 2001 From: gllm-dev Date: Tue, 2 Jul 2024 18:38:38 +0200 Subject: [PATCH 01/10] feat: in memory encryption parts repository --- go.mod | 40 +++++---- go.sum | 87 ++++++++++++------- internal/applications/shareapp/options.go | 9 +- .../ports/repositories/encryptionparts.go | 9 ++ internal/core/ports/services/share.go | 9 +- .../repositories/bunt/client.go | 22 +++++ .../bunt/encryptionpartsrepo/repo.go | 70 +++++++++++++++ .../mocks/encryptionpartsmockrepo/repo.go | 32 +++++++ 8 files changed, 228 insertions(+), 50 deletions(-) create mode 100644 internal/core/ports/repositories/encryptionparts.go create mode 100644 internal/infrastructure/repositories/bunt/client.go create mode 100644 internal/infrastructure/repositories/bunt/encryptionpartsrepo/repo.go create mode 100644 internal/infrastructure/repositories/mocks/encryptionpartsmockrepo/repo.go diff --git a/go.mod b/go.mod index c22a939..c011d6f 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module go.openfort.xyz/shield go 1.22.0 require ( - github.com/MicahParks/keyfunc/v3 v3.2.9 + github.com/MicahParks/keyfunc/v3 v3.3.3 github.com/caarlos0/env/v10 v10.0.0 github.com/codahale/sss v0.0.0-20160501174526-0cb9f6d3f7f1 github.com/golang-jwt/jwt/v5 v5.2.1 @@ -11,25 +11,27 @@ require ( github.com/google/wire v0.6.0 github.com/gorilla/mux v1.8.1 github.com/pressly/goose v2.7.0+incompatible - github.com/rs/cors v1.10.1 - github.com/spf13/cobra v1.8.0 - github.com/stretchr/testify v1.8.4 - golang.org/x/crypto v0.21.0 - gorm.io/driver/mysql v1.5.5 - gorm.io/driver/postgres v1.5.7 - gorm.io/gorm v1.25.8 + github.com/rs/cors v1.11.0 + github.com/spf13/cobra v1.8.1 + github.com/stretchr/testify v1.9.0 + github.com/tidwall/buntdb v1.3.1 + go.uber.org/ratelimit v0.3.1 + golang.org/x/crypto v0.24.0 + gorm.io/driver/mysql v1.5.7 + gorm.io/driver/postgres v1.5.9 + gorm.io/gorm v1.25.10 ) require ( filippo.io/edwards25519 v1.1.0 // indirect - github.com/MicahParks/jwkset v0.5.15 // indirect + github.com/MicahParks/jwkset v0.5.18 // indirect github.com/benbjohnson/clock v1.3.5 // indirect github.com/davecgh/go-spew v1.1.1 // indirect - github.com/go-sql-driver/mysql v1.8.0 // indirect + github.com/go-sql-driver/mysql v1.8.1 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect - github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect - github.com/jackc/pgx/v5 v5.5.5 // indirect + github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect + github.com/jackc/pgx/v5 v5.6.0 // indirect github.com/jackc/puddle/v2 v2.2.1 // indirect github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect @@ -38,10 +40,16 @@ require ( github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rogpeppe/go-internal v1.12.0 // indirect github.com/spf13/pflag v1.0.5 // indirect - github.com/stretchr/objx v0.5.0 // indirect - go.uber.org/ratelimit v0.3.1 // indirect - golang.org/x/sync v0.6.0 // indirect - golang.org/x/text v0.14.0 // indirect + github.com/stretchr/objx v0.5.2 // indirect + github.com/tidwall/btree v1.7.0 // indirect + github.com/tidwall/gjson v1.17.1 // indirect + github.com/tidwall/grect v0.1.4 // indirect + github.com/tidwall/match v1.1.1 // indirect + github.com/tidwall/pretty v1.2.1 // indirect + github.com/tidwall/rtred v0.1.2 // indirect + github.com/tidwall/tinyqueue v0.1.1 // indirect + golang.org/x/sync v0.7.0 // indirect + golang.org/x/text v0.16.0 // indirect golang.org/x/time v0.5.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 7fb9e85..97524be 100644 --- a/go.sum +++ b/go.sum @@ -1,23 +1,23 @@ filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= -github.com/MicahParks/jwkset v0.5.15 h1:ACJY045Zuvo2TVWikeFLnKTIsEDQQHUHrNYiMW+gj24= -github.com/MicahParks/jwkset v0.5.15/go.mod h1:q8ptTGn/Z9c4MwbcfeCDssADeVQb3Pk7PnVxrvi+2QY= -github.com/MicahParks/keyfunc/v3 v3.2.9 h1:juKYzZvb5q4mWnox3439WNq6cusvSdt2fJ5nj+osgCk= -github.com/MicahParks/keyfunc/v3 v3.2.9/go.mod h1:Yx3jN/pn7ZMCxwFsyIrsmSqRfp0HGHAcyezBlhYi1Ew= +github.com/MicahParks/jwkset v0.5.18 h1:WLdyMngF7rCrnstQxA7mpRoxeaWqGzPM/0z40PJUK4w= +github.com/MicahParks/jwkset v0.5.18/go.mod h1:q8ptTGn/Z9c4MwbcfeCDssADeVQb3Pk7PnVxrvi+2QY= +github.com/MicahParks/keyfunc/v3 v3.3.3 h1:c6j9oSu1YUo0k//KwF1miIQlEMtqNlj7XBFLB8jtEmY= +github.com/MicahParks/keyfunc/v3 v3.3.3/go.mod h1:f/UMyXdKfkZzmBeBFUeYk+zu066J1Fcl48f7Wnl5Z48= github.com/benbjohnson/clock v1.3.5 h1:VvXlSJBzZpA/zum6Sj74hxwYI2DIxRWuNIoXAzHZz5o= github.com/benbjohnson/clock v1.3.5/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= github.com/caarlos0/env/v10 v10.0.0 h1:yIHUBZGsyqCnpTkbjk8asUlx6RFhhEs+h7TOBdgdzXA= github.com/caarlos0/env/v10 v10.0.0/go.mod h1:ZfulV76NvVPw3tm591U4SwL3Xx9ldzBP9aGxzeN7G18= github.com/codahale/sss v0.0.0-20160501174526-0cb9f6d3f7f1 h1:PJJtqFbZH8ZW9PtsfB+ALZKVPRiRwNbPrNe+gliLpGo= github.com/codahale/sss v0.0.0-20160501174526-0cb9f6d3f7f1/go.mod h1:0Vm/twPonvi1fkJ3kW8TbuttPQ4EyspL1xHUVr1I3uU= -github.com/cpuguy83/go-md2man/v2 v2.0.3/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= +github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= -github.com/go-sql-driver/mysql v1.8.0 h1:UtktXaU2Nb64z/pLiGIxY4431SJ4/dR5cjMmlVHgnT4= -github.com/go-sql-driver/mysql v1.8.0/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg= +github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y= +github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg= github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk= github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= @@ -32,10 +32,10 @@ github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2 github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= -github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk= -github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= -github.com/jackc/pgx/v5 v5.5.5 h1:amBjrZVmksIdNjxGW/IiIMzxMKZFelXbUoPNb+8sjQw= -github.com/jackc/pgx/v5 v5.5.5/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgx/v5 v5.6.0 h1:SWJzexBzPL5jb0GEsrPMLIsi/3jOo7RHlzTjcAeDrPY= +github.com/jackc/pgx/v5 v5.6.0/go.mod h1:DNZ/vlrUnhWCoFGxHAG8U2ljioxukquj7utPDgtQdTw= github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk= github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= @@ -54,32 +54,53 @@ github.com/pressly/goose v2.7.0+incompatible h1:PWejVEv07LCerQEzMMeAtjuyCKbyprZ/ github.com/pressly/goose v2.7.0+incompatible/go.mod h1:m+QHWCqxR3k8D9l7qfzuC/djtlfzxr34mozWDYEu1z8= github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= -github.com/rs/cors v1.10.1 h1:L0uuZVXIKlI1SShY2nhFfo44TYvDPQ1w4oFkUJNfhyo= -github.com/rs/cors v1.10.1/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU= +github.com/rs/cors v1.11.0 h1:0B9GE/r9Bc2UxRMMtymBkHTenPkHDv0CW4Y98GBY+po= +github.com/rs/cors v1.11.0/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= -github.com/spf13/cobra v1.8.0 h1:7aJaZx1B85qltLMc546zn58BxxfZdR/W22ej9CFoEf0= -github.com/spf13/cobra v1.8.0/go.mod h1:WXLWApfZ71AjXPya3WOlMsY9yMs7YeiHhFVlvLyhcho= +github.com/spf13/cobra v1.8.1 h1:e5/vxKd/rZsfSJMUX1agtjeTDf+qv1/JdBF8gg5k9ZM= +github.com/spf13/cobra v1.8.1/go.mod h1:wHxEcudfqmLYa8iTfL+OuZPbBZkmvliBWKIezN3kD9Y= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= -github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c= -github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= -github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= -github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/tidwall/assert v0.1.0 h1:aWcKyRBUAdLoVebxo95N7+YZVTFF/ASTr7BN4sLP6XI= +github.com/tidwall/assert v0.1.0/go.mod h1:QLYtGyeqse53vuELQheYl9dngGCJQ+mTtlxcktb+Kj8= +github.com/tidwall/btree v1.7.0 h1:L1fkJH/AuEh5zBnnBbmTwQ5Lt+bRJ5A8EWecslvo9iI= +github.com/tidwall/btree v1.7.0/go.mod h1:twD9XRA5jj9VUQGELzDO4HPQTNJsoWWfYEL+EUQ2cKY= +github.com/tidwall/buntdb v1.3.1 h1:HKoDF01/aBhl9RjYtbaLnvX9/OuenwvQiC3OP1CcL4o= +github.com/tidwall/buntdb v1.3.1/go.mod h1:lZZrZUWzlyDJKlLQ6DKAy53LnG7m5kHyrEHvvcDmBpU= +github.com/tidwall/gjson v1.12.1/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.17.1 h1:wlYEnwqAHgzmhNUFfw7Xalt2JzQvsMx2Se4PcoFCT/U= +github.com/tidwall/gjson v1.17.1/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/grect v0.1.4 h1:dA3oIgNgWdSspFzn1kS4S/RDpZFLrIxAZOdJKjYapOg= +github.com/tidwall/grect v0.1.4/go.mod h1:9FBsaYRaR0Tcy4UwefBX/UDcDcDy9V5jUcxHzv2jd5Q= +github.com/tidwall/lotsa v1.0.2 h1:dNVBH5MErdaQ/xd9s769R31/n2dXavsQ0Yf4TMEHHw8= +github.com/tidwall/lotsa v1.0.2/go.mod h1:X6NiU+4yHA3fE3Puvpnn1XMDrFZrE9JO2/w+UMuqgR8= +github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= +github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= +github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= +github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/rtred v0.1.2 h1:exmoQtOLvDoO8ud++6LwVsAMTu0KPzLTUrMln8u1yu8= +github.com/tidwall/rtred v0.1.2/go.mod h1:hd69WNXQ5RP9vHd7dqekAz+RIdtfBogmglkZSRxCHFQ= +github.com/tidwall/tinyqueue v0.1.1 h1:SpNEvEggbpyN5DIReaJ2/1ndroY8iyEGxPYxoSaymYE= +github.com/tidwall/tinyqueue v0.1.1/go.mod h1:O/QNHwrnjqr6IHItYrzoHAKYhBkLI67Q096fQP5zMYw= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +go.uber.org/atomic v1.7.0 h1:ADUqmZGgLDDfbSL9ZmPxKTybcoEYHgpYfELNoN+7hsw= +go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/ratelimit v0.3.1 h1:K4qVE+byfv/B3tC+4nYWP7v/6SimcO7HzHekoMNBma0= go.uber.org/ratelimit v0.3.1/go.mod h1:6euWsTB6U/Nb3X++xEUXA8ciPJvr19Q/0h1+oDcJhRk= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc= golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg= -golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA= -golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs= +golang.org/x/crypto v0.24.0 h1:mnl8DM0o513X8fdIkmyFE/5hTYxbwYOjDS/+rK6qpRI= +golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= @@ -95,8 +116,9 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= -golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ= golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= +golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -118,8 +140,9 @@ golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= -golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= +golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= @@ -135,10 +158,10 @@ gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EV gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -gorm.io/driver/mysql v1.5.5 h1:WxklwX6FozMs1gk9yVadxGfjGiJjrBKPvIIvYZOMyws= -gorm.io/driver/mysql v1.5.5/go.mod h1:sEtPWMiqiN1N1cMXoXmBbd8C6/l+TESwriotuRRpkDM= -gorm.io/driver/postgres v1.5.7 h1:8ptbNJTDbEmhdr62uReG5BGkdQyeasu/FZHxI0IMGnM= -gorm.io/driver/postgres v1.5.7/go.mod h1:3e019WlBaYI5o5LIdNV+LyxCMNtLOQETBXL2h4chKpA= +gorm.io/driver/mysql v1.5.7 h1:MndhOPYOfEp2rHKgkZIhJ16eVUIRf2HmzgoPmh7FCWo= +gorm.io/driver/mysql v1.5.7/go.mod h1:sEtPWMiqiN1N1cMXoXmBbd8C6/l+TESwriotuRRpkDM= +gorm.io/driver/postgres v1.5.9 h1:DkegyItji119OlcaLjqN11kHoUgZ/j13E0jkJZgD6A8= +gorm.io/driver/postgres v1.5.9/go.mod h1:DX3GReXH+3FPWGrrgffdvCk3DQ1dwDPdmbenSkweRGI= gorm.io/gorm v1.25.7/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= -gorm.io/gorm v1.25.8 h1:WAGEZ/aEcznN4D03laj8DKnehe1e9gYQAjW8xyPRdeo= -gorm.io/gorm v1.25.8/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= +gorm.io/gorm v1.25.10 h1:dQpO+33KalOA+aFYGlK+EfxcI5MbO7EP2yYygwh9h+s= +gorm.io/gorm v1.25.10/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= diff --git a/internal/applications/shareapp/options.go b/internal/applications/shareapp/options.go index 69a06c8..6fc0a29 100644 --- a/internal/applications/shareapp/options.go +++ b/internal/applications/shareapp/options.go @@ -1,7 +1,8 @@ package shareapp type options struct { - encryptionPart *string + encryptionPart *string + encryptionSession *string } type Option func(*options) @@ -11,3 +12,9 @@ func WithEncryptionPart(encryptionPart string) Option { o.encryptionPart = &encryptionPart } } + +func WithEncryptionSession(encryptionSession string) Option { + return func(o *options) { + o.encryptionSession = &encryptionSession + } +} diff --git a/internal/core/ports/repositories/encryptionparts.go b/internal/core/ports/repositories/encryptionparts.go new file mode 100644 index 0000000..0e31d54 --- /dev/null +++ b/internal/core/ports/repositories/encryptionparts.go @@ -0,0 +1,9 @@ +package repositories + +import "context" + +type EncryptionPartsRepository interface { + Get(ctx context.Context, sessionId string) (string, error) + Set(ctx context.Context, sessionId, part string) error + Delete(ctx context.Context, sessionId string) error +} diff --git a/internal/core/ports/services/share.go b/internal/core/ports/services/share.go index 32904b2..8d95063 100644 --- a/internal/core/ports/services/share.go +++ b/internal/core/ports/services/share.go @@ -13,7 +13,8 @@ type ShareService interface { type ShareOption func(*ShareOptions) type ShareOptions struct { - EncryptionKey *string + EncryptionKey *string + EncryptionSession *string } func WithEncryptionKey(key string) ShareOption { @@ -21,3 +22,9 @@ func WithEncryptionKey(key string) ShareOption { o.EncryptionKey = &key } } + +func WithEncryptionSession(session string) ShareOption { + return func(o *ShareOptions) { + o.EncryptionSession = &session + } +} diff --git a/internal/infrastructure/repositories/bunt/client.go b/internal/infrastructure/repositories/bunt/client.go new file mode 100644 index 0000000..37970aa --- /dev/null +++ b/internal/infrastructure/repositories/bunt/client.go @@ -0,0 +1,22 @@ +package bunt + +import "github.com/tidwall/buntdb" + +type Client struct { + *buntdb.DB +} + +func New() (*Client, error) { + db, err := buntdb.Open(":memory:") + if err != nil { + return nil, err + } + + return &Client{ + DB: db, + }, nil +} + +func (c *Client) Close() error { + return c.DB.Close() +} diff --git a/internal/infrastructure/repositories/bunt/encryptionpartsrepo/repo.go b/internal/infrastructure/repositories/bunt/encryptionpartsrepo/repo.go new file mode 100644 index 0000000..81fd99d --- /dev/null +++ b/internal/infrastructure/repositories/bunt/encryptionpartsrepo/repo.go @@ -0,0 +1,70 @@ +package encryptionpartsrepo + +import ( + "context" + "errors" + "github.com/tidwall/buntdb" + "go.openfort.xyz/shield/internal/core/domain" + "go.openfort.xyz/shield/internal/core/ports/repositories" + "go.openfort.xyz/shield/internal/infrastructure/repositories/bunt" + "go.openfort.xyz/shield/pkg/logger" + "log/slog" +) + +type repository struct { + db *bunt.Client + logger *slog.Logger +} + +var _ repositories.EncryptionPartsRepository = &repository{} + +func New(db *bunt.Client) repositories.EncryptionPartsRepository { + return &repository{ + db: db, + logger: logger.New("encryption_parts_repository"), + } +} + +func (r *repository) Get(ctx context.Context, sessionId string) (string, error) { + var part string + err := r.db.View(func(tx *buntdb.Tx) error { + var err error + part, err = tx.Get(sessionId) + return err + }) + if err != nil { + if errors.Is(err, buntdb.ErrNotFound) { + return "", domain.ErrEncryptionPartNotFound + } + r.logger.ErrorContext(ctx, "error getting encryption part", logger.Error(err)) + return "", err + } + + if part == "" { + return "", domain.ErrEncryptionPartNotFound + } + + return part, nil +} + +func (r *repository) Set(ctx context.Context, sessionId, part string) error { + return r.db.Update(func(tx *buntdb.Tx) error { + _, _, err := tx.Set(sessionId, part, nil) + if errors.Is(err, buntdb.ErrIndexExists) { + return domain.ErrEncryptionPartAlreadyExists + } + r.logger.ErrorContext(ctx, "error setting encryption part", logger.Error(err)) + return err + }) +} + +func (r *repository) Delete(ctx context.Context, sessionId string) error { + return r.db.Update(func(tx *buntdb.Tx) error { + _, err := tx.Delete(sessionId) + if errors.Is(err, buntdb.ErrNotFound) { + return domain.ErrEncryptionPartNotFound + } + r.logger.ErrorContext(ctx, "error deleting encryption part", logger.Error(err)) + return err + }) +} diff --git a/internal/infrastructure/repositories/mocks/encryptionpartsmockrepo/repo.go b/internal/infrastructure/repositories/mocks/encryptionpartsmockrepo/repo.go new file mode 100644 index 0000000..6da07ab --- /dev/null +++ b/internal/infrastructure/repositories/mocks/encryptionpartsmockrepo/repo.go @@ -0,0 +1,32 @@ +package encryptionpartsmockrepo + +import ( + "context" + + "github.com/stretchr/testify/mock" + "go.openfort.xyz/shield/internal/core/ports/repositories" +) + +type MockEncryptionPartsRepository struct { + mock.Mock +} + +var _ repositories.EncryptionPartsRepository = (*MockEncryptionPartsRepository)(nil) + +func (m *MockEncryptionPartsRepository) Get(ctx context.Context, sessionId string) (string, error) { + args := m.Mock.Called(ctx, sessionId) + if args.Get(0) == nil { + return "", args.Error(1) + } + return args.Get(0).(string), args.Error(1) +} + +func (m *MockEncryptionPartsRepository) Set(ctx context.Context, sessionId, part string) error { + args := m.Mock.Called(ctx, sessionId, part) + return args.Error(0) +} + +func (m *MockEncryptionPartsRepository) Delete(ctx context.Context, projectID string) error { + args := m.Mock.Called(ctx, projectID) + return args.Error(0) +} From df8b6466f87f9e4cc60b7e6b0a2be0aa7a350306 Mon Sep 17 00:00:00 2001 From: gllm-dev Date: Thu, 4 Jul 2024 13:01:50 +0200 Subject: [PATCH 02/10] refactor: factories --- cmd/cli/db.go | 2 +- di/wire.go | 23 ++--- di/wire_gen.go | 23 ++--- internal/adapters/authenticators/factory.go | 59 ++++++++++++ .../identity/custom_identity}/custom.go | 26 +++--- .../authenticators/identity}/errors.go | 2 +- .../authenticators/identity/factory.go | 65 +++++++++++++ .../identity/openfort_identity}/config.go | 2 +- .../identity/openfort_identity}/openfort.go | 67 ++++++------- .../project_authenticator.go | 49 ++++++++++ .../user_authenticator/user_authenticator.go | 60 ++++++++++++ .../handlers/rest/api/errors.go | 0 .../handlers/rest/authmdw/middleware.go | 79 +++++++++++----- .../handlers/rest/config.go | 0 .../handlers/rest/projecthdl/errors.go | 2 +- .../handlers/rest/projecthdl/handler.go | 2 +- .../handlers/rest/projecthdl/parser.go | 0 .../handlers/rest/projecthdl/types.go | 0 .../rest/ratelimitermdw/middleware.go | 0 .../handlers/rest/requestmdw/middleware.go | 0 .../handlers/rest/responsemdw/middleware.go | 0 .../handlers/rest/server.go | 14 +-- .../handlers/rest/sharehdl/errors.go | 2 +- .../handlers/rest/sharehdl/handler.go | 2 +- .../handlers/rest/sharehdl/parser.go | 0 .../handlers/rest/sharehdl/types.go | 0 .../handlers/rest/sharehdl/validator.go | 2 +- .../repositories/bunt/client.go | 0 .../bunt/encryptionpartsrepo/repo.go | 12 +-- .../mocks/encryptionpartsmockrepo/repo.go | 0 .../mocks/projectmockrepo/repo.go | 0 .../mocks/providermockrepo/repo.go | 9 ++ .../repositories/mocks/sharemockrepo/repo.go | 0 .../repositories/mocks/usermockedrepo/repo.go | 0 .../repositories/sql/client.go | 0 .../repositories/sql/config.go | 0 .../repositories/sql/errors.go | 0 .../sql/migrations/20240319132302_init.sql | 0 .../20240328110831_custom_domains.sql | 0 .../sql/migrations/20240404172915_entropy.sql | 0 .../20240405095925_encryption_parts.sql | 0 ...11120851_add_pem_certs_for_custom_auth.sql | 0 .../repositories/sql/projectrepo/parser.go | 0 .../repositories/sql/projectrepo/repo.go | 10 +- .../repositories/sql/projectrepo/types.go | 0 .../repositories/sql/providerrepo/parser.go | 0 .../repositories/sql/providerrepo/repo.go | 31 +++++-- .../repositories/sql/providerrepo/types.go | 0 .../repositories/sql/sharerepo/parser.go | 0 .../repositories/sql/sharerepo/repo.go | 6 +- .../repositories/sql/sharerepo/types.go | 0 .../repositories/sql/userrepo/options.go | 0 .../repositories/sql/userrepo/parser.go | 0 .../repositories/sql/userrepo/repo.go | 6 +- .../repositories/sql/userrepo/types.go | 0 internal/applications/projectapp/app.go | 10 +- internal/applications/projectapp/app_test.go | 32 +++---- internal/applications/projectapp/errors.go | 15 ++- internal/applications/shareapp/app_test.go | 26 +++--- internal/applications/shareapp/errors.go | 11 +-- .../domain/authentication/authentication.go | 6 ++ internal/core/domain/errors.go | 26 ------ internal/core/domain/errors/project.go | 10 ++ internal/core/domain/errors/provider.go | 10 ++ internal/core/domain/errors/share.go | 8 ++ internal/core/domain/errors/user.go | 9 ++ .../core/ports/authentication/apisecret.go | 7 -- internal/core/ports/authentication/user.go | 34 ------- .../core/ports/factories/authentication.go | 15 +++ internal/core/ports/factories/encryption.go | 75 +++++++++++++++ internal/core/ports/factories/identity.go | 15 +++ internal/core/ports/providers/provider.go | 26 ------ internal/core/ports/repositories/provider.go | 1 + internal/core/ports/services/share.go | 9 +- internal/core/ports/services/user.go | 5 +- internal/core/services/projectsvc/svc.go | 6 +- internal/core/services/projectsvc/svc_test.go | 10 +- internal/core/services/providersvc/svc.go | 8 +- .../core/services/providersvc/svc_test.go | 26 +++--- internal/core/services/sharesvc/svc.go | 8 +- internal/core/services/sharesvc/svc_test.go | 18 ++-- internal/core/services/usersvc/svc.go | 55 +++++++---- internal/core/services/usersvc/svc_test.go | 28 +++--- .../authenticationmgr/apisecret.go | 43 --------- .../authenticationmgr/manager.go | 89 ------------------ .../infrastructure/authenticationmgr/user.go | 93 ------------------- .../infrastructure/providersmgr/manager.go | 57 ------------ pkg/jwk/errors.go | 7 ++ .../providersmgr/jwks.go => pkg/jwk/jwk.go | 4 +- 89 files changed, 714 insertions(+), 643 deletions(-) create mode 100644 internal/adapters/authenticators/factory.go rename internal/{infrastructure/providersmgr => adapters/authenticators/identity/custom_identity}/custom.go (63%) rename internal/{infrastructure/providersmgr => adapters/authenticators/identity}/errors.go (96%) create mode 100644 internal/adapters/authenticators/identity/factory.go rename internal/{infrastructure/providersmgr => adapters/authenticators/identity/openfort_identity}/config.go (90%) rename internal/{infrastructure/providersmgr => adapters/authenticators/identity/openfort_identity}/openfort.go (57%) create mode 100644 internal/adapters/authenticators/project_authenticator/project_authenticator.go create mode 100644 internal/adapters/authenticators/user_authenticator/user_authenticator.go rename internal/{infrastructure => adapters}/handlers/rest/api/errors.go (100%) rename internal/{infrastructure => adapters}/handlers/rest/authmdw/middleware.go (51%) rename internal/{infrastructure => adapters}/handlers/rest/config.go (100%) rename internal/{infrastructure => adapters}/handlers/rest/projecthdl/errors.go (95%) rename internal/{infrastructure => adapters}/handlers/rest/projecthdl/handler.go (99%) rename internal/{infrastructure => adapters}/handlers/rest/projecthdl/parser.go (100%) rename internal/{infrastructure => adapters}/handlers/rest/projecthdl/types.go (100%) rename internal/{infrastructure => adapters}/handlers/rest/ratelimitermdw/middleware.go (100%) rename internal/{infrastructure => adapters}/handlers/rest/requestmdw/middleware.go (100%) rename internal/{infrastructure => adapters}/handlers/rest/responsemdw/middleware.go (100%) rename internal/{infrastructure => adapters}/handlers/rest/server.go (87%) rename internal/{infrastructure => adapters}/handlers/rest/sharehdl/errors.go (93%) rename internal/{infrastructure => adapters}/handlers/rest/sharehdl/handler.go (98%) rename internal/{infrastructure => adapters}/handlers/rest/sharehdl/parser.go (100%) rename internal/{infrastructure => adapters}/handlers/rest/sharehdl/types.go (100%) rename internal/{infrastructure => adapters}/handlers/rest/sharehdl/validator.go (93%) rename internal/{infrastructure => adapters}/repositories/bunt/client.go (100%) rename internal/{infrastructure => adapters}/repositories/bunt/encryptionpartsrepo/repo.go (81%) rename internal/{infrastructure => adapters}/repositories/mocks/encryptionpartsmockrepo/repo.go (100%) rename internal/{infrastructure => adapters}/repositories/mocks/projectmockrepo/repo.go (100%) rename internal/{infrastructure => adapters}/repositories/mocks/providermockrepo/repo.go (89%) rename internal/{infrastructure => adapters}/repositories/mocks/sharemockrepo/repo.go (100%) rename internal/{infrastructure => adapters}/repositories/mocks/usermockedrepo/repo.go (100%) rename internal/{infrastructure => adapters}/repositories/sql/client.go (100%) rename internal/{infrastructure => adapters}/repositories/sql/config.go (100%) rename internal/{infrastructure => adapters}/repositories/sql/errors.go (100%) rename internal/{infrastructure => adapters}/repositories/sql/migrations/20240319132302_init.sql (100%) rename internal/{infrastructure => adapters}/repositories/sql/migrations/20240328110831_custom_domains.sql (100%) rename internal/{infrastructure => adapters}/repositories/sql/migrations/20240404172915_entropy.sql (100%) rename internal/{infrastructure => adapters}/repositories/sql/migrations/20240405095925_encryption_parts.sql (100%) rename internal/{infrastructure => adapters}/repositories/sql/migrations/20240411120851_add_pem_certs_for_custom_auth.sql (100%) rename internal/{infrastructure => adapters}/repositories/sql/projectrepo/parser.go (100%) rename internal/{infrastructure => adapters}/repositories/sql/projectrepo/repo.go (91%) rename internal/{infrastructure => adapters}/repositories/sql/projectrepo/types.go (100%) rename internal/{infrastructure => adapters}/repositories/sql/providerrepo/parser.go (100%) rename internal/{infrastructure => adapters}/repositories/sql/providerrepo/repo.go (83%) rename internal/{infrastructure => adapters}/repositories/sql/providerrepo/types.go (100%) rename internal/{infrastructure => adapters}/repositories/sql/sharerepo/parser.go (100%) rename internal/{infrastructure => adapters}/repositories/sql/sharerepo/repo.go (94%) rename internal/{infrastructure => adapters}/repositories/sql/sharerepo/types.go (100%) rename internal/{infrastructure => adapters}/repositories/sql/userrepo/options.go (100%) rename internal/{infrastructure => adapters}/repositories/sql/userrepo/parser.go (100%) rename internal/{infrastructure => adapters}/repositories/sql/userrepo/repo.go (93%) rename internal/{infrastructure => adapters}/repositories/sql/userrepo/types.go (100%) create mode 100644 internal/core/domain/authentication/authentication.go delete mode 100644 internal/core/domain/errors.go create mode 100644 internal/core/domain/errors/project.go create mode 100644 internal/core/domain/errors/provider.go create mode 100644 internal/core/domain/errors/share.go create mode 100644 internal/core/domain/errors/user.go delete mode 100644 internal/core/ports/authentication/apisecret.go delete mode 100644 internal/core/ports/authentication/user.go create mode 100644 internal/core/ports/factories/authentication.go create mode 100644 internal/core/ports/factories/encryption.go create mode 100644 internal/core/ports/factories/identity.go delete mode 100644 internal/core/ports/providers/provider.go delete mode 100644 internal/infrastructure/authenticationmgr/apisecret.go delete mode 100644 internal/infrastructure/authenticationmgr/manager.go delete mode 100644 internal/infrastructure/authenticationmgr/user.go delete mode 100644 internal/infrastructure/providersmgr/manager.go create mode 100644 pkg/jwk/errors.go rename internal/infrastructure/providersmgr/jwks.go => pkg/jwk/jwk.go (81%) diff --git a/cmd/cli/db.go b/cmd/cli/db.go index f683e0a..4364d45 100644 --- a/cmd/cli/db.go +++ b/cmd/cli/db.go @@ -3,7 +3,7 @@ package cli import ( "github.com/spf13/cobra" "go.openfort.xyz/shield/di" - "go.openfort.xyz/shield/internal/infrastructure/repositories/sql" + "go.openfort.xyz/shield/internal/adapters/repositories/sql" ) func NewCmdDB() *cobra.Command { diff --git a/di/wire.go b/di/wire.go index 2245bab..18cb567 100644 --- a/di/wire.go +++ b/di/wire.go @@ -5,6 +5,15 @@ package di import ( "github.com/google/wire" + "go.openfort.xyz/shield/internal/adapters/authenticationmgr" + identity2 "go.openfort.xyz/shield/internal/adapters/authenticators/identity" + "go.openfort.xyz/shield/internal/adapters/authenticators/identity/openfort_identity" + "go.openfort.xyz/shield/internal/adapters/handlers/rest" + "go.openfort.xyz/shield/internal/adapters/repositories/sql" + "go.openfort.xyz/shield/internal/adapters/repositories/sql/projectrepo" + "go.openfort.xyz/shield/internal/adapters/repositories/sql/providerrepo" + "go.openfort.xyz/shield/internal/adapters/repositories/sql/sharerepo" + "go.openfort.xyz/shield/internal/adapters/repositories/sql/userrepo" "go.openfort.xyz/shield/internal/applications/projectapp" "go.openfort.xyz/shield/internal/applications/shareapp" "go.openfort.xyz/shield/internal/core/ports/repositories" @@ -13,14 +22,6 @@ import ( "go.openfort.xyz/shield/internal/core/services/providersvc" "go.openfort.xyz/shield/internal/core/services/sharesvc" "go.openfort.xyz/shield/internal/core/services/usersvc" - "go.openfort.xyz/shield/internal/infrastructure/authenticationmgr" - "go.openfort.xyz/shield/internal/infrastructure/handlers/rest" - "go.openfort.xyz/shield/internal/infrastructure/providersmgr" - "go.openfort.xyz/shield/internal/infrastructure/repositories/sql" - "go.openfort.xyz/shield/internal/infrastructure/repositories/sql/projectrepo" - "go.openfort.xyz/shield/internal/infrastructure/repositories/sql/providerrepo" - "go.openfort.xyz/shield/internal/infrastructure/repositories/sql/sharerepo" - "go.openfort.xyz/shield/internal/infrastructure/repositories/sql/userrepo" ) func ProvideSQL() (c *sql.Client, err error) { @@ -104,10 +105,10 @@ func ProvideShareService() (s services.ShareService, err error) { return } -func ProvideProviderManager() (pm *providersmgr.Manager, err error) { +func ProvideProviderManager() (pm *identity2.identityFactory, err error) { wire.Build( - providersmgr.NewManager, - providersmgr.GetConfigFromEnv, + identity2.NewIdentityFactory, + openfort_identity.GetConfigFromEnv, ProvideSQLProviderRepository, ) diff --git a/di/wire_gen.go b/di/wire_gen.go index 1da86c9..fa352a1 100644 --- a/di/wire_gen.go +++ b/di/wire_gen.go @@ -7,6 +7,15 @@ package di import ( + "go.openfort.xyz/shield/internal/adapters/authenticationmgr" + identity2 "go.openfort.xyz/shield/internal/adapters/authenticators/identity" + "go.openfort.xyz/shield/internal/adapters/authenticators/identity/openfort_identity" + "go.openfort.xyz/shield/internal/adapters/handlers/rest" + "go.openfort.xyz/shield/internal/adapters/repositories/sql" + "go.openfort.xyz/shield/internal/adapters/repositories/sql/projectrepo" + "go.openfort.xyz/shield/internal/adapters/repositories/sql/providerrepo" + "go.openfort.xyz/shield/internal/adapters/repositories/sql/sharerepo" + "go.openfort.xyz/shield/internal/adapters/repositories/sql/userrepo" "go.openfort.xyz/shield/internal/applications/projectapp" "go.openfort.xyz/shield/internal/applications/shareapp" "go.openfort.xyz/shield/internal/core/ports/repositories" @@ -15,14 +24,6 @@ import ( "go.openfort.xyz/shield/internal/core/services/providersvc" "go.openfort.xyz/shield/internal/core/services/sharesvc" "go.openfort.xyz/shield/internal/core/services/usersvc" - "go.openfort.xyz/shield/internal/infrastructure/authenticationmgr" - "go.openfort.xyz/shield/internal/infrastructure/handlers/rest" - "go.openfort.xyz/shield/internal/infrastructure/providersmgr" - "go.openfort.xyz/shield/internal/infrastructure/repositories/sql" - "go.openfort.xyz/shield/internal/infrastructure/repositories/sql/projectrepo" - "go.openfort.xyz/shield/internal/infrastructure/repositories/sql/providerrepo" - "go.openfort.xyz/shield/internal/infrastructure/repositories/sql/sharerepo" - "go.openfort.xyz/shield/internal/infrastructure/repositories/sql/userrepo" ) // Injectors from wire.go: @@ -111,8 +112,8 @@ func ProvideShareService() (services.ShareService, error) { return shareService, nil } -func ProvideProviderManager() (*providersmgr.Manager, error) { - config, err := providersmgr.GetConfigFromEnv() +func ProvideProviderManager() (*identity2.identityFactory, error) { + config, err := openfort_identity.GetConfigFromEnv() if err != nil { return nil, err } @@ -120,7 +121,7 @@ func ProvideProviderManager() (*providersmgr.Manager, error) { if err != nil { return nil, err } - manager := providersmgr.NewManager(config, providerRepository) + manager := identity2.NewIdentityFactory(config, providerRepository) return manager, nil } diff --git a/internal/adapters/authenticators/factory.go b/internal/adapters/authenticators/factory.go new file mode 100644 index 0000000..b810a09 --- /dev/null +++ b/internal/adapters/authenticators/factory.go @@ -0,0 +1,59 @@ +package authenticators + +import ( + "go.openfort.xyz/shield/internal/adapters/authenticators/project_authenticator" + "go.openfort.xyz/shield/internal/adapters/authenticators/user_authenticator" + "go.openfort.xyz/shield/internal/core/ports/factories" + "go.openfort.xyz/shield/internal/core/ports/repositories" + "go.openfort.xyz/shield/internal/core/ports/services" +) + +type authenticatorFactory struct { + projectRepo repositories.ProjectRepository + userService services.UserService +} + +func NewAuthenticatorFactory(projectRepo repositories.ProjectRepository, userService services.UserService) factories.AuthenticationFactory { + return &authenticatorFactory{ + projectRepo: projectRepo, + userService: userService, + } +} + +func (f *authenticatorFactory) CreateProjectAuthenticator(apiKey, apiSecret string) factories.Authenticator { + return project_authenticator.NewProjectAuthenticator(f.projectRepo, apiKey, apiSecret) +} + +func (f *authenticatorFactory) CreateUserAuthenticator(apiKey, token string, identityFactory factories.Identity) factories.Authenticator { + return user_authenticator.NewUserAuthenticator(f.projectRepo, f.userService, apiKey, token, identityFactory) +} + +// func (m *Manager) PreRegisterUser(ctx context.Context, userID string, providerType provider.Type) (string, error) { +// projID := contexter.GetProjectID(ctx) +// prov, err := m.providerManager.GetProvider(ctx, projID, providerType) +// if err != nil { +// m.logger.ErrorContext(ctx, "failed to get provider", logger.Error(err)) +// return "", err +// } +// +// usr, err := m.userService.GetByExternal(ctx, userID, prov.GetProviderID()) +// if err != nil { +// if !errors.Is(err, domainErrors.ErrUserNotFound) && !errors.Is(err, domainErrors.ErrExternalUserNotFound) { +// m.logger.ErrorContext(ctx, "failed to get user by external", logger.Error(err)) +// return "", err +// } +// usr, err = m.userService.Create(ctx, projID) +// if err != nil { +// m.logger.ErrorContext(ctx, "failed to create user", logger.Error(err)) +// return "", err +// } +// +// _, err = m.userService.CreateExternal(ctx, projID, usr.ID, userID, prov.GetProviderID()) +// if err != nil { +// m.logger.ErrorContext(ctx, "failed to create external user", logger.Error(err)) +// return "", err +// } +// } +// +// return usr.ID, nil +//} diff --git a/internal/infrastructure/providersmgr/custom.go b/internal/adapters/authenticators/identity/custom_identity/custom.go similarity index 63% rename from internal/infrastructure/providersmgr/custom.go rename to internal/adapters/authenticators/identity/custom_identity/custom.go index 0950c80..7ffcd36 100644 --- a/internal/infrastructure/providersmgr/custom.go +++ b/internal/adapters/authenticators/identity/custom_identity/custom.go @@ -1,35 +1,37 @@ -package providersmgr +package custom_identity import ( "context" + "go.openfort.xyz/shield/internal/adapters/authenticators/identity" + "go.openfort.xyz/shield/internal/core/ports/factories" + "go.openfort.xyz/shield/pkg/jwk" "log/slog" "github.com/golang-jwt/jwt/v5" "go.openfort.xyz/shield/internal/core/domain/provider" - "go.openfort.xyz/shield/internal/core/ports/providers" "go.openfort.xyz/shield/pkg/logger" ) -type custom struct { +type CustomIdentityFactory struct { config *provider.CustomConfig logger *slog.Logger } -var _ providers.IdentityProvider = (*custom)(nil) +var _ factories.Identity = (*CustomIdentityFactory)(nil) -func newCustomProvider(providerConfig *provider.CustomConfig) providers.IdentityProvider { - return &custom{ +func NewCustomIdentityFactory(providerConfig *provider.CustomConfig) factories.Identity { + return &CustomIdentityFactory{ config: providerConfig, logger: logger.New("custom_provider"), } } -func (c *custom) GetProviderID() string { +func (c *CustomIdentityFactory) GetProviderID() string { return c.config.ProviderID } -func (c *custom) Identify(ctx context.Context, token string, _ ...providers.CustomOption) (string, error) { +func (c *CustomIdentityFactory) Identify(ctx context.Context, token string) (string, error) { c.logger.InfoContext(ctx, "identifying user") var externalUserID string @@ -38,9 +40,9 @@ func (c *custom) Identify(ctx context.Context, token string, _ ...providers.Cust case c.config.PEM != "" && c.config.KeyType != provider.KeyTypeUnknown: externalUserID, err = c.validatePEM(token) case c.config.JWK != "": - externalUserID, err = validateJWKs(token, c.config.JWK) + externalUserID, err = jwk.Validate(token, c.config.JWK) // TODO parse error default: - return "", ErrProviderMisconfigured + return "", identity.ErrProviderMisconfigured } if err != nil { c.logger.ErrorContext(ctx, "failed to validate jwt", logger.Error(err)) @@ -50,7 +52,7 @@ func (c *custom) Identify(ctx context.Context, token string, _ ...providers.Cust return externalUserID, nil } -func (c *custom) validatePEM(token string) (string, error) { +func (c *CustomIdentityFactory) validatePEM(token string) (string, error) { var keyFunc jwt.Keyfunc switch c.config.KeyType { case provider.KeyTypeRSA: @@ -66,7 +68,7 @@ func (c *custom) validatePEM(token string) (string, error) { return jwt.ParseEdPublicKeyFromPEM([]byte(c.config.PEM)) } default: - return "", ErrCertTypeNotSupported + return "", identity.ErrCertTypeNotSupported } parsed, err := jwt.Parse(token, keyFunc) diff --git a/internal/infrastructure/providersmgr/errors.go b/internal/adapters/authenticators/identity/errors.go similarity index 96% rename from internal/infrastructure/providersmgr/errors.go rename to internal/adapters/authenticators/identity/errors.go index 38cd848..50d1dae 100644 --- a/internal/infrastructure/providersmgr/errors.go +++ b/internal/adapters/authenticators/identity/errors.go @@ -1,4 +1,4 @@ -package providersmgr +package identity import "errors" diff --git a/internal/adapters/authenticators/identity/factory.go b/internal/adapters/authenticators/identity/factory.go new file mode 100644 index 0000000..66dd784 --- /dev/null +++ b/internal/adapters/authenticators/identity/factory.go @@ -0,0 +1,65 @@ +package identity + +import ( + "context" + "errors" + "go.openfort.xyz/shield/internal/adapters/authenticators/identity/custom_identity" + "go.openfort.xyz/shield/internal/adapters/authenticators/identity/openfort_identity" + domainErrors "go.openfort.xyz/shield/internal/core/domain/errors" + "go.openfort.xyz/shield/internal/core/ports/factories" + "log/slog" + + "go.openfort.xyz/shield/internal/core/domain/provider" + "go.openfort.xyz/shield/internal/core/ports/repositories" + "go.openfort.xyz/shield/pkg/logger" +) + +type identityFactory struct { + config *openfort_identity.Config + repo repositories.ProviderRepository + logger *slog.Logger +} + +func NewIdentityFactory(cfg *openfort_identity.Config, repo repositories.ProviderRepository) factories.IdentityFactory { + return &identityFactory{ + config: cfg, + repo: repo, + logger: logger.New("provider_manager"), + } +} + +func (p *identityFactory) CreateCustomIdentity(ctx context.Context, apiKey string) (factories.Identity, error) { + prov, err := p.repo.GetByAPIKeyAndType(ctx, apiKey, provider.TypeCustom) + if err != nil { + if errors.Is(err, domainErrors.ErrProjectNotFound) { + return nil, ErrProviderNotConfigured + } + p.logger.ErrorContext(ctx, "failed to get provider", logger.Error(err)) + return nil, err + } + + config, ok := prov.Config.(*provider.CustomConfig) + if !ok { + return nil, ErrProviderConfigMismatch + } + + return custom_identity.NewCustomIdentityFactory(config), nil +} + +func (p *identityFactory) CreateOpenfortIdentity(ctx context.Context, apiKey string, authenticationProvider, tokenType *string) (factories.Identity, error) { + prov, err := p.repo.GetByAPIKeyAndType(ctx, apiKey, provider.TypeOpenfort) + if err != nil { + if errors.Is(err, domainErrors.ErrProjectNotFound) { + return nil, ErrProviderNotConfigured + } + p.logger.ErrorContext(ctx, "failed to get provider", logger.Error(err)) + return nil, err + } + + config, ok := prov.Config.(*provider.OpenfortConfig) + if !ok { + return nil, ErrProviderConfigMismatch + } + + return openfort_identity.NewOpenfortIdentityFactory(p.config, config, authenticationProvider, tokenType), nil +} diff --git a/internal/infrastructure/providersmgr/config.go b/internal/adapters/authenticators/identity/openfort_identity/config.go similarity index 90% rename from internal/infrastructure/providersmgr/config.go rename to internal/adapters/authenticators/identity/openfort_identity/config.go index 1dc0a03..f0ec451 100644 --- a/internal/infrastructure/providersmgr/config.go +++ b/internal/adapters/authenticators/identity/openfort_identity/config.go @@ -1,4 +1,4 @@ -package providersmgr +package openfort_identity import "github.com/caarlos0/env/v10" diff --git a/internal/infrastructure/providersmgr/openfort.go b/internal/adapters/authenticators/identity/openfort_identity/openfort.go similarity index 57% rename from internal/infrastructure/providersmgr/openfort.go rename to internal/adapters/authenticators/identity/openfort_identity/openfort.go index db6ffc4..f7719b6 100644 --- a/internal/infrastructure/providersmgr/openfort.go +++ b/internal/adapters/authenticators/identity/openfort_identity/openfort.go @@ -1,79 +1,70 @@ -package providersmgr +package openfort_identity import ( "bytes" "context" "encoding/json" - "errors" "fmt" + "go.openfort.xyz/shield/internal/adapters/authenticators/identity" + "go.openfort.xyz/shield/internal/core/ports/factories" + "go.openfort.xyz/shield/pkg/jwk" "io" "log/slog" "net/http" "time" "go.openfort.xyz/shield/internal/core/domain/provider" - "go.openfort.xyz/shield/internal/core/ports/providers" "go.openfort.xyz/shield/pkg/logger" ) -type openfort struct { +type OpenfortIdentityFactory struct { publishableKey string baseURL string providerID string logger *slog.Logger + + authenticationProvider *string + tokenType *string } -var _ providers.IdentityProvider = (*openfort)(nil) +var _ factories.Identity = (*OpenfortIdentityFactory)(nil) -func newOpenfortProvider(config *Config, providerConfig *provider.OpenfortConfig) providers.IdentityProvider { - return &openfort{ - publishableKey: providerConfig.PublishableKey, - providerID: providerConfig.ProviderID, - baseURL: config.OpenfortBaseURL, - logger: logger.New("openfort_provider"), +func NewOpenfortIdentityFactory(config *Config, providerConfig *provider.OpenfortConfig, authenticationProvider, tokenType *string) factories.Identity { + return &OpenfortIdentityFactory{ + publishableKey: providerConfig.PublishableKey, + providerID: providerConfig.ProviderID, + baseURL: config.OpenfortBaseURL, + logger: logger.New("openfort_provider"), + authenticationProvider: authenticationProvider, + tokenType: tokenType, } } -func (o *openfort) GetProviderID() string { +func (o *OpenfortIdentityFactory) GetProviderID() string { return o.providerID } -func (o *openfort) Identify(ctx context.Context, token string, opts ...providers.CustomOption) (string, error) { +func (o *OpenfortIdentityFactory) Identify(ctx context.Context, token string) (string, error) { o.logger.InfoContext(ctx, "identifying user") - userID, err := validateJWKs(token, fmt.Sprintf("%s/iam/v1/%s/jwks.json", o.baseURL, o.publishableKey)) - if err != nil { - if !errors.Is(err, ErrInvalidToken) { - o.logger.ErrorContext(ctx, "failed to validate jwks", logger.Error(err)) - return "", err - } - - return o.identifyOAuth(ctx, token, opts...) + if o.authenticationProvider != nil && o.tokenType != nil { + return o.thirdParty(ctx, token, *o.authenticationProvider, *o.tokenType) } - return userID, nil + return o.accessToken(ctx, token) } -func (o *openfort) identifyOAuth(ctx context.Context, token string, opts ...providers.CustomOption) (string, error) { - var opt providers.CustomOptions - for _, o := range opts { - o(&opt) - } - - if opt.OpenfortProvider == nil { - return "", ErrMissingOpenfortProvider - } - - if opt.OpenfortTokenType == nil { - return "", ErrMissingOpenfortTokenType - } +func (o *OpenfortIdentityFactory) accessToken(ctx context.Context, token string) (string, error) { + return jwk.Validate(token, fmt.Sprintf("%s/iam/v1/%s/jwks.json", o.baseURL, o.publishableKey)) // TODO parse error +} +func (o *OpenfortIdentityFactory) thirdParty(ctx context.Context, token, authenticationProvider, tokenType string) (string, error) { url := fmt.Sprintf("%s/iam/v1/oauth/authenticate", o.baseURL) reqBody := authenticateOauthRequest{ - Provider: *opt.OpenfortProvider, + Provider: authenticationProvider, Token: token, - TokenType: *opt.OpenfortTokenType, + TokenType: tokenType, } rawReqBody, err := json.Marshal(reqBody) @@ -96,7 +87,7 @@ func (o *openfort) identifyOAuth(ctx context.Context, token string, opts ...prov defer resp.Body.Close() if resp.StatusCode/100 != 2 { - return "", ErrUnexpectedStatusCode + return "", identity.ErrUnexpectedStatusCode } rawResponse, err := io.ReadAll(resp.Body) diff --git a/internal/adapters/authenticators/project_authenticator/project_authenticator.go b/internal/adapters/authenticators/project_authenticator/project_authenticator.go new file mode 100644 index 0000000..716cdc5 --- /dev/null +++ b/internal/adapters/authenticators/project_authenticator/project_authenticator.go @@ -0,0 +1,49 @@ +package project_authenticator + +import ( + "context" + "go.openfort.xyz/shield/internal/core/domain/authentication" + "go.openfort.xyz/shield/internal/core/ports/factories" + "log/slog" + + "go.openfort.xyz/shield/internal/core/ports/repositories" + "go.openfort.xyz/shield/pkg/logger" + "golang.org/x/crypto/bcrypt" +) + +type ProjectAuthenticator struct { + projectRepo repositories.ProjectRepository + apiKey, apiSecret string + logger *slog.Logger +} + +var _ factories.Authenticator = (*ProjectAuthenticator)(nil) + +func NewProjectAuthenticator(repository repositories.ProjectRepository, apiKey, apiSecret string) factories.Authenticator { + return &ProjectAuthenticator{ + projectRepo: repository, + apiKey: apiKey, + apiSecret: apiSecret, + logger: logger.New("api_key_authenticator"), + } +} + +func (a *ProjectAuthenticator) Authenticate(ctx context.Context) (*authentication.Authentication, error) { + a.logger.InfoContext(ctx, "authenticating api key") + + proj, err := a.projectRepo.GetByAPIKey(ctx, a.apiKey) + if err != nil { + a.logger.ErrorContext(ctx, "failed to authenticate api key", logger.Error(err)) + return nil, err + } + + err = bcrypt.CompareHashAndPassword([]byte(proj.APISecret), []byte(a.apiSecret)) + if err != nil { + a.logger.ErrorContext(ctx, "failed to authenticate api secret", logger.Error(err)) + return nil, err + } + + return &authentication.Authentication{ + ProjectID: proj.ID, + }, nil +} diff --git a/internal/adapters/authenticators/user_authenticator/user_authenticator.go b/internal/adapters/authenticators/user_authenticator/user_authenticator.go new file mode 100644 index 0000000..93b0f3f --- /dev/null +++ b/internal/adapters/authenticators/user_authenticator/user_authenticator.go @@ -0,0 +1,60 @@ +package user_authenticator + +import ( + "context" + "go.openfort.xyz/shield/internal/core/domain/authentication" + "go.openfort.xyz/shield/internal/core/ports/factories" + "log/slog" + + "go.openfort.xyz/shield/internal/core/ports/repositories" + "go.openfort.xyz/shield/internal/core/ports/services" + "go.openfort.xyz/shield/pkg/logger" +) + +type UserAuthenticator struct { + projectRepo repositories.ProjectRepository + userService services.UserService + apiKey, token string + identityFactory factories.Identity + logger *slog.Logger +} + +var _ factories.Authenticator = (*UserAuthenticator)(nil) + +func NewUserAuthenticator(repository repositories.ProjectRepository, userService services.UserService, apiKey, token string, identityFactory factories.Identity) factories.Authenticator { + return &UserAuthenticator{ + projectRepo: repository, + userService: userService, + apiKey: apiKey, + token: token, + identityFactory: identityFactory, + logger: logger.New("api_key_authenticator"), + } +} + +func (a *UserAuthenticator) Authenticate(ctx context.Context) (*authentication.Authentication, error) { + a.logger.InfoContext(ctx, "authenticating api key") + + proj, err := a.projectRepo.GetByAPIKey(ctx, a.apiKey) + if err != nil { + a.logger.ErrorContext(ctx, "failed to authenticate api key", logger.Error(err)) + return nil, err + } + + externalUserID, err := a.identityFactory.Identify(ctx, a.token) + if err != nil { + a.logger.ErrorContext(ctx, "failed to identify user", logger.Error(err)) + return nil, err + } + + usr, err := a.userService.GetOrCreate(ctx, proj.ID, externalUserID, a.identityFactory.GetProviderID()) + if err != nil { + a.logger.ErrorContext(ctx, "failed to get or create user", logger.Error(err)) + return nil, err + } + + return &authentication.Authentication{ + UserID: usr.ID, + ProjectID: proj.ID, + }, nil +} diff --git a/internal/infrastructure/handlers/rest/api/errors.go b/internal/adapters/handlers/rest/api/errors.go similarity index 100% rename from internal/infrastructure/handlers/rest/api/errors.go rename to internal/adapters/handlers/rest/api/errors.go diff --git a/internal/infrastructure/handlers/rest/authmdw/middleware.go b/internal/adapters/handlers/rest/authmdw/middleware.go similarity index 51% rename from internal/infrastructure/handlers/rest/authmdw/middleware.go rename to internal/adapters/handlers/rest/authmdw/middleware.go index 0905891..ab69a94 100644 --- a/internal/infrastructure/handlers/rest/authmdw/middleware.go +++ b/internal/adapters/handlers/rest/authmdw/middleware.go @@ -1,12 +1,12 @@ package authmdw import ( + "go.openfort.xyz/shield/internal/core/ports/factories" + "go.openfort.xyz/shield/internal/core/ports/services" "net/http" "strings" - authenticate "go.openfort.xyz/shield/internal/core/ports/authentication" - "go.openfort.xyz/shield/internal/infrastructure/authenticationmgr" - "go.openfort.xyz/shield/internal/infrastructure/handlers/rest/api" + "go.openfort.xyz/shield/internal/adapters/handlers/rest/api" "go.openfort.xyz/shield/pkg/contexter" ) @@ -19,14 +19,20 @@ const OpenfortTokenTypeHeader = "X-Openfort-Token-Type" //nolint:go const AccessControlAllowOriginHeader = "Access-Control-Allow-Origin" //nolint:gosec const EncryptionPartHeader = "X-Encryption-Part" //nolint:gosec const UserIDHeader = "X-User-ID" //nolint:gosec +const AuthenticationTypeCustom = "custom" //nolint:gosec +const AuthenticationTypeOpenfort = "openfort" //nolint:gosec type Middleware struct { - manager *authenticationmgr.Manager + authenticationFactory factories.AuthenticationFactory + identityFactory factories.IdentityFactory + userService services.UserService } -func New(manager *authenticationmgr.Manager) *Middleware { +func New(authenticationFactory factories.AuthenticationFactory, identityFactory factories.IdentityFactory, userService services.UserService) *Middleware { return &Middleware{ - manager: manager, + authenticationFactory: authenticationFactory, + identityFactory: identityFactory, + userService: userService, } } @@ -44,19 +50,26 @@ func (m *Middleware) AuthenticateAPISecret(next http.Handler) http.Handler { return } - projectID, err := m.manager.GetAPISecretAuthenticator().Authenticate(r.Context(), apiKey, apiSecret) + authenticator := m.authenticationFactory.CreateProjectAuthenticator(apiKey, apiSecret) + authentication, err := authenticator.Authenticate(r.Context()) if err != nil { api.RespondWithError(w, api.ErrInvalidAPISecret) return } - ctx := contexter.WithProjectID(r.Context(), projectID) + ctx := contexter.WithProjectID(r.Context(), authentication.ProjectID) next.ServeHTTP(w, r.WithContext(ctx)) }) } func (m *Middleware) PreRegisterUser(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + apiKey := r.Header.Get(APIKeyHeader) + if apiKey == "" { + api.RespondWithError(w, api.ErrMissingAPIKey) + return + } + userID := r.Header.Get(UserIDHeader) if userID == "" { api.RespondWithError(w, api.ErrMissingUserID) @@ -69,19 +82,28 @@ func (m *Middleware) PreRegisterUser(next http.Handler) http.Handler { return } - provider, err := m.manager.GetAuthProvider(providerStr) + var identity factories.Identity + var err error + if providerStr == AuthenticationTypeCustom { + identity, err = m.identityFactory.CreateCustomIdentity(r.Context(), apiKey) + } else if providerStr == AuthenticationTypeOpenfort { + identity, err = m.identityFactory.CreateOpenfortIdentity(r.Context(), apiKey, nil, nil) + } else { + api.RespondWithError(w, api.ErrInvalidAuthProvider) + return + } if err != nil { api.RespondWithError(w, api.ErrInvalidAuthProvider) return } - usr, err := m.manager.PreRegisterUser(r.Context(), userID, provider) + usr, err := m.userService.GetOrCreate(r.Context(), contexter.GetProjectID(r.Context()), userID, identity.GetProviderID()) if err != nil { api.RespondWithError(w, api.ErrPreRegisterUser) return } - ctx := contexter.WithUserID(r.Context(), usr) + ctx := contexter.WithUserID(r.Context(), usr.ID) next.ServeHTTP(w, r.WithContext(ctx)) }) } @@ -114,29 +136,40 @@ func (m *Middleware) AuthenticateUser(next http.Handler) http.Handler { return } - openfortProvider := r.Header.Get(OpenfortProviderHeader) - openfortTokenType := r.Header.Get(OpenfortTokenTypeHeader) - - var customOptions []authenticate.CustomOption - if openfortProvider != "" && openfortTokenType != "" { - customOptions = append(customOptions, authenticate.WithOpenfortProvider(openfortProvider)) - customOptions = append(customOptions, authenticate.WithOpenfortTokenType(openfortTokenType)) + var identity factories.Identity + var err error + if providerStr == AuthenticationTypeCustom { + identity, err = m.identityFactory.CreateCustomIdentity(r.Context(), apiKey) + } else if providerStr == AuthenticationTypeOpenfort { + var openfortProvider *string + if r.Header.Get(OpenfortProviderHeader) != "" { + openfortProvider = new(string) + *openfortProvider = r.Header.Get(OpenfortProviderHeader) + } + var openfortTokenType *string + if r.Header.Get(OpenfortTokenTypeHeader) != "" { + openfortTokenType = new(string) + *openfortTokenType = r.Header.Get(OpenfortTokenTypeHeader) + } + identity, err = m.identityFactory.CreateOpenfortIdentity(r.Context(), apiKey, openfortProvider, openfortTokenType) + } else { + api.RespondWithError(w, api.ErrInvalidAuthProvider) + return } - - provider, err := m.manager.GetAuthProvider(providerStr) if err != nil { api.RespondWithError(w, api.ErrInvalidAuthProvider) return } - auth, err := m.manager.GetUserAuthenticator().Authenticate(r.Context(), apiKey, token, provider, customOptions...) + authenticator := m.authenticationFactory.CreateUserAuthenticator(apiKey, token, identity) + authentication, err := authenticator.Authenticate(r.Context()) if err != nil { api.RespondWithError(w, api.ErrInvalidToken) return } - ctx := contexter.WithUserID(r.Context(), auth.UserID) - ctx = contexter.WithProjectID(ctx, auth.ProjectID) + ctx := contexter.WithUserID(r.Context(), authentication.UserID) + ctx = contexter.WithProjectID(ctx, authentication.ProjectID) next.ServeHTTP(w, r.WithContext(ctx)) }) } diff --git a/internal/infrastructure/handlers/rest/config.go b/internal/adapters/handlers/rest/config.go similarity index 100% rename from internal/infrastructure/handlers/rest/config.go rename to internal/adapters/handlers/rest/config.go diff --git a/internal/infrastructure/handlers/rest/projecthdl/errors.go b/internal/adapters/handlers/rest/projecthdl/errors.go similarity index 95% rename from internal/infrastructure/handlers/rest/projecthdl/errors.go rename to internal/adapters/handlers/rest/projecthdl/errors.go index 8560506..9bf9c75 100644 --- a/internal/infrastructure/handlers/rest/projecthdl/errors.go +++ b/internal/adapters/handlers/rest/projecthdl/errors.go @@ -3,8 +3,8 @@ package projecthdl import ( "errors" + "go.openfort.xyz/shield/internal/adapters/handlers/rest/api" "go.openfort.xyz/shield/internal/applications/projectapp" - "go.openfort.xyz/shield/internal/infrastructure/handlers/rest/api" ) func fromApplicationError(err error) *api.Error { diff --git a/internal/infrastructure/handlers/rest/projecthdl/handler.go b/internal/adapters/handlers/rest/projecthdl/handler.go similarity index 99% rename from internal/infrastructure/handlers/rest/projecthdl/handler.go rename to internal/adapters/handlers/rest/projecthdl/handler.go index c5af920..fd16818 100644 --- a/internal/infrastructure/handlers/rest/projecthdl/handler.go +++ b/internal/adapters/handlers/rest/projecthdl/handler.go @@ -7,8 +7,8 @@ import ( "net/http" "github.com/gorilla/mux" + "go.openfort.xyz/shield/internal/adapters/handlers/rest/api" "go.openfort.xyz/shield/internal/applications/projectapp" - "go.openfort.xyz/shield/internal/infrastructure/handlers/rest/api" "go.openfort.xyz/shield/pkg/logger" ) diff --git a/internal/infrastructure/handlers/rest/projecthdl/parser.go b/internal/adapters/handlers/rest/projecthdl/parser.go similarity index 100% rename from internal/infrastructure/handlers/rest/projecthdl/parser.go rename to internal/adapters/handlers/rest/projecthdl/parser.go diff --git a/internal/infrastructure/handlers/rest/projecthdl/types.go b/internal/adapters/handlers/rest/projecthdl/types.go similarity index 100% rename from internal/infrastructure/handlers/rest/projecthdl/types.go rename to internal/adapters/handlers/rest/projecthdl/types.go diff --git a/internal/infrastructure/handlers/rest/ratelimitermdw/middleware.go b/internal/adapters/handlers/rest/ratelimitermdw/middleware.go similarity index 100% rename from internal/infrastructure/handlers/rest/ratelimitermdw/middleware.go rename to internal/adapters/handlers/rest/ratelimitermdw/middleware.go diff --git a/internal/infrastructure/handlers/rest/requestmdw/middleware.go b/internal/adapters/handlers/rest/requestmdw/middleware.go similarity index 100% rename from internal/infrastructure/handlers/rest/requestmdw/middleware.go rename to internal/adapters/handlers/rest/requestmdw/middleware.go diff --git a/internal/infrastructure/handlers/rest/responsemdw/middleware.go b/internal/adapters/handlers/rest/responsemdw/middleware.go similarity index 100% rename from internal/infrastructure/handlers/rest/responsemdw/middleware.go rename to internal/adapters/handlers/rest/responsemdw/middleware.go diff --git a/internal/infrastructure/handlers/rest/server.go b/internal/adapters/handlers/rest/server.go similarity index 87% rename from internal/infrastructure/handlers/rest/server.go rename to internal/adapters/handlers/rest/server.go index 64d8982..f055f1d 100644 --- a/internal/infrastructure/handlers/rest/server.go +++ b/internal/adapters/handlers/rest/server.go @@ -9,15 +9,15 @@ import ( "github.com/gorilla/mux" "github.com/rs/cors" + "go.openfort.xyz/shield/internal/adapters/authenticationmgr" + "go.openfort.xyz/shield/internal/adapters/handlers/rest/authmdw" + "go.openfort.xyz/shield/internal/adapters/handlers/rest/projecthdl" + "go.openfort.xyz/shield/internal/adapters/handlers/rest/ratelimitermdw" + "go.openfort.xyz/shield/internal/adapters/handlers/rest/requestmdw" + "go.openfort.xyz/shield/internal/adapters/handlers/rest/responsemdw" + "go.openfort.xyz/shield/internal/adapters/handlers/rest/sharehdl" "go.openfort.xyz/shield/internal/applications/projectapp" "go.openfort.xyz/shield/internal/applications/shareapp" - "go.openfort.xyz/shield/internal/infrastructure/authenticationmgr" - "go.openfort.xyz/shield/internal/infrastructure/handlers/rest/authmdw" - "go.openfort.xyz/shield/internal/infrastructure/handlers/rest/projecthdl" - "go.openfort.xyz/shield/internal/infrastructure/handlers/rest/ratelimitermdw" - "go.openfort.xyz/shield/internal/infrastructure/handlers/rest/requestmdw" - "go.openfort.xyz/shield/internal/infrastructure/handlers/rest/responsemdw" - "go.openfort.xyz/shield/internal/infrastructure/handlers/rest/sharehdl" "go.openfort.xyz/shield/pkg/logger" ) diff --git a/internal/infrastructure/handlers/rest/sharehdl/errors.go b/internal/adapters/handlers/rest/sharehdl/errors.go similarity index 93% rename from internal/infrastructure/handlers/rest/sharehdl/errors.go rename to internal/adapters/handlers/rest/sharehdl/errors.go index 613798e..17bcb7b 100644 --- a/internal/infrastructure/handlers/rest/sharehdl/errors.go +++ b/internal/adapters/handlers/rest/sharehdl/errors.go @@ -3,8 +3,8 @@ package sharehdl import ( "errors" + "go.openfort.xyz/shield/internal/adapters/handlers/rest/api" "go.openfort.xyz/shield/internal/applications/shareapp" - "go.openfort.xyz/shield/internal/infrastructure/handlers/rest/api" ) func fromApplicationError(err error) *api.Error { diff --git a/internal/infrastructure/handlers/rest/sharehdl/handler.go b/internal/adapters/handlers/rest/sharehdl/handler.go similarity index 98% rename from internal/infrastructure/handlers/rest/sharehdl/handler.go rename to internal/adapters/handlers/rest/sharehdl/handler.go index f75a083..1a1f442 100644 --- a/internal/infrastructure/handlers/rest/sharehdl/handler.go +++ b/internal/adapters/handlers/rest/sharehdl/handler.go @@ -6,8 +6,8 @@ import ( "log/slog" "net/http" + "go.openfort.xyz/shield/internal/adapters/handlers/rest/api" "go.openfort.xyz/shield/internal/applications/shareapp" - "go.openfort.xyz/shield/internal/infrastructure/handlers/rest/api" "go.openfort.xyz/shield/pkg/logger" ) diff --git a/internal/infrastructure/handlers/rest/sharehdl/parser.go b/internal/adapters/handlers/rest/sharehdl/parser.go similarity index 100% rename from internal/infrastructure/handlers/rest/sharehdl/parser.go rename to internal/adapters/handlers/rest/sharehdl/parser.go diff --git a/internal/infrastructure/handlers/rest/sharehdl/types.go b/internal/adapters/handlers/rest/sharehdl/types.go similarity index 100% rename from internal/infrastructure/handlers/rest/sharehdl/types.go rename to internal/adapters/handlers/rest/sharehdl/types.go diff --git a/internal/infrastructure/handlers/rest/sharehdl/validator.go b/internal/adapters/handlers/rest/sharehdl/validator.go similarity index 93% rename from internal/infrastructure/handlers/rest/sharehdl/validator.go rename to internal/adapters/handlers/rest/sharehdl/validator.go index 89ec6fb..8f843cc 100644 --- a/internal/infrastructure/handlers/rest/sharehdl/validator.go +++ b/internal/adapters/handlers/rest/sharehdl/validator.go @@ -1,6 +1,6 @@ package sharehdl -import "go.openfort.xyz/shield/internal/infrastructure/handlers/rest/api" +import "go.openfort.xyz/shield/internal/adapters/handlers/rest/api" type validator struct { } diff --git a/internal/infrastructure/repositories/bunt/client.go b/internal/adapters/repositories/bunt/client.go similarity index 100% rename from internal/infrastructure/repositories/bunt/client.go rename to internal/adapters/repositories/bunt/client.go diff --git a/internal/infrastructure/repositories/bunt/encryptionpartsrepo/repo.go b/internal/adapters/repositories/bunt/encryptionpartsrepo/repo.go similarity index 81% rename from internal/infrastructure/repositories/bunt/encryptionpartsrepo/repo.go rename to internal/adapters/repositories/bunt/encryptionpartsrepo/repo.go index 81fd99d..f4291ad 100644 --- a/internal/infrastructure/repositories/bunt/encryptionpartsrepo/repo.go +++ b/internal/adapters/repositories/bunt/encryptionpartsrepo/repo.go @@ -4,9 +4,9 @@ import ( "context" "errors" "github.com/tidwall/buntdb" - "go.openfort.xyz/shield/internal/core/domain" + "go.openfort.xyz/shield/internal/adapters/repositories/bunt" + domainErrors "go.openfort.xyz/shield/internal/core/domain/errors" "go.openfort.xyz/shield/internal/core/ports/repositories" - "go.openfort.xyz/shield/internal/infrastructure/repositories/bunt" "go.openfort.xyz/shield/pkg/logger" "log/slog" ) @@ -34,14 +34,14 @@ func (r *repository) Get(ctx context.Context, sessionId string) (string, error) }) if err != nil { if errors.Is(err, buntdb.ErrNotFound) { - return "", domain.ErrEncryptionPartNotFound + return "", domainErrors.ErrEncryptionPartNotFound } r.logger.ErrorContext(ctx, "error getting encryption part", logger.Error(err)) return "", err } if part == "" { - return "", domain.ErrEncryptionPartNotFound + return "", domainErrors.ErrEncryptionPartNotFound } return part, nil @@ -51,7 +51,7 @@ func (r *repository) Set(ctx context.Context, sessionId, part string) error { return r.db.Update(func(tx *buntdb.Tx) error { _, _, err := tx.Set(sessionId, part, nil) if errors.Is(err, buntdb.ErrIndexExists) { - return domain.ErrEncryptionPartAlreadyExists + return domainErrors.ErrEncryptionPartAlreadyExists } r.logger.ErrorContext(ctx, "error setting encryption part", logger.Error(err)) return err @@ -62,7 +62,7 @@ func (r *repository) Delete(ctx context.Context, sessionId string) error { return r.db.Update(func(tx *buntdb.Tx) error { _, err := tx.Delete(sessionId) if errors.Is(err, buntdb.ErrNotFound) { - return domain.ErrEncryptionPartNotFound + return domainErrors.ErrEncryptionPartNotFound } r.logger.ErrorContext(ctx, "error deleting encryption part", logger.Error(err)) return err diff --git a/internal/infrastructure/repositories/mocks/encryptionpartsmockrepo/repo.go b/internal/adapters/repositories/mocks/encryptionpartsmockrepo/repo.go similarity index 100% rename from internal/infrastructure/repositories/mocks/encryptionpartsmockrepo/repo.go rename to internal/adapters/repositories/mocks/encryptionpartsmockrepo/repo.go diff --git a/internal/infrastructure/repositories/mocks/projectmockrepo/repo.go b/internal/adapters/repositories/mocks/projectmockrepo/repo.go similarity index 100% rename from internal/infrastructure/repositories/mocks/projectmockrepo/repo.go rename to internal/adapters/repositories/mocks/projectmockrepo/repo.go diff --git a/internal/infrastructure/repositories/mocks/providermockrepo/repo.go b/internal/adapters/repositories/mocks/providermockrepo/repo.go similarity index 89% rename from internal/infrastructure/repositories/mocks/providermockrepo/repo.go rename to internal/adapters/repositories/mocks/providermockrepo/repo.go index 877f2d3..d2b560b 100644 --- a/internal/infrastructure/repositories/mocks/providermockrepo/repo.go +++ b/internal/adapters/repositories/mocks/providermockrepo/repo.go @@ -27,6 +27,15 @@ func (m *MockProviderRepository) GetByProjectAndType(ctx context.Context, projec return args.Get(0).(*provider.Provider), args.Error(1) } +func (m *MockProviderRepository) GetByAPIKeyAndType(ctx context.Context, apiKey string, providerType provider.Type) (*provider.Provider, error) { + args := m.Mock.Called(ctx, apiKey, providerType) + if args.Get(0) == nil { + return nil, args.Error(1) + } + + return args.Get(0).(*provider.Provider), args.Error(1) +} + func (m *MockProviderRepository) Get(ctx context.Context, id string) (*provider.Provider, error) { args := m.Mock.Called(ctx, id) if args.Get(0) == nil { diff --git a/internal/infrastructure/repositories/mocks/sharemockrepo/repo.go b/internal/adapters/repositories/mocks/sharemockrepo/repo.go similarity index 100% rename from internal/infrastructure/repositories/mocks/sharemockrepo/repo.go rename to internal/adapters/repositories/mocks/sharemockrepo/repo.go diff --git a/internal/infrastructure/repositories/mocks/usermockedrepo/repo.go b/internal/adapters/repositories/mocks/usermockedrepo/repo.go similarity index 100% rename from internal/infrastructure/repositories/mocks/usermockedrepo/repo.go rename to internal/adapters/repositories/mocks/usermockedrepo/repo.go diff --git a/internal/infrastructure/repositories/sql/client.go b/internal/adapters/repositories/sql/client.go similarity index 100% rename from internal/infrastructure/repositories/sql/client.go rename to internal/adapters/repositories/sql/client.go diff --git a/internal/infrastructure/repositories/sql/config.go b/internal/adapters/repositories/sql/config.go similarity index 100% rename from internal/infrastructure/repositories/sql/config.go rename to internal/adapters/repositories/sql/config.go diff --git a/internal/infrastructure/repositories/sql/errors.go b/internal/adapters/repositories/sql/errors.go similarity index 100% rename from internal/infrastructure/repositories/sql/errors.go rename to internal/adapters/repositories/sql/errors.go diff --git a/internal/infrastructure/repositories/sql/migrations/20240319132302_init.sql b/internal/adapters/repositories/sql/migrations/20240319132302_init.sql similarity index 100% rename from internal/infrastructure/repositories/sql/migrations/20240319132302_init.sql rename to internal/adapters/repositories/sql/migrations/20240319132302_init.sql diff --git a/internal/infrastructure/repositories/sql/migrations/20240328110831_custom_domains.sql b/internal/adapters/repositories/sql/migrations/20240328110831_custom_domains.sql similarity index 100% rename from internal/infrastructure/repositories/sql/migrations/20240328110831_custom_domains.sql rename to internal/adapters/repositories/sql/migrations/20240328110831_custom_domains.sql diff --git a/internal/infrastructure/repositories/sql/migrations/20240404172915_entropy.sql b/internal/adapters/repositories/sql/migrations/20240404172915_entropy.sql similarity index 100% rename from internal/infrastructure/repositories/sql/migrations/20240404172915_entropy.sql rename to internal/adapters/repositories/sql/migrations/20240404172915_entropy.sql diff --git a/internal/infrastructure/repositories/sql/migrations/20240405095925_encryption_parts.sql b/internal/adapters/repositories/sql/migrations/20240405095925_encryption_parts.sql similarity index 100% rename from internal/infrastructure/repositories/sql/migrations/20240405095925_encryption_parts.sql rename to internal/adapters/repositories/sql/migrations/20240405095925_encryption_parts.sql diff --git a/internal/infrastructure/repositories/sql/migrations/20240411120851_add_pem_certs_for_custom_auth.sql b/internal/adapters/repositories/sql/migrations/20240411120851_add_pem_certs_for_custom_auth.sql similarity index 100% rename from internal/infrastructure/repositories/sql/migrations/20240411120851_add_pem_certs_for_custom_auth.sql rename to internal/adapters/repositories/sql/migrations/20240411120851_add_pem_certs_for_custom_auth.sql diff --git a/internal/infrastructure/repositories/sql/projectrepo/parser.go b/internal/adapters/repositories/sql/projectrepo/parser.go similarity index 100% rename from internal/infrastructure/repositories/sql/projectrepo/parser.go rename to internal/adapters/repositories/sql/projectrepo/parser.go diff --git a/internal/infrastructure/repositories/sql/projectrepo/repo.go b/internal/adapters/repositories/sql/projectrepo/repo.go similarity index 91% rename from internal/infrastructure/repositories/sql/projectrepo/repo.go rename to internal/adapters/repositories/sql/projectrepo/repo.go index 5a775d6..afb1464 100644 --- a/internal/infrastructure/repositories/sql/projectrepo/repo.go +++ b/internal/adapters/repositories/sql/projectrepo/repo.go @@ -3,14 +3,14 @@ package projectrepo import ( "context" "errors" + domainErrors "go.openfort.xyz/shield/internal/core/domain/errors" "log/slog" "github.com/google/uuid" - "go.openfort.xyz/shield/internal/core/domain" + "go.openfort.xyz/shield/internal/adapters/repositories/sql" "go.openfort.xyz/shield/internal/core/domain/project" "go.openfort.xyz/shield/internal/core/ports/repositories" - "go.openfort.xyz/shield/internal/infrastructure/repositories/sql" "go.openfort.xyz/shield/pkg/logger" "gorm.io/gorm" ) @@ -54,7 +54,7 @@ func (r *repository) Get(ctx context.Context, projectID string) (*project.Projec err := r.db.Where("id = ?", projectID).First(dbProj).Error if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, domain.ErrProjectNotFound + return nil, domainErrors.ErrProjectNotFound } r.logger.ErrorContext(ctx, "error getting project", logger.Error(err)) return nil, err @@ -70,7 +70,7 @@ func (r *repository) GetByAPIKey(ctx context.Context, apiKey string) (*project.P err := r.db.Where("api_key = ?", apiKey).First(dbProj).Error if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, domain.ErrProjectNotFound + return nil, domainErrors.ErrProjectNotFound } r.logger.ErrorContext(ctx, "error getting project", logger.Error(err)) return nil, err @@ -98,7 +98,7 @@ func (r *repository) GetEncryptionPart(ctx context.Context, projectID string) (s err := r.db.Where("project_id = ?", projectID).First(encryptionPart).Error if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { - return "", domain.ErrEncryptionPartNotFound + return "", domainErrors.ErrEncryptionPartNotFound } r.logger.ErrorContext(ctx, "error getting encryption part", logger.Error(err)) return "", err diff --git a/internal/infrastructure/repositories/sql/projectrepo/types.go b/internal/adapters/repositories/sql/projectrepo/types.go similarity index 100% rename from internal/infrastructure/repositories/sql/projectrepo/types.go rename to internal/adapters/repositories/sql/projectrepo/types.go diff --git a/internal/infrastructure/repositories/sql/providerrepo/parser.go b/internal/adapters/repositories/sql/providerrepo/parser.go similarity index 100% rename from internal/infrastructure/repositories/sql/providerrepo/parser.go rename to internal/adapters/repositories/sql/providerrepo/parser.go diff --git a/internal/infrastructure/repositories/sql/providerrepo/repo.go b/internal/adapters/repositories/sql/providerrepo/repo.go similarity index 83% rename from internal/infrastructure/repositories/sql/providerrepo/repo.go rename to internal/adapters/repositories/sql/providerrepo/repo.go index 7c9cb67..32ce52c 100644 --- a/internal/infrastructure/repositories/sql/providerrepo/repo.go +++ b/internal/adapters/repositories/sql/providerrepo/repo.go @@ -3,13 +3,13 @@ package providerrepo import ( "context" "errors" + domainErrors "go.openfort.xyz/shield/internal/core/domain/errors" "log/slog" "github.com/google/uuid" - "go.openfort.xyz/shield/internal/core/domain" + "go.openfort.xyz/shield/internal/adapters/repositories/sql" "go.openfort.xyz/shield/internal/core/domain/provider" "go.openfort.xyz/shield/internal/core/ports/repositories" - "go.openfort.xyz/shield/internal/infrastructure/repositories/sql" "go.openfort.xyz/shield/pkg/logger" "gorm.io/gorm" ) @@ -54,7 +54,7 @@ func (r *repository) GetByProjectAndType(ctx context.Context, projectID string, err := r.db.Preload("Custom").Preload("Openfort").Where("project_id = ? AND type = ?", projectID, r.parser.mapProviderTypeToDatabase[providerType]).First(&dbProv).Error if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, domain.ErrProviderNotFound + return nil, domainErrors.ErrProviderNotFound } r.logger.ErrorContext(ctx, "error getting provider", logger.Error(err)) return nil, err @@ -63,6 +63,23 @@ func (r *repository) GetByProjectAndType(ctx context.Context, projectID string, return r.parser.toDomainProvider(dbProv), nil } +func (r *repository) GetByAPIKeyAndType(ctx context.Context, apiKey string, providerType provider.Type) (*provider.Provider, error) { + r.logger.InfoContext(ctx, "getting provider", slog.String("api_key", apiKey), slog.String("provider_type", providerType.String())) + + dbProv := Provider{} + err := r.db.Preload("Custom").Preload("Openfort").Joins("INNER JOIN shld_projects sp on shld_providers.project_id = sp.id").Where("sp.api_key = ? AND shld_providers.type = ?", apiKey, r.parser.mapProviderTypeToDatabase[providerType]).First(&dbProv).Error + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, domainErrors.ErrProviderNotFound + } + r.logger.ErrorContext(ctx, "error getting provider", logger.Error(err)) + return nil, err + + } + + return r.parser.toDomainProvider(dbProv), nil +} + func (r *repository) Get(ctx context.Context, id string) (*provider.Provider, error) { r.logger.InfoContext(ctx, "getting provider", slog.String("provider_id", id)) @@ -70,7 +87,7 @@ func (r *repository) Get(ctx context.Context, id string) (*provider.Provider, er err := r.db.Preload("Custom").Preload("Openfort").Where("id = ?", id).First(&dbProv).Error if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, domain.ErrProviderNotFound + return nil, domainErrors.ErrProviderNotFound } r.logger.ErrorContext(ctx, "error getting provider", logger.Error(err)) return nil, err @@ -107,7 +124,7 @@ func (r *repository) Delete(ctx context.Context, providerID string) error { } if cmd.RowsAffected == 0 { - return domain.ErrProviderNotFound + return domainErrors.ErrProviderNotFound } return nil @@ -133,7 +150,7 @@ func (r *repository) GetCustom(ctx context.Context, providerID string) (*provide err := r.db.Where("provider_id = ?", providerID).First(dbProv).Error if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, domain.ErrProviderNotFound + return nil, domainErrors.ErrProviderNotFound } r.logger.ErrorContext(ctx, "error getting custom provider", logger.Error(err)) return nil, err @@ -175,7 +192,7 @@ func (r *repository) GetOpenfort(ctx context.Context, providerID string) (*provi err := r.db.Where("provider_id = ?", providerID).First(dbProv).Error if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, domain.ErrProviderNotFound + return nil, domainErrors.ErrProviderNotFound } r.logger.ErrorContext(ctx, "error getting openfort provider", logger.Error(err)) return nil, err diff --git a/internal/infrastructure/repositories/sql/providerrepo/types.go b/internal/adapters/repositories/sql/providerrepo/types.go similarity index 100% rename from internal/infrastructure/repositories/sql/providerrepo/types.go rename to internal/adapters/repositories/sql/providerrepo/types.go diff --git a/internal/infrastructure/repositories/sql/sharerepo/parser.go b/internal/adapters/repositories/sql/sharerepo/parser.go similarity index 100% rename from internal/infrastructure/repositories/sql/sharerepo/parser.go rename to internal/adapters/repositories/sql/sharerepo/parser.go diff --git a/internal/infrastructure/repositories/sql/sharerepo/repo.go b/internal/adapters/repositories/sql/sharerepo/repo.go similarity index 94% rename from internal/infrastructure/repositories/sql/sharerepo/repo.go rename to internal/adapters/repositories/sql/sharerepo/repo.go index 2c286bb..6da1931 100644 --- a/internal/infrastructure/repositories/sql/sharerepo/repo.go +++ b/internal/adapters/repositories/sql/sharerepo/repo.go @@ -3,13 +3,13 @@ package sharerepo import ( "context" "errors" + domainErrors "go.openfort.xyz/shield/internal/core/domain/errors" "log/slog" "github.com/google/uuid" - "go.openfort.xyz/shield/internal/core/domain" + "go.openfort.xyz/shield/internal/adapters/repositories/sql" "go.openfort.xyz/shield/internal/core/domain/share" "go.openfort.xyz/shield/internal/core/ports/repositories" - "go.openfort.xyz/shield/internal/infrastructure/repositories/sql" "go.openfort.xyz/shield/pkg/logger" "gorm.io/gorm" ) @@ -54,7 +54,7 @@ func (r *repository) GetByUserID(ctx context.Context, userID string) (*share.Sha err := r.db.Where("user_id = ?", userID).First(dbShr).Error if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, domain.ErrShareNotFound + return nil, domainErrors.ErrShareNotFound } r.logger.ErrorContext(ctx, "error getting share", logger.Error(err)) return nil, err diff --git a/internal/infrastructure/repositories/sql/sharerepo/types.go b/internal/adapters/repositories/sql/sharerepo/types.go similarity index 100% rename from internal/infrastructure/repositories/sql/sharerepo/types.go rename to internal/adapters/repositories/sql/sharerepo/types.go diff --git a/internal/infrastructure/repositories/sql/userrepo/options.go b/internal/adapters/repositories/sql/userrepo/options.go similarity index 100% rename from internal/infrastructure/repositories/sql/userrepo/options.go rename to internal/adapters/repositories/sql/userrepo/options.go diff --git a/internal/infrastructure/repositories/sql/userrepo/parser.go b/internal/adapters/repositories/sql/userrepo/parser.go similarity index 100% rename from internal/infrastructure/repositories/sql/userrepo/parser.go rename to internal/adapters/repositories/sql/userrepo/parser.go diff --git a/internal/infrastructure/repositories/sql/userrepo/repo.go b/internal/adapters/repositories/sql/userrepo/repo.go similarity index 93% rename from internal/infrastructure/repositories/sql/userrepo/repo.go rename to internal/adapters/repositories/sql/userrepo/repo.go index 28b9bd5..0009077 100644 --- a/internal/infrastructure/repositories/sql/userrepo/repo.go +++ b/internal/adapters/repositories/sql/userrepo/repo.go @@ -3,13 +3,13 @@ package userrepo import ( "context" "errors" + domainErrors "go.openfort.xyz/shield/internal/core/domain/errors" "log/slog" "github.com/google/uuid" - "go.openfort.xyz/shield/internal/core/domain" + "go.openfort.xyz/shield/internal/adapters/repositories/sql" "go.openfort.xyz/shield/internal/core/domain/user" "go.openfort.xyz/shield/internal/core/ports/repositories" - "go.openfort.xyz/shield/internal/infrastructure/repositories/sql" "go.openfort.xyz/shield/pkg/logger" "gorm.io/gorm" ) @@ -54,7 +54,7 @@ func (r *repository) Get(ctx context.Context, userID string) (*user.User, error) err := r.db.Where("id = ?", userID).First(dbUsr).Error if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, domain.ErrUserNotFound + return nil, domainErrors.ErrUserNotFound } r.logger.ErrorContext(ctx, "error getting user", logger.Error(err)) return nil, err diff --git a/internal/infrastructure/repositories/sql/userrepo/types.go b/internal/adapters/repositories/sql/userrepo/types.go similarity index 100% rename from internal/infrastructure/repositories/sql/userrepo/types.go rename to internal/adapters/repositories/sql/userrepo/types.go diff --git a/internal/applications/projectapp/app.go b/internal/applications/projectapp/app.go index ef0eea1..d377855 100644 --- a/internal/applications/projectapp/app.go +++ b/internal/applications/projectapp/app.go @@ -3,9 +3,9 @@ package projectapp import ( "context" "errors" + domainErrors "go.openfort.xyz/shield/internal/core/domain/errors" "log/slog" - "go.openfort.xyz/shield/internal/core/domain" "go.openfort.xyz/shield/internal/core/domain/project" "go.openfort.xyz/shield/internal/core/domain/provider" "go.openfort.xyz/shield/internal/core/domain/share" @@ -93,7 +93,7 @@ func (a *ProjectApplication) AddProviders(ctx context.Context, opts ...ProviderO var providers []*provider.Provider if cfg.openfortPublishableKey != nil { prov, err := a.providerRepo.GetByProjectAndType(ctx, projectID, provider.TypeOpenfort) - if err != nil && !errors.Is(err, domain.ErrProviderNotFound) { + if err != nil && !errors.Is(err, domainErrors.ErrProviderNotFound) { a.logger.ErrorContext(ctx, "failed to get provider", logger.Error(err)) return nil, fromDomainError(err) } @@ -109,7 +109,7 @@ func (a *ProjectApplication) AddProviders(ctx context.Context, opts ...ProviderO if cfg.jwkURL != nil { prov, err := a.providerRepo.GetByProjectAndType(ctx, projectID, provider.TypeCustom) - if err != nil && !errors.Is(err, domain.ErrProviderNotFound) { + if err != nil && !errors.Is(err, domainErrors.ErrProviderNotFound) { a.logger.ErrorContext(ctx, "failed to get provider", logger.Error(err)) return nil, fromDomainError(err) } @@ -121,7 +121,7 @@ func (a *ProjectApplication) AddProviders(ctx context.Context, opts ...ProviderO if cfg.pem != nil { prov, err := a.providerRepo.GetByProjectAndType(ctx, projectID, provider.TypeCustom) - if err != nil && !errors.Is(err, domain.ErrProviderNotFound) { + if err != nil && !errors.Is(err, domainErrors.ErrProviderNotFound) { a.logger.ErrorContext(ctx, "failed to get provider", logger.Error(err)) return nil, fromDomainError(err) } @@ -320,7 +320,7 @@ func (a *ProjectApplication) RegisterEncryptionKey(ctx context.Context) (string, projectID := contexter.GetProjectID(ctx) ep, err := a.projectRepo.GetEncryptionPart(ctx, projectID) - if err != nil && !errors.Is(err, domain.ErrEncryptionPartNotFound) { + if err != nil && !errors.Is(err, domainErrors.ErrEncryptionPartNotFound) { a.logger.ErrorContext(ctx, "failed to get encryption part", logger.Error(err)) return "", fromDomainError(err) } diff --git a/internal/applications/projectapp/app_test.go b/internal/applications/projectapp/app_test.go index f80feac..a4727d9 100644 --- a/internal/applications/projectapp/app_test.go +++ b/internal/applications/projectapp/app_test.go @@ -5,15 +5,15 @@ import ( "errors" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" - "go.openfort.xyz/shield/internal/core/domain" + "go.openfort.xyz/shield/internal/adapters/repositories/mocks/projectmockrepo" + "go.openfort.xyz/shield/internal/adapters/repositories/mocks/providermockrepo" + "go.openfort.xyz/shield/internal/adapters/repositories/mocks/sharemockrepo" + domainErrors "go.openfort.xyz/shield/internal/core/domain/errors" "go.openfort.xyz/shield/internal/core/domain/project" "go.openfort.xyz/shield/internal/core/domain/provider" "go.openfort.xyz/shield/internal/core/domain/share" "go.openfort.xyz/shield/internal/core/services/projectsvc" "go.openfort.xyz/shield/internal/core/services/providersvc" - "go.openfort.xyz/shield/internal/infrastructure/repositories/mocks/projectmockrepo" - "go.openfort.xyz/shield/internal/infrastructure/repositories/mocks/providermockrepo" - "go.openfort.xyz/shield/internal/infrastructure/repositories/mocks/sharemockrepo" "go.openfort.xyz/shield/pkg/contexter" "go.openfort.xyz/shield/pkg/cypher" "testing" @@ -158,7 +158,7 @@ func TestProjectApplication_GetProject(t *testing.T) { name: "project not found", mock: func() { projectRepo.ExpectedCalls = nil - projectRepo.On("Get", mock.Anything, mock.Anything).Return(nil, domain.ErrProjectNotFound) + projectRepo.On("Get", mock.Anything, mock.Anything).Return(nil, domainErrors.ErrProjectNotFound) }, wantProj: nil, wantErr: ErrProjectNotFound, @@ -213,8 +213,8 @@ func TestProjectApplication_AddProviders(t *testing.T) { mock: func() { projectRepo.ExpectedCalls = nil providerRepo.ExpectedCalls = nil - providerRepo.On("GetByProjectAndType", mock.Anything, mock.Anything, provider.TypeOpenfort).Return(nil, domain.ErrProviderNotFound) - providerRepo.On("GetByProjectAndType", mock.Anything, mock.Anything, provider.TypeCustom).Return(nil, domain.ErrProviderNotFound) + providerRepo.On("GetByProjectAndType", mock.Anything, mock.Anything, provider.TypeOpenfort).Return(nil, domainErrors.ErrProviderNotFound) + providerRepo.On("GetByProjectAndType", mock.Anything, mock.Anything, provider.TypeCustom).Return(nil, domainErrors.ErrProviderNotFound) providerRepo.On("Create", mock.Anything, mock.AnythingOfType("*provider.Provider")).Return(nil) providerRepo.On("CreateOpenfort", mock.Anything, mock.AnythingOfType("*provider.OpenfortConfig")).Return(nil) providerRepo.On("CreateCustom", mock.Anything, mock.AnythingOfType("*provider.CustomConfig")).Return(nil) @@ -230,7 +230,7 @@ func TestProjectApplication_AddProviders(t *testing.T) { mock: func() { projectRepo.ExpectedCalls = nil providerRepo.ExpectedCalls = nil - providerRepo.On("GetByProjectAndType", mock.Anything, mock.Anything, provider.TypeCustom).Return(nil, domain.ErrProviderNotFound) + providerRepo.On("GetByProjectAndType", mock.Anything, mock.Anything, provider.TypeCustom).Return(nil, domainErrors.ErrProviderNotFound) providerRepo.On("Create", mock.Anything, mock.AnythingOfType("*provider.Provider")).Return(nil) providerRepo.On("CreateCustom", mock.Anything, mock.AnythingOfType("*provider.CustomConfig")).Return(nil) }, @@ -337,8 +337,8 @@ func TestProjectApplication_AddProviders(t *testing.T) { mock: func() { projectRepo.ExpectedCalls = nil providerRepo.ExpectedCalls = nil - providerRepo.On("GetByProjectAndType", mock.Anything, mock.Anything, provider.TypeOpenfort).Return(nil, domain.ErrProviderNotFound) - providerRepo.On("GetByProjectAndType", mock.Anything, mock.Anything, provider.TypeCustom).Return(nil, domain.ErrProviderNotFound) + providerRepo.On("GetByProjectAndType", mock.Anything, mock.Anything, provider.TypeOpenfort).Return(nil, domainErrors.ErrProviderNotFound) + providerRepo.On("GetByProjectAndType", mock.Anything, mock.Anything, provider.TypeCustom).Return(nil, domainErrors.ErrProviderNotFound) providerRepo.On("Create", mock.Anything, mock.AnythingOfType("*provider.Provider")).Return(errors.New("repository error")) }, }, @@ -471,7 +471,7 @@ func TestProjectApplication_GetProviderDetail(t *testing.T) { mock: func() { projectRepo.ExpectedCalls = nil providerRepo.ExpectedCalls = nil - providerRepo.On("Get", mock.Anything, mock.Anything).Return(nil, domain.ErrProviderNotFound) + providerRepo.On("Get", mock.Anything, mock.Anything).Return(nil, domainErrors.ErrProviderNotFound) }, }, { @@ -591,7 +591,7 @@ func TestProjectApplication_UpdateProvider(t *testing.T) { mock: func() { projectRepo.ExpectedCalls = nil providerRepo.ExpectedCalls = nil - providerRepo.On("Get", mock.Anything, mock.Anything).Return(nil, domain.ErrProviderNotFound) + providerRepo.On("Get", mock.Anything, mock.Anything).Return(nil, domainErrors.ErrProviderNotFound) }, }, { @@ -765,7 +765,7 @@ func TestProjectApplication_RemoveProvider(t *testing.T) { mock: func() { projectRepo.ExpectedCalls = nil providerRepo.ExpectedCalls = nil - providerRepo.On("Get", mock.Anything, mock.Anything).Return(nil, domain.ErrProviderNotFound) + providerRepo.On("Get", mock.Anything, mock.Anything).Return(nil, domainErrors.ErrProviderNotFound) }, }, { @@ -880,7 +880,7 @@ func TestProjectApplication_EncryptProjectShares(t *testing.T) { externalPart: externalPart, mock: func() { projectRepo.ExpectedCalls = nil - projectRepo.On("GetEncryptionPart", mock.Anything, mock.Anything).Return("", domain.ErrEncryptionPartNotFound) + projectRepo.On("GetEncryptionPart", mock.Anything, mock.Anything).Return("", domainErrors.ErrEncryptionPartNotFound) }, wantErr: ErrEncryptionNotConfigured, }, @@ -957,7 +957,7 @@ func TestProjectApplication_RegisterEncryptionKey(t *testing.T) { wantErr: nil, mock: func() { projectRepo.ExpectedCalls = nil - projectRepo.On("GetEncryptionPart", mock.Anything, "project_id").Return("", domain.ErrEncryptionPartNotFound) + projectRepo.On("GetEncryptionPart", mock.Anything, "project_id").Return("", domainErrors.ErrEncryptionPartNotFound) projectRepo.On("SetEncryptionPart", mock.Anything, "project_id", mock.Anything).Return(nil) }, }, @@ -982,7 +982,7 @@ func TestProjectApplication_RegisterEncryptionKey(t *testing.T) { wantErr: ErrInternal, mock: func() { projectRepo.ExpectedCalls = nil - projectRepo.On("GetEncryptionPart", mock.Anything, "project_id").Return("", domain.ErrEncryptionPartNotFound) + projectRepo.On("GetEncryptionPart", mock.Anything, "project_id").Return("", domainErrors.ErrEncryptionPartNotFound) projectRepo.On("SetEncryptionPart", mock.Anything, "project_id", mock.Anything).Return(errors.New("repository error")) }, }, diff --git a/internal/applications/projectapp/errors.go b/internal/applications/projectapp/errors.go index 7da3d7b..1685240 100644 --- a/internal/applications/projectapp/errors.go +++ b/internal/applications/projectapp/errors.go @@ -2,8 +2,7 @@ package projectapp import ( "errors" - - "go.openfort.xyz/shield/internal/core/domain" + domainErrors "go.openfort.xyz/shield/internal/core/domain/errors" ) var ( @@ -26,27 +25,27 @@ func fromDomainError(err error) error { if err == nil { return nil } - if errors.Is(err, domain.ErrProjectNotFound) { + if errors.Is(err, domainErrors.ErrProjectNotFound) { return ErrProjectNotFound } - if errors.Is(err, domain.ErrInvalidProviderConfig) { + if errors.Is(err, domainErrors.ErrInvalidProviderConfig) { return ErrInvalidProviderConfig } - if errors.Is(err, domain.ErrUnknownProviderType) { + if errors.Is(err, domainErrors.ErrUnknownProviderType) { return ErrUnknownProviderType } - if errors.Is(err, domain.ErrProviderAlreadyExists) { + if errors.Is(err, domainErrors.ErrProviderAlreadyExists) { return ErrProviderAlreadyExists } - if errors.Is(err, domain.ErrProviderNotFound) { + if errors.Is(err, domainErrors.ErrProviderNotFound) { return ErrProviderNotFound } - if errors.Is(err, domain.ErrEncryptionPartNotFound) { + if errors.Is(err, domainErrors.ErrEncryptionPartNotFound) { return ErrEncryptionNotConfigured } diff --git a/internal/applications/shareapp/app_test.go b/internal/applications/shareapp/app_test.go index e863dd0..a4f1a79 100644 --- a/internal/applications/shareapp/app_test.go +++ b/internal/applications/shareapp/app_test.go @@ -5,11 +5,11 @@ import ( "errors" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" - "go.openfort.xyz/shield/internal/core/domain" + "go.openfort.xyz/shield/internal/adapters/repositories/mocks/projectmockrepo" + "go.openfort.xyz/shield/internal/adapters/repositories/mocks/sharemockrepo" + domainErrors "go.openfort.xyz/shield/internal/core/domain/errors" "go.openfort.xyz/shield/internal/core/domain/share" "go.openfort.xyz/shield/internal/core/services/sharesvc" - "go.openfort.xyz/shield/internal/infrastructure/repositories/mocks/projectmockrepo" - "go.openfort.xyz/shield/internal/infrastructure/repositories/mocks/sharemockrepo" "go.openfort.xyz/shield/pkg/contexter" "go.openfort.xyz/shield/pkg/cypher" "testing" @@ -102,7 +102,7 @@ func TestShareApplication_GetShare(t *testing.T) { shareRepo.ExpectedCalls = nil projectRepo.ExpectedCalls = nil shareRepo.On("GetByUserID", mock.Anything, "user_id").Return(encryptedShare, nil) - projectRepo.On("GetEncryptionPart", mock.Anything, "project_id").Return("", domain.ErrEncryptionPartNotFound) + projectRepo.On("GetEncryptionPart", mock.Anything, "project_id").Return("", domainErrors.ErrEncryptionPartNotFound) }, opts: []Option{ WithEncryptionPart(externalPart), @@ -140,7 +140,7 @@ func TestShareApplication_GetShare(t *testing.T) { mock: func() { shareRepo.ExpectedCalls = nil projectRepo.ExpectedCalls = nil - shareRepo.On("GetByUserID", mock.Anything, "user_id").Return(nil, domain.ErrShareNotFound) + shareRepo.On("GetByUserID", mock.Anything, "user_id").Return(nil, domainErrors.ErrShareNotFound) }, }, { @@ -172,7 +172,7 @@ func TestShareApplication_GetShare(t *testing.T) { shareRepo.ExpectedCalls = nil projectRepo.ExpectedCalls = nil shareRepo.On("GetByUserID", mock.Anything, "user_id").Return(encryptedShare, nil) - projectRepo.On("GetEncryptionPart", mock.Anything, "project_id").Return("", domain.ErrEncryptionPartNotFound) + projectRepo.On("GetEncryptionPart", mock.Anything, "project_id").Return("", domainErrors.ErrEncryptionPartNotFound) }, opts: []Option{ WithEncryptionPart(externalPart), @@ -240,7 +240,7 @@ func TestShareApplication_RegisterShare(t *testing.T) { share: plainShare, mock: func() { shareRepo.ExpectedCalls = nil - shareRepo.On("GetByUserID", mock.Anything, mock.Anything, mock.Anything).Return(nil, domain.ErrShareNotFound) + shareRepo.On("GetByUserID", mock.Anything, mock.Anything, mock.Anything).Return(nil, domainErrors.ErrShareNotFound) shareRepo.On("Create", mock.Anything, plainShare).Return(nil) }, }, @@ -251,7 +251,7 @@ func TestShareApplication_RegisterShare(t *testing.T) { mock: func() { shareRepo.ExpectedCalls = nil projectRepo.ExpectedCalls = nil - shareRepo.On("GetByUserID", mock.Anything, mock.Anything, mock.Anything).Return(nil, domain.ErrShareNotFound) + shareRepo.On("GetByUserID", mock.Anything, mock.Anything, mock.Anything).Return(nil, domainErrors.ErrShareNotFound) shareRepo.On("Create", mock.Anything, encryptedShare).Return(nil) projectRepo.On("GetEncryptionPart", mock.Anything, "project_id").Return(storedPart, nil) }, @@ -265,7 +265,7 @@ func TestShareApplication_RegisterShare(t *testing.T) { share: encryptedShare, mock: func() { shareRepo.ExpectedCalls = nil - shareRepo.On("GetByUserID", mock.Anything, mock.Anything, mock.Anything).Return(nil, domain.ErrShareNotFound) + shareRepo.On("GetByUserID", mock.Anything, mock.Anything, mock.Anything).Return(nil, domainErrors.ErrShareNotFound) }, }, { @@ -275,8 +275,8 @@ func TestShareApplication_RegisterShare(t *testing.T) { mock: func() { shareRepo.ExpectedCalls = nil projectRepo.ExpectedCalls = nil - shareRepo.On("GetByUserID", mock.Anything, mock.Anything, mock.Anything).Return(nil, domain.ErrShareNotFound) - projectRepo.On("GetEncryptionPart", mock.Anything, "project_id").Return("", domain.ErrEncryptionPartNotFound) + shareRepo.On("GetByUserID", mock.Anything, mock.Anything, mock.Anything).Return(nil, domainErrors.ErrShareNotFound) + projectRepo.On("GetEncryptionPart", mock.Anything, "project_id").Return("", domainErrors.ErrEncryptionPartNotFound) }, opts: []Option{ WithEncryptionPart(externalPart), @@ -289,7 +289,7 @@ func TestShareApplication_RegisterShare(t *testing.T) { mock: func() { shareRepo.ExpectedCalls = nil projectRepo.ExpectedCalls = nil - shareRepo.On("GetByUserID", mock.Anything, mock.Anything, mock.Anything).Return(nil, domain.ErrShareNotFound) + shareRepo.On("GetByUserID", mock.Anything, mock.Anything, mock.Anything).Return(nil, domainErrors.ErrShareNotFound) projectRepo.On("GetEncryptionPart", mock.Anything, "project_id").Return(storedPart, nil) }, opts: []Option{ @@ -353,7 +353,7 @@ func TestShareApplication_DeleteShare(t *testing.T) { wantErr: ErrShareNotFound, mock: func() { shareRepo.ExpectedCalls = nil - shareRepo.On("GetByUserID", mock.Anything, "user_id").Return(nil, domain.ErrShareNotFound) + shareRepo.On("GetByUserID", mock.Anything, "user_id").Return(nil, domainErrors.ErrShareNotFound) }, }, { diff --git a/internal/applications/shareapp/errors.go b/internal/applications/shareapp/errors.go index 5c1a666..b24c123 100644 --- a/internal/applications/shareapp/errors.go +++ b/internal/applications/shareapp/errors.go @@ -2,8 +2,7 @@ package shareapp import ( "errors" - - "go.openfort.xyz/shield/internal/core/domain" + domainErrors "go.openfort.xyz/shield/internal/core/domain/errors" ) var ( @@ -19,19 +18,19 @@ var ( ) func fromDomainError(err error) error { - if errors.Is(err, domain.ErrShareNotFound) { + if errors.Is(err, domainErrors.ErrShareNotFound) { return ErrShareNotFound } - if errors.Is(err, domain.ErrShareAlreadyExists) { + if errors.Is(err, domainErrors.ErrShareAlreadyExists) { return ErrShareAlreadyExists } - if errors.Is(err, domain.ErrEncryptionPartRequired) { + if errors.Is(err, domainErrors.ErrEncryptionPartRequired) { return ErrEncryptionPartRequired } - if errors.Is(err, domain.ErrEncryptionPartNotFound) { + if errors.Is(err, domainErrors.ErrEncryptionPartNotFound) { return ErrEncryptionNotConfigured } diff --git a/internal/core/domain/authentication/authentication.go b/internal/core/domain/authentication/authentication.go new file mode 100644 index 0000000..e07803e --- /dev/null +++ b/internal/core/domain/authentication/authentication.go @@ -0,0 +1,6 @@ +package authentication + +type Authentication struct { + UserID string + ProjectID string +} diff --git a/internal/core/domain/errors.go b/internal/core/domain/errors.go deleted file mode 100644 index f78bad7..0000000 --- a/internal/core/domain/errors.go +++ /dev/null @@ -1,26 +0,0 @@ -package domain - -import "errors" - -var ( - // Project errors - ErrProjectNotFound = errors.New("project not found") - ErrEncryptionPartNotFound = errors.New("encryption part not found") - ErrEncryptionPartAlreadyExists = errors.New("encryption part already exists") - ErrEncryptionPartRequired = errors.New("encryption part is required") - - // Provider errors - ErrInvalidProviderConfig = errors.New("invalid provider config") - ErrUnknownProviderType = errors.New("unknown provider type") - ErrProviderAlreadyExists = errors.New("custom authentication already registered for this project") - ErrProviderNotFound = errors.New("custom authentication not found") - - // Share errors - ErrShareNotFound = errors.New("share not found") - ErrShareAlreadyExists = errors.New("share already exists") - - // User errors - ErrUserNotFound = errors.New("user not found") - ErrExternalUserNotFound = errors.New("external user not found") - ErrExternalUserAlreadyExists = errors.New("external user already exists") -) diff --git a/internal/core/domain/errors/project.go b/internal/core/domain/errors/project.go new file mode 100644 index 0000000..b6601e0 --- /dev/null +++ b/internal/core/domain/errors/project.go @@ -0,0 +1,10 @@ +package errors + +import "errors" + +var ( + ErrProjectNotFound = errors.New("project not found") + ErrEncryptionPartNotFound = errors.New("encryption part not found") + ErrEncryptionPartAlreadyExists = errors.New("encryption part already exists") + ErrEncryptionPartRequired = errors.New("encryption part is required") +) diff --git a/internal/core/domain/errors/provider.go b/internal/core/domain/errors/provider.go new file mode 100644 index 0000000..c770ede --- /dev/null +++ b/internal/core/domain/errors/provider.go @@ -0,0 +1,10 @@ +package errors + +import "errors" + +var ( + ErrInvalidProviderConfig = errors.New("invalid provider config") + ErrUnknownProviderType = errors.New("unknown provider type") + ErrProviderAlreadyExists = errors.New("custom authentication already registered for this project") + ErrProviderNotFound = errors.New("custom authentication not found") +) diff --git a/internal/core/domain/errors/share.go b/internal/core/domain/errors/share.go new file mode 100644 index 0000000..217ae0c --- /dev/null +++ b/internal/core/domain/errors/share.go @@ -0,0 +1,8 @@ +package errors + +import "errors" + +var ( + ErrShareNotFound = errors.New("share not found") + ErrShareAlreadyExists = errors.New("share already exists") +) diff --git a/internal/core/domain/errors/user.go b/internal/core/domain/errors/user.go new file mode 100644 index 0000000..555dc1b --- /dev/null +++ b/internal/core/domain/errors/user.go @@ -0,0 +1,9 @@ +package errors + +import "errors" + +var ( + ErrUserNotFound = errors.New("user not found") + ErrExternalUserNotFound = errors.New("external user not found") + ErrExternalUserAlreadyExists = errors.New("external user already exists") +) diff --git a/internal/core/ports/authentication/apisecret.go b/internal/core/ports/authentication/apisecret.go deleted file mode 100644 index b76d00d..0000000 --- a/internal/core/ports/authentication/apisecret.go +++ /dev/null @@ -1,7 +0,0 @@ -package authentication - -import "context" - -type APISecretAuthenticator interface { - Authenticate(ctx context.Context, apiKey, apiSecret string) (projectID string, err error) -} diff --git a/internal/core/ports/authentication/user.go b/internal/core/ports/authentication/user.go deleted file mode 100644 index ef075e6..0000000 --- a/internal/core/ports/authentication/user.go +++ /dev/null @@ -1,34 +0,0 @@ -package authentication - -import ( - "context" - - "go.openfort.xyz/shield/internal/core/domain/provider" -) - -type UserAuthenticator interface { - Authenticate(ctx context.Context, apiKey, token string, providerType provider.Type, opts ...CustomOption) (*Authentication, error) -} - -type CustomOption func(*CustomOptions) -type CustomOptions struct { - OpenfortProvider *string - OpenfortTokenType *string -} - -func WithOpenfortProvider(value string) CustomOption { - return func(opts *CustomOptions) { - opts.OpenfortProvider = &value - } -} - -func WithOpenfortTokenType(value string) CustomOption { - return func(opts *CustomOptions) { - opts.OpenfortTokenType = &value - } -} - -type Authentication struct { - UserID string - ProjectID string -} diff --git a/internal/core/ports/factories/authentication.go b/internal/core/ports/factories/authentication.go new file mode 100644 index 0000000..237e82b --- /dev/null +++ b/internal/core/ports/factories/authentication.go @@ -0,0 +1,15 @@ +package factories + +import ( + "context" + "go.openfort.xyz/shield/internal/core/domain/authentication" +) + +type AuthenticationFactory interface { + CreateProjectAuthenticator(apiKey, apiSecret string) Authenticator + CreateUserAuthenticator(apiKey, token string, identityFactory Identity) Authenticator +} + +type Authenticator interface { + Authenticate(ctx context.Context) (*authentication.Authentication, error) +} diff --git a/internal/core/ports/factories/encryption.go b/internal/core/ports/factories/encryption.go new file mode 100644 index 0000000..ead4c0f --- /dev/null +++ b/internal/core/ports/factories/encryption.go @@ -0,0 +1,75 @@ +package factories + +import ( + "context" + "errors" + "go.openfort.xyz/shield/internal/core/ports/repositories" + "go.openfort.xyz/shield/pkg/cypher" +) + +type EncryptionFactory interface { + CreateEncryptionStrategy() EncryptionStrategy +} + +type EncryptionStrategy interface { + Encrypt(ctx context.Context, plain string) (string, error) + Decrypt(ctx context.Context, encrypted string) (string, error) +} + +type EncryptionKeyBuilder interface { + SetEncryptionPart(ctx context.Context, part string) EncryptionKeyBuilder + SetSessionPart(ctx context.Context, sessionID string) (EncryptionKeyBuilder, error) + SetDatabasePart(ctx context.Context, projectID string) (EncryptionKeyBuilder, error) + Build(ctx context.Context) (string, error) +} + +type EncryptionKeyBuilderImpl struct { + projectPart string + databasePart string + encryptionPartsRepo repositories.EncryptionPartsRepository + projectRepo repositories.ProjectRepository +} + +func NewEncryptionKeyBuilder() EncryptionKeyBuilder { + return &EncryptionKeyBuilderImpl{ + projectPart: "", + databasePart: "", + } +} + +func (b *EncryptionKeyBuilderImpl) SetEncryptionPart(ctx context.Context, part string) EncryptionKeyBuilder { + b.projectPart = part + return b +} + +func (b *EncryptionKeyBuilderImpl) SetSessionPart(ctx context.Context, sessionID string) (EncryptionKeyBuilder, error) { + part, err := b.encryptionPartsRepo.Get(ctx, sessionID) + if err != nil { + return nil, err + } + + b.projectPart = part + return b, nil +} + +func (b *EncryptionKeyBuilderImpl) SetDatabasePart(ctx context.Context, projectID string) (EncryptionKeyBuilder, error) { + part, err := b.projectRepo.GetEncryptionPart(ctx, projectID) + if err != nil { + return nil, err + } + + b.databasePart = part + return b, nil +} + +func (b *EncryptionKeyBuilderImpl) Build(ctx context.Context) (string, error) { + if b.projectPart == "" { + return "", errors.New("project part is required") // TODO extract error + } + + if b.databasePart == "" { + return "", errors.New("database part is required") // TODO extract error + } + + return cypher.ReconstructEncryptionKey(b.projectPart, b.databasePart) +} diff --git a/internal/core/ports/factories/identity.go b/internal/core/ports/factories/identity.go new file mode 100644 index 0000000..7ad05d6 --- /dev/null +++ b/internal/core/ports/factories/identity.go @@ -0,0 +1,15 @@ +package factories + +import ( + "context" +) + +type IdentityFactory interface { + CreateCustomIdentity(ctx context.Context, apiKey string) (Identity, error) + CreateOpenfortIdentity(ctx context.Context, apiKey string, authenticationProvider, tokenType *string) (Identity, error) +} + +type Identity interface { + GetProviderID() string + Identify(ctx context.Context, token string) (string, error) +} diff --git a/internal/core/ports/providers/provider.go b/internal/core/ports/providers/provider.go deleted file mode 100644 index ce60d07..0000000 --- a/internal/core/ports/providers/provider.go +++ /dev/null @@ -1,26 +0,0 @@ -package providers - -import "context" - -type IdentityProvider interface { - GetProviderID() string - Identify(ctx context.Context, token string, opts ...CustomOption) (string, error) -} - -type CustomOption func(*CustomOptions) -type CustomOptions struct { - OpenfortProvider *string - OpenfortTokenType *string -} - -func WithOpenfortProvider(value string) CustomOption { - return func(opts *CustomOptions) { - opts.OpenfortProvider = &value - } -} - -func WithOpenfortTokenType(value string) CustomOption { - return func(opts *CustomOptions) { - opts.OpenfortTokenType = &value - } -} diff --git a/internal/core/ports/repositories/provider.go b/internal/core/ports/repositories/provider.go index a16157e..b10c3ee 100644 --- a/internal/core/ports/repositories/provider.go +++ b/internal/core/ports/repositories/provider.go @@ -10,6 +10,7 @@ type ProviderRepository interface { Create(ctx context.Context, prov *provider.Provider) error Get(ctx context.Context, id string) (*provider.Provider, error) GetByProjectAndType(ctx context.Context, projectID string, providerType provider.Type) (*provider.Provider, error) + GetByAPIKeyAndType(ctx context.Context, apiKey string, providerType provider.Type) (*provider.Provider, error) List(ctx context.Context, projectID string) ([]*provider.Provider, error) Delete(ctx context.Context, providerID string) error diff --git a/internal/core/ports/services/share.go b/internal/core/ports/services/share.go index 8d95063..32904b2 100644 --- a/internal/core/ports/services/share.go +++ b/internal/core/ports/services/share.go @@ -13,8 +13,7 @@ type ShareService interface { type ShareOption func(*ShareOptions) type ShareOptions struct { - EncryptionKey *string - EncryptionSession *string + EncryptionKey *string } func WithEncryptionKey(key string) ShareOption { @@ -22,9 +21,3 @@ func WithEncryptionKey(key string) ShareOption { o.EncryptionKey = &key } } - -func WithEncryptionSession(session string) ShareOption { - return func(o *ShareOptions) { - o.EncryptionSession = &session - } -} diff --git a/internal/core/ports/services/user.go b/internal/core/ports/services/user.go index 8a72df8..fd76ddc 100644 --- a/internal/core/ports/services/user.go +++ b/internal/core/ports/services/user.go @@ -7,8 +7,5 @@ import ( ) type UserService interface { - Create(ctx context.Context, projectID string) (*user.User, error) - Get(ctx context.Context, userID string) (*user.User, error) - GetByExternal(ctx context.Context, externalUserID, providerID string) (*user.User, error) - CreateExternal(ctx context.Context, projectID, userID, externalUserID, providerID string) (*user.ExternalUser, error) + GetOrCreate(ctx context.Context, projectID, externalUserID, providerID string) (*user.User, error) } diff --git a/internal/core/services/projectsvc/svc.go b/internal/core/services/projectsvc/svc.go index 9735f66..f40686f 100644 --- a/internal/core/services/projectsvc/svc.go +++ b/internal/core/services/projectsvc/svc.go @@ -3,10 +3,10 @@ package projectsvc import ( "context" "errors" + domainErrors "go.openfort.xyz/shield/internal/core/domain/errors" "log/slog" "github.com/google/uuid" - "go.openfort.xyz/shield/internal/core/domain" "go.openfort.xyz/shield/internal/core/domain/project" "go.openfort.xyz/shield/internal/core/ports/repositories" "go.openfort.xyz/shield/internal/core/ports/services" @@ -58,14 +58,14 @@ func (s *service) Create(ctx context.Context, name string) (*project.Project, er func (s *service) SetEncryptionPart(ctx context.Context, projectID, part string) error { s.logger.InfoContext(ctx, "setting encryption part", slog.String("project_id", projectID)) ep, err := s.repo.GetEncryptionPart(ctx, projectID) - if err != nil && !errors.Is(err, domain.ErrEncryptionPartNotFound) { + if err != nil && !errors.Is(err, domainErrors.ErrEncryptionPartNotFound) { s.logger.ErrorContext(ctx, "failed to get encryption part", logger.Error(err)) return err } if ep != "" { s.logger.Warn("encryption part already exists", slog.String("project_id", projectID)) - return domain.ErrEncryptionPartAlreadyExists + return domainErrors.ErrEncryptionPartAlreadyExists } err = s.repo.SetEncryptionPart(ctx, projectID, part) diff --git a/internal/core/services/projectsvc/svc_test.go b/internal/core/services/projectsvc/svc_test.go index 516dffa..62bc68c 100644 --- a/internal/core/services/projectsvc/svc_test.go +++ b/internal/core/services/projectsvc/svc_test.go @@ -3,11 +3,11 @@ package projectsvc import ( "context" "errors" - "go.openfort.xyz/shield/internal/core/domain" + domainErrors "go.openfort.xyz/shield/internal/core/domain/errors" "testing" "github.com/stretchr/testify/mock" - "go.openfort.xyz/shield/internal/infrastructure/repositories/mocks/projectmockrepo" + "go.openfort.xyz/shield/internal/adapters/repositories/mocks/projectmockrepo" ) func TestService_Create(t *testing.T) { @@ -79,7 +79,7 @@ func TestService_SetEncryptionPart(t *testing.T) { wantErr: false, mock: func() { mockRepo.ExpectedCalls = nil - mockRepo.On("GetEncryptionPart", mock.Anything, testProjectID).Return("", domain.ErrEncryptionPartNotFound) + mockRepo.On("GetEncryptionPart", mock.Anything, testProjectID).Return("", domainErrors.ErrEncryptionPartNotFound) mockRepo.On("SetEncryptionPart", mock.Anything, testProjectID, testPart).Return(nil) }, }, @@ -90,7 +90,7 @@ func TestService_SetEncryptionPart(t *testing.T) { mockRepo.ExpectedCalls = nil mockRepo.On("GetEncryptionPart", mock.Anything, testProjectID).Return("test-encryption-part", nil) }, - err: domain.ErrEncryptionPartAlreadyExists, + err: domainErrors.ErrEncryptionPartAlreadyExists, }, { name: "repository error on get encryption part", @@ -105,7 +105,7 @@ func TestService_SetEncryptionPart(t *testing.T) { wantErr: true, mock: func() { mockRepo.ExpectedCalls = nil - mockRepo.On("GetEncryptionPart", mock.Anything, testProjectID).Return("", domain.ErrEncryptionPartNotFound) + mockRepo.On("GetEncryptionPart", mock.Anything, testProjectID).Return("", domainErrors.ErrEncryptionPartNotFound) mockRepo.On("SetEncryptionPart", mock.Anything, testProjectID, testPart).Return(errors.New("repository error")) }, }, diff --git a/internal/core/services/providersvc/svc.go b/internal/core/services/providersvc/svc.go index 144ae2b..8ebf3e2 100644 --- a/internal/core/services/providersvc/svc.go +++ b/internal/core/services/providersvc/svc.go @@ -3,9 +3,9 @@ package providersvc import ( "context" "errors" + domainErrors "go.openfort.xyz/shield/internal/core/domain/errors" "log/slog" - "go.openfort.xyz/shield/internal/core/domain" "go.openfort.xyz/shield/internal/core/domain/provider" "go.openfort.xyz/shield/internal/core/ports/repositories" "go.openfort.xyz/shield/internal/core/ports/services" @@ -33,7 +33,7 @@ func (s *service) Configure(ctx context.Context, prov *provider.Provider) error case provider.TypeOpenfort: return s.configureOpenfortProvider(ctx, prov) default: - return domain.ErrUnknownProviderType + return domainErrors.ErrUnknownProviderType } } @@ -49,7 +49,7 @@ func (s *service) configureCustomProvider(ctx context.Context, prov *provider.Pr customAuth, ok := prov.Config.(*provider.CustomConfig) if !ok { s.logger.ErrorContext(ctx, "invalid custom provider config") - return domain.ErrInvalidProviderConfig + return domainErrors.ErrInvalidProviderConfig } customAuth.ProviderID = prov.ID @@ -80,7 +80,7 @@ func (s *service) configureOpenfortProvider(ctx context.Context, prov *provider. openfortAuth, ok := prov.Config.(*provider.OpenfortConfig) if !ok { s.logger.ErrorContext(ctx, "invalid openfort provider config") - return domain.ErrInvalidProviderConfig + return domainErrors.ErrInvalidProviderConfig } openfortAuth.ProviderID = prov.ID diff --git a/internal/core/services/providersvc/svc_test.go b/internal/core/services/providersvc/svc_test.go index 2957552..b5fc889 100644 --- a/internal/core/services/providersvc/svc_test.go +++ b/internal/core/services/providersvc/svc_test.go @@ -3,12 +3,12 @@ package providersvc import ( "context" "errors" + domainErrors "go.openfort.xyz/shield/internal/core/domain/errors" "testing" "github.com/stretchr/testify/mock" - "go.openfort.xyz/shield/internal/core/domain" + "go.openfort.xyz/shield/internal/adapters/repositories/mocks/providermockrepo" "go.openfort.xyz/shield/internal/core/domain/provider" - "go.openfort.xyz/shield/internal/infrastructure/repositories/mocks/providermockrepo" ) func TestConfigureProvider(t *testing.T) { @@ -65,7 +65,7 @@ func TestConfigureProvider(t *testing.T) { provider: customProvider, mock: func() { mockRepo.ExpectedCalls = nil - mockRepo.On("GetByProjectAndType", mock.Anything, projectID, provider.TypeCustom).Return(nil, domain.ErrProviderNotFound) + mockRepo.On("GetByProjectAndType", mock.Anything, projectID, provider.TypeCustom).Return(nil, domainErrors.ErrProviderNotFound) mockRepo.On("Create", mock.Anything, mock.AnythingOfType("*provider.Provider")).Return(nil) mockRepo.On("CreateCustom", mock.Anything, mock.AnythingOfType("*provider.CustomConfig")).Return(nil) }, @@ -75,7 +75,7 @@ func TestConfigureProvider(t *testing.T) { provider: openfortProvider, mock: func() { mockRepo.ExpectedCalls = nil - mockRepo.On("GetByProjectAndType", mock.Anything, projectID, provider.TypeOpenfort).Return(nil, domain.ErrProviderNotFound) + mockRepo.On("GetByProjectAndType", mock.Anything, projectID, provider.TypeOpenfort).Return(nil, domainErrors.ErrProviderNotFound) mockRepo.On("Create", mock.Anything, mock.AnythingOfType("*provider.Provider")).Return(nil) mockRepo.On("CreateOpenfort", mock.Anything, mock.AnythingOfType("*provider.OpenfortConfig")).Return(nil) }, @@ -85,21 +85,21 @@ func TestConfigureProvider(t *testing.T) { provider: unknownProvider, wantErr: true, mock: func() {}, - err: domain.ErrUnknownProviderType, + err: domainErrors.ErrUnknownProviderType, }, { name: "invalid custom provider config", provider: fakeCustomProvider, wantErr: true, mock: func() {}, - err: domain.ErrInvalidProviderConfig, + err: domainErrors.ErrInvalidProviderConfig, }, { name: "invalid openfort provider config", provider: fakeOpenfortProvider, wantErr: true, mock: func() {}, - err: domain.ErrInvalidProviderConfig, + err: domainErrors.ErrInvalidProviderConfig, }, { name: "failed to create custom provider", @@ -107,7 +107,7 @@ func TestConfigureProvider(t *testing.T) { wantErr: true, mock: func() { mockRepo.ExpectedCalls = nil - mockRepo.On("GetByProjectAndType", mock.Anything, projectID, provider.TypeCustom).Return(nil, domain.ErrProviderNotFound) + mockRepo.On("GetByProjectAndType", mock.Anything, projectID, provider.TypeCustom).Return(nil, domainErrors.ErrProviderNotFound) mockRepo.On("Create", mock.Anything, mock.AnythingOfType("*provider.Provider")).Return(errors.New("repository error")) }, }, @@ -117,7 +117,7 @@ func TestConfigureProvider(t *testing.T) { wantErr: true, mock: func() { mockRepo.ExpectedCalls = nil - mockRepo.On("GetByProjectAndType", mock.Anything, projectID, provider.TypeCustom).Return(nil, domain.ErrProviderNotFound) + mockRepo.On("GetByProjectAndType", mock.Anything, projectID, provider.TypeCustom).Return(nil, domainErrors.ErrProviderNotFound) mockRepo.On("Create", mock.Anything, mock.AnythingOfType("*provider.Provider")).Return(nil) mockRepo.On("CreateCustom", mock.Anything, mock.AnythingOfType("*provider.CustomConfig")).Return(errors.New("repository error")) mockRepo.On("Delete", mock.Anything, mock.AnythingOfType("string")).Return(nil) @@ -129,7 +129,7 @@ func TestConfigureProvider(t *testing.T) { wantErr: true, mock: func() { mockRepo.ExpectedCalls = nil - mockRepo.On("GetByProjectAndType", mock.Anything, projectID, provider.TypeCustom).Return(nil, domain.ErrProviderNotFound) + mockRepo.On("GetByProjectAndType", mock.Anything, projectID, provider.TypeCustom).Return(nil, domainErrors.ErrProviderNotFound) mockRepo.On("Create", mock.Anything, mock.AnythingOfType("*provider.Provider")).Return(nil) mockRepo.On("CreateCustom", mock.Anything, mock.AnythingOfType("*provider.CustomConfig")).Return(errors.New("repository error")) mockRepo.On("Delete", mock.Anything, mock.AnythingOfType("string")).Return(errors.New("repository error")) @@ -141,7 +141,7 @@ func TestConfigureProvider(t *testing.T) { wantErr: true, mock: func() { mockRepo.ExpectedCalls = nil - mockRepo.On("GetByProjectAndType", mock.Anything, projectID, provider.TypeOpenfort).Return(nil, domain.ErrProviderNotFound) + mockRepo.On("GetByProjectAndType", mock.Anything, projectID, provider.TypeOpenfort).Return(nil, domainErrors.ErrProviderNotFound) mockRepo.On("Create", mock.Anything, mock.AnythingOfType("*provider.Provider")).Return(errors.New("repository error")) }, }, @@ -151,7 +151,7 @@ func TestConfigureProvider(t *testing.T) { wantErr: true, mock: func() { mockRepo.ExpectedCalls = nil - mockRepo.On("GetByProjectAndType", mock.Anything, projectID, provider.TypeOpenfort).Return(nil, domain.ErrProviderNotFound) + mockRepo.On("GetByProjectAndType", mock.Anything, projectID, provider.TypeOpenfort).Return(nil, domainErrors.ErrProviderNotFound) mockRepo.On("Create", mock.Anything, mock.AnythingOfType("*provider.Provider")).Return(nil) mockRepo.On("CreateOpenfort", mock.Anything, mock.AnythingOfType("*provider.OpenfortConfig")).Return(errors.New("repository error")) mockRepo.On("Delete", mock.Anything, mock.AnythingOfType("string")).Return(nil) @@ -163,7 +163,7 @@ func TestConfigureProvider(t *testing.T) { wantErr: true, mock: func() { mockRepo.ExpectedCalls = nil - mockRepo.On("GetByProjectAndType", mock.Anything, projectID, provider.TypeOpenfort).Return(nil, domain.ErrProviderNotFound) + mockRepo.On("GetByProjectAndType", mock.Anything, projectID, provider.TypeOpenfort).Return(nil, domainErrors.ErrProviderNotFound) mockRepo.On("Create", mock.Anything, mock.AnythingOfType("*provider.Provider")).Return(nil) mockRepo.On("CreateOpenfort", mock.Anything, mock.AnythingOfType("*provider.OpenfortConfig")).Return(errors.New("repository error")) mockRepo.On("Delete", mock.Anything, mock.AnythingOfType("string")).Return(errors.New("repository error")) diff --git a/internal/core/services/sharesvc/svc.go b/internal/core/services/sharesvc/svc.go index 07686ef..bb8d7ba 100644 --- a/internal/core/services/sharesvc/svc.go +++ b/internal/core/services/sharesvc/svc.go @@ -3,9 +3,9 @@ package sharesvc import ( "context" "errors" + domainErrors "go.openfort.xyz/shield/internal/core/domain/errors" "log/slog" - "go.openfort.xyz/shield/internal/core/domain" "go.openfort.xyz/shield/internal/core/domain/share" "go.openfort.xyz/shield/internal/core/ports/repositories" "go.openfort.xyz/shield/internal/core/ports/services" @@ -31,14 +31,14 @@ func (s *service) Create(ctx context.Context, shr *share.Share, opts ...services s.logger.InfoContext(ctx, "creating share", slog.String("user_id", shr.UserID)) shrRepo, err := s.repo.GetByUserID(ctx, shr.UserID) - if err != nil && !errors.Is(err, domain.ErrShareNotFound) { + if err != nil && !errors.Is(err, domainErrors.ErrShareNotFound) { s.logger.ErrorContext(ctx, "failed to get share", logger.Error(err)) return err } if shrRepo != nil { s.logger.ErrorContext(ctx, "share already exists", slog.String("user_id", shr.UserID)) - return domain.ErrShareAlreadyExists + return domainErrors.ErrShareAlreadyExists } var o services.ShareOptions @@ -48,7 +48,7 @@ func (s *service) Create(ctx context.Context, shr *share.Share, opts ...services if shr.RequiresEncryption() { if o.EncryptionKey == nil { - return domain.ErrEncryptionPartRequired + return domainErrors.ErrEncryptionPartRequired } shr.Secret, err = cypher.Encrypt(shr.Secret, *o.EncryptionKey) diff --git a/internal/core/services/sharesvc/svc_test.go b/internal/core/services/sharesvc/svc_test.go index d116a1b..7b905b4 100644 --- a/internal/core/services/sharesvc/svc_test.go +++ b/internal/core/services/sharesvc/svc_test.go @@ -3,14 +3,14 @@ package sharesvc import ( "context" "errors" + domainErrors "go.openfort.xyz/shield/internal/core/domain/errors" "go.openfort.xyz/shield/internal/core/ports/services" "go.openfort.xyz/shield/pkg/cypher" "testing" "github.com/stretchr/testify/mock" - "go.openfort.xyz/shield/internal/core/domain" + "go.openfort.xyz/shield/internal/adapters/repositories/mocks/sharemockrepo" "go.openfort.xyz/shield/internal/core/domain/share" - "go.openfort.xyz/shield/internal/infrastructure/repositories/mocks/sharemockrepo" ) func TestCreateShare(t *testing.T) { @@ -53,7 +53,7 @@ func TestCreateShare(t *testing.T) { share: testShare, mock: func() { mockRepo.ExpectedCalls = nil - mockRepo.On("GetByUserID", mock.Anything, testUserID).Return(nil, domain.ErrShareNotFound) + mockRepo.On("GetByUserID", mock.Anything, testUserID).Return(nil, domainErrors.ErrShareNotFound) mockRepo.On("Create", mock.Anything, mock.AnythingOfType("*share.Share")).Return(nil) }, }, @@ -63,7 +63,7 @@ func TestCreateShare(t *testing.T) { wantErr: false, mock: func() { mockRepo.ExpectedCalls = nil - mockRepo.On("GetByUserID", mock.Anything, testUserID).Return(nil, domain.ErrShareNotFound) + mockRepo.On("GetByUserID", mock.Anything, testUserID).Return(nil, domainErrors.ErrShareNotFound) mockRepo.On("Create", mock.Anything, mock.AnythingOfType("*share.Share")).Return(nil) }, opts: []services.ShareOption{ @@ -76,9 +76,9 @@ func TestCreateShare(t *testing.T) { share: testEncryptionShare, mock: func() { mockRepo.ExpectedCalls = nil - mockRepo.On("GetByUserID", mock.Anything, testUserID).Return(nil, domain.ErrShareNotFound) + mockRepo.On("GetByUserID", mock.Anything, testUserID).Return(nil, domainErrors.ErrShareNotFound) }, - err: domain.ErrEncryptionPartRequired, + err: domainErrors.ErrEncryptionPartRequired, }, { name: "encryption error", @@ -86,7 +86,7 @@ func TestCreateShare(t *testing.T) { share: testEncryptionShare, mock: func() { mockRepo.ExpectedCalls = nil - mockRepo.On("GetByUserID", mock.Anything, testUserID).Return(nil, domain.ErrShareNotFound) + mockRepo.On("GetByUserID", mock.Anything, testUserID).Return(nil, domainErrors.ErrShareNotFound) }, opts: []services.ShareOption{ services.WithEncryptionKey("invalid-key"), @@ -96,7 +96,7 @@ func TestCreateShare(t *testing.T) { name: "share already exists", wantErr: true, share: testShare, - err: domain.ErrShareAlreadyExists, + err: domainErrors.ErrShareAlreadyExists, mock: func() { mockRepo.ExpectedCalls = nil mockRepo.On("GetByUserID", mock.Anything, testUserID).Return(&share.Share{}, nil) @@ -117,7 +117,7 @@ func TestCreateShare(t *testing.T) { share: testShare, mock: func() { mockRepo.ExpectedCalls = nil - mockRepo.On("GetByUserID", mock.Anything, testUserID).Return(nil, domain.ErrShareNotFound) + mockRepo.On("GetByUserID", mock.Anything, testUserID).Return(nil, domainErrors.ErrShareNotFound) mockRepo.On("Create", mock.Anything, mock.AnythingOfType("*share.Share")).Return(errors.New("repository error")) }, }, diff --git a/internal/core/services/usersvc/svc.go b/internal/core/services/usersvc/svc.go index 7fb8125..f2a6593 100644 --- a/internal/core/services/usersvc/svc.go +++ b/internal/core/services/usersvc/svc.go @@ -3,9 +3,9 @@ package usersvc import ( "context" "errors" + domainErrors "go.openfort.xyz/shield/internal/core/domain/errors" "log/slog" - "go.openfort.xyz/shield/internal/core/domain" "go.openfort.xyz/shield/internal/core/domain/user" "go.openfort.xyz/shield/internal/core/ports/repositories" "go.openfort.xyz/shield/internal/core/ports/services" @@ -26,33 +26,48 @@ func New(repo repositories.UserRepository) services.UserService { } } -func (s *service) Create(ctx context.Context, projectID string) (*user.User, error) { - s.logger.InfoContext(ctx, "creating user", slog.String("project_id", projectID)) - usr := &user.User{ - ProjectID: projectID, - } +func (s *service) GetOrCreate(ctx context.Context, projectID, externalUserID, providerID string) (*user.User, error) { + s.logger.InfoContext(ctx, "getting or creating user", slog.String("project_id", projectID), slog.String("external_user_id", externalUserID), slog.String("provider_id", providerID)) - err := s.repo.Create(ctx, usr) - if err != nil { - s.logger.ErrorContext(ctx, "failed to create user", logger.Error(err)) + usr, err := s.getByExternal(ctx, externalUserID, providerID) + if err != nil && !errors.Is(err, domainErrors.ErrExternalUserNotFound) { + s.logger.ErrorContext(ctx, "failed to get user by external", logger.Error(err)) return nil, err } + if usr == nil { + usr, err = s.create(ctx, projectID) + if err != nil { + s.logger.ErrorContext(ctx, "failed to create user", logger.Error(err)) + return nil, err + } + + _, err = s.createExternal(ctx, projectID, usr.ID, externalUserID, providerID) + if err != nil { + s.logger.ErrorContext(ctx, "failed to create external user", logger.Error(err)) + return nil, err + } + } + return usr, nil } -func (s *service) Get(ctx context.Context, userID string) (*user.User, error) { - s.logger.InfoContext(ctx, "getting user", slog.String("user_id", userID)) - usr, err := s.repo.Get(ctx, userID) +func (s *service) create(ctx context.Context, projectID string) (*user.User, error) { + s.logger.InfoContext(ctx, "creating user", slog.String("project_id", projectID)) + usr := &user.User{ + ProjectID: projectID, + } + + err := s.repo.Create(ctx, usr) if err != nil { - s.logger.ErrorContext(ctx, "failed to get user", logger.Error(err)) + s.logger.ErrorContext(ctx, "failed to create user", logger.Error(err)) return nil, err } return usr, nil } -func (s *service) GetByExternal(ctx context.Context, externalUserID, providerID string) (*user.User, error) { +func (s *service) getByExternal(ctx context.Context, externalUserID, providerID string) (*user.User, error) { s.logger.InfoContext(ctx, "getting user by external user", slog.String("external_user_id", externalUserID), slog.String("provider_id", providerID)) extUsrs, err := s.repo.FindExternalBy(ctx, s.repo.WithExternalUserID(externalUserID), s.repo.WithProviderID(providerID)) @@ -63,7 +78,7 @@ func (s *service) GetByExternal(ctx context.Context, externalUserID, providerID if len(extUsrs) == 0 { s.logger.ErrorContext(ctx, "external user not found", slog.String("external_user_id", externalUserID), slog.String("provider_id", providerID)) - return nil, domain.ErrExternalUserNotFound + return nil, domainErrors.ErrExternalUserNotFound } extUsr := extUsrs[0] @@ -76,7 +91,7 @@ func (s *service) GetByExternal(ctx context.Context, externalUserID, providerID return usr, nil } -func (s *service) CreateExternal(ctx context.Context, projectID, userID, externalUserID, providerID string) (*user.ExternalUser, error) { +func (s *service) createExternal(ctx context.Context, projectID, userID, externalUserID, providerID string) (*user.ExternalUser, error) { s.logger.InfoContext(ctx, "creating external user", slog.String("project_id", projectID)) usr, err := s.repo.Get(ctx, userID) @@ -87,23 +102,23 @@ func (s *service) CreateExternal(ctx context.Context, projectID, userID, externa if usr == nil { s.logger.ErrorContext(ctx, "user not found", slog.String("user_id", userID)) - return nil, domain.ErrUserNotFound + return nil, domainErrors.ErrUserNotFound } if usr.ProjectID != projectID { s.logger.ErrorContext(ctx, "user does not belong to project", slog.String("project_id", projectID), slog.String("user_id", userID)) - return nil, domain.ErrUserNotFound + return nil, domainErrors.ErrUserNotFound } extUsrs, err := s.repo.FindExternalBy(ctx, s.repo.WithUserID(userID), s.repo.WithProviderID(providerID)) - if err != nil && !errors.Is(err, domain.ErrExternalUserNotFound) { + if err != nil && !errors.Is(err, domainErrors.ErrExternalUserNotFound) { s.logger.ErrorContext(ctx, "failed to get external user", logger.Error(err)) return nil, err } if len(extUsrs) != 0 { s.logger.ErrorContext(ctx, "external user already exists for this user and provider", slog.String("user_id", userID), slog.String("provider_type", providerID)) - return nil, domain.ErrExternalUserAlreadyExists + return nil, domainErrors.ErrExternalUserAlreadyExists } extUsr := &user.ExternalUser{ diff --git a/internal/core/services/usersvc/svc_test.go b/internal/core/services/usersvc/svc_test.go index de4145e..943e56a 100644 --- a/internal/core/services/usersvc/svc_test.go +++ b/internal/core/services/usersvc/svc_test.go @@ -3,12 +3,12 @@ package usersvc import ( "context" "errors" + domainErrors "go.openfort.xyz/shield/internal/core/domain/errors" "testing" "github.com/stretchr/testify/mock" - "go.openfort.xyz/shield/internal/core/domain" + "go.openfort.xyz/shield/internal/adapters/repositories/mocks/usermockedrepo" "go.openfort.xyz/shield/internal/core/domain/user" - "go.openfort.xyz/shield/internal/infrastructure/repositories/mocks/usermockedrepo" ) func TestCreateUser(t *testing.T) { @@ -73,10 +73,10 @@ func TestGetUser(t *testing.T) { { name: "not found", wantErr: true, - err: domain.ErrUserNotFound, + err: domainErrors.ErrUserNotFound, mock: func() { mockRepo.ExpectedCalls = []*mock.Call{} - mockRepo.On("Get", mock.Anything, mock.Anything).Return(nil, domain.ErrUserNotFound) + mockRepo.On("Get", mock.Anything, mock.Anything).Return(nil, domainErrors.ErrUserNotFound) }, }, { @@ -128,16 +128,16 @@ func TestGetUserByExternal(t *testing.T) { { name: "external not found", wantErr: true, - err: domain.ErrExternalUserNotFound, + err: domainErrors.ErrExternalUserNotFound, mock: func() { mockRepo.ExpectedCalls = []*mock.Call{} - mockRepo.On("FindExternalBy", mock.Anything, mock.Anything).Return(nil, domain.ErrExternalUserNotFound) + mockRepo.On("FindExternalBy", mock.Anything, mock.Anything).Return(nil, domainErrors.ErrExternalUserNotFound) }, }, { name: "external empty", wantErr: true, - err: domain.ErrExternalUserNotFound, + err: domainErrors.ErrExternalUserNotFound, mock: func() { mockRepo.ExpectedCalls = []*mock.Call{} mockRepo.On("FindExternalBy", mock.Anything, mock.Anything).Return([]*user.ExternalUser{}, nil) @@ -146,11 +146,11 @@ func TestGetUserByExternal(t *testing.T) { { name: "user not found", wantErr: true, - err: domain.ErrUserNotFound, + err: domainErrors.ErrUserNotFound, mock: func() { mockRepo.ExpectedCalls = []*mock.Call{} mockRepo.On("FindExternalBy", mock.Anything, mock.Anything).Return([]*user.ExternalUser{{}}, nil) - mockRepo.On("Get", mock.Anything, mock.Anything).Return(nil, domain.ErrUserNotFound) + mockRepo.On("Get", mock.Anything, mock.Anything).Return(nil, domainErrors.ErrUserNotFound) }, }, { @@ -203,16 +203,16 @@ func TestCreateExternalUser(t *testing.T) { { name: "user not found on repo", wantErr: true, - err: domain.ErrUserNotFound, + err: domainErrors.ErrUserNotFound, mock: func() { mockRepo.ExpectedCalls = []*mock.Call{} - mockRepo.On("Get", mock.Anything, mock.Anything).Return(nil, domain.ErrUserNotFound) + mockRepo.On("Get", mock.Anything, mock.Anything).Return(nil, domainErrors.ErrUserNotFound) }, }, { name: "user empty on repo", wantErr: true, - err: domain.ErrUserNotFound, + err: domainErrors.ErrUserNotFound, mock: func() { mockRepo.ExpectedCalls = []*mock.Call{} mockRepo.On("Get", mock.Anything, mock.Anything).Return(nil, nil) @@ -221,7 +221,7 @@ func TestCreateExternalUser(t *testing.T) { { name: "user not found project mismatch", wantErr: true, - err: domain.ErrUserNotFound, + err: domainErrors.ErrUserNotFound, mock: func() { mockRepo.ExpectedCalls = []*mock.Call{} mockRepo.On("Get", mock.Anything, mock.Anything).Return(&user.User{ProjectID: "noproject"}, nil) @@ -235,7 +235,7 @@ func TestCreateExternalUser(t *testing.T) { mockRepo.On("FindExternalBy", mock.Anything, mock.Anything).Return([]*user.ExternalUser{{}}, nil) }, wantErr: true, - err: domain.ErrExternalUserAlreadyExists, + err: domainErrors.ErrExternalUserAlreadyExists, }, { name: "cant find external user", diff --git a/internal/infrastructure/authenticationmgr/apisecret.go b/internal/infrastructure/authenticationmgr/apisecret.go deleted file mode 100644 index 85b84e2..0000000 --- a/internal/infrastructure/authenticationmgr/apisecret.go +++ /dev/null @@ -1,43 +0,0 @@ -package authenticationmgr - -import ( - "context" - "log/slog" - - "go.openfort.xyz/shield/internal/core/ports/authentication" - "go.openfort.xyz/shield/internal/core/ports/repositories" - "go.openfort.xyz/shield/pkg/logger" - "golang.org/x/crypto/bcrypt" -) - -type apiSecret struct { - projectRepo repositories.ProjectRepository - logger *slog.Logger -} - -var _ authentication.APISecretAuthenticator = (*apiSecret)(nil) - -func newAPISecretAuthenticator(repository repositories.ProjectRepository) authentication.APISecretAuthenticator { - return &apiSecret{ - projectRepo: repository, - logger: logger.New("api_key_authenticator"), - } -} - -func (a *apiSecret) Authenticate(ctx context.Context, apiKey, apiSecret string) (string, error) { - a.logger.InfoContext(ctx, "authenticating api key") - - proj, err := a.projectRepo.GetByAPIKey(ctx, apiKey) - if err != nil { - a.logger.ErrorContext(ctx, "failed to authenticate api key", logger.Error(err)) - return "", err - } - - err = bcrypt.CompareHashAndPassword([]byte(proj.APISecret), []byte(apiSecret)) - if err != nil { - a.logger.ErrorContext(ctx, "failed to authenticate api secret", logger.Error(err)) - return "", err - } - - return proj.ID, nil -} diff --git a/internal/infrastructure/authenticationmgr/manager.go b/internal/infrastructure/authenticationmgr/manager.go deleted file mode 100644 index ce8a620..0000000 --- a/internal/infrastructure/authenticationmgr/manager.go +++ /dev/null @@ -1,89 +0,0 @@ -package authenticationmgr - -import ( - "context" - "errors" - "log/slog" - "strings" - - "go.openfort.xyz/shield/pkg/contexter" - "go.openfort.xyz/shield/pkg/logger" - - "go.openfort.xyz/shield/internal/core/domain" - "go.openfort.xyz/shield/internal/core/domain/provider" - "go.openfort.xyz/shield/internal/core/ports/authentication" - "go.openfort.xyz/shield/internal/core/ports/repositories" - "go.openfort.xyz/shield/internal/core/ports/services" - "go.openfort.xyz/shield/internal/infrastructure/providersmgr" -) - -type Manager struct { - APISecretAuthenticator authentication.APISecretAuthenticator - UserAuthenticator authentication.UserAuthenticator - repo repositories.ProjectRepository - providerManager *providersmgr.Manager - userService services.UserService - mapOrigins map[string][]string - logger *slog.Logger -} - -func NewManager(repo repositories.ProjectRepository, providerManager *providersmgr.Manager, userService services.UserService) *Manager { - return &Manager{ - repo: repo, - APISecretAuthenticator: newAPISecretAuthenticator(repo), - providerManager: providerManager, - UserAuthenticator: newUserAuthenticator(repo, providerManager, userService), - userService: userService, - mapOrigins: make(map[string][]string), - logger: logger.New("authentication_manager"), - } -} - -func (m *Manager) GetAPISecretAuthenticator() authentication.APISecretAuthenticator { - return m.APISecretAuthenticator -} - -func (m *Manager) GetUserAuthenticator() authentication.UserAuthenticator { - return m.UserAuthenticator -} - -func (m *Manager) GetAuthProvider(providerStr string) (provider.Type, error) { - switch strings.ToLower(providerStr) { - case "openfort": - return provider.TypeOpenfort, nil - case "custom": - return provider.TypeCustom, nil - default: - return provider.TypeUnknown, domain.ErrUnknownProviderType - } -} - -func (m *Manager) PreRegisterUser(ctx context.Context, userID string, providerType provider.Type) (string, error) { - projID := contexter.GetProjectID(ctx) - prov, err := m.providerManager.GetProvider(ctx, projID, providerType) - if err != nil { - m.logger.ErrorContext(ctx, "failed to get provider", logger.Error(err)) - return "", err - } - - usr, err := m.userService.GetByExternal(ctx, userID, prov.GetProviderID()) - if err != nil { - if !errors.Is(err, domain.ErrUserNotFound) && !errors.Is(err, domain.ErrExternalUserNotFound) { - m.logger.ErrorContext(ctx, "failed to get user by external", logger.Error(err)) - return "", err - } - usr, err = m.userService.Create(ctx, projID) - if err != nil { - m.logger.ErrorContext(ctx, "failed to create user", logger.Error(err)) - return "", err - } - - _, err = m.userService.CreateExternal(ctx, projID, usr.ID, userID, prov.GetProviderID()) - if err != nil { - m.logger.ErrorContext(ctx, "failed to create external user", logger.Error(err)) - return "", err - } - } - - return usr.ID, nil -} diff --git a/internal/infrastructure/authenticationmgr/user.go b/internal/infrastructure/authenticationmgr/user.go deleted file mode 100644 index 6742f78..0000000 --- a/internal/infrastructure/authenticationmgr/user.go +++ /dev/null @@ -1,93 +0,0 @@ -package authenticationmgr - -import ( - "context" - "errors" - "log/slog" - - "go.openfort.xyz/shield/internal/core/domain" - "go.openfort.xyz/shield/internal/core/domain/provider" - "go.openfort.xyz/shield/internal/core/ports/authentication" - "go.openfort.xyz/shield/internal/core/ports/providers" - "go.openfort.xyz/shield/internal/core/ports/repositories" - "go.openfort.xyz/shield/internal/core/ports/services" - "go.openfort.xyz/shield/internal/infrastructure/providersmgr" - "go.openfort.xyz/shield/pkg/logger" -) - -type user struct { - projectRepo repositories.ProjectRepository - providerManager *providersmgr.Manager - userService services.UserService - logger *slog.Logger -} - -var _ authentication.UserAuthenticator = (*user)(nil) - -func newUserAuthenticator(repository repositories.ProjectRepository, providerManager *providersmgr.Manager, userService services.UserService) authentication.UserAuthenticator { - return &user{ - projectRepo: repository, - providerManager: providerManager, - userService: userService, - logger: logger.New("api_key_authenticator"), - } -} - -func (a *user) Authenticate(ctx context.Context, apiKey, token string, providerType provider.Type, opts ...authentication.CustomOption) (*authentication.Authentication, error) { - a.logger.InfoContext(ctx, "authenticating api key") - - proj, err := a.projectRepo.GetByAPIKey(ctx, apiKey) - if err != nil { - a.logger.ErrorContext(ctx, "failed to authenticate api key", logger.Error(err)) - return nil, err - } - - prov, err := a.providerManager.GetProvider(ctx, proj.ID, providerType) - if err != nil { - a.logger.ErrorContext(ctx, "failed to get provider", logger.Error(err)) - return nil, err - } - - var opt authentication.CustomOptions - for _, o := range opts { - o(&opt) - } - - var providerCustomOptions []providers.CustomOption - if opt.OpenfortProvider != nil { - providerCustomOptions = append(providerCustomOptions, providers.WithOpenfortProvider(*opt.OpenfortProvider)) - } - if opt.OpenfortTokenType != nil { - providerCustomOptions = append(providerCustomOptions, providers.WithOpenfortTokenType(*opt.OpenfortTokenType)) - } - - externalUserID, err := prov.Identify(ctx, token, providerCustomOptions...) - if err != nil { - a.logger.ErrorContext(ctx, "failed to identify user", logger.Error(err)) - return nil, err - } - - usr, err := a.userService.GetByExternal(ctx, externalUserID, prov.GetProviderID()) - if err != nil { - if !errors.Is(err, domain.ErrUserNotFound) && !errors.Is(err, domain.ErrExternalUserNotFound) { - a.logger.ErrorContext(ctx, "failed to get user by external", logger.Error(err)) - return nil, err - } - usr, err = a.userService.Create(ctx, proj.ID) - if err != nil { - a.logger.ErrorContext(ctx, "failed to create user", logger.Error(err)) - return nil, err - } - - _, err = a.userService.CreateExternal(ctx, proj.ID, usr.ID, externalUserID, prov.GetProviderID()) - if err != nil { - a.logger.ErrorContext(ctx, "failed to create external user", logger.Error(err)) - return nil, err - } - } - - return &authentication.Authentication{ - UserID: usr.ID, - ProjectID: proj.ID, - }, nil -} diff --git a/internal/infrastructure/providersmgr/manager.go b/internal/infrastructure/providersmgr/manager.go deleted file mode 100644 index a8d9458..0000000 --- a/internal/infrastructure/providersmgr/manager.go +++ /dev/null @@ -1,57 +0,0 @@ -package providersmgr - -import ( - "context" - "errors" - "log/slog" - - "go.openfort.xyz/shield/internal/core/domain" - "go.openfort.xyz/shield/internal/core/domain/provider" - "go.openfort.xyz/shield/internal/core/ports/providers" - "go.openfort.xyz/shield/internal/core/ports/repositories" - "go.openfort.xyz/shield/pkg/logger" -) - -type Manager struct { - config *Config - repo repositories.ProviderRepository - logger *slog.Logger -} - -func NewManager(cfg *Config, repo repositories.ProviderRepository) *Manager { - return &Manager{ - config: cfg, - repo: repo, - logger: logger.New("provider_manager"), - } -} - -func (p *Manager) GetProvider(ctx context.Context, projectID string, providerType provider.Type) (providers.IdentityProvider, error) { - p.logger.InfoContext(ctx, "getting provider", slog.String("provider_type", string(providerType))) - - prov, err := p.repo.GetByProjectAndType(ctx, projectID, providerType) - if err != nil { - if errors.Is(err, domain.ErrProjectNotFound) { - return nil, ErrProviderNotConfigured - } - p.logger.ErrorContext(ctx, "failed to get provider", logger.Error(err)) - return nil, err - } - - switch prov.Type { - case provider.TypeCustom: - config, ok := prov.Config.(*provider.CustomConfig) - if !ok { - return nil, ErrProviderConfigMismatch - } - return newCustomProvider(config), nil - case provider.TypeOpenfort: - config, ok := prov.Config.(*provider.OpenfortConfig) - if !ok { - return nil, ErrProviderConfigMismatch - } - return newOpenfortProvider(p.config, config), nil - default: - return nil, ErrProviderNotSupported - } -} diff --git a/pkg/jwk/errors.go b/pkg/jwk/errors.go new file mode 100644 index 0000000..0645678 --- /dev/null +++ b/pkg/jwk/errors.go @@ -0,0 +1,7 @@ +package jwk + +import "errors" + +var ( + ErrInvalidToken = errors.New("invalid token") +) diff --git a/internal/infrastructure/providersmgr/jwks.go b/pkg/jwk/jwk.go similarity index 81% rename from internal/infrastructure/providersmgr/jwks.go rename to pkg/jwk/jwk.go index e76dff8..a888d50 100644 --- a/internal/infrastructure/providersmgr/jwks.go +++ b/pkg/jwk/jwk.go @@ -1,11 +1,11 @@ -package providersmgr +package jwk import ( "github.com/MicahParks/keyfunc/v3" "github.com/golang-jwt/jwt/v5" ) -func validateJWKs(token, jwkURL string) (string, error) { +func Validate(token, jwkURL string) (string, error) { k, err := keyfunc.NewDefault([]string{jwkURL}) if err != nil { return "", err From 3d36e7b6d16822bd2c6f887032db4f00c1ff82d9 Mon Sep 17 00:00:00 2001 From: gllm-dev Date: Fri, 5 Jul 2024 19:02:51 +0200 Subject: [PATCH 03/10] feat: encryption sessions --- Dockerfile | 2 +- di/wire.go | 64 ++++++++++--- di/wire_gen.go | 94 +++++++++++++++---- internal/adapters/authenticators/factory.go | 30 ------ .../identity/custom_identity/custom.go | 6 +- .../authenticators/identity/errors.go | 15 --- .../authenticators/identity/factory.go | 8 +- .../identity/openfort_identity/openfort.go | 5 +- .../aes_encryption_strategy/strategy.go | 19 ++++ internal/adapters/encryption/factory.go | 44 +++++++++ .../encryption/plain_builder/builder.go | 48 ++++++++++ .../encryption/session_builder/builder.go | 64 +++++++++++++ .../sss_reconstruction_strategy/strategy.go | 30 ++++++ internal/adapters/handlers/rest/api/errors.go | 1 + .../handlers/rest/projecthdl/errors.go | 2 + .../handlers/rest/projecthdl/handler.go | 46 +++++++++ .../handlers/rest/projecthdl/types.go | 8 ++ internal/adapters/handlers/rest/server.go | 36 ++++--- .../adapters/handlers/rest/sharehdl/errors.go | 3 +- .../handlers/rest/sharehdl/handler.go | 8 ++ .../adapters/handlers/rest/sharehdl/parser.go | 4 + .../adapters/handlers/rest/sharehdl/types.go | 16 ++-- internal/adapters/repositories/bunt/client.go | 10 ++ .../bunt/encryptionpartsrepo/repo.go | 20 ++-- internal/adapters/repositories/sql/client.go | 5 +- internal/adapters/repositories/sql/config.go | 2 +- .../repositories/sql/projectrepo/repo.go | 2 +- internal/applications/projectapp/app.go | 87 ++++++++++++----- internal/applications/projectapp/errors.go | 4 + internal/applications/shareapp/app.go | 55 +++++++---- internal/applications/shareapp/errors.go | 4 + internal/core/domain/errors/project.go | 1 + internal/core/domain/errors/provider.go | 13 ++- internal/core/ports/builders/encryption.go | 11 +++ internal/core/ports/factories/encryption.go | 76 ++------------- ...encryptionparts.go => encryption_parts.go} | 0 internal/core/ports/strategies/encryption.go | 6 ++ .../core/ports/strategies/reconstruction.go | 6 ++ internal/core/services/sharesvc/svc.go | 17 ++-- pkg/cypher/cypher.go | 34 +------ pkg/random/random.go | 24 +++++ 41 files changed, 658 insertions(+), 272 deletions(-) delete mode 100644 internal/adapters/authenticators/identity/errors.go create mode 100644 internal/adapters/encryption/aes_encryption_strategy/strategy.go create mode 100644 internal/adapters/encryption/factory.go create mode 100644 internal/adapters/encryption/plain_builder/builder.go create mode 100644 internal/adapters/encryption/session_builder/builder.go create mode 100644 internal/adapters/encryption/sss_reconstruction_strategy/strategy.go create mode 100644 internal/core/ports/builders/encryption.go rename internal/core/ports/repositories/{encryptionparts.go => encryption_parts.go} (100%) create mode 100644 internal/core/ports/strategies/encryption.go create mode 100644 internal/core/ports/strategies/reconstruction.go create mode 100644 pkg/random/random.go diff --git a/Dockerfile b/Dockerfile index c7134d7..f9a0106 100644 --- a/Dockerfile +++ b/Dockerfile @@ -8,5 +8,5 @@ FROM scratch WORKDIR /app COPY --from=builder /etc/ssl/certs/ca-certificates.crt /etc/ssl/certs/ COPY --from=builder /app/app /usr/bin/ -COPY internal/infrastructure/repositories/sql/migrations /app/internal/infrastructure/repositories/sql/migrations +COPY internal/infrastructure/adapters/sql/migrations /app/internal/infrastructure/adapters/sql/migrations ENTRYPOINT ["app"] \ No newline at end of file diff --git a/di/wire.go b/di/wire.go index 18cb567..dab64ca 100644 --- a/di/wire.go +++ b/di/wire.go @@ -5,10 +5,13 @@ package di import ( "github.com/google/wire" - "go.openfort.xyz/shield/internal/adapters/authenticationmgr" - identity2 "go.openfort.xyz/shield/internal/adapters/authenticators/identity" + "go.openfort.xyz/shield/internal/adapters/authenticators" + "go.openfort.xyz/shield/internal/adapters/authenticators/identity" "go.openfort.xyz/shield/internal/adapters/authenticators/identity/openfort_identity" + "go.openfort.xyz/shield/internal/adapters/encryption" "go.openfort.xyz/shield/internal/adapters/handlers/rest" + "go.openfort.xyz/shield/internal/adapters/repositories/bunt" + "go.openfort.xyz/shield/internal/adapters/repositories/bunt/encryptionpartsrepo" "go.openfort.xyz/shield/internal/adapters/repositories/sql" "go.openfort.xyz/shield/internal/adapters/repositories/sql/projectrepo" "go.openfort.xyz/shield/internal/adapters/repositories/sql/providerrepo" @@ -16,6 +19,7 @@ import ( "go.openfort.xyz/shield/internal/adapters/repositories/sql/userrepo" "go.openfort.xyz/shield/internal/applications/projectapp" "go.openfort.xyz/shield/internal/applications/shareapp" + "go.openfort.xyz/shield/internal/core/ports/factories" "go.openfort.xyz/shield/internal/core/ports/repositories" "go.openfort.xyz/shield/internal/core/ports/services" "go.openfort.xyz/shield/internal/core/services/projectsvc" @@ -33,6 +37,14 @@ func ProvideSQL() (c *sql.Client, err error) { return } +func ProvideBuntDB() (c *bunt.Client, err error) { + wire.Build( + bunt.New, + ) + + return +} + func ProvideSQLUserRepository() (r repositories.UserRepository, err error) { wire.Build( userrepo.New, @@ -69,6 +81,15 @@ func ProvideSQLShareRepository() (r repositories.ShareRepository, err error) { return } +func ProvideInMemoryEncryptionPartsRepository() (r repositories.EncryptionPartsRepository, err error) { + wire.Build( + encryptionpartsrepo.New, + ProvideBuntDB, + ) + + return +} + func ProvideProjectService() (s services.ProjectService, err error) { wire.Build( projectsvc.New, @@ -96,20 +117,21 @@ func ProvideUserService() (s services.UserService, err error) { return } -func ProvideShareService() (s services.ShareService, err error) { +func ProvideEncryptionFactory() (f factories.EncryptionFactory, err error) { wire.Build( - sharesvc.New, - ProvideSQLShareRepository, + encryption.NewEncryptionFactory, + ProvideInMemoryEncryptionPartsRepository, + ProvideSQLProjectRepository, ) return } -func ProvideProviderManager() (pm *identity2.identityFactory, err error) { +func ProvideShareService() (s services.ShareService, err error) { wire.Build( - identity2.NewIdentityFactory, - openfort_identity.GetConfigFromEnv, - ProvideSQLProviderRepository, + sharesvc.New, + ProvideSQLShareRepository, + ProvideEncryptionFactory, ) return @@ -121,6 +143,7 @@ func ProvideShareApplication() (a *shareapp.ShareApplication, err error) { ProvideShareService, ProvideSQLShareRepository, ProvideSQLProjectRepository, + ProvideEncryptionFactory, ) return @@ -134,17 +157,28 @@ func ProvideProjectApplication() (a *projectapp.ProjectApplication, err error) { ProvideProviderService, ProvideSQLProviderRepository, ProvideSQLShareRepository, + ProvideEncryptionFactory, + ProvideInMemoryEncryptionPartsRepository, ) return } -func ProvideAuthenticationManager() (am *authenticationmgr.Manager, err error) { +func ProvideAuthenticationFactory() (f factories.AuthenticationFactory, err error) { wire.Build( - authenticationmgr.NewManager, - ProvideSQLProjectRepository, - ProvideProviderManager, + authenticators.NewAuthenticatorFactory, ProvideUserService, + ProvideSQLProjectRepository, + ) + + return +} + +func ProvideIdentityFactory() (f factories.IdentityFactory, err error) { + wire.Build( + identity.NewIdentityFactory, + openfort_identity.GetConfigFromEnv, + ProvideSQLProviderRepository, ) return @@ -156,7 +190,9 @@ func ProvideRESTServer() (s *rest.Server, err error) { rest.GetConfigFromEnv, ProvideShareApplication, ProvideProjectApplication, - ProvideAuthenticationManager, + ProvideUserService, + ProvideAuthenticationFactory, + ProvideIdentityFactory, ) return diff --git a/di/wire_gen.go b/di/wire_gen.go index fa352a1..c312b64 100644 --- a/di/wire_gen.go +++ b/di/wire_gen.go @@ -7,10 +7,13 @@ package di import ( - "go.openfort.xyz/shield/internal/adapters/authenticationmgr" - identity2 "go.openfort.xyz/shield/internal/adapters/authenticators/identity" + "go.openfort.xyz/shield/internal/adapters/authenticators" + "go.openfort.xyz/shield/internal/adapters/authenticators/identity" "go.openfort.xyz/shield/internal/adapters/authenticators/identity/openfort_identity" + "go.openfort.xyz/shield/internal/adapters/encryption" "go.openfort.xyz/shield/internal/adapters/handlers/rest" + "go.openfort.xyz/shield/internal/adapters/repositories/bunt" + "go.openfort.xyz/shield/internal/adapters/repositories/bunt/encryptionpartsrepo" "go.openfort.xyz/shield/internal/adapters/repositories/sql" "go.openfort.xyz/shield/internal/adapters/repositories/sql/projectrepo" "go.openfort.xyz/shield/internal/adapters/repositories/sql/providerrepo" @@ -18,6 +21,7 @@ import ( "go.openfort.xyz/shield/internal/adapters/repositories/sql/userrepo" "go.openfort.xyz/shield/internal/applications/projectapp" "go.openfort.xyz/shield/internal/applications/shareapp" + "go.openfort.xyz/shield/internal/core/ports/factories" "go.openfort.xyz/shield/internal/core/ports/repositories" "go.openfort.xyz/shield/internal/core/ports/services" "go.openfort.xyz/shield/internal/core/services/projectsvc" @@ -40,6 +44,14 @@ func ProvideSQL() (*sql.Client, error) { return client, nil } +func ProvideBuntDB() (*bunt.Client, error) { + client, err := bunt.New() + if err != nil { + return nil, err + } + return client, nil +} + func ProvideSQLUserRepository() (repositories.UserRepository, error) { client, err := ProvideSQL() if err != nil { @@ -76,6 +88,15 @@ func ProvideSQLShareRepository() (repositories.ShareRepository, error) { return shareRepository, nil } +func ProvideInMemoryEncryptionPartsRepository() (repositories.EncryptionPartsRepository, error) { + client, err := ProvideBuntDB() + if err != nil { + return nil, err + } + encryptionPartsRepository := encryptionpartsrepo.New(client) + return encryptionPartsRepository, nil +} + func ProvideProjectService() (services.ProjectService, error) { projectRepository, err := ProvideSQLProjectRepository() if err != nil { @@ -103,26 +124,30 @@ func ProvideUserService() (services.UserService, error) { return userService, nil } -func ProvideShareService() (services.ShareService, error) { - shareRepository, err := ProvideSQLShareRepository() +func ProvideEncryptionFactory() (factories.EncryptionFactory, error) { + encryptionPartsRepository, err := ProvideInMemoryEncryptionPartsRepository() if err != nil { return nil, err } - shareService := sharesvc.New(shareRepository) - return shareService, nil + projectRepository, err := ProvideSQLProjectRepository() + if err != nil { + return nil, err + } + encryptionFactory := encryption.NewEncryptionFactory(encryptionPartsRepository, projectRepository) + return encryptionFactory, nil } -func ProvideProviderManager() (*identity2.identityFactory, error) { - config, err := openfort_identity.GetConfigFromEnv() +func ProvideShareService() (services.ShareService, error) { + shareRepository, err := ProvideSQLShareRepository() if err != nil { return nil, err } - providerRepository, err := ProvideSQLProviderRepository() + encryptionFactory, err := ProvideEncryptionFactory() if err != nil { return nil, err } - manager := identity2.NewIdentityFactory(config, providerRepository) - return manager, nil + shareService := sharesvc.New(shareRepository, encryptionFactory) + return shareService, nil } func ProvideShareApplication() (*shareapp.ShareApplication, error) { @@ -138,7 +163,11 @@ func ProvideShareApplication() (*shareapp.ShareApplication, error) { if err != nil { return nil, err } - shareApplication := shareapp.New(shareService, shareRepository, projectRepository) + encryptionFactory, err := ProvideEncryptionFactory() + if err != nil { + return nil, err + } + shareApplication := shareapp.New(shareService, shareRepository, projectRepository, encryptionFactory) return shareApplication, nil } @@ -163,25 +192,42 @@ func ProvideProjectApplication() (*projectapp.ProjectApplication, error) { if err != nil { return nil, err } - projectApplication := projectapp.New(projectService, projectRepository, providerService, providerRepository, shareRepository) + encryptionFactory, err := ProvideEncryptionFactory() + if err != nil { + return nil, err + } + encryptionPartsRepository, err := ProvideInMemoryEncryptionPartsRepository() + if err != nil { + return nil, err + } + projectApplication := projectapp.New(projectService, projectRepository, providerService, providerRepository, shareRepository, encryptionFactory, encryptionPartsRepository) return projectApplication, nil } -func ProvideAuthenticationManager() (*authenticationmgr.Manager, error) { +func ProvideAuthenticationFactory() (factories.AuthenticationFactory, error) { projectRepository, err := ProvideSQLProjectRepository() if err != nil { return nil, err } - manager, err := ProvideProviderManager() + userService, err := ProvideUserService() if err != nil { return nil, err } - userService, err := ProvideUserService() + authenticationFactory := authenticators.NewAuthenticatorFactory(projectRepository, userService) + return authenticationFactory, nil +} + +func ProvideIdentityFactory() (factories.IdentityFactory, error) { + config, err := openfort_identity.GetConfigFromEnv() + if err != nil { + return nil, err + } + providerRepository, err := ProvideSQLProviderRepository() if err != nil { return nil, err } - authenticationmgrManager := authenticationmgr.NewManager(projectRepository, manager, userService) - return authenticationmgrManager, nil + identityFactory := identity.NewIdentityFactory(config, providerRepository) + return identityFactory, nil } func ProvideRESTServer() (*rest.Server, error) { @@ -197,10 +243,18 @@ func ProvideRESTServer() (*rest.Server, error) { if err != nil { return nil, err } - manager, err := ProvideAuthenticationManager() + authenticationFactory, err := ProvideAuthenticationFactory() + if err != nil { + return nil, err + } + identityFactory, err := ProvideIdentityFactory() + if err != nil { + return nil, err + } + userService, err := ProvideUserService() if err != nil { return nil, err } - server := rest.New(config, projectApplication, shareApplication, manager) + server := rest.New(config, projectApplication, shareApplication, authenticationFactory, identityFactory, userService) return server, nil } diff --git a/internal/adapters/authenticators/factory.go b/internal/adapters/authenticators/factory.go index b810a09..eea61c1 100644 --- a/internal/adapters/authenticators/factory.go +++ b/internal/adapters/authenticators/factory.go @@ -27,33 +27,3 @@ func (f *authenticatorFactory) CreateProjectAuthenticator(apiKey, apiSecret stri func (f *authenticatorFactory) CreateUserAuthenticator(apiKey, token string, identityFactory factories.Identity) factories.Authenticator { return user_authenticator.NewUserAuthenticator(f.projectRepo, f.userService, apiKey, token, identityFactory) } - -// func (m *Manager) PreRegisterUser(ctx context.Context, userID string, providerType provider.Type) (string, error) { -// projID := contexter.GetProjectID(ctx) -// prov, err := m.providerManager.GetProvider(ctx, projID, providerType) -// if err != nil { -// m.logger.ErrorContext(ctx, "failed to get provider", logger.Error(err)) -// return "", err -// } -// -// usr, err := m.userService.GetByExternal(ctx, userID, prov.GetProviderID()) -// if err != nil { -// if !errors.Is(err, domainErrors.ErrUserNotFound) && !errors.Is(err, domainErrors.ErrExternalUserNotFound) { -// m.logger.ErrorContext(ctx, "failed to get user by external", logger.Error(err)) -// return "", err -// } -// usr, err = m.userService.Create(ctx, projID) -// if err != nil { -// m.logger.ErrorContext(ctx, "failed to create user", logger.Error(err)) -// return "", err -// } -// -// _, err = m.userService.CreateExternal(ctx, projID, usr.ID, userID, prov.GetProviderID()) -// if err != nil { -// m.logger.ErrorContext(ctx, "failed to create external user", logger.Error(err)) -// return "", err -// } -// } -// -// return usr.ID, nil -//} diff --git a/internal/adapters/authenticators/identity/custom_identity/custom.go b/internal/adapters/authenticators/identity/custom_identity/custom.go index 7ffcd36..5ed9982 100644 --- a/internal/adapters/authenticators/identity/custom_identity/custom.go +++ b/internal/adapters/authenticators/identity/custom_identity/custom.go @@ -2,7 +2,7 @@ package custom_identity import ( "context" - "go.openfort.xyz/shield/internal/adapters/authenticators/identity" + "go.openfort.xyz/shield/internal/core/domain/errors" "go.openfort.xyz/shield/internal/core/ports/factories" "go.openfort.xyz/shield/pkg/jwk" "log/slog" @@ -42,7 +42,7 @@ func (c *CustomIdentityFactory) Identify(ctx context.Context, token string) (str case c.config.JWK != "": externalUserID, err = jwk.Validate(token, c.config.JWK) // TODO parse error default: - return "", identity.ErrProviderMisconfigured + return "", errors.ErrProviderMisconfigured } if err != nil { c.logger.ErrorContext(ctx, "failed to validate jwt", logger.Error(err)) @@ -68,7 +68,7 @@ func (c *CustomIdentityFactory) validatePEM(token string) (string, error) { return jwt.ParseEdPublicKeyFromPEM([]byte(c.config.PEM)) } default: - return "", identity.ErrCertTypeNotSupported + return "", errors.ErrCertTypeNotSupported } parsed, err := jwt.Parse(token, keyFunc) diff --git a/internal/adapters/authenticators/identity/errors.go b/internal/adapters/authenticators/identity/errors.go deleted file mode 100644 index 50d1dae..0000000 --- a/internal/adapters/authenticators/identity/errors.go +++ /dev/null @@ -1,15 +0,0 @@ -package identity - -import "errors" - -var ( - ErrProviderNotSupported = errors.New("provider not supported") - ErrProviderNotConfigured = errors.New("provider not configured") - ErrProviderConfigMismatch = errors.New("provider config mismatch") - ErrInvalidToken = errors.New("invalid token") - ErrMissingOpenfortProvider = errors.New("missing openfort provider") - ErrMissingOpenfortTokenType = errors.New("missing openfort token type") - ErrUnexpectedStatusCode = errors.New("unexpected status code") - ErrCertTypeNotSupported = errors.New("certificate type not supported") - ErrProviderMisconfigured = errors.New("provider misconfigured") -) diff --git a/internal/adapters/authenticators/identity/factory.go b/internal/adapters/authenticators/identity/factory.go index 66dd784..c3efc6f 100644 --- a/internal/adapters/authenticators/identity/factory.go +++ b/internal/adapters/authenticators/identity/factory.go @@ -32,7 +32,7 @@ func (p *identityFactory) CreateCustomIdentity(ctx context.Context, apiKey strin prov, err := p.repo.GetByAPIKeyAndType(ctx, apiKey, provider.TypeCustom) if err != nil { if errors.Is(err, domainErrors.ErrProjectNotFound) { - return nil, ErrProviderNotConfigured + return nil, domainErrors.ErrProviderNotConfigured } p.logger.ErrorContext(ctx, "failed to get provider", logger.Error(err)) return nil, err @@ -40,7 +40,7 @@ func (p *identityFactory) CreateCustomIdentity(ctx context.Context, apiKey strin config, ok := prov.Config.(*provider.CustomConfig) if !ok { - return nil, ErrProviderConfigMismatch + return nil, domainErrors.ErrProviderConfigMismatch } return custom_identity.NewCustomIdentityFactory(config), nil @@ -50,7 +50,7 @@ func (p *identityFactory) CreateOpenfortIdentity(ctx context.Context, apiKey str prov, err := p.repo.GetByAPIKeyAndType(ctx, apiKey, provider.TypeOpenfort) if err != nil { if errors.Is(err, domainErrors.ErrProjectNotFound) { - return nil, ErrProviderNotConfigured + return nil, domainErrors.ErrProviderNotConfigured } p.logger.ErrorContext(ctx, "failed to get provider", logger.Error(err)) return nil, err @@ -58,7 +58,7 @@ func (p *identityFactory) CreateOpenfortIdentity(ctx context.Context, apiKey str config, ok := prov.Config.(*provider.OpenfortConfig) if !ok { - return nil, ErrProviderConfigMismatch + return nil, domainErrors.ErrProviderConfigMismatch } return openfort_identity.NewOpenfortIdentityFactory(p.config, config, authenticationProvider, tokenType), nil diff --git a/internal/adapters/authenticators/identity/openfort_identity/openfort.go b/internal/adapters/authenticators/identity/openfort_identity/openfort.go index f7719b6..5d3b257 100644 --- a/internal/adapters/authenticators/identity/openfort_identity/openfort.go +++ b/internal/adapters/authenticators/identity/openfort_identity/openfort.go @@ -5,7 +5,7 @@ import ( "context" "encoding/json" "fmt" - "go.openfort.xyz/shield/internal/adapters/authenticators/identity" + domainErrors "go.openfort.xyz/shield/internal/core/domain/errors" "go.openfort.xyz/shield/internal/core/ports/factories" "go.openfort.xyz/shield/pkg/jwk" "io" @@ -87,7 +87,8 @@ func (o *OpenfortIdentityFactory) thirdParty(ctx context.Context, token, authent defer resp.Body.Close() if resp.StatusCode/100 != 2 { - return "", identity.ErrUnexpectedStatusCode + o.logger.ErrorContext(ctx, "unexpected status code", slog.Int("status_code", resp.StatusCode)) + return "", domainErrors.ErrUnexpectedStatusCode } rawResponse, err := io.ReadAll(resp.Body) diff --git a/internal/adapters/encryption/aes_encryption_strategy/strategy.go b/internal/adapters/encryption/aes_encryption_strategy/strategy.go new file mode 100644 index 0000000..057518c --- /dev/null +++ b/internal/adapters/encryption/aes_encryption_strategy/strategy.go @@ -0,0 +1,19 @@ +package aes_encryption_strategy + +import "go.openfort.xyz/shield/pkg/cypher" + +type AESEncryptionStrategy struct { + key string +} + +func NewAESEncryptionStrategy(key string) *AESEncryptionStrategy { + return &AESEncryptionStrategy{key: key} +} + +func (s *AESEncryptionStrategy) Encrypt(data string) (string, error) { + return cypher.Encrypt(data, s.key) +} + +func (s *AESEncryptionStrategy) Decrypt(data string) (string, error) { + return cypher.Decrypt(data, s.key) +} diff --git a/internal/adapters/encryption/factory.go b/internal/adapters/encryption/factory.go new file mode 100644 index 0000000..69efc85 --- /dev/null +++ b/internal/adapters/encryption/factory.go @@ -0,0 +1,44 @@ +package encryption + +import ( + "errors" + "go.openfort.xyz/shield/internal/adapters/encryption/aes_encryption_strategy" + "go.openfort.xyz/shield/internal/adapters/encryption/plain_builder" + "go.openfort.xyz/shield/internal/adapters/encryption/session_builder" + "go.openfort.xyz/shield/internal/adapters/encryption/sss_reconstruction_strategy" + "go.openfort.xyz/shield/internal/core/ports/builders" + "go.openfort.xyz/shield/internal/core/ports/factories" + "go.openfort.xyz/shield/internal/core/ports/repositories" + "go.openfort.xyz/shield/internal/core/ports/strategies" +) + +type encryptionFactory struct { + encryptionPartsRepo repositories.EncryptionPartsRepository + projectRepo repositories.ProjectRepository +} + +func NewEncryptionFactory(encryptionPartsRepo repositories.EncryptionPartsRepository, projectRepo repositories.ProjectRepository) factories.EncryptionFactory { + return &encryptionFactory{ + encryptionPartsRepo: encryptionPartsRepo, + projectRepo: projectRepo, + } +} + +func (e *encryptionFactory) CreateEncryptionKeyBuilder(builderType factories.EncryptionKeyBuilderType) (builders.EncryptionKeyBuilder, error) { + switch builderType { + case factories.Plain: + return plain_builder.NewEncryptionKeyBuilder(e.projectRepo), nil + case factories.Session: + return session_builder.NewEncryptionKeyBuilder(e.encryptionPartsRepo, e.projectRepo), nil + } + + return nil, errors.New("invalid builder type") //TODO extract error +} + +func (e *encryptionFactory) CreateReconstructionStrategy() strategies.ReconstructionStrategy { + return sss_reconstruction_strategy.NewSSSReconstructionStrategy() +} + +func (e *encryptionFactory) CreateEncryptionStrategy(key string) strategies.EncryptionStrategy { + return aes_encryption_strategy.NewAESEncryptionStrategy(key) +} diff --git a/internal/adapters/encryption/plain_builder/builder.go b/internal/adapters/encryption/plain_builder/builder.go new file mode 100644 index 0000000..c1bbcf1 --- /dev/null +++ b/internal/adapters/encryption/plain_builder/builder.go @@ -0,0 +1,48 @@ +package plain_builder + +import ( + "context" + "errors" + "go.openfort.xyz/shield/internal/core/ports/builders" + "go.openfort.xyz/shield/internal/core/ports/repositories" + "go.openfort.xyz/shield/pkg/cypher" +) + +type plainBuilder struct { + projectPart string + databasePart string + projectRepo repositories.ProjectRepository +} + +func NewEncryptionKeyBuilder(repo repositories.ProjectRepository) builders.EncryptionKeyBuilder { + return &plainBuilder{ + projectRepo: repo, + } +} + +func (b *plainBuilder) SetProjectPart(ctx context.Context, identifier string) error { + b.projectPart = identifier + return nil +} + +func (b *plainBuilder) SetDatabasePart(ctx context.Context, identifier string) error { + part, err := b.projectRepo.GetEncryptionPart(ctx, identifier) + if err != nil { + return err + } + + b.databasePart = part + return nil +} + +func (b *plainBuilder) Build(ctx context.Context) (string, error) { + if b.projectPart == "" { + return "", errors.New("project part is required") // TODO extract error + } + + if b.databasePart == "" { + return "", errors.New("database part is required") // TODO extract error + } + + return cypher.ReconstructEncryptionKey(b.projectPart, b.databasePart) +} diff --git a/internal/adapters/encryption/session_builder/builder.go b/internal/adapters/encryption/session_builder/builder.go new file mode 100644 index 0000000..ed455ec --- /dev/null +++ b/internal/adapters/encryption/session_builder/builder.go @@ -0,0 +1,64 @@ +package session_builder + +import ( + "context" + "errors" + domainErrors "go.openfort.xyz/shield/internal/core/domain/errors" + "go.openfort.xyz/shield/internal/core/ports/builders" + "go.openfort.xyz/shield/internal/core/ports/repositories" + "go.openfort.xyz/shield/pkg/cypher" +) + +type sessionBuilder struct { + projectPart string + databasePart string + encryptionPartsRepo repositories.EncryptionPartsRepository + projectRepo repositories.ProjectRepository +} + +func NewEncryptionKeyBuilder(encryptionPartsRepo repositories.EncryptionPartsRepository, projectRepository repositories.ProjectRepository) builders.EncryptionKeyBuilder { + return &sessionBuilder{ + encryptionPartsRepo: encryptionPartsRepo, + projectRepo: projectRepository, + } +} + +func (b *sessionBuilder) SetProjectPart(ctx context.Context, identifier string) error { + part, err := b.encryptionPartsRepo.Get(ctx, identifier) + if err != nil { + if errors.Is(err, domainErrors.ErrEncryptionPartNotFound) { + return domainErrors.ErrInvalidEncryptionSession + } + return err + } + + err = b.encryptionPartsRepo.Delete(ctx, identifier) + if err != nil { + return err + } + + b.projectPart = part + return nil +} + +func (b *sessionBuilder) SetDatabasePart(ctx context.Context, identifier string) error { + part, err := b.projectRepo.GetEncryptionPart(ctx, identifier) + if err != nil { + return err + } + + b.databasePart = part + return nil +} + +func (b *sessionBuilder) Build(ctx context.Context) (string, error) { + if b.projectPart == "" { + return "", errors.New("project part is required") // TODO extract error + } + + if b.databasePart == "" { + return "", errors.New("database part is required") // TODO extract error + } + + return cypher.ReconstructEncryptionKey(b.projectPart, b.databasePart) +} diff --git a/internal/adapters/encryption/sss_reconstruction_strategy/strategy.go b/internal/adapters/encryption/sss_reconstruction_strategy/strategy.go new file mode 100644 index 0000000..c8da2b4 --- /dev/null +++ b/internal/adapters/encryption/sss_reconstruction_strategy/strategy.go @@ -0,0 +1,30 @@ +package sss_reconstruction_strategy + +import ( + "errors" + "go.openfort.xyz/shield/internal/core/ports/strategies" + "go.openfort.xyz/shield/pkg/cypher" +) + +type SSSReconstructionStrategy struct{} + +func NewSSSReconstructionStrategy() strategies.ReconstructionStrategy { + return &SSSReconstructionStrategy{} +} + +func (s *SSSReconstructionStrategy) Split(data string) ([]string, error) { + firstPart, secondPart, err := cypher.SplitEncryptionKey(data) + if err != nil { + return nil, err + } + + return []string{firstPart, secondPart}, nil +} + +func (s *SSSReconstructionStrategy) Reconstruct(parts []string) (string, error) { + if len(parts) != 2 { + return "", errors.New("invalid number of parts") //TODO extract error + } + + return cypher.ReconstructEncryptionKey(parts[0], parts[1]) +} diff --git a/internal/adapters/handlers/rest/api/errors.go b/internal/adapters/handlers/rest/api/errors.go index eafbd00..cca248d 100644 --- a/internal/adapters/handlers/rest/api/errors.go +++ b/internal/adapters/handlers/rest/api/errors.go @@ -44,6 +44,7 @@ var ( ErrEncryptionNotConfigured = &Error{"Encryption not configured", "EC_MISSING", http.StatusConflict} ErrJWKPemConflict = &Error{"JWK and PEM cannot be set at the same time", "PV_CFG_INVALID", http.StatusConflict} ErrInvalidEncryptionPart = &Error{"Invalid encryption part", "EC_INVALID", http.StatusBadRequest} + ErrInvalidEncryptionSession = &Error{"Invalid encryption session", "EC_INVALID", http.StatusBadRequest} ErrEncryptionPartAlreadyExists = &Error{"Encryption part already exists", "EC_EXISTS", http.StatusConflict} ErrMissingAPIKey = &Error{"Missing API key", "A_MISSING", http.StatusUnauthorized} diff --git a/internal/adapters/handlers/rest/projecthdl/errors.go b/internal/adapters/handlers/rest/projecthdl/errors.go index 9bf9c75..7547e59 100644 --- a/internal/adapters/handlers/rest/projecthdl/errors.go +++ b/internal/adapters/handlers/rest/projecthdl/errors.go @@ -30,6 +30,8 @@ func fromApplicationError(err error) *api.Error { return api.ErrProviderNotFound case errors.Is(err, projectapp.ErrInvalidEncryptionPart): return api.ErrInvalidEncryptionPart + case errors.Is(err, projectapp.ErrInvalidEncryptionSession): + return api.ErrInvalidEncryptionSession case errors.Is(err, projectapp.ErrEncryptionPartAlreadyExists): return api.ErrEncryptionPartAlreadyExists case errors.Is(err, projectapp.ErrEncryptionNotConfigured): diff --git a/internal/adapters/handlers/rest/projecthdl/handler.go b/internal/adapters/handlers/rest/projecthdl/handler.go index fd16818..a55c115 100644 --- a/internal/adapters/handlers/rest/projecthdl/handler.go +++ b/internal/adapters/handlers/rest/projecthdl/handler.go @@ -346,6 +346,52 @@ func (h *Handler) EncryptProjectShares(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) } +// RegisterEncryptionSession registers a session with a one-time encryption key for a project +// @Summary Register encryption session +// @Description Register a session with a one-time encryption key for a project +// @Tags Project +// @Accept json +// @Produce json +// @Param X-API-Key header string true "API Key" +// @Param X-API-Secret header string true "API Secret" +// @Param registerEncryptionSessionRequest body RegisterEncryptionSessionRequest true "Add Allowed Origin Request" +// @Success 200 {object} RegisterEncryptionSessionResponse "Encryption session registered successfully" +// @Failure 400 "Bad Request" +// @Failure 500 {object} api.Error "Internal Server Error" +// @Router /project/encryption-session [post] +func (h *Handler) RegisterEncryptionSession(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + h.logger.InfoContext(ctx, "registering encryption session") + + body, err := io.ReadAll(r.Body) + if err != nil { + api.RespondWithError(w, api.ErrBadRequestWithMessage("failed to read request body")) + return + } + + var req RegisterEncryptionSessionRequest + err = json.Unmarshal(body, &req) + if err != nil { + api.RespondWithError(w, api.ErrBadRequestWithMessage("failed to parse request body")) + return + } + + sessionID, err := h.app.RegisterEncryptionSession(ctx, req.EncryptionPart) + if err != nil { + api.RespondWithError(w, fromApplicationError(err)) + return + } + + resp, err := json.Marshal(RegisterEncryptionSessionResponse{SessionID: sessionID}) + if err != nil { + api.RespondWithError(w, api.ErrInternal) + return + } + + w.WriteHeader(http.StatusOK) + _, _ = w.Write(resp) +} + // RegisterEncryptionKey registers an encryption key for a project // @Summary Register encryption key // @Description Register an encryption key for a project diff --git a/internal/adapters/handlers/rest/projecthdl/types.go b/internal/adapters/handlers/rest/projecthdl/types.go index 15b6ad4..2bd839c 100644 --- a/internal/adapters/handlers/rest/projecthdl/types.go +++ b/internal/adapters/handlers/rest/projecthdl/types.go @@ -83,3 +83,11 @@ type EncryptBodyRequest struct { type RegisterEncryptionKeyResponse struct { EncryptionPart string `json:"encryption_part"` } + +type RegisterEncryptionSessionRequest struct { + EncryptionPart string `json:"encryption_part"` +} + +type RegisterEncryptionSessionResponse struct { + SessionID string `json:"session_id"` +} diff --git a/internal/adapters/handlers/rest/server.go b/internal/adapters/handlers/rest/server.go index f055f1d..f47fc4a 100644 --- a/internal/adapters/handlers/rest/server.go +++ b/internal/adapters/handlers/rest/server.go @@ -3,13 +3,14 @@ package rest import ( "context" "fmt" + "go.openfort.xyz/shield/internal/core/ports/factories" + "go.openfort.xyz/shield/internal/core/ports/services" "log/slog" "net/http" "strings" "github.com/gorilla/mux" "github.com/rs/cors" - "go.openfort.xyz/shield/internal/adapters/authenticationmgr" "go.openfort.xyz/shield/internal/adapters/handlers/rest/authmdw" "go.openfort.xyz/shield/internal/adapters/handlers/rest/projecthdl" "go.openfort.xyz/shield/internal/adapters/handlers/rest/ratelimitermdw" @@ -23,23 +24,27 @@ import ( // Server is the REST server for the shield API type Server struct { - projectApp *projectapp.ProjectApplication - shareApp *shareapp.ShareApplication - authManager *authenticationmgr.Manager - server *http.Server - logger *slog.Logger - config *Config + projectApp *projectapp.ProjectApplication + shareApp *shareapp.ShareApplication + server *http.Server + logger *slog.Logger + config *Config + authenticationFactory factories.AuthenticationFactory + identityFactory factories.IdentityFactory + userService services.UserService } // New creates a new REST server -func New(cfg *Config, projectApp *projectapp.ProjectApplication, shareApp *shareapp.ShareApplication, authManager *authenticationmgr.Manager) *Server { +func New(cfg *Config, projectApp *projectapp.ProjectApplication, shareApp *shareapp.ShareApplication, authenticationFactory factories.AuthenticationFactory, identityFactory factories.IdentityFactory, userService services.UserService) *Server { return &Server{ - projectApp: projectApp, - shareApp: shareApp, - authManager: authManager, - server: new(http.Server), - logger: logger.New("rest_server"), - config: cfg, + projectApp: projectApp, + shareApp: shareApp, + server: new(http.Server), + logger: logger.New("rest_server"), + config: cfg, + authenticationFactory: authenticationFactory, + identityFactory: identityFactory, + userService: userService, } } @@ -47,7 +52,7 @@ func New(cfg *Config, projectApp *projectapp.ProjectApplication, shareApp *share func (s *Server) Start(ctx context.Context) error { projectHdl := projecthdl.New(s.projectApp) shareHdl := sharehdl.New(s.shareApp) - authMdw := authmdw.New(s.authManager) + authMdw := authmdw.New(s.authenticationFactory, s.identityFactory, s.userService) rateLimiterMdw := ratelimitermdw.New(s.config.RPS) r := mux.NewRouter() @@ -64,6 +69,7 @@ func (s *Server) Start(ctx context.Context) error { p.HandleFunc("/providers/{provider}", projectHdl.UpdateProvider).Methods(http.MethodPut) p.HandleFunc("/providers/{provider}", projectHdl.DeleteProvider).Methods(http.MethodDelete) p.HandleFunc("/encrypt", projectHdl.EncryptProjectShares).Methods(http.MethodPost) + p.HandleFunc("/encryption-session", projectHdl.RegisterEncryptionSession).Methods(http.MethodPost) p.HandleFunc("/encryption-key", projectHdl.RegisterEncryptionKey).Methods(http.MethodPost) u := r.PathPrefix("/shares").Subrouter() diff --git a/internal/adapters/handlers/rest/sharehdl/errors.go b/internal/adapters/handlers/rest/sharehdl/errors.go index 17bcb7b..4453940 100644 --- a/internal/adapters/handlers/rest/sharehdl/errors.go +++ b/internal/adapters/handlers/rest/sharehdl/errors.go @@ -2,7 +2,6 @@ package sharehdl import ( "errors" - "go.openfort.xyz/shield/internal/adapters/handlers/rest/api" "go.openfort.xyz/shield/internal/applications/shareapp" ) @@ -28,6 +27,8 @@ func fromApplicationError(err error) *api.Error { return api.ErrEncryptionNotConfigured case errors.Is(err, shareapp.ErrInvalidEncryptionPart): return api.ErrInvalidEncryptionPart + case errors.Is(err, shareapp.ErrInvalidEncryptionSession): + return api.ErrInvalidEncryptionSession default: return api.ErrInternal } diff --git a/internal/adapters/handlers/rest/sharehdl/handler.go b/internal/adapters/handlers/rest/sharehdl/handler.go index 1a1f442..da6e1ce 100644 --- a/internal/adapters/handlers/rest/sharehdl/handler.go +++ b/internal/adapters/handlers/rest/sharehdl/handler.go @@ -71,6 +71,9 @@ func (h *Handler) RegisterShare(w http.ResponseWriter, r *http.Request) { if req.EncryptionPart != "" { opts = append(opts, shareapp.WithEncryptionPart(req.EncryptionPart)) } + if req.EncryptionSession != "" { + opts = append(opts, shareapp.WithEncryptionSession(req.EncryptionSession)) + } err = h.app.RegisterShare(ctx, share, opts...) if err != nil { api.RespondWithError(w, fromApplicationError(err)) @@ -134,6 +137,11 @@ func (h *Handler) GetShare(w http.ResponseWriter, r *http.Request) { opts = append(opts, shareapp.WithEncryptionPart(encryptionPart)) } + encryptionSession := r.Header.Get(EncryptionSessionHeader) + if encryptionSession != "" { + opts = append(opts, shareapp.WithEncryptionSession(encryptionSession)) + } + shr, err := h.app.GetShare(ctx, opts...) if err != nil { api.RespondWithError(w, fromApplicationError(err)) diff --git a/internal/adapters/handlers/rest/sharehdl/parser.go b/internal/adapters/handlers/rest/sharehdl/parser.go index fc1d43f..4489256 100644 --- a/internal/adapters/handlers/rest/sharehdl/parser.go +++ b/internal/adapters/handlers/rest/sharehdl/parser.go @@ -30,6 +30,10 @@ func (p *parser) toDomain(s *Share) *share.Share { }, } + if s.EncryptionPart != "" || s.EncryptionSession != "" { + shr.EncryptionParameters.Entropy = share.EntropyProject + } + if s.Salt != "" { shr.EncryptionParameters.Salt = s.Salt } diff --git a/internal/adapters/handlers/rest/sharehdl/types.go b/internal/adapters/handlers/rest/sharehdl/types.go index 9fafbc5..ceac378 100644 --- a/internal/adapters/handlers/rest/sharehdl/types.go +++ b/internal/adapters/handlers/rest/sharehdl/types.go @@ -1,15 +1,17 @@ package sharehdl const EncryptionPartHeader = "X-Encryption-Part" +const EncryptionSessionHeader = "X-Encryption-Session" type Share struct { - Secret string `json:"secret"` - Entropy Entropy `json:"entropy"` - Salt string `json:"salt,omitempty"` - Iterations int `json:"iterations,omitempty"` - Length int `json:"length,omitempty"` - Digest string `json:"digest,omitempty"` - EncryptionPart string `json:"encryption_part,omitempty"` + Secret string `json:"secret"` + Entropy Entropy `json:"entropy"` + Salt string `json:"salt,omitempty"` + Iterations int `json:"iterations,omitempty"` + Length int `json:"length,omitempty"` + Digest string `json:"digest,omitempty"` + EncryptionPart string `json:"encryption_part,omitempty"` + EncryptionSession string `json:"encryption_session,omitempty"` } type RegisterShareRequest Share diff --git a/internal/adapters/repositories/bunt/client.go b/internal/adapters/repositories/bunt/client.go index 37970aa..2a4715b 100644 --- a/internal/adapters/repositories/bunt/client.go +++ b/internal/adapters/repositories/bunt/client.go @@ -6,12 +6,22 @@ type Client struct { *buntdb.DB } +var singleton *buntdb.DB + func New() (*Client, error) { + if singleton != nil { + return &Client{ + DB: singleton, + }, nil + } + db, err := buntdb.Open(":memory:") if err != nil { return nil, err } + singleton = db + return &Client{ DB: db, }, nil diff --git a/internal/adapters/repositories/bunt/encryptionpartsrepo/repo.go b/internal/adapters/repositories/bunt/encryptionpartsrepo/repo.go index f4291ad..cbb09fe 100644 --- a/internal/adapters/repositories/bunt/encryptionpartsrepo/repo.go +++ b/internal/adapters/repositories/bunt/encryptionpartsrepo/repo.go @@ -50,21 +50,27 @@ func (r *repository) Get(ctx context.Context, sessionId string) (string, error) func (r *repository) Set(ctx context.Context, sessionId, part string) error { return r.db.Update(func(tx *buntdb.Tx) error { _, _, err := tx.Set(sessionId, part, nil) - if errors.Is(err, buntdb.ErrIndexExists) { - return domainErrors.ErrEncryptionPartAlreadyExists + if err != nil { + if errors.Is(err, buntdb.ErrIndexExists) { + return domainErrors.ErrEncryptionPartAlreadyExists + } + r.logger.ErrorContext(ctx, "error setting encryption part", logger.Error(err)) + return err } - r.logger.ErrorContext(ctx, "error setting encryption part", logger.Error(err)) - return err + + return nil }) } func (r *repository) Delete(ctx context.Context, sessionId string) error { return r.db.Update(func(tx *buntdb.Tx) error { _, err := tx.Delete(sessionId) - if errors.Is(err, buntdb.ErrNotFound) { - return domainErrors.ErrEncryptionPartNotFound + if err != nil { + if errors.Is(err, buntdb.ErrNotFound) { + return domainErrors.ErrEncryptionPartNotFound + } + r.logger.ErrorContext(ctx, "error deleting encryption part", logger.Error(err)) } - r.logger.ErrorContext(ctx, "error deleting encryption part", logger.Error(err)) return err }) } diff --git a/internal/adapters/repositories/sql/client.go b/internal/adapters/repositories/sql/client.go index f763098..e895481 100644 --- a/internal/adapters/repositories/sql/client.go +++ b/internal/adapters/repositories/sql/client.go @@ -2,13 +2,11 @@ package sql import ( "database/sql" - "fmt" - "path/filepath" - "github.com/pressly/goose" "gorm.io/driver/mysql" "gorm.io/driver/postgres" "gorm.io/gorm" + "path/filepath" ) type Client struct { @@ -65,7 +63,6 @@ func newMySQL(cfg *Config) (gorm.Dialector, error) { func newCloudSQL(cfg *Config) (gorm.Dialector, error) { dsn := cfg.CloudSQLDSN() - fmt.Println("DSN: " + dsn) sqlDB, err := sql.Open("mysql", dsn) if err != nil { return nil, err diff --git a/internal/adapters/repositories/sql/config.go b/internal/adapters/repositories/sql/config.go index 30802b6..5c4ff85 100644 --- a/internal/adapters/repositories/sql/config.go +++ b/internal/adapters/repositories/sql/config.go @@ -33,7 +33,7 @@ type Config struct { UnixSocketPath string `env:"INSTANCE_UNIX_SOCKET"` } -const migrationDirectory = "internal/infrastructure/repositories/sql/migrations" +const migrationDirectory = "internal/adapters/repositories/sql/migrations" func GetConfigFromEnv() (*Config, error) { cfg := &Config{} diff --git a/internal/adapters/repositories/sql/projectrepo/repo.go b/internal/adapters/repositories/sql/projectrepo/repo.go index afb1464..26360f0 100644 --- a/internal/adapters/repositories/sql/projectrepo/repo.go +++ b/internal/adapters/repositories/sql/projectrepo/repo.go @@ -95,7 +95,7 @@ func (r *repository) GetEncryptionPart(ctx context.Context, projectID string) (s r.logger.InfoContext(ctx, "getting encryption part") encryptionPart := &EncryptionPart{} - err := r.db.Where("project_id = ?", projectID).First(encryptionPart).Error + err := r.db.Model(&EncryptionPart{}).Where("project_id = ?", projectID).First(encryptionPart).Error if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return "", domainErrors.ErrEncryptionPartNotFound diff --git a/internal/applications/projectapp/app.go b/internal/applications/projectapp/app.go index d377855..bdb82a2 100644 --- a/internal/applications/projectapp/app.go +++ b/internal/applications/projectapp/app.go @@ -3,7 +3,10 @@ package projectapp import ( "context" "errors" + "github.com/google/uuid" domainErrors "go.openfort.xyz/shield/internal/core/domain/errors" + "go.openfort.xyz/shield/internal/core/ports/factories" + "go.openfort.xyz/shield/pkg/random" "log/slog" "go.openfort.xyz/shield/internal/core/domain/project" @@ -12,27 +15,30 @@ import ( "go.openfort.xyz/shield/internal/core/ports/repositories" "go.openfort.xyz/shield/internal/core/ports/services" "go.openfort.xyz/shield/pkg/contexter" - "go.openfort.xyz/shield/pkg/cypher" "go.openfort.xyz/shield/pkg/logger" ) type ProjectApplication struct { - projectSvc services.ProjectService - projectRepo repositories.ProjectRepository - providerSvc services.ProviderService - providerRepo repositories.ProviderRepository - sharesRepo repositories.ShareRepository - logger *slog.Logger + projectSvc services.ProjectService + projectRepo repositories.ProjectRepository + providerSvc services.ProviderService + providerRepo repositories.ProviderRepository + sharesRepo repositories.ShareRepository + logger *slog.Logger + encryptionFactory factories.EncryptionFactory + encryptionPartsRepo repositories.EncryptionPartsRepository } -func New(projectSvc services.ProjectService, projectRepo repositories.ProjectRepository, providerSvc services.ProviderService, providerRepo repositories.ProviderRepository, sharesRepo repositories.ShareRepository) *ProjectApplication { +func New(projectSvc services.ProjectService, projectRepo repositories.ProjectRepository, providerSvc services.ProviderService, providerRepo repositories.ProviderRepository, sharesRepo repositories.ShareRepository, encryptionFactory factories.EncryptionFactory, encryptionPartsRepo repositories.EncryptionPartsRepository) *ProjectApplication { return &ProjectApplication{ - projectSvc: projectSvc, - projectRepo: projectRepo, - providerSvc: providerSvc, - providerRepo: providerRepo, - sharesRepo: sharesRepo, - logger: logger.New("project_application"), + projectSvc: projectSvc, + projectRepo: projectRepo, + providerSvc: providerSvc, + providerRepo: providerRepo, + sharesRepo: sharesRepo, + logger: logger.New("project_application"), + encryptionFactory: encryptionFactory, + encryptionPartsRepo: encryptionPartsRepo, } } @@ -267,13 +273,25 @@ func (a *ProjectApplication) EncryptProjectShares(ctx context.Context, externalP a.logger.InfoContext(ctx, "encrypting project shares") projectID := contexter.GetProjectID(ctx) - storedPart, err := a.projectRepo.GetEncryptionPart(ctx, projectID) + builder, err := a.encryptionFactory.CreateEncryptionKeyBuilder(factories.Plain) + if err != nil { + a.logger.ErrorContext(ctx, "failed to create encryption key builder", logger.Error(err)) + return ErrInternal + } + + err = builder.SetDatabasePart(ctx, projectID) if err != nil { a.logger.ErrorContext(ctx, "failed to get encryption part", logger.Error(err)) return fromDomainError(err) } - encryptionKey, err := cypher.ReconstructEncryptionKey(storedPart, externalPart) + err = builder.SetProjectPart(ctx, externalPart) + if err != nil { + a.logger.ErrorContext(ctx, "failed to get encryption part", logger.Error(err)) + return fromDomainError(err) + } + + encryptionKey, err := builder.Build(ctx) if err != nil { a.logger.ErrorContext(ctx, "failed to reconstruct encryption key", logger.Error(err)) return ErrInvalidEncryptionPart @@ -291,7 +309,8 @@ func (a *ProjectApplication) EncryptProjectShares(ctx context.Context, externalP continue } - shr.Secret, err = cypher.Encrypt(shr.Secret, encryptionKey) + cypher := a.encryptionFactory.CreateEncryptionStrategy(encryptionKey) + shr.Secret, err = cypher.Encrypt(shr.Secret) if err != nil { a.logger.ErrorContext(ctx, "failed to encrypt share", logger.Error(err)) return fromDomainError(err) @@ -315,6 +334,19 @@ func (a *ProjectApplication) EncryptProjectShares(ctx context.Context, externalP return nil } +func (a *ProjectApplication) RegisterEncryptionSession(ctx context.Context, encryptionPart string) (string, error) { + a.logger.InfoContext(ctx, "registering encryption session") + + sessionID := uuid.NewString() + err := a.encryptionPartsRepo.Set(ctx, sessionID, encryptionPart) + if err != nil { + a.logger.ErrorContext(ctx, "failed to set encryption part", logger.Error(err)) + return "", fromDomainError(err) + } + + return sessionID, nil +} + func (a *ProjectApplication) RegisterEncryptionKey(ctx context.Context) (string, error) { a.logger.InfoContext(ctx, "registering encryption key") projectID := contexter.GetProjectID(ctx) @@ -340,16 +372,27 @@ func (a *ProjectApplication) RegisterEncryptionKey(ctx context.Context) (string, } func (a *ProjectApplication) registerEncryptionKey(ctx context.Context, projectID string) (externalPart string, err error) { - var shieldPart string - shieldPart, externalPart, err = cypher.GenerateEncryptionKey() + key, err := random.GenerateRandomString(32) if err != nil { - return "", err + a.logger.Error("failed to generate random key", logger.Error(err)) + return "", ErrInternal } - err = a.projectSvc.SetEncryptionPart(ctx, projectID, shieldPart) + reconstructionStrategy := a.encryptionFactory.CreateReconstructionStrategy() + parts, err := reconstructionStrategy.Split(key) + if err != nil { + a.logger.Error("failed to split encryption key", logger.Error(err)) + return "", ErrInternal + } + if len(parts) != 2 { + a.logger.Error("invalid encryption key parts", slog.Int("parts", len(parts))) + return "", ErrInternal + } + + err = a.projectSvc.SetEncryptionPart(ctx, projectID, parts[0]) if err != nil { return "", err } - return externalPart, nil + return parts[1], nil } diff --git a/internal/applications/projectapp/errors.go b/internal/applications/projectapp/errors.go index 1685240..ec59f36 100644 --- a/internal/applications/projectapp/errors.go +++ b/internal/applications/projectapp/errors.go @@ -15,6 +15,7 @@ var ( ErrProviderAlreadyExists = errors.New("custom authentication already registered for this project") ErrProviderNotFound = errors.New("custom authentication not found") ErrInvalidEncryptionPart = errors.New("invalid encryption part") + ErrInvalidEncryptionSession = errors.New("invalid encryption session") ErrEncryptionPartAlreadyExists = errors.New("encryption part already exists") ErrEncryptionNotConfigured = errors.New("encryption not configured") ErrJWKPemConflict = errors.New("jwk and pem cannot be set at the same time") @@ -49,5 +50,8 @@ func fromDomainError(err error) error { return ErrEncryptionNotConfigured } + if errors.Is(err, domainErrors.ErrInvalidEncryptionSession) { + return ErrInvalidEncryptionSession + } return ErrInternal } diff --git a/internal/applications/shareapp/app.go b/internal/applications/shareapp/app.go index 5d2b186..43a0747 100644 --- a/internal/applications/shareapp/app.go +++ b/internal/applications/shareapp/app.go @@ -2,6 +2,7 @@ package shareapp import ( "context" + "go.openfort.xyz/shield/internal/core/ports/factories" "log/slog" "go.openfort.xyz/shield/internal/core/domain/share" @@ -13,18 +14,20 @@ import ( ) type ShareApplication struct { - shareSvc services.ShareService - shareRepo repositories.ShareRepository - projectRepo repositories.ProjectRepository - logger *slog.Logger + shareSvc services.ShareService + shareRepo repositories.ShareRepository + projectRepo repositories.ProjectRepository + logger *slog.Logger + encryptionFactory factories.EncryptionFactory } -func New(shareSvc services.ShareService, shareRepo repositories.ShareRepository, projectRepo repositories.ProjectRepository) *ShareApplication { +func New(shareSvc services.ShareService, shareRepo repositories.ShareRepository, projectRepo repositories.ProjectRepository, encryptionFactory factories.EncryptionFactory) *ShareApplication { return &ShareApplication{ - shareSvc: shareSvc, - shareRepo: shareRepo, - projectRepo: projectRepo, - logger: logger.New("share_application"), + shareSvc: shareSvc, + shareRepo: shareRepo, + projectRepo: projectRepo, + logger: logger.New("share_application"), + encryptionFactory: encryptionFactory, } } @@ -41,10 +44,6 @@ func (a *ShareApplication) RegisterShare(ctx context.Context, shr *share.Share, var shrOpts []services.ShareOption if shr.RequiresEncryption() { - if opt.encryptionPart == nil { - return ErrEncryptionPartRequired - } - encryptionKey, err := a.reconstructEncryptionKey(ctx, projID, opt) if err != nil { return err @@ -114,20 +113,42 @@ func (a *ShareApplication) DeleteShare(ctx context.Context) error { } func (a *ShareApplication) reconstructEncryptionKey(ctx context.Context, projID string, opt options) (string, error) { - if opt.encryptionPart == nil || *opt.encryptionPart == "" { + var builderType factories.EncryptionKeyBuilderType + var identifier string + switch { + case opt.encryptionPart != nil && *opt.encryptionPart != "": + builderType = factories.Plain + identifier = *opt.encryptionPart + case opt.encryptionSession != nil && *opt.encryptionSession != "": + builderType = factories.Session + identifier = *opt.encryptionSession + default: return "", ErrEncryptionPartRequired } - storedPart, err := a.projectRepo.GetEncryptionPart(ctx, projID) + builder, err := a.encryptionFactory.CreateEncryptionKeyBuilder(builderType) + if err != nil { + a.logger.ErrorContext(ctx, "failed to create encryption key builder", logger.Error(err)) + return "", ErrInternal + } + + err = builder.SetDatabasePart(ctx, projID) if err != nil { - a.logger.ErrorContext(ctx, "failed to get encryption part", logger.Error(err)) + a.logger.ErrorContext(ctx, "failed to get database encryption part", logger.Error(err)) return "", fromDomainError(err) } - encryptionKey, err := cypher.ReconstructEncryptionKey(storedPart, *opt.encryptionPart) + err = builder.SetProjectPart(ctx, identifier) + if err != nil { + a.logger.ErrorContext(ctx, "failed to get project encryption part", logger.Error(err)) + return "", fromDomainError(err) + } + + encryptionKey, err := builder.Build(ctx) if err != nil { a.logger.ErrorContext(ctx, "failed to reconstruct encryption key", logger.Error(err)) return "", ErrInvalidEncryptionPart } + return encryptionKey, nil } diff --git a/internal/applications/shareapp/errors.go b/internal/applications/shareapp/errors.go index b24c123..da14008 100644 --- a/internal/applications/shareapp/errors.go +++ b/internal/applications/shareapp/errors.go @@ -14,6 +14,7 @@ var ( ErrEncryptionPartRequired = errors.New("encryption part is required") ErrEncryptionNotConfigured = errors.New("encryption not configured") ErrInvalidEncryptionPart = errors.New("invalid encryption part") + ErrInvalidEncryptionSession = errors.New("invalid encryption session") ErrInternal = errors.New("internal error") ) @@ -34,5 +35,8 @@ func fromDomainError(err error) error { return ErrEncryptionNotConfigured } + if errors.Is(err, domainErrors.ErrInvalidEncryptionSession) { + return ErrInvalidEncryptionSession + } return ErrInternal } diff --git a/internal/core/domain/errors/project.go b/internal/core/domain/errors/project.go index b6601e0..6b9eeca 100644 --- a/internal/core/domain/errors/project.go +++ b/internal/core/domain/errors/project.go @@ -7,4 +7,5 @@ var ( ErrEncryptionPartNotFound = errors.New("encryption part not found") ErrEncryptionPartAlreadyExists = errors.New("encryption part already exists") ErrEncryptionPartRequired = errors.New("encryption part is required") + ErrInvalidEncryptionSession = errors.New("invalid encryption session") ) diff --git a/internal/core/domain/errors/provider.go b/internal/core/domain/errors/provider.go index c770ede..287bd3a 100644 --- a/internal/core/domain/errors/provider.go +++ b/internal/core/domain/errors/provider.go @@ -3,8 +3,13 @@ package errors import "errors" var ( - ErrInvalidProviderConfig = errors.New("invalid provider config") - ErrUnknownProviderType = errors.New("unknown provider type") - ErrProviderAlreadyExists = errors.New("custom authentication already registered for this project") - ErrProviderNotFound = errors.New("custom authentication not found") + ErrInvalidProviderConfig = errors.New("invalid provider config") + ErrUnknownProviderType = errors.New("unknown provider type") + ErrProviderAlreadyExists = errors.New("custom authentication already registered for this project") + ErrProviderNotFound = errors.New("custom authentication not found") + ErrProviderNotConfigured = errors.New("provider not configured") + ErrProviderConfigMismatch = errors.New("provider config mismatch") + ErrUnexpectedStatusCode = errors.New("unexpected status code") + ErrCertTypeNotSupported = errors.New("certificate type not supported") + ErrProviderMisconfigured = errors.New("provider misconfigured") ) diff --git a/internal/core/ports/builders/encryption.go b/internal/core/ports/builders/encryption.go new file mode 100644 index 0000000..02abd77 --- /dev/null +++ b/internal/core/ports/builders/encryption.go @@ -0,0 +1,11 @@ +package builders + +import ( + "context" +) + +type EncryptionKeyBuilder interface { + SetProjectPart(ctx context.Context, identifier string) error + SetDatabasePart(ctx context.Context, identifier string) error + Build(ctx context.Context) (string, error) +} diff --git a/internal/core/ports/factories/encryption.go b/internal/core/ports/factories/encryption.go index ead4c0f..e4ba092 100644 --- a/internal/core/ports/factories/encryption.go +++ b/internal/core/ports/factories/encryption.go @@ -1,75 +1,19 @@ package factories import ( - "context" - "errors" - "go.openfort.xyz/shield/internal/core/ports/repositories" - "go.openfort.xyz/shield/pkg/cypher" + "go.openfort.xyz/shield/internal/core/ports/builders" + "go.openfort.xyz/shield/internal/core/ports/strategies" ) type EncryptionFactory interface { - CreateEncryptionStrategy() EncryptionStrategy + CreateEncryptionKeyBuilder(builderType EncryptionKeyBuilderType) (builders.EncryptionKeyBuilder, error) + CreateReconstructionStrategy() strategies.ReconstructionStrategy + CreateEncryptionStrategy(key string) strategies.EncryptionStrategy } -type EncryptionStrategy interface { - Encrypt(ctx context.Context, plain string) (string, error) - Decrypt(ctx context.Context, encrypted string) (string, error) -} - -type EncryptionKeyBuilder interface { - SetEncryptionPart(ctx context.Context, part string) EncryptionKeyBuilder - SetSessionPart(ctx context.Context, sessionID string) (EncryptionKeyBuilder, error) - SetDatabasePart(ctx context.Context, projectID string) (EncryptionKeyBuilder, error) - Build(ctx context.Context) (string, error) -} - -type EncryptionKeyBuilderImpl struct { - projectPart string - databasePart string - encryptionPartsRepo repositories.EncryptionPartsRepository - projectRepo repositories.ProjectRepository -} - -func NewEncryptionKeyBuilder() EncryptionKeyBuilder { - return &EncryptionKeyBuilderImpl{ - projectPart: "", - databasePart: "", - } -} - -func (b *EncryptionKeyBuilderImpl) SetEncryptionPart(ctx context.Context, part string) EncryptionKeyBuilder { - b.projectPart = part - return b -} - -func (b *EncryptionKeyBuilderImpl) SetSessionPart(ctx context.Context, sessionID string) (EncryptionKeyBuilder, error) { - part, err := b.encryptionPartsRepo.Get(ctx, sessionID) - if err != nil { - return nil, err - } +type EncryptionKeyBuilderType int8 - b.projectPart = part - return b, nil -} - -func (b *EncryptionKeyBuilderImpl) SetDatabasePart(ctx context.Context, projectID string) (EncryptionKeyBuilder, error) { - part, err := b.projectRepo.GetEncryptionPart(ctx, projectID) - if err != nil { - return nil, err - } - - b.databasePart = part - return b, nil -} - -func (b *EncryptionKeyBuilderImpl) Build(ctx context.Context) (string, error) { - if b.projectPart == "" { - return "", errors.New("project part is required") // TODO extract error - } - - if b.databasePart == "" { - return "", errors.New("database part is required") // TODO extract error - } - - return cypher.ReconstructEncryptionKey(b.projectPart, b.databasePart) -} +const ( + Plain EncryptionKeyBuilderType = iota + Session +) diff --git a/internal/core/ports/repositories/encryptionparts.go b/internal/core/ports/repositories/encryption_parts.go similarity index 100% rename from internal/core/ports/repositories/encryptionparts.go rename to internal/core/ports/repositories/encryption_parts.go diff --git a/internal/core/ports/strategies/encryption.go b/internal/core/ports/strategies/encryption.go new file mode 100644 index 0000000..5a9aa69 --- /dev/null +++ b/internal/core/ports/strategies/encryption.go @@ -0,0 +1,6 @@ +package strategies + +type EncryptionStrategy interface { + Encrypt(data string) (string, error) + Decrypt(data string) (string, error) +} diff --git a/internal/core/ports/strategies/reconstruction.go b/internal/core/ports/strategies/reconstruction.go new file mode 100644 index 0000000..c5652e0 --- /dev/null +++ b/internal/core/ports/strategies/reconstruction.go @@ -0,0 +1,6 @@ +package strategies + +type ReconstructionStrategy interface { + Split(data string) ([]string, error) + Reconstruct(parts []string) (string, error) +} diff --git a/internal/core/services/sharesvc/svc.go b/internal/core/services/sharesvc/svc.go index bb8d7ba..eaa3bdf 100644 --- a/internal/core/services/sharesvc/svc.go +++ b/internal/core/services/sharesvc/svc.go @@ -4,26 +4,28 @@ import ( "context" "errors" domainErrors "go.openfort.xyz/shield/internal/core/domain/errors" + "go.openfort.xyz/shield/internal/core/ports/factories" "log/slog" "go.openfort.xyz/shield/internal/core/domain/share" "go.openfort.xyz/shield/internal/core/ports/repositories" "go.openfort.xyz/shield/internal/core/ports/services" - "go.openfort.xyz/shield/pkg/cypher" "go.openfort.xyz/shield/pkg/logger" ) type service struct { - repo repositories.ShareRepository - logger *slog.Logger + repo repositories.ShareRepository + logger *slog.Logger + encryptionFactory factories.EncryptionFactory } var _ services.ShareService = (*service)(nil) -func New(repo repositories.ShareRepository) services.ShareService { +func New(repo repositories.ShareRepository, encryptionFactory factories.EncryptionFactory) services.ShareService { return &service{ - repo: repo, - logger: logger.New("share_service"), + repo: repo, + logger: logger.New("share_service"), + encryptionFactory: encryptionFactory, } } @@ -51,7 +53,8 @@ func (s *service) Create(ctx context.Context, shr *share.Share, opts ...services return domainErrors.ErrEncryptionPartRequired } - shr.Secret, err = cypher.Encrypt(shr.Secret, *o.EncryptionKey) + cypher := s.encryptionFactory.CreateEncryptionStrategy(*o.EncryptionKey) + shr.Secret, err = cypher.Encrypt(shr.Secret) if err != nil { s.logger.ErrorContext(ctx, "failed to encrypt secret", logger.Error(err)) return err diff --git a/pkg/cypher/cypher.go b/pkg/cypher/cypher.go index 261a006..7cd9dd4 100644 --- a/pkg/cypher/cypher.go +++ b/pkg/cypher/cypher.go @@ -3,31 +3,12 @@ package cypher import ( "crypto/aes" "crypto/cipher" - "crypto/rand" "encoding/base64" "errors" - "io" - "github.com/codahale/sss" + "go.openfort.xyz/shield/pkg/random" ) -func generateRandomBytes(n int) ([]byte, error) { - b := make([]byte, n) - _, err := io.ReadFull(rand.Reader, b) - if err != nil { - return nil, err - } - return b, nil -} - -func generateRandomString(n int) (string, error) { - b, err := generateRandomBytes(n) - if err != nil { - return "", err - } - return base64.StdEncoding.EncodeToString(b), nil -} - func Encrypt(plaintext, key string) (string, error) { keyBytes, err := base64.StdEncoding.DecodeString(key) if err != nil { @@ -44,7 +25,7 @@ func Encrypt(plaintext, key string) (string, error) { return "", err } - nonce, err := generateRandomBytes(aesGCM.NonceSize()) + nonce, err := random.GenerateRandomBytes(aesGCM.NonceSize()) if err != nil { return "", err } @@ -88,16 +69,7 @@ func Decrypt(encrypted, key string) (string, error) { return string(plaintext), nil } -func GenerateEncryptionKey() (string, string, error) { - key, err := generateRandomString(32) - if err != nil { - return "", "", err - } - - return splitKey(key) -} - -func splitKey(key string) (string, string, error) { +func SplitEncryptionKey(key string) (string, string, error) { rawKey, err := base64.StdEncoding.DecodeString(key) if err != nil { return "", "", err diff --git a/pkg/random/random.go b/pkg/random/random.go new file mode 100644 index 0000000..82c9a0f --- /dev/null +++ b/pkg/random/random.go @@ -0,0 +1,24 @@ +package random + +import ( + "crypto/rand" + "encoding/base64" + "io" +) + +func GenerateRandomBytes(n int) ([]byte, error) { + b := make([]byte, n) + _, err := io.ReadFull(rand.Reader, b) + if err != nil { + return nil, err + } + return b, nil +} + +func GenerateRandomString(n int) (string, error) { + b, err := GenerateRandomBytes(n) + if err != nil { + return "", err + } + return base64.StdEncoding.EncodeToString(b), nil +} From 0dfe7ec0812743a42f4f907a8f9a9e3326526622 Mon Sep 17 00:00:00 2001 From: gllm-dev Date: Tue, 9 Jul 2024 17:48:57 +0200 Subject: [PATCH 04/10] feat: update share --- .../handlers/rest/sharehdl/handler.go | 63 +++++++++++++++++++ .../adapters/handlers/rest/sharehdl/types.go | 2 + .../repositories/mocks/sharemockrepo/repo.go | 5 ++ .../repositories/sql/sharerepo/repo.go | 13 ++++ internal/applications/shareapp/app.go | 51 ++++++++++++++- internal/core/ports/repositories/shares.go | 1 + 6 files changed, 133 insertions(+), 2 deletions(-) diff --git a/internal/adapters/handlers/rest/sharehdl/handler.go b/internal/adapters/handlers/rest/sharehdl/handler.go index da6e1ce..255505c 100644 --- a/internal/adapters/handlers/rest/sharehdl/handler.go +++ b/internal/adapters/handlers/rest/sharehdl/handler.go @@ -83,6 +83,69 @@ func (h *Handler) RegisterShare(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusCreated) } +// UpdateShare updates a share +// @Summary Update share +// @Description Update a share for the user +// @Tags Share +// @Accept json +// @Produce json +// @Param X-API-Key header string true "API Key" +// @Param Authorization header string true "Bearer token" +// @Param X-Auth-Provider header string true "Auth Provider" +// @Param X-Openfort-Provider header string false "Openfort Provider" +// @Param X-Openfort-Token-Type header string false "Openfort Token Type" +// @Param updateShareRequest body UpdateShareRequest true "Update Share Request" +// @Success 200 {object} UpdateShareResponse "Successful response" +// @Failure 400 {object} api.Error "Bad Request" +// @Failure 404 {object} api.Error "Not Found" +// @Failure 500 {object} api.Error "Internal Server Error" +// @Router /shares [put] +func (h *Handler) UpdateShare(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + h.logger.InfoContext(ctx, "updating share") + + body, err := io.ReadAll(r.Body) + if err != nil { + api.RespondWithError(w, api.ErrBadRequestWithMessage("failed to read request body")) + return + } + + var req UpdateShareRequest + err = json.Unmarshal(body, &req) + if err != nil { + api.RespondWithError(w, api.ErrBadRequestWithMessage("failed to parse request body")) + return + } + + if errV := h.validator.validateShare((*Share)(&req)); errV != nil { + api.RespondWithError(w, errV) + return + } + + share := h.parser.toDomain((*Share)(&req)) + var opts []shareapp.Option + if req.EncryptionPart != "" { + opts = append(opts, shareapp.WithEncryptionPart(req.EncryptionPart)) + } + if req.EncryptionSession != "" { + opts = append(opts, shareapp.WithEncryptionSession(req.EncryptionSession)) + } + shr, err := h.app.UpdateShare(ctx, share, opts...) + if err != nil { + api.RespondWithError(w, fromApplicationError(err)) + return + } + + resp, err := json.Marshal(UpdateShareResponse(*h.parser.fromDomain(shr))) + if err != nil { + api.RespondWithError(w, api.ErrInternal) + return + } + + w.WriteHeader(http.StatusOK) + _, _ = w.Write(resp) +} + // DeleteShare deletes a share // @Summary Delete share // @Description Delete a share for the user diff --git a/internal/adapters/handlers/rest/sharehdl/types.go b/internal/adapters/handlers/rest/sharehdl/types.go index ceac378..c535d63 100644 --- a/internal/adapters/handlers/rest/sharehdl/types.go +++ b/internal/adapters/handlers/rest/sharehdl/types.go @@ -16,6 +16,8 @@ type Share struct { type RegisterShareRequest Share type GetShareResponse Share +type UpdateShareRequest Share +type UpdateShareResponse Share type Entropy string diff --git a/internal/adapters/repositories/mocks/sharemockrepo/repo.go b/internal/adapters/repositories/mocks/sharemockrepo/repo.go index 48df296..3db6ec7 100644 --- a/internal/adapters/repositories/mocks/sharemockrepo/repo.go +++ b/internal/adapters/repositories/mocks/sharemockrepo/repo.go @@ -44,3 +44,8 @@ func (m *MockShareRepository) UpdateProjectEncryption(ctx context.Context, share args := m.Mock.Called(ctx, shareID, encrypted) return args.Error(0) } + +func (m *MockShareRepository) Update(ctx context.Context, shr *share.Share) error { + args := m.Mock.Called(ctx, shr) + return args.Error(0) +} diff --git a/internal/adapters/repositories/sql/sharerepo/repo.go b/internal/adapters/repositories/sql/sharerepo/repo.go index 6da1931..2d61ab5 100644 --- a/internal/adapters/repositories/sql/sharerepo/repo.go +++ b/internal/adapters/repositories/sql/sharerepo/repo.go @@ -108,3 +108,16 @@ func (r *repository) UpdateProjectEncryption(ctx context.Context, shareID string return nil } + +func (r *repository) Update(ctx context.Context, shr *share.Share) error { + r.logger.InfoContext(ctx, "updating share", slog.String("id", shr.ID)) + + dbShr := r.parser.toDatabase(shr) + err := r.db.Model(&Share{}).Where("id = ?", shr.ID).Updates(dbShr).Error + if err != nil { + r.logger.ErrorContext(ctx, "error updating share", logger.Error(err)) + return err + } + + return nil +} diff --git a/internal/applications/shareapp/app.go b/internal/applications/shareapp/app.go index 43a0747..71a13f0 100644 --- a/internal/applications/shareapp/app.go +++ b/internal/applications/shareapp/app.go @@ -9,7 +9,6 @@ import ( "go.openfort.xyz/shield/internal/core/ports/repositories" "go.openfort.xyz/shield/internal/core/ports/services" "go.openfort.xyz/shield/pkg/contexter" - "go.openfort.xyz/shield/pkg/cypher" "go.openfort.xyz/shield/pkg/logger" ) @@ -61,6 +60,53 @@ func (a *ShareApplication) RegisterShare(ctx context.Context, shr *share.Share, return nil } +func (a *ShareApplication) UpdateShare(ctx context.Context, shr *share.Share, opts ...Option) (*share.Share, error) { + a.logger.InfoContext(ctx, "updating share") + usrID := contexter.GetUserID(ctx) + projID := contexter.GetProjectID(ctx) + + dbShare, err := a.shareRepo.GetByUserID(ctx, usrID) + if err != nil { + a.logger.ErrorContext(ctx, "failed to get share by user ID", logger.Error(err)) + return nil, fromDomainError(err) + } + + if shr.EncryptionParameters != nil { + dbShare.EncryptionParameters = shr.EncryptionParameters + } + + if shr.Secret != "" { + dbShare.Secret = shr.Secret + } + + var opt options + for _, o := range opts { + o(&opt) + } + + if dbShare.RequiresEncryption() { + encryptionKey, err := a.reconstructEncryptionKey(ctx, projID, opt) + if err != nil { + return nil, err + } + + cypher := a.encryptionFactory.CreateEncryptionStrategy(encryptionKey) + dbShare.Secret, err = cypher.Encrypt(dbShare.Secret) + if err != nil { + a.logger.ErrorContext(ctx, "failed to encrypt secret", logger.Error(err)) + return nil, ErrInternal + } + } + + err = a.shareRepo.Update(ctx, dbShare) + if err != nil { + a.logger.ErrorContext(ctx, "failed to create share", logger.Error(err)) + return nil, fromDomainError(err) + } + + return shr, nil +} + func (a *ShareApplication) GetShare(ctx context.Context, opts ...Option) (*share.Share, error) { a.logger.InfoContext(ctx, "getting share") usrID := contexter.GetUserID(ctx) @@ -83,7 +129,8 @@ func (a *ShareApplication) GetShare(ctx context.Context, opts ...Option) (*share return nil, err } - shr.Secret, err = cypher.Decrypt(shr.Secret, encryptionKey) + cypher := a.encryptionFactory.CreateEncryptionStrategy(encryptionKey) + shr.Secret, err = cypher.Decrypt(shr.Secret) if err != nil { a.logger.ErrorContext(ctx, "failed to decrypt secret", logger.Error(err)) return nil, ErrInternal diff --git a/internal/core/ports/repositories/shares.go b/internal/core/ports/repositories/shares.go index e1d96e5..4f681b7 100644 --- a/internal/core/ports/repositories/shares.go +++ b/internal/core/ports/repositories/shares.go @@ -12,4 +12,5 @@ type ShareRepository interface { Delete(ctx context.Context, shareID string) error ListDecryptedByProjectID(ctx context.Context, projectID string) ([]*share.Share, error) UpdateProjectEncryption(ctx context.Context, shareID string, encrypted string) error + Update(ctx context.Context, shr *share.Share) error } From e5508cf97e343758d417adb0e03822a9235a7c41 Mon Sep 17 00:00:00 2001 From: gllm-dev Date: Wed, 10 Jul 2024 13:50:21 +0200 Subject: [PATCH 05/10] fix: tests --- go.mod | 4 +- go.sum | 2 + internal/adapters/encryption/factory.go | 4 +- .../encryption/plain_builder/builder.go | 16 +-- .../encryption/session_builder/builder.go | 20 ++-- .../sss_reconstruction_strategy/strategy.go | 18 +--- internal/applications/projectapp/app.go | 10 +- internal/applications/projectapp/app_test.go | 58 +++++++--- internal/applications/shareapp/app.go | 4 + internal/applications/shareapp/app_test.go | 101 ++++++++++++------ .../core/ports/strategies/reconstruction.go | 4 +- internal/core/services/sharesvc/svc_test.go | 20 +++- internal/core/services/usersvc/svc_test.go | 16 ++- pkg/cypher/cypher.go | 16 ++- 14 files changed, 191 insertions(+), 102 deletions(-) diff --git a/go.mod b/go.mod index c011d6f..b5c9fb8 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module go.openfort.xyz/shield -go 1.22.0 +go 1.22.5 require ( github.com/MicahParks/keyfunc/v3 v3.3.3 @@ -16,7 +16,7 @@ require ( github.com/stretchr/testify v1.9.0 github.com/tidwall/buntdb v1.3.1 go.uber.org/ratelimit v0.3.1 - golang.org/x/crypto v0.24.0 + golang.org/x/crypto v0.25.0 gorm.io/driver/mysql v1.5.7 gorm.io/driver/postgres v1.5.9 gorm.io/gorm v1.25.10 diff --git a/go.sum b/go.sum index 97524be..6a25aba 100644 --- a/go.sum +++ b/go.sum @@ -101,6 +101,8 @@ golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliY golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg= golang.org/x/crypto v0.24.0 h1:mnl8DM0o513X8fdIkmyFE/5hTYxbwYOjDS/+rK6qpRI= golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM= +golang.org/x/crypto v0.25.0 h1:ypSNr+bnYL2YhwoMt2zPxHFmbAN1KZs/njMG3hxUp30= +golang.org/x/crypto v0.25.0/go.mod h1:T+wALwcMOSE0kXgUAnPAHqTLW+XHgcELELW8VaDgm/M= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= diff --git a/internal/adapters/encryption/factory.go b/internal/adapters/encryption/factory.go index 69efc85..efdd2c2 100644 --- a/internal/adapters/encryption/factory.go +++ b/internal/adapters/encryption/factory.go @@ -27,9 +27,9 @@ func NewEncryptionFactory(encryptionPartsRepo repositories.EncryptionPartsReposi func (e *encryptionFactory) CreateEncryptionKeyBuilder(builderType factories.EncryptionKeyBuilderType) (builders.EncryptionKeyBuilder, error) { switch builderType { case factories.Plain: - return plain_builder.NewEncryptionKeyBuilder(e.projectRepo), nil + return plain_builder.NewEncryptionKeyBuilder(e.projectRepo, sss_reconstruction_strategy.NewSSSReconstructionStrategy()), nil case factories.Session: - return session_builder.NewEncryptionKeyBuilder(e.encryptionPartsRepo, e.projectRepo), nil + return session_builder.NewEncryptionKeyBuilder(e.encryptionPartsRepo, e.projectRepo, sss_reconstruction_strategy.NewSSSReconstructionStrategy()), nil } return nil, errors.New("invalid builder type") //TODO extract error diff --git a/internal/adapters/encryption/plain_builder/builder.go b/internal/adapters/encryption/plain_builder/builder.go index c1bbcf1..cf029c9 100644 --- a/internal/adapters/encryption/plain_builder/builder.go +++ b/internal/adapters/encryption/plain_builder/builder.go @@ -5,18 +5,20 @@ import ( "errors" "go.openfort.xyz/shield/internal/core/ports/builders" "go.openfort.xyz/shield/internal/core/ports/repositories" - "go.openfort.xyz/shield/pkg/cypher" + "go.openfort.xyz/shield/internal/core/ports/strategies" ) type plainBuilder struct { - projectPart string - databasePart string - projectRepo repositories.ProjectRepository + projectPart string + databasePart string + projectRepo repositories.ProjectRepository + reconstructionStrategy strategies.ReconstructionStrategy } -func NewEncryptionKeyBuilder(repo repositories.ProjectRepository) builders.EncryptionKeyBuilder { +func NewEncryptionKeyBuilder(repo repositories.ProjectRepository, reconstructionStrategy strategies.ReconstructionStrategy) builders.EncryptionKeyBuilder { return &plainBuilder{ - projectRepo: repo, + projectRepo: repo, + reconstructionStrategy: reconstructionStrategy, } } @@ -44,5 +46,5 @@ func (b *plainBuilder) Build(ctx context.Context) (string, error) { return "", errors.New("database part is required") // TODO extract error } - return cypher.ReconstructEncryptionKey(b.projectPart, b.databasePart) + return b.reconstructionStrategy.Reconstruct(b.databasePart, b.projectPart) } diff --git a/internal/adapters/encryption/session_builder/builder.go b/internal/adapters/encryption/session_builder/builder.go index ed455ec..fd86f90 100644 --- a/internal/adapters/encryption/session_builder/builder.go +++ b/internal/adapters/encryption/session_builder/builder.go @@ -6,20 +6,22 @@ import ( domainErrors "go.openfort.xyz/shield/internal/core/domain/errors" "go.openfort.xyz/shield/internal/core/ports/builders" "go.openfort.xyz/shield/internal/core/ports/repositories" - "go.openfort.xyz/shield/pkg/cypher" + "go.openfort.xyz/shield/internal/core/ports/strategies" ) type sessionBuilder struct { - projectPart string - databasePart string - encryptionPartsRepo repositories.EncryptionPartsRepository - projectRepo repositories.ProjectRepository + projectPart string + databasePart string + encryptionPartsRepo repositories.EncryptionPartsRepository + projectRepo repositories.ProjectRepository + reconstructionStrategy strategies.ReconstructionStrategy } -func NewEncryptionKeyBuilder(encryptionPartsRepo repositories.EncryptionPartsRepository, projectRepository repositories.ProjectRepository) builders.EncryptionKeyBuilder { +func NewEncryptionKeyBuilder(encryptionPartsRepo repositories.EncryptionPartsRepository, projectRepository repositories.ProjectRepository, reconstructionStrategy strategies.ReconstructionStrategy) builders.EncryptionKeyBuilder { return &sessionBuilder{ - encryptionPartsRepo: encryptionPartsRepo, - projectRepo: projectRepository, + encryptionPartsRepo: encryptionPartsRepo, + projectRepo: projectRepository, + reconstructionStrategy: reconstructionStrategy, } } @@ -60,5 +62,5 @@ func (b *sessionBuilder) Build(ctx context.Context) (string, error) { return "", errors.New("database part is required") // TODO extract error } - return cypher.ReconstructEncryptionKey(b.projectPart, b.databasePart) + return b.reconstructionStrategy.Reconstruct(b.databasePart, b.projectPart) } diff --git a/internal/adapters/encryption/sss_reconstruction_strategy/strategy.go b/internal/adapters/encryption/sss_reconstruction_strategy/strategy.go index c8da2b4..75e04de 100644 --- a/internal/adapters/encryption/sss_reconstruction_strategy/strategy.go +++ b/internal/adapters/encryption/sss_reconstruction_strategy/strategy.go @@ -1,7 +1,6 @@ package sss_reconstruction_strategy import ( - "errors" "go.openfort.xyz/shield/internal/core/ports/strategies" "go.openfort.xyz/shield/pkg/cypher" ) @@ -12,19 +11,10 @@ func NewSSSReconstructionStrategy() strategies.ReconstructionStrategy { return &SSSReconstructionStrategy{} } -func (s *SSSReconstructionStrategy) Split(data string) ([]string, error) { - firstPart, secondPart, err := cypher.SplitEncryptionKey(data) - if err != nil { - return nil, err - } - - return []string{firstPart, secondPart}, nil +func (s *SSSReconstructionStrategy) Split(data string) (storedPart string, projectPart string, err error) { + return cypher.SplitEncryptionKey(data) } -func (s *SSSReconstructionStrategy) Reconstruct(parts []string) (string, error) { - if len(parts) != 2 { - return "", errors.New("invalid number of parts") //TODO extract error - } - - return cypher.ReconstructEncryptionKey(parts[0], parts[1]) +func (s *SSSReconstructionStrategy) Reconstruct(storedPart string, projectPart string) (string, error) { + return cypher.ReconstructEncryptionKey(storedPart, projectPart) } diff --git a/internal/applications/projectapp/app.go b/internal/applications/projectapp/app.go index bdb82a2..7a13054 100644 --- a/internal/applications/projectapp/app.go +++ b/internal/applications/projectapp/app.go @@ -379,20 +379,16 @@ func (a *ProjectApplication) registerEncryptionKey(ctx context.Context, projectI } reconstructionStrategy := a.encryptionFactory.CreateReconstructionStrategy() - parts, err := reconstructionStrategy.Split(key) + storedPart, projectPart, err := reconstructionStrategy.Split(key) if err != nil { a.logger.Error("failed to split encryption key", logger.Error(err)) return "", ErrInternal } - if len(parts) != 2 { - a.logger.Error("invalid encryption key parts", slog.Int("parts", len(parts))) - return "", ErrInternal - } - err = a.projectSvc.SetEncryptionPart(ctx, projectID, parts[0]) + err = a.projectSvc.SetEncryptionPart(ctx, projectID, storedPart) if err != nil { return "", err } - return parts[1], nil + return projectPart, nil } diff --git a/internal/applications/projectapp/app_test.go b/internal/applications/projectapp/app_test.go index a4727d9..766110e 100644 --- a/internal/applications/projectapp/app_test.go +++ b/internal/applications/projectapp/app_test.go @@ -5,6 +5,8 @@ import ( "errors" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + "go.openfort.xyz/shield/internal/adapters/encryption" + "go.openfort.xyz/shield/internal/adapters/repositories/mocks/encryptionpartsmockrepo" "go.openfort.xyz/shield/internal/adapters/repositories/mocks/projectmockrepo" "go.openfort.xyz/shield/internal/adapters/repositories/mocks/providermockrepo" "go.openfort.xyz/shield/internal/adapters/repositories/mocks/sharemockrepo" @@ -15,7 +17,7 @@ import ( "go.openfort.xyz/shield/internal/core/services/projectsvc" "go.openfort.xyz/shield/internal/core/services/providersvc" "go.openfort.xyz/shield/pkg/contexter" - "go.openfort.xyz/shield/pkg/cypher" + "go.openfort.xyz/shield/pkg/random" "testing" ) @@ -27,7 +29,9 @@ func TestProjectApplication_CreateProject(t *testing.T) { providerRepo := new(providermockrepo.MockProviderRepository) projectService := projectsvc.New(projectRepo) providerService := providersvc.New(providerRepo) - app := New(projectService, projectRepo, providerService, providerRepo, shareRepo) + encryptionPartsRepo := new(encryptionpartsmockrepo.MockEncryptionPartsRepository) + encryptionFactory := encryption.NewEncryptionFactory(encryptionPartsRepo, projectRepo) + app := New(projectService, projectRepo, providerService, providerRepo, shareRepo, encryptionFactory, encryptionPartsRepo) tc := []struct { name string @@ -130,7 +134,9 @@ func TestProjectApplication_GetProject(t *testing.T) { providerRepo := new(providermockrepo.MockProviderRepository) projectService := projectsvc.New(projectRepo) providerService := providersvc.New(providerRepo) - app := New(projectService, projectRepo, providerService, providerRepo, shareRepo) + encryptionPartsRepo := new(encryptionpartsmockrepo.MockEncryptionPartsRepository) + encryptionFactory := encryption.NewEncryptionFactory(encryptionPartsRepo, projectRepo) + app := New(projectService, projectRepo, providerService, providerRepo, shareRepo, encryptionFactory, encryptionPartsRepo) projOK := &project.Project{ ID: "project-id", Name: "project name", @@ -193,7 +199,9 @@ func TestProjectApplication_AddProviders(t *testing.T) { providerRepo := new(providermockrepo.MockProviderRepository) projectService := projectsvc.New(projectRepo) providerService := providersvc.New(providerRepo) - app := New(projectService, projectRepo, providerService, providerRepo, shareRepo) + encryptionPartsRepo := new(encryptionpartsmockrepo.MockEncryptionPartsRepository) + encryptionFactory := encryption.NewEncryptionFactory(encryptionPartsRepo, projectRepo) + app := New(projectService, projectRepo, providerService, providerRepo, shareRepo, encryptionFactory, encryptionPartsRepo) tc := []struct { name string @@ -363,7 +371,9 @@ func TestProjectApplication_GetProviders(t *testing.T) { providerRepo := new(providermockrepo.MockProviderRepository) projectService := projectsvc.New(projectRepo) providerService := providersvc.New(providerRepo) - app := New(projectService, projectRepo, providerService, providerRepo, shareRepo) + encryptionPartsRepo := new(encryptionpartsmockrepo.MockEncryptionPartsRepository) + encryptionFactory := encryption.NewEncryptionFactory(encryptionPartsRepo, projectRepo) + app := New(projectService, projectRepo, providerService, providerRepo, shareRepo, encryptionFactory, encryptionPartsRepo) providers := []*provider.Provider{ { ID: "provider-id", @@ -433,7 +443,9 @@ func TestProjectApplication_GetProviderDetail(t *testing.T) { providerRepo := new(providermockrepo.MockProviderRepository) projectService := projectsvc.New(projectRepo) providerService := providersvc.New(providerRepo) - app := New(projectService, projectRepo, providerService, providerRepo, shareRepo) + encryptionPartsRepo := new(encryptionpartsmockrepo.MockEncryptionPartsRepository) + encryptionFactory := encryption.NewEncryptionFactory(encryptionPartsRepo, projectRepo) + app := New(projectService, projectRepo, providerService, providerRepo, shareRepo, encryptionFactory, encryptionPartsRepo) prov := &provider.Provider{ ID: "provider-id", @@ -517,7 +529,9 @@ func TestProjectApplication_UpdateProvider(t *testing.T) { providerRepo := new(providermockrepo.MockProviderRepository) projectService := projectsvc.New(projectRepo) providerService := providersvc.New(providerRepo) - app := New(projectService, projectRepo, providerService, providerRepo, shareRepo) + encryptionPartsRepo := new(encryptionpartsmockrepo.MockEncryptionPartsRepository) + encryptionFactory := encryption.NewEncryptionFactory(encryptionPartsRepo, projectRepo) + app := New(projectService, projectRepo, providerService, providerRepo, shareRepo, encryptionFactory, encryptionPartsRepo) openfortProvider := &provider.Provider{ ID: "provider-id", @@ -729,7 +743,9 @@ func TestProjectApplication_RemoveProvider(t *testing.T) { providerRepo := new(providermockrepo.MockProviderRepository) projectService := projectsvc.New(projectRepo) providerService := providersvc.New(providerRepo) - app := New(projectService, projectRepo, providerService, providerRepo, shareRepo) + encryptionPartsRepo := new(encryptionpartsmockrepo.MockEncryptionPartsRepository) + encryptionFactory := encryption.NewEncryptionFactory(encryptionPartsRepo, projectRepo) + app := New(projectService, projectRepo, providerService, providerRepo, shareRepo, encryptionFactory, encryptionPartsRepo) openfortProvider := &provider.Provider{ ID: "provider-id", @@ -819,9 +835,17 @@ func TestProjectApplication_EncryptProjectShares(t *testing.T) { providerRepo := new(providermockrepo.MockProviderRepository) projectService := projectsvc.New(projectRepo) providerService := providersvc.New(providerRepo) - app := New(projectService, projectRepo, providerService, providerRepo, shareRepo) + encryptionPartsRepo := new(encryptionpartsmockrepo.MockEncryptionPartsRepository) + encryptionFactory := encryption.NewEncryptionFactory(encryptionPartsRepo, projectRepo) + app := New(projectService, projectRepo, providerService, providerRepo, shareRepo, encryptionFactory, encryptionPartsRepo) - storedPart, externalPart, err := cypher.GenerateEncryptionKey() + key, err := random.GenerateRandomString(32) + if err != nil { + t.Fatalf(key) + } + + reconstructor := encryptionFactory.CreateReconstructionStrategy() + storedPart, projectPart, err := reconstructor.Split(key) if err != nil { t.Fatalf("failed to generate encryption key: %v", err) } @@ -865,7 +889,7 @@ func TestProjectApplication_EncryptProjectShares(t *testing.T) { }{ { name: "success", - externalPart: externalPart, + externalPart: projectPart, mock: func() { projectRepo.ExpectedCalls = nil shareRepo.ExpectedCalls = nil @@ -877,7 +901,7 @@ func TestProjectApplication_EncryptProjectShares(t *testing.T) { }, { name: "encryption part not found", - externalPart: externalPart, + externalPart: projectPart, mock: func() { projectRepo.ExpectedCalls = nil projectRepo.On("GetEncryptionPart", mock.Anything, mock.Anything).Return("", domainErrors.ErrEncryptionPartNotFound) @@ -886,7 +910,7 @@ func TestProjectApplication_EncryptProjectShares(t *testing.T) { }, { name: "error getting encryption part", - externalPart: externalPart, + externalPart: projectPart, mock: func() { projectRepo.ExpectedCalls = nil projectRepo.On("GetEncryptionPart", mock.Anything, mock.Anything).Return("", errors.New("repository error")) @@ -904,7 +928,7 @@ func TestProjectApplication_EncryptProjectShares(t *testing.T) { }, { name: "error listing shares", - externalPart: externalPart, + externalPart: projectPart, mock: func() { projectRepo.ExpectedCalls = nil shareRepo.ExpectedCalls = nil @@ -915,7 +939,7 @@ func TestProjectApplication_EncryptProjectShares(t *testing.T) { }, { name: "error updating share", - externalPart: externalPart, + externalPart: projectPart, mock: func() { projectRepo.ExpectedCalls = nil projectRepo.On("GetEncryptionPart", mock.Anything, mock.Anything).Return(storedPart, nil) @@ -945,7 +969,9 @@ func TestProjectApplication_RegisterEncryptionKey(t *testing.T) { providerRepo := new(providermockrepo.MockProviderRepository) projectService := projectsvc.New(projectRepo) providerService := providersvc.New(providerRepo) - app := New(projectService, projectRepo, providerService, providerRepo, shareRepo) + encryptionPartsRepo := new(encryptionpartsmockrepo.MockEncryptionPartsRepository) + encryptionFactory := encryption.NewEncryptionFactory(encryptionPartsRepo, projectRepo) + app := New(projectService, projectRepo, providerService, providerRepo, shareRepo, encryptionFactory, encryptionPartsRepo) tc := []struct { name string diff --git a/internal/applications/shareapp/app.go b/internal/applications/shareapp/app.go index 71a13f0..837892a 100644 --- a/internal/applications/shareapp/app.go +++ b/internal/applications/shareapp/app.go @@ -2,6 +2,7 @@ package shareapp import ( "context" + "fmt" "go.openfort.xyz/shield/internal/core/ports/factories" "log/slog" @@ -118,6 +119,8 @@ func (a *ShareApplication) GetShare(ctx context.Context, opts ...Option) (*share return nil, fromDomainError(err) } + fmt.Println(shr) + var opt options for _, o := range opts { o(&opt) @@ -128,6 +131,7 @@ func (a *ShareApplication) GetShare(ctx context.Context, opts ...Option) (*share if err != nil { return nil, err } + fmt.Println(encryptionKey) cypher := a.encryptionFactory.CreateEncryptionStrategy(encryptionKey) shr.Secret, err = cypher.Decrypt(shr.Secret) diff --git a/internal/applications/shareapp/app_test.go b/internal/applications/shareapp/app_test.go index a4f1a79..2e4c9b1 100644 --- a/internal/applications/shareapp/app_test.go +++ b/internal/applications/shareapp/app_test.go @@ -5,13 +5,15 @@ import ( "errors" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + "go.openfort.xyz/shield/internal/adapters/encryption" + "go.openfort.xyz/shield/internal/adapters/repositories/mocks/encryptionpartsmockrepo" "go.openfort.xyz/shield/internal/adapters/repositories/mocks/projectmockrepo" "go.openfort.xyz/shield/internal/adapters/repositories/mocks/sharemockrepo" domainErrors "go.openfort.xyz/shield/internal/core/domain/errors" "go.openfort.xyz/shield/internal/core/domain/share" "go.openfort.xyz/shield/internal/core/services/sharesvc" "go.openfort.xyz/shield/pkg/contexter" - "go.openfort.xyz/shield/pkg/cypher" + "go.openfort.xyz/shield/pkg/random" "testing" ) @@ -20,20 +22,25 @@ func TestShareApplication_GetShare(t *testing.T) { ctx = contexter.WithUserID(ctx, "user_id") shareRepo := new(sharemockrepo.MockShareRepository) projectRepo := new(projectmockrepo.MockProjectRepository) - shareSvc := sharesvc.New(shareRepo) - app := New(shareSvc, shareRepo, projectRepo) - storedPart, externalPart, err := cypher.GenerateEncryptionKey() + encryptionPartsRepo := new(encryptionpartsmockrepo.MockEncryptionPartsRepository) + encryptionFactory := encryption.NewEncryptionFactory(encryptionPartsRepo, projectRepo) + shareSvc := sharesvc.New(shareRepo, encryptionFactory) + app := New(shareSvc, shareRepo, projectRepo, encryptionFactory) + key, err := random.GenerateRandomString(32) if err != nil { - t.Fatalf("failed to generate encryption key: %v", err) + t.Fatalf(key) } - encryptionKey, err := cypher.ReconstructEncryptionKey(storedPart, externalPart) + + reconstructor := encryptionFactory.CreateReconstructionStrategy() + storedPart, projectPart, err := reconstructor.Split(key) if err != nil { - t.Fatalf("failed to reconstruct encryption key: %v", err) + t.Fatalf("failed to generate encryption key: %v", err) } - encryptedSecret, err := cypher.Encrypt("secret", encryptionKey) + cypher := encryptionFactory.CreateEncryptionStrategy(key) + encryptedSecret, err := cypher.Encrypt("secret") if err != nil { - t.Fatalf("failed to encrypt secret: %v", err) + t.Fatalf("failed to cypher secret: %v", err) } plainShare := &share.Share{ @@ -77,44 +84,66 @@ func TestShareApplication_GetShare(t *testing.T) { wantErr: nil, want: decryptedShare, mock: func() { + tmpEncryptedShare := *encryptedShare shareRepo.ExpectedCalls = nil projectRepo.ExpectedCalls = nil - shareRepo.On("GetByUserID", mock.Anything, "user_id").Return(encryptedShare, nil) + shareRepo.On("GetByUserID", mock.Anything, "user_id").Return(&tmpEncryptedShare, nil) projectRepo.On("GetEncryptionPart", mock.Anything, "project_id").Return(storedPart, nil) }, opts: []Option{ - WithEncryptionPart(externalPart), + WithEncryptionPart(projectPart), + }, + }, + { + name: "encrypted success with session", + wantErr: nil, + want: decryptedShare, + mock: func() { + tmpEncryptedShare := *encryptedShare + shareRepo.ExpectedCalls = nil + projectRepo.ExpectedCalls = nil + encryptionPartsRepo.ExpectedCalls = nil + shareRepo.On("GetByUserID", mock.Anything, "user_id").Return(&tmpEncryptedShare, nil) + projectRepo.On("GetEncryptionPart", mock.Anything, "project_id").Return(storedPart, nil) + encryptionPartsRepo.On("Get", mock.Anything, "sessionID").Return(projectPart, nil) + encryptionPartsRepo.On("Delete", mock.Anything, "sessionID").Return(nil) + }, + opts: []Option{ + WithEncryptionSession("sessionID"), }, }, { name: "encryption part required", wantErr: ErrEncryptionPartRequired, mock: func() { + tmpEncryptedShare := *encryptedShare shareRepo.ExpectedCalls = nil projectRepo.ExpectedCalls = nil - shareRepo.On("GetByUserID", mock.Anything, "user_id").Return(encryptedShare, nil) + shareRepo.On("GetByUserID", mock.Anything, "user_id").Return(&tmpEncryptedShare, nil) }, }, { name: "encryption not configured", wantErr: ErrEncryptionNotConfigured, mock: func() { + tmpEncryptedShare := *encryptedShare shareRepo.ExpectedCalls = nil projectRepo.ExpectedCalls = nil - shareRepo.On("GetByUserID", mock.Anything, "user_id").Return(encryptedShare, nil) + shareRepo.On("GetByUserID", mock.Anything, "user_id").Return(&tmpEncryptedShare, nil) projectRepo.On("GetEncryptionPart", mock.Anything, "project_id").Return("", domainErrors.ErrEncryptionPartNotFound) }, opts: []Option{ - WithEncryptionPart(externalPart), + WithEncryptionPart(projectPart), }, }, { name: "invalid encryption part", wantErr: ErrInvalidEncryptionPart, mock: func() { + tmpEncryptedShare := *encryptedShare shareRepo.ExpectedCalls = nil projectRepo.ExpectedCalls = nil - shareRepo.On("GetByUserID", mock.Anything, "user_id").Return(encryptedShare, nil) + shareRepo.On("GetByUserID", mock.Anything, "user_id").Return(&tmpEncryptedShare, nil) projectRepo.On("GetEncryptionPart", mock.Anything, "project_id").Return(storedPart, nil) }, opts: []Option{ @@ -131,7 +160,7 @@ func TestShareApplication_GetShare(t *testing.T) { projectRepo.On("GetEncryptionPart", mock.Anything, "project_id").Return(storedPart, nil) }, opts: []Option{ - WithEncryptionPart(externalPart), + WithEncryptionPart(projectPart), }, }, { @@ -156,26 +185,28 @@ func TestShareApplication_GetShare(t *testing.T) { name: "get encryption part repository error", wantErr: ErrInternal, mock: func() { + tmpEncryptedShare := *encryptedShare shareRepo.ExpectedCalls = nil projectRepo.ExpectedCalls = nil - shareRepo.On("GetByUserID", mock.Anything, "user_id").Return(encryptedShare, nil) + shareRepo.On("GetByUserID", mock.Anything, "user_id").Return(&tmpEncryptedShare, nil) projectRepo.On("GetEncryptionPart", mock.Anything, "project_id").Return("", errors.New("repository error")) }, opts: []Option{ - WithEncryptionPart(externalPart), + WithEncryptionPart(projectPart), }, }, { name: "encryption part not found", wantErr: ErrEncryptionNotConfigured, mock: func() { + tmpEncryptedShare := *encryptedShare shareRepo.ExpectedCalls = nil projectRepo.ExpectedCalls = nil - shareRepo.On("GetByUserID", mock.Anything, "user_id").Return(encryptedShare, nil) + shareRepo.On("GetByUserID", mock.Anything, "user_id").Return(&tmpEncryptedShare, nil) projectRepo.On("GetEncryptionPart", mock.Anything, "project_id").Return("", domainErrors.ErrEncryptionPartNotFound) }, opts: []Option{ - WithEncryptionPart(externalPart), + WithEncryptionPart(projectPart), }, }, } @@ -196,20 +227,24 @@ func TestShareApplication_RegisterShare(t *testing.T) { ctx = contexter.WithUserID(ctx, "user_id") shareRepo := new(sharemockrepo.MockShareRepository) projectRepo := new(projectmockrepo.MockProjectRepository) - shareSvc := sharesvc.New(shareRepo) - app := New(shareSvc, shareRepo, projectRepo) - storedPart, externalPart, err := cypher.GenerateEncryptionKey() + encryptionPartsRepo := new(encryptionpartsmockrepo.MockEncryptionPartsRepository) + encryptionFactory := encryption.NewEncryptionFactory(encryptionPartsRepo, projectRepo) + shareSvc := sharesvc.New(shareRepo, encryptionFactory) + app := New(shareSvc, shareRepo, projectRepo, encryptionFactory) + key, err := random.GenerateRandomString(32) if err != nil { - t.Fatalf("failed to generate encryption key: %v", err) + t.Fatalf(key) } - encryptionKey, err := cypher.ReconstructEncryptionKey(storedPart, externalPart) + + storedPart, projectPart, err := encryptionFactory.CreateReconstructionStrategy().Split(key) if err != nil { - t.Fatalf("failed to reconstruct encryption key: %v", err) + t.Fatalf("failed to generate encryption key: %v", err) } - encryptedSecret, err := cypher.Encrypt("secret", encryptionKey) + cypher := encryptionFactory.CreateEncryptionStrategy(key) + encryptedSecret, err := cypher.Encrypt("secret") if err != nil { - t.Fatalf("failed to encrypt secret: %v", err) + t.Fatalf("failed to cypher secret: %v", err) } plainShare := &share.Share{ @@ -256,7 +291,7 @@ func TestShareApplication_RegisterShare(t *testing.T) { projectRepo.On("GetEncryptionPart", mock.Anything, "project_id").Return(storedPart, nil) }, opts: []Option{ - WithEncryptionPart(externalPart), + WithEncryptionPart(projectPart), }, }, { @@ -279,7 +314,7 @@ func TestShareApplication_RegisterShare(t *testing.T) { projectRepo.On("GetEncryptionPart", mock.Anything, "project_id").Return("", domainErrors.ErrEncryptionPartNotFound) }, opts: []Option{ - WithEncryptionPart(externalPart), + WithEncryptionPart(projectPart), }, }, { @@ -331,8 +366,10 @@ func TestShareApplication_DeleteShare(t *testing.T) { ctx = contexter.WithUserID(ctx, "user_id") shareRepo := new(sharemockrepo.MockShareRepository) projectRepo := new(projectmockrepo.MockProjectRepository) - shareSvc := sharesvc.New(shareRepo) - app := New(shareSvc, shareRepo, projectRepo) + encryptionPartsRepo := new(encryptionpartsmockrepo.MockEncryptionPartsRepository) + encryptionFactory := encryption.NewEncryptionFactory(encryptionPartsRepo, projectRepo) + shareSvc := sharesvc.New(shareRepo, encryptionFactory) + app := New(shareSvc, shareRepo, projectRepo, encryptionFactory) tc := []struct { name string diff --git a/internal/core/ports/strategies/reconstruction.go b/internal/core/ports/strategies/reconstruction.go index c5652e0..c2fab4f 100644 --- a/internal/core/ports/strategies/reconstruction.go +++ b/internal/core/ports/strategies/reconstruction.go @@ -1,6 +1,6 @@ package strategies type ReconstructionStrategy interface { - Split(data string) ([]string, error) - Reconstruct(parts []string) (string, error) + Split(data string) (storedPart string, projectPart string, err error) + Reconstruct(storedPart string, projectPart string) (string, error) } diff --git a/internal/core/services/sharesvc/svc_test.go b/internal/core/services/sharesvc/svc_test.go index 7b905b4..48c653d 100644 --- a/internal/core/services/sharesvc/svc_test.go +++ b/internal/core/services/sharesvc/svc_test.go @@ -3,9 +3,13 @@ package sharesvc import ( "context" "errors" + "go.openfort.xyz/shield/internal/adapters/encryption" + "go.openfort.xyz/shield/internal/adapters/repositories/mocks/encryptionpartsmockrepo" + "go.openfort.xyz/shield/internal/adapters/repositories/mocks/projectmockrepo" domainErrors "go.openfort.xyz/shield/internal/core/domain/errors" "go.openfort.xyz/shield/internal/core/ports/services" "go.openfort.xyz/shield/pkg/cypher" + "go.openfort.xyz/shield/pkg/random" "testing" "github.com/stretchr/testify/mock" @@ -15,7 +19,10 @@ import ( func TestCreateShare(t *testing.T) { mockRepo := new(sharemockrepo.MockShareRepository) - svc := New(mockRepo) + projectRepo := new(projectmockrepo.MockProjectRepository) + encryptionPartsRepo := new(encryptionpartsmockrepo.MockEncryptionPartsRepository) + encryptionFactory := encryption.NewEncryptionFactory(encryptionPartsRepo, projectRepo) + svc := New(mockRepo, encryptionFactory) ctx := context.Background() testUserID := "test-user" testData := "test-data" @@ -30,11 +37,18 @@ func TestCreateShare(t *testing.T) { Entropy: share.EntropyProject, }, } - storedPart, externalPart, err := cypher.GenerateEncryptionKey() + key, err := random.GenerateRandomString(32) + if err != nil { + t.Fatalf(key) + } + + reconstructor := encryptionFactory.CreateReconstructionStrategy() + storedPart, projectPart, err := reconstructor.Split(key) if err != nil { t.Fatalf("failed to generate encryption key: %v", err) } - encryptionKey, err := cypher.ReconstructEncryptionKey(storedPart, externalPart) + + encryptionKey, err := cypher.ReconstructEncryptionKey(storedPart, projectPart) if err != nil { t.Fatalf("failed to reconstruct encryption key: %v", err) } diff --git a/internal/core/services/usersvc/svc_test.go b/internal/core/services/usersvc/svc_test.go index 943e56a..cad1be7 100644 --- a/internal/core/services/usersvc/svc_test.go +++ b/internal/core/services/usersvc/svc_test.go @@ -16,6 +16,10 @@ func TestCreateUser(t *testing.T) { svc := New(mockRepo) ctx := context.Background() + projectID := "project" + providerID := "provider" + externalUserID := "external" + tc := []struct { name string wantErr bool @@ -42,7 +46,7 @@ func TestCreateUser(t *testing.T) { for _, tt := range tc { t.Run(tt.name, func(t *testing.T) { tt.mock() - _, err := svc.Create(ctx, "fdsa") + _, err := svc.GetOrCreate(ctx, projectID, externalUserID, providerID) if (err != nil) != tt.wantErr { t.Errorf("Create() error = %v, wantErr %v", err, tt.wantErr) return @@ -56,6 +60,10 @@ func TestGetUser(t *testing.T) { svc := New(mockRepo) ctx := context.Background() + projectID := "project" + providerID := "provider" + externalUserID := "external" + tc := []struct { name string wantErr bool @@ -92,7 +100,7 @@ func TestGetUser(t *testing.T) { for _, tt := range tc { t.Run(tt.name, func(t *testing.T) { tt.mock() - _, err := svc.Get(ctx, "fdsa") + _, err := svc.GetOrCreate(ctx, projectID, externalUserID, providerID) if (err != nil) != tt.wantErr { t.Errorf("Get() error = %v, wantErr %v", err, tt.wantErr) return @@ -166,7 +174,7 @@ func TestGetUserByExternal(t *testing.T) { for _, tt := range tc { t.Run(tt.name, func(t *testing.T) { tt.mock() - _, err := svc.GetByExternal(ctx, "fdsa", "fdsa") + _, err := svc.GetOrCreate(ctx, "project", "user", "provider") if (err != nil) != tt.wantErr { t.Errorf("GetByExternal() error = %v, wantErr %v", err, tt.wantErr) return @@ -261,7 +269,7 @@ func TestCreateExternalUser(t *testing.T) { for _, tt := range tc { t.Run(tt.name, func(t *testing.T) { tt.mock() - _, err := svc.CreateExternal(ctx, "project", "user", "external", "provider") + _, err := svc.GetOrCreate(ctx, "project", "user", "provider") if (err != nil) != tt.wantErr { t.Errorf("CreateExternal() error = %v, wantErr %v", err, tt.wantErr) return diff --git a/pkg/cypher/cypher.go b/pkg/cypher/cypher.go index 7cd9dd4..d436b6f 100644 --- a/pkg/cypher/cypher.go +++ b/pkg/cypher/cypher.go @@ -5,6 +5,7 @@ import ( "crypto/cipher" "encoding/base64" "errors" + "fmt" "github.com/codahale/sss" "go.openfort.xyz/shield/pkg/random" ) @@ -35,29 +36,35 @@ func Encrypt(plaintext, key string) (string, error) { } func Decrypt(encrypted, key string) (string, error) { + fmt.Printf("encrypted: %s key: %s\n", encrypted, key) encryptedBytes, err := base64.StdEncoding.DecodeString(encrypted) if err != nil { + fmt.Println("error decoding base64") return "", err } keyBytes, err := base64.StdEncoding.DecodeString(key) if err != nil { + fmt.Println("error decoding key") return "", err } block, err := aes.NewCipher(keyBytes) if err != nil { + fmt.Println("error creating cipher") return "", err } aesGCM, err := cipher.NewGCM(block) if err != nil { + fmt.Println("error creating gcm") return "", err } nonceSize := aesGCM.NonceSize() if len(encryptedBytes) < nonceSize { - return "", err + fmt.Println("invalid nonce size") + return "", errors.New("ciphertext too short") } nonce, ciphertext := encryptedBytes[:nonceSize], encryptedBytes[nonceSize:] @@ -74,6 +81,7 @@ func SplitEncryptionKey(key string) (string, string, error) { if err != nil { return "", "", err } + shares, err := sss.Split(2, 2, rawKey) if err != nil { return "", "", err @@ -101,9 +109,9 @@ func ReconstructEncryptionKey(part1, part2 string) (string, error) { return "", err } - subset := make(map[byte][]byte, 2) - subset[0] = rawPart1 - subset[1] = rawPart2 + subset := make(map[byte][]byte) + subset[1] = rawPart1 + subset[2] = rawPart2 key := sss.Combine(subset) From d56759284d9069d6725b5aa553904b0fe9066c81 Mon Sep 17 00:00:00 2001 From: gllm-dev Date: Wed, 10 Jul 2024 13:53:23 +0200 Subject: [PATCH 06/10] fix: rm println --- internal/applications/shareapp/app.go | 4 ---- pkg/cypher/cypher.go | 7 ------- 2 files changed, 11 deletions(-) diff --git a/internal/applications/shareapp/app.go b/internal/applications/shareapp/app.go index 837892a..71a13f0 100644 --- a/internal/applications/shareapp/app.go +++ b/internal/applications/shareapp/app.go @@ -2,7 +2,6 @@ package shareapp import ( "context" - "fmt" "go.openfort.xyz/shield/internal/core/ports/factories" "log/slog" @@ -119,8 +118,6 @@ func (a *ShareApplication) GetShare(ctx context.Context, opts ...Option) (*share return nil, fromDomainError(err) } - fmt.Println(shr) - var opt options for _, o := range opts { o(&opt) @@ -131,7 +128,6 @@ func (a *ShareApplication) GetShare(ctx context.Context, opts ...Option) (*share if err != nil { return nil, err } - fmt.Println(encryptionKey) cypher := a.encryptionFactory.CreateEncryptionStrategy(encryptionKey) shr.Secret, err = cypher.Decrypt(shr.Secret) diff --git a/pkg/cypher/cypher.go b/pkg/cypher/cypher.go index d436b6f..e235466 100644 --- a/pkg/cypher/cypher.go +++ b/pkg/cypher/cypher.go @@ -5,7 +5,6 @@ import ( "crypto/cipher" "encoding/base64" "errors" - "fmt" "github.com/codahale/sss" "go.openfort.xyz/shield/pkg/random" ) @@ -36,34 +35,28 @@ func Encrypt(plaintext, key string) (string, error) { } func Decrypt(encrypted, key string) (string, error) { - fmt.Printf("encrypted: %s key: %s\n", encrypted, key) encryptedBytes, err := base64.StdEncoding.DecodeString(encrypted) if err != nil { - fmt.Println("error decoding base64") return "", err } keyBytes, err := base64.StdEncoding.DecodeString(key) if err != nil { - fmt.Println("error decoding key") return "", err } block, err := aes.NewCipher(keyBytes) if err != nil { - fmt.Println("error creating cipher") return "", err } aesGCM, err := cipher.NewGCM(block) if err != nil { - fmt.Println("error creating gcm") return "", err } nonceSize := aesGCM.NonceSize() if len(encryptedBytes) < nonceSize { - fmt.Println("invalid nonce size") return "", errors.New("ciphertext too short") } From 2fedb8d1da010d43614c6a9f3ecf43868689b1d9 Mon Sep 17 00:00:00 2001 From: gllm-dev Date: Wed, 10 Jul 2024 14:17:35 +0200 Subject: [PATCH 07/10] feat: tests --- internal/applications/projectapp/app_test.go | 46 +++++++++++++ internal/applications/shareapp/app_test.go | 70 ++++++++++++++++++++ 2 files changed, 116 insertions(+) diff --git a/internal/applications/projectapp/app_test.go b/internal/applications/projectapp/app_test.go index 766110e..e4f64ab 100644 --- a/internal/applications/projectapp/app_test.go +++ b/internal/applications/projectapp/app_test.go @@ -1023,3 +1023,49 @@ func TestProjectApplication_RegisterEncryptionKey(t *testing.T) { }) } } + +func TestProjectApplication_RegisterEncryptionSession(t *testing.T) { + ctx := contexter.WithProjectID(context.Background(), "project_id") + ctx = contexter.WithUserID(ctx, "user_id") + shareRepo := new(sharemockrepo.MockShareRepository) + projectRepo := new(projectmockrepo.MockProjectRepository) + providerRepo := new(providermockrepo.MockProviderRepository) + projectService := projectsvc.New(projectRepo) + providerService := providersvc.New(providerRepo) + encryptionPartsRepo := new(encryptionpartsmockrepo.MockEncryptionPartsRepository) + encryptionFactory := encryption.NewEncryptionFactory(encryptionPartsRepo, projectRepo) + app := New(projectService, projectRepo, providerService, providerRepo, shareRepo, encryptionFactory, encryptionPartsRepo) + + tc := []struct { + name string + wantErr error + mock func() + }{ + { + name: "success", + wantErr: nil, + mock: func() { + encryptionPartsRepo.ExpectedCalls = nil + encryptionPartsRepo.On("Set", mock.Anything, mock.Anything, mock.Anything).Return(nil) + }, + }, + { + name: "error setting encryption session", + wantErr: ErrInternal, + mock: func() { + encryptionPartsRepo.ExpectedCalls = nil + encryptionPartsRepo.On("Set", mock.Anything, mock.Anything, mock.Anything).Return(errors.New("repository error")) + }, + }, + } + + for _, tt := range tc { + t.Run(tt.name, func(t *testing.T) { + tt.mock() + ass := assert.New(t) + _, err := app.RegisterEncryptionSession(ctx, "encryptionPart") + ass.Equal(tt.wantErr, err) + }) + } + +} diff --git a/internal/applications/shareapp/app_test.go b/internal/applications/shareapp/app_test.go index 2e4c9b1..2a44c93 100644 --- a/internal/applications/shareapp/app_test.go +++ b/internal/applications/shareapp/app_test.go @@ -421,3 +421,73 @@ func TestShareApplication_DeleteShare(t *testing.T) { }) } } + +func TestShareApplication_UpdateShare(t *testing.T) { + ctx := contexter.WithProjectID(context.Background(), "project_id") + ctx = contexter.WithUserID(ctx, "user_id") + shareRepo := new(sharemockrepo.MockShareRepository) + projectRepo := new(projectmockrepo.MockProjectRepository) + encryptionPartsRepo := new(encryptionpartsmockrepo.MockEncryptionPartsRepository) + encryptionFactory := encryption.NewEncryptionFactory(encryptionPartsRepo, projectRepo) + shareSvc := sharesvc.New(shareRepo, encryptionFactory) + app := New(shareSvc, shareRepo, projectRepo, encryptionFactory) + updates := &share.Share{ + ID: "share-id", + Secret: "secret", + UserID: "user_id", + EncryptionParameters: nil, + } + + tc := []struct { + name string + wantErr error + mock func() + updates *share.Share + }{ + { + name: "success", + wantErr: nil, + updates: updates, + mock: func() { + shareRepo.ExpectedCalls = nil + shareRepo.On("GetByUserID", mock.Anything, "user_id").Return(&share.Share{ID: "share-id"}, nil) + shareRepo.On("Update", mock.Anything, mock.Anything).Return(nil) + }, + }, + { + name: "share not found", + wantErr: ErrShareNotFound, + mock: func() { + shareRepo.ExpectedCalls = nil + shareRepo.On("GetByUserID", mock.Anything, "user_id").Return(nil, domainErrors.ErrShareNotFound) + }, + }, + { + name: "repository error", + wantErr: ErrInternal, + mock: func() { + shareRepo.ExpectedCalls = nil + shareRepo.On("GetByUserID", mock.Anything, "user_id").Return(nil, errors.New("repository error")) + }, + }, + { + name: "delete error", + updates: updates, + wantErr: ErrInternal, + mock: func() { + shareRepo.ExpectedCalls = nil + shareRepo.On("GetByUserID", mock.Anything, "user_id").Return(&share.Share{ID: "share-id"}, nil) + shareRepo.On("Update", mock.Anything, mock.Anything).Return(errors.New("repository error")) + }, + }, + } + + for _, tt := range tc { + t.Run(tt.name, func(t *testing.T) { + tt.mock() + ass := assert.New(t) + _, err := app.UpdateShare(ctx, tt.updates) + ass.ErrorIs(tt.wantErr, err) + }) + } +} From 4150054b289ff07b03c69580253dfa6dc965c836 Mon Sep 17 00:00:00 2001 From: gllm-dev Date: Thu, 11 Jul 2024 12:30:51 +0200 Subject: [PATCH 08/10] feat: tests --- di/wire.go | 2 +- di/wire_gen.go | 2 +- internal/adapters/authenticators/factory.go | 8 +- .../identity/custom_identity/custom.go | 7 +- .../authenticators/identity/factory.go | 16 +- .../identity/openfort_identity/config.go | 2 +- .../identity/openfort_identity/openfort.go | 13 +- .../project_authenticator.go | 5 +- .../user_authenticator/user_authenticator.go | 5 +- .../aes_encryption_strategy/strategy.go | 2 +- internal/adapters/encryption/factory.go | 20 +- .../encryption/plain_builder/builder.go | 14 +- .../encryption/session_builder/builder.go | 9 +- .../sss_reconstruction_strategy/strategy.go | 34 ++- .../handlers/rest/authmdw/middleware.go | 20 +- internal/adapters/handlers/rest/server.go | 6 +- .../adapters/handlers/rest/sharehdl/errors.go | 1 + .../adapters/handlers/rest/sharehdl/parser.go | 50 ++-- .../bunt/encryptionpartsrepo/repo.go | 15 +- .../mocks/encryptionpartsmockrepo/repo.go | 8 +- internal/adapters/repositories/sql/client.go | 3 +- .../repositories/sql/projectrepo/repo.go | 1 + .../repositories/sql/providerrepo/repo.go | 4 +- .../repositories/sql/sharerepo/parser.go | 63 ++++- .../repositories/sql/sharerepo/repo.go | 5 +- .../repositories/sql/userrepo/repo.go | 6 +- internal/applications/projectapp/app.go | 9 +- internal/applications/projectapp/app_test.go | 28 +-- internal/applications/projectapp/errors.go | 1 + internal/applications/shareapp/app.go | 13 +- internal/applications/shareapp/app_test.go | 34 +-- internal/applications/shareapp/errors.go | 1 + internal/core/domain/errors/project.go | 14 +- .../domain/share/encryption_parameters.go | 1 - internal/core/domain/share/entropy.go | 2 +- internal/core/domain/share/share.go | 3 +- .../core/ports/factories/authentication.go | 1 + .../ports/repositories/encryption_parts.go | 6 +- internal/core/services/projectsvc/svc.go | 3 +- internal/core/services/providersvc/svc.go | 3 +- internal/core/services/sharesvc/svc.go | 3 +- internal/core/services/sharesvc/svc_test.go | 8 +- internal/core/services/usersvc/svc.go | 40 +--- internal/core/services/usersvc/svc_test.go | 222 ++---------------- pkg/cypher/cypher.go | 1 + 45 files changed, 318 insertions(+), 396 deletions(-) diff --git a/di/wire.go b/di/wire.go index dab64ca..e87c5f7 100644 --- a/di/wire.go +++ b/di/wire.go @@ -177,7 +177,7 @@ func ProvideAuthenticationFactory() (f factories.AuthenticationFactory, err erro func ProvideIdentityFactory() (f factories.IdentityFactory, err error) { wire.Build( identity.NewIdentityFactory, - openfort_identity.GetConfigFromEnv, + ofidty.GetConfigFromEnv, ProvideSQLProviderRepository, ) diff --git a/di/wire_gen.go b/di/wire_gen.go index c312b64..468c3d7 100644 --- a/di/wire_gen.go +++ b/di/wire_gen.go @@ -218,7 +218,7 @@ func ProvideAuthenticationFactory() (factories.AuthenticationFactory, error) { } func ProvideIdentityFactory() (factories.IdentityFactory, error) { - config, err := openfort_identity.GetConfigFromEnv() + config, err := ofidty.GetConfigFromEnv() if err != nil { return nil, err } diff --git a/internal/adapters/authenticators/factory.go b/internal/adapters/authenticators/factory.go index eea61c1..58d644a 100644 --- a/internal/adapters/authenticators/factory.go +++ b/internal/adapters/authenticators/factory.go @@ -1,8 +1,8 @@ package authenticators import ( - "go.openfort.xyz/shield/internal/adapters/authenticators/project_authenticator" - "go.openfort.xyz/shield/internal/adapters/authenticators/user_authenticator" + projauth "go.openfort.xyz/shield/internal/adapters/authenticators/project_authenticator" + usrauth "go.openfort.xyz/shield/internal/adapters/authenticators/user_authenticator" "go.openfort.xyz/shield/internal/core/ports/factories" "go.openfort.xyz/shield/internal/core/ports/repositories" "go.openfort.xyz/shield/internal/core/ports/services" @@ -21,9 +21,9 @@ func NewAuthenticatorFactory(projectRepo repositories.ProjectRepository, userSer } func (f *authenticatorFactory) CreateProjectAuthenticator(apiKey, apiSecret string) factories.Authenticator { - return project_authenticator.NewProjectAuthenticator(f.projectRepo, apiKey, apiSecret) + return projauth.NewProjectAuthenticator(f.projectRepo, apiKey, apiSecret) } func (f *authenticatorFactory) CreateUserAuthenticator(apiKey, token string, identityFactory factories.Identity) factories.Authenticator { - return user_authenticator.NewUserAuthenticator(f.projectRepo, f.userService, apiKey, token, identityFactory) + return usrauth.NewUserAuthenticator(f.projectRepo, f.userService, apiKey, token, identityFactory) } diff --git a/internal/adapters/authenticators/identity/custom_identity/custom.go b/internal/adapters/authenticators/identity/custom_identity/custom.go index 5ed9982..6f9de78 100644 --- a/internal/adapters/authenticators/identity/custom_identity/custom.go +++ b/internal/adapters/authenticators/identity/custom_identity/custom.go @@ -1,11 +1,12 @@ -package custom_identity +package cstmidty import ( "context" + "log/slog" + "go.openfort.xyz/shield/internal/core/domain/errors" "go.openfort.xyz/shield/internal/core/ports/factories" "go.openfort.xyz/shield/pkg/jwk" - "log/slog" "github.com/golang-jwt/jwt/v5" @@ -40,7 +41,7 @@ func (c *CustomIdentityFactory) Identify(ctx context.Context, token string) (str case c.config.PEM != "" && c.config.KeyType != provider.KeyTypeUnknown: externalUserID, err = c.validatePEM(token) case c.config.JWK != "": - externalUserID, err = jwk.Validate(token, c.config.JWK) // TODO parse error + externalUserID, err = jwk.Validate(token, c.config.JWK) default: return "", errors.ErrProviderMisconfigured } diff --git a/internal/adapters/authenticators/identity/factory.go b/internal/adapters/authenticators/identity/factory.go index c3efc6f..66db395 100644 --- a/internal/adapters/authenticators/identity/factory.go +++ b/internal/adapters/authenticators/identity/factory.go @@ -3,11 +3,13 @@ package identity import ( "context" "errors" - "go.openfort.xyz/shield/internal/adapters/authenticators/identity/custom_identity" - "go.openfort.xyz/shield/internal/adapters/authenticators/identity/openfort_identity" + "log/slog" + + cstmidty "go.openfort.xyz/shield/internal/adapters/authenticators/identity/custom_identity" + ofidty "go.openfort.xyz/shield/internal/adapters/authenticators/identity/openfort_identity" + domainErrors "go.openfort.xyz/shield/internal/core/domain/errors" "go.openfort.xyz/shield/internal/core/ports/factories" - "log/slog" "go.openfort.xyz/shield/internal/core/domain/provider" "go.openfort.xyz/shield/internal/core/ports/repositories" @@ -15,12 +17,12 @@ import ( ) type identityFactory struct { - config *openfort_identity.Config + config *ofidty.Config repo repositories.ProviderRepository logger *slog.Logger } -func NewIdentityFactory(cfg *openfort_identity.Config, repo repositories.ProviderRepository) factories.IdentityFactory { +func NewIdentityFactory(cfg *ofidty.Config, repo repositories.ProviderRepository) factories.IdentityFactory { return &identityFactory{ config: cfg, repo: repo, @@ -43,7 +45,7 @@ func (p *identityFactory) CreateCustomIdentity(ctx context.Context, apiKey strin return nil, domainErrors.ErrProviderConfigMismatch } - return custom_identity.NewCustomIdentityFactory(config), nil + return cstmidty.NewCustomIdentityFactory(config), nil } func (p *identityFactory) CreateOpenfortIdentity(ctx context.Context, apiKey string, authenticationProvider, tokenType *string) (factories.Identity, error) { @@ -61,5 +63,5 @@ func (p *identityFactory) CreateOpenfortIdentity(ctx context.Context, apiKey str return nil, domainErrors.ErrProviderConfigMismatch } - return openfort_identity.NewOpenfortIdentityFactory(p.config, config, authenticationProvider, tokenType), nil + return ofidty.NewOpenfortIdentityFactory(p.config, config, authenticationProvider, tokenType), nil } diff --git a/internal/adapters/authenticators/identity/openfort_identity/config.go b/internal/adapters/authenticators/identity/openfort_identity/config.go index f0ec451..33ce9e0 100644 --- a/internal/adapters/authenticators/identity/openfort_identity/config.go +++ b/internal/adapters/authenticators/identity/openfort_identity/config.go @@ -1,4 +1,4 @@ -package openfort_identity +package ofidty import "github.com/caarlos0/env/v10" diff --git a/internal/adapters/authenticators/identity/openfort_identity/openfort.go b/internal/adapters/authenticators/identity/openfort_identity/openfort.go index 5d3b257..0b67264 100644 --- a/internal/adapters/authenticators/identity/openfort_identity/openfort.go +++ b/internal/adapters/authenticators/identity/openfort_identity/openfort.go @@ -1,18 +1,19 @@ -package openfort_identity +package ofidty import ( "bytes" "context" "encoding/json" "fmt" - domainErrors "go.openfort.xyz/shield/internal/core/domain/errors" - "go.openfort.xyz/shield/internal/core/ports/factories" - "go.openfort.xyz/shield/pkg/jwk" "io" "log/slog" "net/http" "time" + domainErrors "go.openfort.xyz/shield/internal/core/domain/errors" + "go.openfort.xyz/shield/internal/core/ports/factories" + "go.openfort.xyz/shield/pkg/jwk" + "go.openfort.xyz/shield/internal/core/domain/provider" "go.openfort.xyz/shield/pkg/logger" ) @@ -54,8 +55,8 @@ func (o *OpenfortIdentityFactory) Identify(ctx context.Context, token string) (s return o.accessToken(ctx, token) } -func (o *OpenfortIdentityFactory) accessToken(ctx context.Context, token string) (string, error) { - return jwk.Validate(token, fmt.Sprintf("%s/iam/v1/%s/jwks.json", o.baseURL, o.publishableKey)) // TODO parse error +func (o *OpenfortIdentityFactory) accessToken(_ context.Context, token string) (string, error) { + return jwk.Validate(token, fmt.Sprintf("%s/iam/v1/%s/jwks.json", o.baseURL, o.publishableKey)) } func (o *OpenfortIdentityFactory) thirdParty(ctx context.Context, token, authenticationProvider, tokenType string) (string, error) { diff --git a/internal/adapters/authenticators/project_authenticator/project_authenticator.go b/internal/adapters/authenticators/project_authenticator/project_authenticator.go index 716cdc5..e604777 100644 --- a/internal/adapters/authenticators/project_authenticator/project_authenticator.go +++ b/internal/adapters/authenticators/project_authenticator/project_authenticator.go @@ -1,10 +1,11 @@ -package project_authenticator +package projauth import ( "context" + "log/slog" + "go.openfort.xyz/shield/internal/core/domain/authentication" "go.openfort.xyz/shield/internal/core/ports/factories" - "log/slog" "go.openfort.xyz/shield/internal/core/ports/repositories" "go.openfort.xyz/shield/pkg/logger" diff --git a/internal/adapters/authenticators/user_authenticator/user_authenticator.go b/internal/adapters/authenticators/user_authenticator/user_authenticator.go index 93b0f3f..89bd12f 100644 --- a/internal/adapters/authenticators/user_authenticator/user_authenticator.go +++ b/internal/adapters/authenticators/user_authenticator/user_authenticator.go @@ -1,10 +1,11 @@ -package user_authenticator +package usrauth import ( "context" + "log/slog" + "go.openfort.xyz/shield/internal/core/domain/authentication" "go.openfort.xyz/shield/internal/core/ports/factories" - "log/slog" "go.openfort.xyz/shield/internal/core/ports/repositories" "go.openfort.xyz/shield/internal/core/ports/services" diff --git a/internal/adapters/encryption/aes_encryption_strategy/strategy.go b/internal/adapters/encryption/aes_encryption_strategy/strategy.go index 057518c..20a0849 100644 --- a/internal/adapters/encryption/aes_encryption_strategy/strategy.go +++ b/internal/adapters/encryption/aes_encryption_strategy/strategy.go @@ -1,4 +1,4 @@ -package aes_encryption_strategy +package aesenc import "go.openfort.xyz/shield/pkg/cypher" diff --git a/internal/adapters/encryption/factory.go b/internal/adapters/encryption/factory.go index efdd2c2..1c4afa6 100644 --- a/internal/adapters/encryption/factory.go +++ b/internal/adapters/encryption/factory.go @@ -1,11 +1,11 @@ package encryption import ( - "errors" - "go.openfort.xyz/shield/internal/adapters/encryption/aes_encryption_strategy" - "go.openfort.xyz/shield/internal/adapters/encryption/plain_builder" - "go.openfort.xyz/shield/internal/adapters/encryption/session_builder" - "go.openfort.xyz/shield/internal/adapters/encryption/sss_reconstruction_strategy" + aesencryptionstrategy "go.openfort.xyz/shield/internal/adapters/encryption/aes_encryption_strategy" + plnbldr "go.openfort.xyz/shield/internal/adapters/encryption/plain_builder" + sessbldr "go.openfort.xyz/shield/internal/adapters/encryption/session_builder" + sssrec "go.openfort.xyz/shield/internal/adapters/encryption/sss_reconstruction_strategy" + "go.openfort.xyz/shield/internal/core/domain/errors" "go.openfort.xyz/shield/internal/core/ports/builders" "go.openfort.xyz/shield/internal/core/ports/factories" "go.openfort.xyz/shield/internal/core/ports/repositories" @@ -27,18 +27,18 @@ func NewEncryptionFactory(encryptionPartsRepo repositories.EncryptionPartsReposi func (e *encryptionFactory) CreateEncryptionKeyBuilder(builderType factories.EncryptionKeyBuilderType) (builders.EncryptionKeyBuilder, error) { switch builderType { case factories.Plain: - return plain_builder.NewEncryptionKeyBuilder(e.projectRepo, sss_reconstruction_strategy.NewSSSReconstructionStrategy()), nil + return plnbldr.NewEncryptionKeyBuilder(e.projectRepo, sssrec.NewSSSReconstructionStrategy()), nil case factories.Session: - return session_builder.NewEncryptionKeyBuilder(e.encryptionPartsRepo, e.projectRepo, sss_reconstruction_strategy.NewSSSReconstructionStrategy()), nil + return sessbldr.NewEncryptionKeyBuilder(e.encryptionPartsRepo, e.projectRepo, sssrec.NewSSSReconstructionStrategy()), nil } - return nil, errors.New("invalid builder type") //TODO extract error + return nil, errors.ErrInvalidEncryptionKeyBuilderType } func (e *encryptionFactory) CreateReconstructionStrategy() strategies.ReconstructionStrategy { - return sss_reconstruction_strategy.NewSSSReconstructionStrategy() + return sssrec.NewSSSReconstructionStrategy() } func (e *encryptionFactory) CreateEncryptionStrategy(key string) strategies.EncryptionStrategy { - return aes_encryption_strategy.NewAESEncryptionStrategy(key) + return aesencryptionstrategy.NewAESEncryptionStrategy(key) } diff --git a/internal/adapters/encryption/plain_builder/builder.go b/internal/adapters/encryption/plain_builder/builder.go index cf029c9..1dfed9c 100644 --- a/internal/adapters/encryption/plain_builder/builder.go +++ b/internal/adapters/encryption/plain_builder/builder.go @@ -1,8 +1,10 @@ -package plain_builder +package plnbldr import ( "context" - "errors" + + domainErrors "go.openfort.xyz/shield/internal/core/domain/errors" + "go.openfort.xyz/shield/internal/core/ports/builders" "go.openfort.xyz/shield/internal/core/ports/repositories" "go.openfort.xyz/shield/internal/core/ports/strategies" @@ -22,7 +24,7 @@ func NewEncryptionKeyBuilder(repo repositories.ProjectRepository, reconstruction } } -func (b *plainBuilder) SetProjectPart(ctx context.Context, identifier string) error { +func (b *plainBuilder) SetProjectPart(_ context.Context, identifier string) error { b.projectPart = identifier return nil } @@ -37,13 +39,13 @@ func (b *plainBuilder) SetDatabasePart(ctx context.Context, identifier string) e return nil } -func (b *plainBuilder) Build(ctx context.Context) (string, error) { +func (b *plainBuilder) Build(_ context.Context) (string, error) { if b.projectPart == "" { - return "", errors.New("project part is required") // TODO extract error + return "", domainErrors.ErrProjectPartRequired } if b.databasePart == "" { - return "", errors.New("database part is required") // TODO extract error + return "", domainErrors.ErrDatabasePartRequired } return b.reconstructionStrategy.Reconstruct(b.databasePart, b.projectPart) diff --git a/internal/adapters/encryption/session_builder/builder.go b/internal/adapters/encryption/session_builder/builder.go index fd86f90..d6fd856 100644 --- a/internal/adapters/encryption/session_builder/builder.go +++ b/internal/adapters/encryption/session_builder/builder.go @@ -1,8 +1,9 @@ -package session_builder +package sessbldr import ( "context" "errors" + domainErrors "go.openfort.xyz/shield/internal/core/domain/errors" "go.openfort.xyz/shield/internal/core/ports/builders" "go.openfort.xyz/shield/internal/core/ports/repositories" @@ -53,13 +54,13 @@ func (b *sessionBuilder) SetDatabasePart(ctx context.Context, identifier string) return nil } -func (b *sessionBuilder) Build(ctx context.Context) (string, error) { +func (b *sessionBuilder) Build(_ context.Context) (string, error) { if b.projectPart == "" { - return "", errors.New("project part is required") // TODO extract error + return "", domainErrors.ErrProjectPartRequired } if b.databasePart == "" { - return "", errors.New("database part is required") // TODO extract error + return "", domainErrors.ErrDatabasePartRequired } return b.reconstructionStrategy.Reconstruct(b.databasePart, b.projectPart) diff --git a/internal/adapters/encryption/sss_reconstruction_strategy/strategy.go b/internal/adapters/encryption/sss_reconstruction_strategy/strategy.go index 75e04de..b058a0e 100644 --- a/internal/adapters/encryption/sss_reconstruction_strategy/strategy.go +++ b/internal/adapters/encryption/sss_reconstruction_strategy/strategy.go @@ -1,10 +1,15 @@ -package sss_reconstruction_strategy +package sssrec import ( + "go.openfort.xyz/shield/internal/core/domain/errors" "go.openfort.xyz/shield/internal/core/ports/strategies" "go.openfort.xyz/shield/pkg/cypher" ) +const ( + MaxReties = 5 +) + type SSSReconstructionStrategy struct{} func NewSSSReconstructionStrategy() strategies.ReconstructionStrategy { @@ -12,9 +17,34 @@ func NewSSSReconstructionStrategy() strategies.ReconstructionStrategy { } func (s *SSSReconstructionStrategy) Split(data string) (storedPart string, projectPart string, err error) { - return cypher.SplitEncryptionKey(data) + for i := 0; i < MaxReties; i++ { + storedPart, projectPart, err = cypher.SplitEncryptionKey(data) + if err != nil { + continue + } + + err = s.validateSplit(data, storedPart, projectPart) + if err == nil { + return + } + } + + return } func (s *SSSReconstructionStrategy) Reconstruct(storedPart string, projectPart string) (string, error) { return cypher.ReconstructEncryptionKey(storedPart, projectPart) } + +func (s *SSSReconstructionStrategy) validateSplit(data string, storedPart string, projectPart string) error { + reconstructed, err := s.Reconstruct(storedPart, projectPart) + if err != nil { + return err + } + + if data != reconstructed { + return errors.ErrReconstructedKeyMismatch + } + + return nil +} diff --git a/internal/adapters/handlers/rest/authmdw/middleware.go b/internal/adapters/handlers/rest/authmdw/middleware.go index ab69a94..3b54cc6 100644 --- a/internal/adapters/handlers/rest/authmdw/middleware.go +++ b/internal/adapters/handlers/rest/authmdw/middleware.go @@ -1,11 +1,12 @@ package authmdw import ( - "go.openfort.xyz/shield/internal/core/ports/factories" - "go.openfort.xyz/shield/internal/core/ports/services" "net/http" "strings" + "go.openfort.xyz/shield/internal/core/ports/factories" + "go.openfort.xyz/shield/internal/core/ports/services" + "go.openfort.xyz/shield/internal/adapters/handlers/rest/api" "go.openfort.xyz/shield/pkg/contexter" ) @@ -84,11 +85,12 @@ func (m *Middleware) PreRegisterUser(next http.Handler) http.Handler { var identity factories.Identity var err error - if providerStr == AuthenticationTypeCustom { + switch providerStr { + case AuthenticationTypeCustom: identity, err = m.identityFactory.CreateCustomIdentity(r.Context(), apiKey) - } else if providerStr == AuthenticationTypeOpenfort { + case AuthenticationTypeOpenfort: identity, err = m.identityFactory.CreateOpenfortIdentity(r.Context(), apiKey, nil, nil) - } else { + default: api.RespondWithError(w, api.ErrInvalidAuthProvider) return } @@ -138,9 +140,11 @@ func (m *Middleware) AuthenticateUser(next http.Handler) http.Handler { var identity factories.Identity var err error - if providerStr == AuthenticationTypeCustom { + + switch providerStr { + case AuthenticationTypeCustom: identity, err = m.identityFactory.CreateCustomIdentity(r.Context(), apiKey) - } else if providerStr == AuthenticationTypeOpenfort { + case AuthenticationTypeOpenfort: var openfortProvider *string if r.Header.Get(OpenfortProviderHeader) != "" { openfortProvider = new(string) @@ -152,7 +156,7 @@ func (m *Middleware) AuthenticateUser(next http.Handler) http.Handler { *openfortTokenType = r.Header.Get(OpenfortTokenTypeHeader) } identity, err = m.identityFactory.CreateOpenfortIdentity(r.Context(), apiKey, openfortProvider, openfortTokenType) - } else { + default: api.RespondWithError(w, api.ErrInvalidAuthProvider) return } diff --git a/internal/adapters/handlers/rest/server.go b/internal/adapters/handlers/rest/server.go index f47fc4a..f90a31d 100644 --- a/internal/adapters/handlers/rest/server.go +++ b/internal/adapters/handlers/rest/server.go @@ -3,12 +3,13 @@ package rest import ( "context" "fmt" - "go.openfort.xyz/shield/internal/core/ports/factories" - "go.openfort.xyz/shield/internal/core/ports/services" "log/slog" "net/http" "strings" + "go.openfort.xyz/shield/internal/core/ports/factories" + "go.openfort.xyz/shield/internal/core/ports/services" + "github.com/gorilla/mux" "github.com/rs/cors" "go.openfort.xyz/shield/internal/adapters/handlers/rest/authmdw" @@ -77,6 +78,7 @@ func (s *Server) Start(ctx context.Context) error { u.HandleFunc("", shareHdl.GetShare).Methods(http.MethodGet) u.HandleFunc("", shareHdl.RegisterShare).Methods(http.MethodPost) u.HandleFunc("", shareHdl.DeleteShare).Methods(http.MethodDelete) + u.HandleFunc("", shareHdl.UpdateShare).Methods(http.MethodPut) a := r.PathPrefix("/admin").Subrouter() a.Use(authMdw.AuthenticateAPISecret) diff --git a/internal/adapters/handlers/rest/sharehdl/errors.go b/internal/adapters/handlers/rest/sharehdl/errors.go index 4453940..f4fca0b 100644 --- a/internal/adapters/handlers/rest/sharehdl/errors.go +++ b/internal/adapters/handlers/rest/sharehdl/errors.go @@ -2,6 +2,7 @@ package sharehdl import ( "errors" + "go.openfort.xyz/shield/internal/adapters/handlers/rest/api" "go.openfort.xyz/shield/internal/applications/shareapp" ) diff --git a/internal/adapters/handlers/rest/sharehdl/parser.go b/internal/adapters/handlers/rest/sharehdl/parser.go index 4489256..f1648a2 100644 --- a/internal/adapters/handlers/rest/sharehdl/parser.go +++ b/internal/adapters/handlers/rest/sharehdl/parser.go @@ -24,49 +24,65 @@ func newParser() *parser { func (p *parser) toDomain(s *Share) *share.Share { shr := &share.Share{ - Secret: s.Secret, - EncryptionParameters: &share.EncryptionParameters{ - Entropy: p.mapEntropyDomain[s.Entropy], - }, + Secret: s.Secret, + Entropy: p.mapEntropyDomain[s.Entropy], } if s.EncryptionPart != "" || s.EncryptionSession != "" { - shr.EncryptionParameters.Entropy = share.EntropyProject + shr.Entropy = share.EntropyProject } if s.Salt != "" { + if shr.EncryptionParameters == nil { + shr.EncryptionParameters = new(share.EncryptionParameters) + } shr.EncryptionParameters.Salt = s.Salt } if s.Iterations != 0 { + if shr.EncryptionParameters == nil { + shr.EncryptionParameters = new(share.EncryptionParameters) + } shr.EncryptionParameters.Iterations = s.Iterations } if s.Length != 0 { + if shr.EncryptionParameters == nil { + shr.EncryptionParameters = new(share.EncryptionParameters) + } shr.EncryptionParameters.Length = s.Length } if s.Digest != "" { + if shr.EncryptionParameters == nil { + shr.EncryptionParameters = new(share.EncryptionParameters) + } shr.EncryptionParameters.Digest = s.Digest } + if shr.EncryptionParameters != nil { + shr.Entropy = share.EntropyUser + } + return shr } func (p *parser) fromDomain(s *share.Share) *Share { shr := &Share{ Secret: s.Secret, - Entropy: p.mapDomainEntropy[s.EncryptionParameters.Entropy], + Entropy: p.mapDomainEntropy[s.Entropy], } - if s.EncryptionParameters.Salt != "" { - shr.Salt = s.EncryptionParameters.Salt - } - if s.EncryptionParameters.Iterations != 0 { - shr.Iterations = s.EncryptionParameters.Iterations - } - if s.EncryptionParameters.Length != 0 { - shr.Length = s.EncryptionParameters.Length - } - if s.EncryptionParameters.Digest != "" { - shr.Digest = s.EncryptionParameters.Digest + if s.EncryptionParameters != nil { + if s.EncryptionParameters.Salt != "" { + shr.Salt = s.EncryptionParameters.Salt + } + if s.EncryptionParameters.Iterations != 0 { + shr.Iterations = s.EncryptionParameters.Iterations + } + if s.EncryptionParameters.Length != 0 { + shr.Length = s.EncryptionParameters.Length + } + if s.EncryptionParameters.Digest != "" { + shr.Digest = s.EncryptionParameters.Digest + } } return shr diff --git a/internal/adapters/repositories/bunt/encryptionpartsrepo/repo.go b/internal/adapters/repositories/bunt/encryptionpartsrepo/repo.go index cbb09fe..a17ec8a 100644 --- a/internal/adapters/repositories/bunt/encryptionpartsrepo/repo.go +++ b/internal/adapters/repositories/bunt/encryptionpartsrepo/repo.go @@ -3,12 +3,13 @@ package encryptionpartsrepo import ( "context" "errors" + "log/slog" + "github.com/tidwall/buntdb" "go.openfort.xyz/shield/internal/adapters/repositories/bunt" domainErrors "go.openfort.xyz/shield/internal/core/domain/errors" "go.openfort.xyz/shield/internal/core/ports/repositories" "go.openfort.xyz/shield/pkg/logger" - "log/slog" ) type repository struct { @@ -25,11 +26,11 @@ func New(db *bunt.Client) repositories.EncryptionPartsRepository { } } -func (r *repository) Get(ctx context.Context, sessionId string) (string, error) { +func (r *repository) Get(ctx context.Context, sessionID string) (string, error) { var part string err := r.db.View(func(tx *buntdb.Tx) error { var err error - part, err = tx.Get(sessionId) + part, err = tx.Get(sessionID) return err }) if err != nil { @@ -47,9 +48,9 @@ func (r *repository) Get(ctx context.Context, sessionId string) (string, error) return part, nil } -func (r *repository) Set(ctx context.Context, sessionId, part string) error { +func (r *repository) Set(ctx context.Context, sessionID, part string) error { return r.db.Update(func(tx *buntdb.Tx) error { - _, _, err := tx.Set(sessionId, part, nil) + _, _, err := tx.Set(sessionID, part, nil) if err != nil { if errors.Is(err, buntdb.ErrIndexExists) { return domainErrors.ErrEncryptionPartAlreadyExists @@ -62,9 +63,9 @@ func (r *repository) Set(ctx context.Context, sessionId, part string) error { }) } -func (r *repository) Delete(ctx context.Context, sessionId string) error { +func (r *repository) Delete(ctx context.Context, sessionID string) error { return r.db.Update(func(tx *buntdb.Tx) error { - _, err := tx.Delete(sessionId) + _, err := tx.Delete(sessionID) if err != nil { if errors.Is(err, buntdb.ErrNotFound) { return domainErrors.ErrEncryptionPartNotFound diff --git a/internal/adapters/repositories/mocks/encryptionpartsmockrepo/repo.go b/internal/adapters/repositories/mocks/encryptionpartsmockrepo/repo.go index 6da07ab..d0bf1f8 100644 --- a/internal/adapters/repositories/mocks/encryptionpartsmockrepo/repo.go +++ b/internal/adapters/repositories/mocks/encryptionpartsmockrepo/repo.go @@ -13,16 +13,16 @@ type MockEncryptionPartsRepository struct { var _ repositories.EncryptionPartsRepository = (*MockEncryptionPartsRepository)(nil) -func (m *MockEncryptionPartsRepository) Get(ctx context.Context, sessionId string) (string, error) { - args := m.Mock.Called(ctx, sessionId) +func (m *MockEncryptionPartsRepository) Get(ctx context.Context, sessionID string) (string, error) { + args := m.Mock.Called(ctx, sessionID) if args.Get(0) == nil { return "", args.Error(1) } return args.Get(0).(string), args.Error(1) } -func (m *MockEncryptionPartsRepository) Set(ctx context.Context, sessionId, part string) error { - args := m.Mock.Called(ctx, sessionId, part) +func (m *MockEncryptionPartsRepository) Set(ctx context.Context, sessionID, part string) error { + args := m.Mock.Called(ctx, sessionID, part) return args.Error(0) } diff --git a/internal/adapters/repositories/sql/client.go b/internal/adapters/repositories/sql/client.go index e895481..44ff645 100644 --- a/internal/adapters/repositories/sql/client.go +++ b/internal/adapters/repositories/sql/client.go @@ -2,11 +2,12 @@ package sql import ( "database/sql" + "path/filepath" + "github.com/pressly/goose" "gorm.io/driver/mysql" "gorm.io/driver/postgres" "gorm.io/gorm" - "path/filepath" ) type Client struct { diff --git a/internal/adapters/repositories/sql/projectrepo/repo.go b/internal/adapters/repositories/sql/projectrepo/repo.go index 26360f0..45499d2 100644 --- a/internal/adapters/repositories/sql/projectrepo/repo.go +++ b/internal/adapters/repositories/sql/projectrepo/repo.go @@ -3,6 +3,7 @@ package projectrepo import ( "context" "errors" + domainErrors "go.openfort.xyz/shield/internal/core/domain/errors" "log/slog" diff --git a/internal/adapters/repositories/sql/providerrepo/repo.go b/internal/adapters/repositories/sql/providerrepo/repo.go index 32ce52c..29c1d6d 100644 --- a/internal/adapters/repositories/sql/providerrepo/repo.go +++ b/internal/adapters/repositories/sql/providerrepo/repo.go @@ -3,9 +3,10 @@ package providerrepo import ( "context" "errors" - domainErrors "go.openfort.xyz/shield/internal/core/domain/errors" "log/slog" + domainErrors "go.openfort.xyz/shield/internal/core/domain/errors" + "github.com/google/uuid" "go.openfort.xyz/shield/internal/adapters/repositories/sql" "go.openfort.xyz/shield/internal/core/domain/provider" @@ -74,7 +75,6 @@ func (r *repository) GetByAPIKeyAndType(ctx context.Context, apiKey string, prov } r.logger.ErrorContext(ctx, "error getting provider", logger.Error(err)) return nil, err - } return r.parser.toDomainProvider(dbProv), nil diff --git a/internal/adapters/repositories/sql/sharerepo/parser.go b/internal/adapters/repositories/sql/sharerepo/parser.go index 1bafde8..c7a0003 100644 --- a/internal/adapters/repositories/sql/sharerepo/parser.go +++ b/internal/adapters/repositories/sql/sharerepo/parser.go @@ -2,6 +2,7 @@ package sharerepo import ( "go.openfort.xyz/shield/internal/core/domain/share" + "gorm.io/gorm" ) type parser struct { @@ -25,20 +26,27 @@ func newParser() *parser { } func (p *parser) toDomain(s *Share) *share.Share { - encryptionParameters := &share.EncryptionParameters{ - Entropy: p.mapEntropyDomain[s.Entropy], - } - + var encryptionParameters *share.EncryptionParameters if s.Salt != "" { + encryptionParameters = new(share.EncryptionParameters) encryptionParameters.Salt = s.Salt } if s.Iterations != 0 { + if encryptionParameters == nil { + encryptionParameters = new(share.EncryptionParameters) + } encryptionParameters.Iterations = s.Iterations } if s.Length != 0 { + if encryptionParameters == nil { + encryptionParameters = new(share.EncryptionParameters) + } encryptionParameters.Length = s.Length } if s.Digest != "" { + if encryptionParameters == nil { + encryptionParameters = new(share.EncryptionParameters) + } encryptionParameters.Digest = s.Digest } @@ -46,19 +54,20 @@ func (p *parser) toDomain(s *Share) *share.Share { ID: s.ID, Secret: s.Data, UserID: s.UserID, + Entropy: p.mapEntropyDomain[s.Entropy], EncryptionParameters: encryptionParameters, } } func (p *parser) toDatabase(s *share.Share) *Share { shr := &Share{ - ID: s.ID, - Data: s.Secret, - UserID: s.UserID, + ID: s.ID, + Data: s.Secret, + UserID: s.UserID, + Entropy: p.mapDomainEntropy[s.Entropy], } if s.EncryptionParameters != nil { - shr.Entropy = p.mapDomainEntropy[s.EncryptionParameters.Entropy] if s.EncryptionParameters.Salt != "" { shr.Salt = s.EncryptionParameters.Salt } @@ -71,9 +80,43 @@ func (p *parser) toDatabase(s *share.Share) *Share { if s.EncryptionParameters.Digest != "" { shr.Digest = s.EncryptionParameters.Digest } - } else { - shr.Entropy = EntropyNone } return shr } + +func (p *parser) toUpdates(s *share.Share) map[string]interface{} { + updates := make(map[string]interface{}) + + if s.Secret != "" { + updates["data"] = s.Secret + } + + if s.Entropy != 0 { + updates["entropy"] = p.mapDomainEntropy[s.Entropy] + } + + if s.Entropy != share.EntropyUser { + updates["salt"] = gorm.Expr("NULL") + updates["iterations"] = gorm.Expr("NULL") + updates["length"] = gorm.Expr("NULL") + updates["digest"] = gorm.Expr("NULL") + } + + if s.EncryptionParameters != nil && s.Entropy == share.EntropyUser { + if s.EncryptionParameters.Salt != "" { + updates["salt"] = s.EncryptionParameters.Salt + } + if s.EncryptionParameters.Iterations != 0 { + updates["iterations"] = s.EncryptionParameters.Iterations + } + if s.EncryptionParameters.Length != 0 { + updates["length"] = s.EncryptionParameters.Length + } + if s.EncryptionParameters.Digest != "" { + updates["digest"] = s.EncryptionParameters.Digest + } + } + + return updates +} diff --git a/internal/adapters/repositories/sql/sharerepo/repo.go b/internal/adapters/repositories/sql/sharerepo/repo.go index 2d61ab5..d5d5877 100644 --- a/internal/adapters/repositories/sql/sharerepo/repo.go +++ b/internal/adapters/repositories/sql/sharerepo/repo.go @@ -3,9 +3,10 @@ package sharerepo import ( "context" "errors" - domainErrors "go.openfort.xyz/shield/internal/core/domain/errors" "log/slog" + domainErrors "go.openfort.xyz/shield/internal/core/domain/errors" + "github.com/google/uuid" "go.openfort.xyz/shield/internal/adapters/repositories/sql" "go.openfort.xyz/shield/internal/core/domain/share" @@ -112,7 +113,7 @@ func (r *repository) UpdateProjectEncryption(ctx context.Context, shareID string func (r *repository) Update(ctx context.Context, shr *share.Share) error { r.logger.InfoContext(ctx, "updating share", slog.String("id", shr.ID)) - dbShr := r.parser.toDatabase(shr) + dbShr := r.parser.toUpdates(shr) err := r.db.Model(&Share{}).Where("id = ?", shr.ID).Updates(dbShr).Error if err != nil { r.logger.ErrorContext(ctx, "error updating share", logger.Error(err)) diff --git a/internal/adapters/repositories/sql/userrepo/repo.go b/internal/adapters/repositories/sql/userrepo/repo.go index 0009077..8801205 100644 --- a/internal/adapters/repositories/sql/userrepo/repo.go +++ b/internal/adapters/repositories/sql/userrepo/repo.go @@ -3,9 +3,10 @@ package userrepo import ( "context" "errors" - domainErrors "go.openfort.xyz/shield/internal/core/domain/errors" "log/slog" + domainErrors "go.openfort.xyz/shield/internal/core/domain/errors" + "github.com/google/uuid" "go.openfort.xyz/shield/internal/adapters/repositories/sql" "go.openfort.xyz/shield/internal/core/domain/user" @@ -93,6 +94,9 @@ func (r *repository) FindExternalBy(ctx context.Context, opts ...repositories.Op var dbExtUsrs []ExternalUser err := r.db.Where(options.query).Find(&dbExtUsrs).Error if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return []*user.ExternalUser{}, nil + } r.logger.ErrorContext(ctx, "error finding external user", logger.Error(err)) return nil, err } diff --git a/internal/applications/projectapp/app.go b/internal/applications/projectapp/app.go index 7a13054..e5b684b 100644 --- a/internal/applications/projectapp/app.go +++ b/internal/applications/projectapp/app.go @@ -3,11 +3,12 @@ package projectapp import ( "context" "errors" + "log/slog" + "github.com/google/uuid" domainErrors "go.openfort.xyz/shield/internal/core/domain/errors" "go.openfort.xyz/shield/internal/core/ports/factories" "go.openfort.xyz/shield/pkg/random" - "log/slog" "go.openfort.xyz/shield/internal/core/domain/project" "go.openfort.xyz/shield/internal/core/domain/provider" @@ -305,7 +306,7 @@ func (a *ProjectApplication) EncryptProjectShares(ctx context.Context, externalP var encryptedShares []*share.Share for _, shr := range shares { - if shr.EncryptionParameters != nil && shr.EncryptionParameters.Entropy != share.EntropyNone { + if shr.EncryptionParameters != nil || shr.Entropy != share.EntropyNone { continue } @@ -316,9 +317,7 @@ func (a *ProjectApplication) EncryptProjectShares(ctx context.Context, externalP return fromDomainError(err) } - shr.EncryptionParameters = &share.EncryptionParameters{ - Entropy: share.EntropyProject, - } + shr.Entropy = share.EntropyProject encryptedShares = append(encryptedShares, shr) } diff --git a/internal/applications/projectapp/app_test.go b/internal/applications/projectapp/app_test.go index e4f64ab..ea2f912 100644 --- a/internal/applications/projectapp/app_test.go +++ b/internal/applications/projectapp/app_test.go @@ -851,11 +851,11 @@ func TestProjectApplication_EncryptProjectShares(t *testing.T) { } encryptedShare := &share.Share{ - ID: "encrypted_share_id", - Secret: "djksalfjadsfds", - UserID: "user_id", + ID: "encrypted_share_id", + Secret: "djksalfjadsfds", + UserID: "user_id", + Entropy: share.EntropyUser, EncryptionParameters: &share.EncryptionParameters{ - Entropy: share.EntropyUser, Salt: "somesalt", Iterations: 1000, Length: 256, @@ -864,21 +864,17 @@ func TestProjectApplication_EncryptProjectShares(t *testing.T) { } plainShare := &share.Share{ - ID: "share_id", - Secret: "secret", - UserID: "user_id", - EncryptionParameters: &share.EncryptionParameters{ - Entropy: share.EntropyNone, - }, + ID: "share_id", + Secret: "secret", + UserID: "user_id", + Entropy: share.EntropyNone, } plainShare2 := &share.Share{ - ID: "share_id", - Secret: "secret", - UserID: "user_id", - EncryptionParameters: &share.EncryptionParameters{ - Entropy: share.EntropyNone, - }, + ID: "share_id", + Secret: "secret", + UserID: "user_id", + Entropy: share.EntropyNone, } tc := []struct { diff --git a/internal/applications/projectapp/errors.go b/internal/applications/projectapp/errors.go index ec59f36..49d8448 100644 --- a/internal/applications/projectapp/errors.go +++ b/internal/applications/projectapp/errors.go @@ -2,6 +2,7 @@ package projectapp import ( "errors" + domainErrors "go.openfort.xyz/shield/internal/core/domain/errors" ) diff --git a/internal/applications/shareapp/app.go b/internal/applications/shareapp/app.go index 71a13f0..98ab5e9 100644 --- a/internal/applications/shareapp/app.go +++ b/internal/applications/shareapp/app.go @@ -2,9 +2,10 @@ package shareapp import ( "context" - "go.openfort.xyz/shield/internal/core/ports/factories" "log/slog" + "go.openfort.xyz/shield/internal/core/ports/factories" + "go.openfort.xyz/shield/internal/core/domain/share" "go.openfort.xyz/shield/internal/core/ports/repositories" "go.openfort.xyz/shield/internal/core/ports/services" @@ -71,10 +72,20 @@ func (a *ShareApplication) UpdateShare(ctx context.Context, shr *share.Share, op return nil, fromDomainError(err) } + if shr.Entropy != 0 { + dbShare.Entropy = shr.Entropy + } + if shr.EncryptionParameters != nil { dbShare.EncryptionParameters = shr.EncryptionParameters } + if dbShare.Entropy == share.EntropyNone { + if dbShare.EncryptionParameters != nil { + dbShare.EncryptionParameters = nil + } + } + if shr.Secret != "" { dbShare.Secret = shr.Secret } diff --git a/internal/applications/shareapp/app_test.go b/internal/applications/shareapp/app_test.go index 2a44c93..ea6f2cc 100644 --- a/internal/applications/shareapp/app_test.go +++ b/internal/applications/shareapp/app_test.go @@ -44,22 +44,16 @@ func TestShareApplication_GetShare(t *testing.T) { } plainShare := &share.Share{ - Secret: "secret", - EncryptionParameters: &share.EncryptionParameters{ - Entropy: share.EntropyNone, - }, + Secret: "secret", + Entropy: share.EntropyNone, } encryptedShare := &share.Share{ - Secret: encryptedSecret, - EncryptionParameters: &share.EncryptionParameters{ - Entropy: share.EntropyProject, - }, + Secret: encryptedSecret, + Entropy: share.EntropyProject, } decryptedShare := &share.Share{ - Secret: "secret", - EncryptionParameters: &share.EncryptionParameters{ - Entropy: share.EntropyProject, - }, + Secret: "secret", + Entropy: share.EntropyProject, } tc := []struct { @@ -248,18 +242,14 @@ func TestShareApplication_RegisterShare(t *testing.T) { } plainShare := &share.Share{ - Secret: "secret", - UserID: "user_id", - EncryptionParameters: &share.EncryptionParameters{ - Entropy: share.EntropyNone, - }, + Secret: "secret", + UserID: "user_id", + Entropy: share.EntropyNone, } encryptedShare := &share.Share{ - Secret: encryptedSecret, - UserID: "user_id", - EncryptionParameters: &share.EncryptionParameters{ - Entropy: share.EntropyProject, - }, + Secret: encryptedSecret, + UserID: "user_id", + Entropy: share.EntropyProject, } tc := []struct { diff --git a/internal/applications/shareapp/errors.go b/internal/applications/shareapp/errors.go index da14008..02c2af2 100644 --- a/internal/applications/shareapp/errors.go +++ b/internal/applications/shareapp/errors.go @@ -2,6 +2,7 @@ package shareapp import ( "errors" + domainErrors "go.openfort.xyz/shield/internal/core/domain/errors" ) diff --git a/internal/core/domain/errors/project.go b/internal/core/domain/errors/project.go index 6b9eeca..7d7c4d0 100644 --- a/internal/core/domain/errors/project.go +++ b/internal/core/domain/errors/project.go @@ -3,9 +3,13 @@ package errors import "errors" var ( - ErrProjectNotFound = errors.New("project not found") - ErrEncryptionPartNotFound = errors.New("encryption part not found") - ErrEncryptionPartAlreadyExists = errors.New("encryption part already exists") - ErrEncryptionPartRequired = errors.New("encryption part is required") - ErrInvalidEncryptionSession = errors.New("invalid encryption session") + ErrProjectNotFound = errors.New("project not found") + ErrEncryptionPartNotFound = errors.New("encryption part not found") + ErrEncryptionPartAlreadyExists = errors.New("encryption part already exists") + ErrEncryptionPartRequired = errors.New("encryption part is required") + ErrInvalidEncryptionSession = errors.New("invalid encryption session") + ErrInvalidEncryptionKeyBuilderType = errors.New("invalid encryption key builder type") + ErrReconstructedKeyMismatch = errors.New("reconstructed key mismatch") + ErrProjectPartRequired = errors.New("project part is required") + ErrDatabasePartRequired = errors.New("database part is required") ) diff --git a/internal/core/domain/share/encryption_parameters.go b/internal/core/domain/share/encryption_parameters.go index 944cda6..51ad3be 100644 --- a/internal/core/domain/share/encryption_parameters.go +++ b/internal/core/domain/share/encryption_parameters.go @@ -1,7 +1,6 @@ package share type EncryptionParameters struct { - Entropy Entropy Salt string Iterations int Length int diff --git a/internal/core/domain/share/entropy.go b/internal/core/domain/share/entropy.go index 3f5ab50..8c2db22 100644 --- a/internal/core/domain/share/entropy.go +++ b/internal/core/domain/share/entropy.go @@ -3,7 +3,7 @@ package share type Entropy int8 const ( - EntropyNone Entropy = iota + EntropyNone Entropy = iota + 1 EntropyUser EntropyProject ) diff --git a/internal/core/domain/share/share.go b/internal/core/domain/share/share.go index df33e51..55205fd 100644 --- a/internal/core/domain/share/share.go +++ b/internal/core/domain/share/share.go @@ -4,9 +4,10 @@ type Share struct { ID string Secret string UserID string + Entropy Entropy EncryptionParameters *EncryptionParameters } func (s *Share) RequiresEncryption() bool { - return s.EncryptionParameters != nil && s.EncryptionParameters.Entropy == EntropyProject + return s.Entropy == EntropyProject } diff --git a/internal/core/ports/factories/authentication.go b/internal/core/ports/factories/authentication.go index 237e82b..9ca847d 100644 --- a/internal/core/ports/factories/authentication.go +++ b/internal/core/ports/factories/authentication.go @@ -2,6 +2,7 @@ package factories import ( "context" + "go.openfort.xyz/shield/internal/core/domain/authentication" ) diff --git a/internal/core/ports/repositories/encryption_parts.go b/internal/core/ports/repositories/encryption_parts.go index 0e31d54..488a025 100644 --- a/internal/core/ports/repositories/encryption_parts.go +++ b/internal/core/ports/repositories/encryption_parts.go @@ -3,7 +3,7 @@ package repositories import "context" type EncryptionPartsRepository interface { - Get(ctx context.Context, sessionId string) (string, error) - Set(ctx context.Context, sessionId, part string) error - Delete(ctx context.Context, sessionId string) error + Get(ctx context.Context, sessionID string) (string, error) + Set(ctx context.Context, sessionID, part string) error + Delete(ctx context.Context, sessionID string) error } diff --git a/internal/core/services/projectsvc/svc.go b/internal/core/services/projectsvc/svc.go index f40686f..0c11930 100644 --- a/internal/core/services/projectsvc/svc.go +++ b/internal/core/services/projectsvc/svc.go @@ -3,9 +3,10 @@ package projectsvc import ( "context" "errors" - domainErrors "go.openfort.xyz/shield/internal/core/domain/errors" "log/slog" + domainErrors "go.openfort.xyz/shield/internal/core/domain/errors" + "github.com/google/uuid" "go.openfort.xyz/shield/internal/core/domain/project" "go.openfort.xyz/shield/internal/core/ports/repositories" diff --git a/internal/core/services/providersvc/svc.go b/internal/core/services/providersvc/svc.go index 8ebf3e2..a795341 100644 --- a/internal/core/services/providersvc/svc.go +++ b/internal/core/services/providersvc/svc.go @@ -3,9 +3,10 @@ package providersvc import ( "context" "errors" - domainErrors "go.openfort.xyz/shield/internal/core/domain/errors" "log/slog" + domainErrors "go.openfort.xyz/shield/internal/core/domain/errors" + "go.openfort.xyz/shield/internal/core/domain/provider" "go.openfort.xyz/shield/internal/core/ports/repositories" "go.openfort.xyz/shield/internal/core/ports/services" diff --git a/internal/core/services/sharesvc/svc.go b/internal/core/services/sharesvc/svc.go index eaa3bdf..574209e 100644 --- a/internal/core/services/sharesvc/svc.go +++ b/internal/core/services/sharesvc/svc.go @@ -3,9 +3,10 @@ package sharesvc import ( "context" "errors" + "log/slog" + domainErrors "go.openfort.xyz/shield/internal/core/domain/errors" "go.openfort.xyz/shield/internal/core/ports/factories" - "log/slog" "go.openfort.xyz/shield/internal/core/domain/share" "go.openfort.xyz/shield/internal/core/ports/repositories" diff --git a/internal/core/services/sharesvc/svc_test.go b/internal/core/services/sharesvc/svc_test.go index 48c653d..e219ca5 100644 --- a/internal/core/services/sharesvc/svc_test.go +++ b/internal/core/services/sharesvc/svc_test.go @@ -31,11 +31,9 @@ func TestCreateShare(t *testing.T) { Secret: testData, } testEncryptionShare := &share.Share{ - UserID: testUserID, - Secret: testData, - EncryptionParameters: &share.EncryptionParameters{ - Entropy: share.EntropyProject, - }, + UserID: testUserID, + Secret: testData, + Entropy: share.EntropyProject, } key, err := random.GenerateRandomString(32) if err != nil { diff --git a/internal/core/services/usersvc/svc.go b/internal/core/services/usersvc/svc.go index f2a6593..8b9fff4 100644 --- a/internal/core/services/usersvc/svc.go +++ b/internal/core/services/usersvc/svc.go @@ -3,9 +3,10 @@ package usersvc import ( "context" "errors" - domainErrors "go.openfort.xyz/shield/internal/core/domain/errors" "log/slog" + domainErrors "go.openfort.xyz/shield/internal/core/domain/errors" + "go.openfort.xyz/shield/internal/core/domain/user" "go.openfort.xyz/shield/internal/core/ports/repositories" "go.openfort.xyz/shield/internal/core/ports/services" @@ -42,7 +43,7 @@ func (s *service) GetOrCreate(ctx context.Context, projectID, externalUserID, pr return nil, err } - _, err = s.createExternal(ctx, projectID, usr.ID, externalUserID, providerID) + _, err = s.createExternal(ctx, usr, externalUserID, providerID) if err != nil { s.logger.ErrorContext(ctx, "failed to create external user", logger.Error(err)) return nil, err @@ -91,43 +92,16 @@ func (s *service) getByExternal(ctx context.Context, externalUserID, providerID return usr, nil } -func (s *service) createExternal(ctx context.Context, projectID, userID, externalUserID, providerID string) (*user.ExternalUser, error) { - s.logger.InfoContext(ctx, "creating external user", slog.String("project_id", projectID)) - - usr, err := s.repo.Get(ctx, userID) - if err != nil { - s.logger.ErrorContext(ctx, "failed to get user", logger.Error(err)) - return nil, err - } - - if usr == nil { - s.logger.ErrorContext(ctx, "user not found", slog.String("user_id", userID)) - return nil, domainErrors.ErrUserNotFound - } - - if usr.ProjectID != projectID { - s.logger.ErrorContext(ctx, "user does not belong to project", slog.String("project_id", projectID), slog.String("user_id", userID)) - return nil, domainErrors.ErrUserNotFound - } - - extUsrs, err := s.repo.FindExternalBy(ctx, s.repo.WithUserID(userID), s.repo.WithProviderID(providerID)) - if err != nil && !errors.Is(err, domainErrors.ErrExternalUserNotFound) { - s.logger.ErrorContext(ctx, "failed to get external user", logger.Error(err)) - return nil, err - } - - if len(extUsrs) != 0 { - s.logger.ErrorContext(ctx, "external user already exists for this user and provider", slog.String("user_id", userID), slog.String("provider_type", providerID)) - return nil, domainErrors.ErrExternalUserAlreadyExists - } +func (s *service) createExternal(ctx context.Context, usr *user.User, externalUserID, providerID string) (*user.ExternalUser, error) { + s.logger.InfoContext(ctx, "creating external user") extUsr := &user.ExternalUser{ - UserID: userID, + UserID: usr.ID, ExternalUserID: externalUserID, ProviderID: providerID, } - err = s.repo.CreateExternal(ctx, extUsr) + err := s.repo.CreateExternal(ctx, extUsr) if err != nil { s.logger.ErrorContext(ctx, "failed to create external user", logger.Error(err)) return nil, err diff --git a/internal/core/services/usersvc/svc_test.go b/internal/core/services/usersvc/svc_test.go index cad1be7..8f13daf 100644 --- a/internal/core/services/usersvc/svc_test.go +++ b/internal/core/services/usersvc/svc_test.go @@ -11,7 +11,7 @@ import ( "go.openfort.xyz/shield/internal/core/domain/user" ) -func TestCreateUser(t *testing.T) { +func TestService_GetOrCreate(t *testing.T) { mockRepo := new(usermockedrepo.MockUserRepository) svc := New(mockRepo) ctx := context.Background() @@ -20,49 +20,17 @@ func TestCreateUser(t *testing.T) { providerID := "provider" externalUserID := "external" - tc := []struct { - name string - wantErr bool - mock func() - }{ - { - name: "success", - wantErr: false, - mock: func() { - mockRepo.ExpectedCalls = []*mock.Call{} - mockRepo.On("Create", mock.Anything, mock.Anything).Return(nil) - }, - }, - { - name: "failure", - wantErr: true, - mock: func() { - mockRepo.ExpectedCalls = []*mock.Call{} - mockRepo.On("Create", mock.Anything, mock.Anything).Return(errors.New("random error")) - }, - }, + randomUser := &user.User{ + ID: "user-id", + ProjectID: "project-id", } - for _, tt := range tc { - t.Run(tt.name, func(t *testing.T) { - tt.mock() - _, err := svc.GetOrCreate(ctx, projectID, externalUserID, providerID) - if (err != nil) != tt.wantErr { - t.Errorf("Create() error = %v, wantErr %v", err, tt.wantErr) - return - } - }) + randomExternalUser := &user.ExternalUser{ + ID: "external-user-id", + UserID: "user-id", + ExternalUserID: "external-id", + ProviderID: "provider-id", } -} - -func TestGetUser(t *testing.T) { - mockRepo := new(usermockedrepo.MockUserRepository) - svc := New(mockRepo) - ctx := context.Background() - - projectID := "project" - providerID := "provider" - externalUserID := "external" tc := []struct { name string @@ -71,196 +39,58 @@ func TestGetUser(t *testing.T) { mock func() }{ { - name: "success", + name: "get success", wantErr: false, mock: func() { mockRepo.ExpectedCalls = []*mock.Call{} - mockRepo.On("Get", mock.Anything, mock.Anything).Return(&user.User{}, nil) - }, - }, - { - name: "not found", - wantErr: true, - err: domainErrors.ErrUserNotFound, - mock: func() { - mockRepo.ExpectedCalls = []*mock.Call{} - mockRepo.On("Get", mock.Anything, mock.Anything).Return(nil, domainErrors.ErrUserNotFound) + mockRepo.On("Get", mock.Anything, mock.Anything).Return(randomUser, nil) + mockRepo.On("FindExternalBy", mock.Anything, mock.Anything).Return([]*user.ExternalUser{randomExternalUser}, nil) }, }, { - name: "failure", + name: "get failed external user", wantErr: true, mock: func() { mockRepo.ExpectedCalls = []*mock.Call{} - mockRepo.On("Get", mock.Anything, mock.Anything).Return(nil, errors.New("random error")) - }, - }, - } - - for _, tt := range tc { - t.Run(tt.name, func(t *testing.T) { - tt.mock() - _, err := svc.GetOrCreate(ctx, projectID, externalUserID, providerID) - if (err != nil) != tt.wantErr { - t.Errorf("Get() error = %v, wantErr %v", err, tt.wantErr) - return - } - if tt.err != nil && !errors.Is(err, tt.err) { - t.Errorf("Get() error = %v, wantErr %v", err, tt.err) - return - } - }) - } -} - -func TestGetUserByExternal(t *testing.T) { - mockRepo := new(usermockedrepo.MockUserRepository) - svc := New(mockRepo) - ctx := context.Background() - - tc := []struct { - name string - wantErr bool - err error - mock func() - }{ - { - name: "success", - wantErr: false, - mock: func() { - mockRepo.ExpectedCalls = []*mock.Call{} - mockRepo.On("FindExternalBy", mock.Anything, mock.Anything).Return([]*user.ExternalUser{{}}, nil) - mockRepo.On("Get", mock.Anything, mock.Anything).Return(&user.User{}, nil) - }, - }, - { - name: "external not found", - wantErr: true, - err: domainErrors.ErrExternalUserNotFound, - mock: func() { - mockRepo.ExpectedCalls = []*mock.Call{} - mockRepo.On("FindExternalBy", mock.Anything, mock.Anything).Return(nil, domainErrors.ErrExternalUserNotFound) - }, - }, - { - name: "external empty", - wantErr: true, - err: domainErrors.ErrExternalUserNotFound, - mock: func() { - mockRepo.ExpectedCalls = []*mock.Call{} - mockRepo.On("FindExternalBy", mock.Anything, mock.Anything).Return([]*user.ExternalUser{}, nil) + mockRepo.On("FindExternalBy", mock.Anything, mock.Anything).Return([]*user.ExternalUser{}, errors.New("random error")) }, }, { - name: "user not found", + name: "get failed to get user", wantErr: true, err: domainErrors.ErrUserNotFound, mock: func() { mockRepo.ExpectedCalls = []*mock.Call{} - mockRepo.On("FindExternalBy", mock.Anything, mock.Anything).Return([]*user.ExternalUser{{}}, nil) + mockRepo.On("FindExternalBy", mock.Anything, mock.Anything).Return([]*user.ExternalUser{randomExternalUser}, nil) mockRepo.On("Get", mock.Anything, mock.Anything).Return(nil, domainErrors.ErrUserNotFound) }, }, { - name: "failure", - wantErr: true, - mock: func() { - mockRepo.ExpectedCalls = []*mock.Call{} - mockRepo.On("FindExternalBy", mock.Anything, mock.Anything).Return(nil, errors.New("random error")) - }, - }, - } - - for _, tt := range tc { - t.Run(tt.name, func(t *testing.T) { - tt.mock() - _, err := svc.GetOrCreate(ctx, "project", "user", "provider") - if (err != nil) != tt.wantErr { - t.Errorf("GetByExternal() error = %v, wantErr %v", err, tt.wantErr) - return - } - if tt.err != nil && !errors.Is(err, tt.err) { - t.Errorf("GetByExternal() error = %v, wantErr %v", err, tt.err) - return - } - }) - } -} - -func TestCreateExternalUser(t *testing.T) { - mockRepo := new(usermockedrepo.MockUserRepository) - svc := New(mockRepo) - ctx := context.Background() - - tc := []struct { - name string - wantErr bool - err error - mock func() - }{ - { - name: "success", + name: "create success", wantErr: false, mock: func() { mockRepo.ExpectedCalls = []*mock.Call{} - mockRepo.On("Get", mock.Anything, mock.Anything).Return(&user.User{ProjectID: "project"}, nil) mockRepo.On("FindExternalBy", mock.Anything, mock.Anything).Return([]*user.ExternalUser{}, nil) + mockRepo.On("Create", mock.Anything, mock.Anything).Return(nil) mockRepo.On("CreateExternal", mock.Anything, mock.Anything).Return(nil) }, }, { - name: "user not found on repo", - wantErr: true, - err: domainErrors.ErrUserNotFound, - mock: func() { - mockRepo.ExpectedCalls = []*mock.Call{} - mockRepo.On("Get", mock.Anything, mock.Anything).Return(nil, domainErrors.ErrUserNotFound) - }, - }, - { - name: "user empty on repo", - wantErr: true, - err: domainErrors.ErrUserNotFound, - mock: func() { - mockRepo.ExpectedCalls = []*mock.Call{} - mockRepo.On("Get", mock.Anything, mock.Anything).Return(nil, nil) - }, - }, - { - name: "user not found project mismatch", + name: "create failed to create user", wantErr: true, - err: domainErrors.ErrUserNotFound, - mock: func() { - mockRepo.ExpectedCalls = []*mock.Call{} - mockRepo.On("Get", mock.Anything, mock.Anything).Return(&user.User{ProjectID: "noproject"}, nil) - }, - }, - { - name: "external user already exists", mock: func() { mockRepo.ExpectedCalls = []*mock.Call{} - mockRepo.On("Get", mock.Anything, mock.Anything).Return(&user.User{ProjectID: "project"}, nil) - mockRepo.On("FindExternalBy", mock.Anything, mock.Anything).Return([]*user.ExternalUser{{}}, nil) - }, - wantErr: true, - err: domainErrors.ErrExternalUserAlreadyExists, - }, - { - name: "cant find external user", - wantErr: true, - mock: func() { - mockRepo.ExpectedCalls = []*mock.Call{} - mockRepo.On("Get", mock.Anything, mock.Anything).Return(&user.User{ProjectID: "project"}, nil) - mockRepo.On("FindExternalBy", mock.Anything, mock.Anything).Return([]*user.ExternalUser{}, errors.New("random error")) + mockRepo.On("FindExternalBy", mock.Anything, mock.Anything).Return([]*user.ExternalUser{}, nil) + mockRepo.On("Create", mock.Anything, mock.Anything).Return(errors.New("random error")) }, }, { - name: "failure", + name: "create failed to create external user", wantErr: true, mock: func() { mockRepo.ExpectedCalls = []*mock.Call{} - mockRepo.On("Get", mock.Anything, mock.Anything).Return(&user.User{ProjectID: "project"}, nil) mockRepo.On("FindExternalBy", mock.Anything, mock.Anything).Return([]*user.ExternalUser{}, nil) + mockRepo.On("Create", mock.Anything, mock.Anything).Return(nil) mockRepo.On("CreateExternal", mock.Anything, mock.Anything).Return(errors.New("random error")) }, }, @@ -269,13 +99,13 @@ func TestCreateExternalUser(t *testing.T) { for _, tt := range tc { t.Run(tt.name, func(t *testing.T) { tt.mock() - _, err := svc.GetOrCreate(ctx, "project", "user", "provider") + _, err := svc.GetOrCreate(ctx, projectID, externalUserID, providerID) if (err != nil) != tt.wantErr { - t.Errorf("CreateExternal() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf("have error = %v, wantErr %v", err != nil, tt.wantErr) return } if tt.err != nil && !errors.Is(err, tt.err) { - t.Errorf("CreateExternal() error = %v, wantErr %v", err, tt.err) + t.Errorf("have error = %v, wantErr %v", err, tt.err) return } }) diff --git a/pkg/cypher/cypher.go b/pkg/cypher/cypher.go index e235466..c78e0cb 100644 --- a/pkg/cypher/cypher.go +++ b/pkg/cypher/cypher.go @@ -5,6 +5,7 @@ import ( "crypto/cipher" "encoding/base64" "errors" + "github.com/codahale/sss" "go.openfort.xyz/shield/pkg/random" ) From ba4186373354a1c4be57cbb0d4da677fe7cc4aff Mon Sep 17 00:00:00 2001 From: gllm-dev Date: Thu, 11 Jul 2024 12:34:20 +0200 Subject: [PATCH 09/10] chore: add CHANGELOG --- .gitignore | 1 + CHANGELOG.md | 10 ++++++++++ 2 files changed, 11 insertions(+) create mode 100644 CHANGELOG.md diff --git a/.gitignore b/.gitignore index ae1de32..4d89686 100644 --- a/.gitignore +++ b/.gitignore @@ -16,6 +16,7 @@ !go.mod !.golangci.yml !Dockerfile +!CHANGELOG.md !.github/workflows/run-rests.yml !.github/workflows/docker-image.yml diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..9fa6f4c --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,10 @@ +# Changelog + +All notable changes to this project will be documented in this file. + +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), +and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). + +## [v0.1.11] +### Added +- Encryption Sessions, allow projects to register a on time use session with an encryption part to encrypt/decrypt a secret. \ No newline at end of file From d32d103eb6257354dc35ff4cb5b367e909fc188a Mon Sep 17 00:00:00 2001 From: gllm-dev Date: Thu, 11 Jul 2024 12:42:21 +0200 Subject: [PATCH 10/10] fix: dockerfile --- Dockerfile | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Dockerfile b/Dockerfile index f9a0106..f381c22 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,4 +1,4 @@ -FROM golang:1.22.0-alpine as builder +FROM golang:1.22.5-alpine as builder RUN apk add --no-cache ca-certificates WORKDIR /app COPY . . @@ -8,5 +8,5 @@ FROM scratch WORKDIR /app COPY --from=builder /etc/ssl/certs/ca-certificates.crt /etc/ssl/certs/ COPY --from=builder /app/app /usr/bin/ -COPY internal/infrastructure/adapters/sql/migrations /app/internal/infrastructure/adapters/sql/migrations +COPY internal/adapters/repositories/sql/migrations /app/internal/adapters/repositories/sql/migrations ENTRYPOINT ["app"] \ No newline at end of file