diff --git a/auth/api/grpc/auth/client.go b/auth/api/grpc/auth/client.go index de521c876c..ced2d6926a 100644 --- a/auth/api/grpc/auth/client.go +++ b/auth/api/grpc/auth/client.go @@ -7,6 +7,7 @@ import ( "context" "time" + "github.com/absmach/supermq/auth" grpcapi "github.com/absmach/supermq/auth/api/grpc" grpcAuthV1 "github.com/absmach/supermq/internal/grpc/auth/v1" "github.com/go-kit/kit/endpoint" @@ -17,9 +18,11 @@ import ( const authSvcName = "auth.v1.AuthService" type authGrpcClient struct { - authenticate endpoint.Endpoint - authorize endpoint.Endpoint - timeout time.Duration + authenticate endpoint.Endpoint + authenticatePAT endpoint.Endpoint + authorize endpoint.Endpoint + authorizePAT endpoint.Endpoint + timeout time.Duration } var _ grpcAuthV1.AuthServiceClient = (*authGrpcClient)(nil) @@ -35,6 +38,14 @@ func NewAuthClient(conn *grpc.ClientConn, timeout time.Duration) grpcAuthV1.Auth decodeIdentifyResponse, grpcAuthV1.AuthNRes{}, ).Endpoint(), + authenticatePAT: kitgrpc.NewClient( + conn, + authSvcName, + "AuthenticatePAT", + encodeIdentifyRequest, + decodeIdentifyPATResponse, + grpcAuthV1.AuthNRes{}, + ).Endpoint(), authorize: kitgrpc.NewClient( conn, authSvcName, @@ -43,6 +54,14 @@ func NewAuthClient(conn *grpc.ClientConn, timeout time.Duration) grpcAuthV1.Auth decodeAuthorizeResponse, grpcAuthV1.AuthZRes{}, ).Endpoint(), + authorizePAT: kitgrpc.NewClient( + conn, + authSvcName, + "AuthorizePAT", + encodeAuthorizePATRequest, + decodeAuthorizeResponse, + grpcAuthV1.AuthZRes{}, + ).Endpoint(), timeout: timeout, } } @@ -69,6 +88,23 @@ func decodeIdentifyResponse(_ context.Context, grpcRes interface{}) (interface{} return authenticateRes{id: res.GetId(), userID: res.GetUserId(), domainID: res.GetDomainId()}, nil } +func (client authGrpcClient) AuthenticatePAT(ctx context.Context, token *grpcAuthV1.AuthNReq, _ ...grpc.CallOption) (*grpcAuthV1.AuthNRes, error) { + ctx, cancel := context.WithTimeout(ctx, client.timeout) + defer cancel() + + res, err := client.authenticatePAT(ctx, authenticateReq{token: token.GetToken()}) + if err != nil { + return &grpcAuthV1.AuthNRes{}, grpcapi.DecodeError(err) + } + ir := res.(authenticateRes) + return &grpcAuthV1.AuthNRes{Id: ir.id, UserId: ir.userID}, nil +} + +func decodeIdentifyPATResponse(_ context.Context, grpcRes interface{}) (interface{}, error) { + res := grpcRes.(*grpcAuthV1.AuthNRes) + return authenticateRes{id: res.GetId(), userID: res.GetUserId()}, nil +} + func (client authGrpcClient) Authorize(ctx context.Context, req *grpcAuthV1.AuthZReq, _ ...grpc.CallOption) (r *grpcAuthV1.AuthZRes, err error) { ctx, cancel := context.WithTimeout(ctx, client.timeout) defer cancel() @@ -109,3 +145,37 @@ func encodeAuthorizeRequest(_ context.Context, grpcReq interface{}) (interface{} Object: req.Object, }, nil } + +func (client authGrpcClient) AuthorizePAT(ctx context.Context, req *grpcAuthV1.AuthZPatReq, _ ...grpc.CallOption) (r *grpcAuthV1.AuthZRes, err error) { + ctx, cancel := context.WithTimeout(ctx, client.timeout) + defer cancel() + + res, err := client.authorizePAT(ctx, authPATReq{ + userID: req.GetUserId(), + patID: req.GetPatId(), + platformEntityType: auth.PlatformEntityType(req.GetPlatformEntityType()), + optionalDomainID: req.GetOptionalDomainId(), + optionalDomainEntityType: auth.DomainEntityType(req.GetOptionalDomainEntityType()), + operation: auth.OperationType(req.GetOperation()), + entityIDs: req.GetEntityIds(), + }) + if err != nil { + return &grpcAuthV1.AuthZRes{}, grpcapi.DecodeError(err) + } + + ar := res.(authorizeRes) + return &grpcAuthV1.AuthZRes{Authorized: ar.authorized, Id: ar.id}, nil +} + +func encodeAuthorizePATRequest(_ context.Context, grpcReq interface{}) (interface{}, error) { + req := grpcReq.(authPATReq) + return &grpcAuthV1.AuthZPatReq{ + UserId: req.userID, + PatId: req.patID, + PlatformEntityType: uint32(req.platformEntityType), + OptionalDomainId: req.optionalDomainID, + OptionalDomainEntityType: uint32(req.optionalDomainEntityType), + Operation: uint32(req.operation), + EntityIds: req.entityIDs, + }, nil +} diff --git a/auth/api/grpc/auth/endpoint.go b/auth/api/grpc/auth/endpoint.go index b4f77e3a89..05516b64e6 100644 --- a/auth/api/grpc/auth/endpoint.go +++ b/auth/api/grpc/auth/endpoint.go @@ -27,6 +27,22 @@ func authenticateEndpoint(svc auth.Service) endpoint.Endpoint { } } +func authenticatePATEndpoint(svc auth.Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + req := request.(authenticateReq) + if err := req.validate(); err != nil { + return authenticateRes{}, err + } + + pat, err := svc.IdentifyPAT(ctx, req.token) + if err != nil { + return authenticateRes{}, err + } + + return authenticateRes{id: pat.ID, userID: pat.User}, nil + } +} + func authorizeEndpoint(svc auth.Service) endpoint.Endpoint { return func(ctx context.Context, request interface{}) (interface{}, error) { req := request.(authReq) @@ -50,3 +66,18 @@ func authorizeEndpoint(svc auth.Service) endpoint.Endpoint { return authorizeRes{authorized: true}, nil } } + +func authorizePATEndpoint(svc auth.Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + req := request.(authPATReq) + + if err := req.validate(); err != nil { + return authorizeRes{}, err + } + err := svc.AuthorizePAT(ctx, req.userID, req.patID, req.platformEntityType, req.optionalDomainID, req.optionalDomainEntityType, req.operation, req.entityIDs...) + if err != nil { + return authorizeRes{authorized: false}, err + } + return authorizeRes{authorized: true}, nil + } +} diff --git a/auth/api/grpc/auth/endpoint_test.go b/auth/api/grpc/auth/endpoint_test.go index a9b5c75993..8ee61c4e18 100644 --- a/auth/api/grpc/auth/endpoint_test.go +++ b/auth/api/grpc/auth/endpoint_test.go @@ -41,12 +41,15 @@ const ( invalidDuration = 7 * 24 * time.Hour validToken = "valid" inValidToken = "invalid" + validPATToken = "valid" + inValidPATToken = "invalid" validPolicy = "valid" ) var ( domainID = testsutil.GenerateUUID(&testing.T{}) authAddr = fmt.Sprintf("localhost:%d", port) + clientID = testsutil.GenerateUUID(&testing.T{}) ) func startGRPCServer(svc auth.Service, port int) *grpc.Server { @@ -63,8 +66,8 @@ func startGRPCServer(svc auth.Service, port int) *grpc.Server { func TestIdentify(t *testing.T) { conn, err := grpc.NewClient(authAddr, grpc.WithTransportCredentials(insecure.NewCredentials())) - defer conn.Close() assert.Nil(t, err, fmt.Sprintf("Unexpected error creating client connection %s", err)) + defer conn.Close() grpcClient := grpcapi.NewAuthClient(conn, time.Second) cases := []struct { @@ -96,20 +99,23 @@ func TestIdentify(t *testing.T) { } for _, tc := range cases { - svcCall := svc.On("Identify", mock.Anything, mock.Anything, mock.Anything).Return(auth.Key{Subject: id, User: email, Domain: domainID}, tc.svcErr) - idt, err := grpcClient.Authenticate(context.Background(), &grpcAuthV1.AuthNReq{Token: tc.token}) - if idt != nil { - assert.Equal(t, tc.idt, idt, fmt.Sprintf("%s: expected %v got %v", tc.desc, tc.idt, idt)) - } - assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) - svcCall.Unset() + t.Run(tc.desc, func(t *testing.T) { + svcCall := svc.On("Identify", mock.Anything, mock.Anything).Return(auth.Key{Subject: id, User: email, Domain: domainID}, tc.svcErr) + idt, err := grpcClient.Authenticate(context.Background(), &grpcAuthV1.AuthNReq{Token: tc.token}) + if idt != nil { + assert.Equal(t, tc.idt, idt, fmt.Sprintf("%s: expected %v got %v", tc.desc, tc.idt, idt)) + } + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + svcCall.Unset() + }) } } func TestAuthorize(t *testing.T) { conn, err := grpc.NewClient(authAddr, grpc.WithTransportCredentials(insecure.NewCredentials())) - defer conn.Close() assert.Nil(t, err, fmt.Sprintf("Unexpected error creating client connection %s", err)) + defer conn.Close() + grpcClient := grpcapi.NewAuthClient(conn, time.Second) cases := []struct { @@ -219,12 +225,154 @@ func TestAuthorize(t *testing.T) { }, } for _, tc := range cases { - svccall := svc.On("Authorize", mock.Anything, mock.Anything).Return(tc.err) - ar, err := grpcClient.Authorize(context.Background(), tc.authRequest) - if ar != nil { - assert.Equal(t, tc.authResponse, ar, fmt.Sprintf("%s: expected %v got %v", tc.desc, tc.authResponse, ar)) - } - assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) - svccall.Unset() + t.Run(tc.desc, func(t *testing.T) { + svccall := svc.On("Authorize", mock.Anything, mock.Anything).Return(tc.err) + ar, err := grpcClient.Authorize(context.Background(), tc.authRequest) + if ar != nil { + assert.Equal(t, tc.authResponse, ar, fmt.Sprintf("%s: expected %v got %v", tc.desc, tc.authResponse, ar)) + } + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + svccall.Unset() + }) + } +} + +func TestIdentifyPAT(t *testing.T) { + conn, err := grpc.NewClient(authAddr, grpc.WithTransportCredentials(insecure.NewCredentials())) + assert.Nil(t, err, fmt.Sprintf("Unexpected error creating client connection %s", err)) + defer conn.Close() + grpcClient := grpcapi.NewAuthClient(conn, time.Second) + + cases := []struct { + desc string + token string + idt *grpcAuthV1.AuthNRes + svcErr error + err error + }{ + { + desc: "authenticate user with valid user token", + token: validToken, + idt: &grpcAuthV1.AuthNRes{Id: id, UserId: clientID}, + err: nil, + }, + { + desc: "authenticate user with invalid user token", + token: "invalid", + idt: &grpcAuthV1.AuthNRes{}, + svcErr: svcerr.ErrAuthentication, + err: svcerr.ErrAuthentication, + }, + { + desc: "authenticate user with empty token", + token: "", + idt: &grpcAuthV1.AuthNRes{}, + err: apiutil.ErrBearerToken, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + svcCall := svc.On("IdentifyPAT", mock.Anything, tc.token).Return(auth.PAT{ID: id, User: clientID, IssuedAt: time.Now()}, tc.svcErr) + idt, err := grpcClient.AuthenticatePAT(context.Background(), &grpcAuthV1.AuthNReq{Token: tc.token}) + if idt != nil { + assert.Equal(t, tc.idt, idt, fmt.Sprintf("%s: expected %v got %v", tc.desc, tc.idt, idt)) + } + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + svcCall.Unset() + }) + } +} + +func TestAuthorizePAT(t *testing.T) { + conn, err := grpc.NewClient(authAddr, grpc.WithTransportCredentials(insecure.NewCredentials())) + assert.Nil(t, err, fmt.Sprintf("Unexpected error creating client connection %s", err)) + defer conn.Close() + + grpcClient := grpcapi.NewAuthClient(conn, time.Second) + cases := []struct { + desc string + token string + authRequest *grpcAuthV1.AuthZPatReq + authResponse *grpcAuthV1.AuthZRes + err error + }{ + { + desc: "authorize user with authorized token", + token: validPATToken, + authRequest: &grpcAuthV1.AuthZPatReq{ + UserId: id, + PatId: id, + PlatformEntityType: uint32(auth.PlatformDomainsScope), + OptionalDomainId: domainID, + OptionalDomainEntityType: uint32(auth.DomainClientsScope), + Operation: uint32(auth.CreateOp), + EntityIds: []string{clientID}, + }, + authResponse: &grpcAuthV1.AuthZRes{Authorized: true}, + err: nil, + }, + { + desc: "authorize user with unauthorized token", + token: inValidPATToken, + authRequest: &grpcAuthV1.AuthZPatReq{ + UserId: id, + PatId: id, + PlatformEntityType: uint32(auth.PlatformDomainsScope), + OptionalDomainId: domainID, + OptionalDomainEntityType: uint32(auth.DomainClientsScope), + Operation: uint32(auth.CreateOp), + EntityIds: []string{clientID}, + }, + authResponse: &grpcAuthV1.AuthZRes{Authorized: false}, + err: svcerr.ErrAuthorization, + }, + { + desc: "authorize user with missing user id", + token: validPATToken, + authRequest: &grpcAuthV1.AuthZPatReq{ + PatId: id, + PlatformEntityType: uint32(auth.PlatformDomainsScope), + OptionalDomainId: domainID, + OptionalDomainEntityType: uint32(auth.DomainClientsScope), + Operation: uint32(auth.CreateOp), + EntityIds: []string{clientID}, + }, + authResponse: &grpcAuthV1.AuthZRes{Authorized: false}, + err: apiutil.ErrMissingUserID, + }, + { + desc: "authorize user with missing pat id", + token: validPATToken, + authRequest: &grpcAuthV1.AuthZPatReq{ + UserId: id, + PlatformEntityType: uint32(auth.PlatformDomainsScope), + OptionalDomainId: domainID, + OptionalDomainEntityType: uint32(auth.DomainClientsScope), + Operation: uint32(auth.CreateOp), + EntityIds: []string{clientID}, + }, + authResponse: &grpcAuthV1.AuthZRes{Authorized: false}, + err: apiutil.ErrMissingPATID, + }, + } + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + svccall := svc.On("AuthorizePAT", + mock.Anything, + tc.authRequest.UserId, + tc.authRequest.PatId, + mock.Anything, + tc.authRequest.OptionalDomainId, + mock.Anything, + mock.Anything, + mock.Anything).Return(tc.err) + ar, err := grpcClient.AuthorizePAT(context.Background(), tc.authRequest) + if ar != nil { + assert.Equal(t, tc.authResponse, ar, fmt.Sprintf("%s: expected %v got %v", tc.desc, tc.authResponse, ar)) + } + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + svccall.Unset() + }) } } diff --git a/auth/api/grpc/auth/requests.go b/auth/api/grpc/auth/requests.go index 37a68becd7..d5fc3bd540 100644 --- a/auth/api/grpc/auth/requests.go +++ b/auth/api/grpc/auth/requests.go @@ -4,6 +4,7 @@ package auth import ( + "github.com/absmach/supermq/auth" "github.com/absmach/supermq/pkg/apiutil" ) @@ -49,3 +50,23 @@ func (req authReq) validate() error { return nil } + +type authPATReq struct { + userID string + patID string + platformEntityType auth.PlatformEntityType + optionalDomainID string + optionalDomainEntityType auth.DomainEntityType + operation auth.OperationType + entityIDs []string +} + +func (req authPATReq) validate() error { + if req.userID == "" { + return apiutil.ErrMissingUserID + } + if req.patID == "" { + return apiutil.ErrMissingPATID + } + return nil +} diff --git a/auth/api/grpc/auth/server.go b/auth/api/grpc/auth/server.go index b5e7e3121c..d8b945cb11 100644 --- a/auth/api/grpc/auth/server.go +++ b/auth/api/grpc/auth/server.go @@ -16,8 +16,10 @@ var _ grpcAuthV1.AuthServiceServer = (*authGrpcServer)(nil) type authGrpcServer struct { grpcAuthV1.UnimplementedAuthServiceServer - authorize kitgrpc.Handler - authenticate kitgrpc.Handler + authorize kitgrpc.Handler + authenticate kitgrpc.Handler + authenticatePAT kitgrpc.Handler + authorizePAT kitgrpc.Handler } // NewAuthServer returns new AuthnServiceServer instance. @@ -34,6 +36,18 @@ func NewAuthServer(svc auth.Service) grpcAuthV1.AuthServiceServer { decodeAuthenticateRequest, encodeAuthenticateResponse, ), + + authenticatePAT: kitgrpc.NewServer( + (authenticatePATEndpoint(svc)), + decodeAuthenticateRequest, + encodeAuthenticatePATResponse, + ), + + authorizePAT: kitgrpc.NewServer( + (authorizePATEndpoint(svc)), + decodeAuthorizePATRequest, + encodeAuthorizeResponse, + ), } } @@ -45,6 +59,14 @@ func (s *authGrpcServer) Authenticate(ctx context.Context, req *grpcAuthV1.AuthN return res.(*grpcAuthV1.AuthNRes), nil } +func (s *authGrpcServer) AuthenticatePAT(ctx context.Context, req *grpcAuthV1.AuthNReq) (*grpcAuthV1.AuthNRes, error) { + _, res, err := s.authenticatePAT.ServeGRPC(ctx, req) + if err != nil { + return nil, grpcapi.EncodeError(err) + } + return res.(*grpcAuthV1.AuthNRes), nil +} + func (s *authGrpcServer) Authorize(ctx context.Context, req *grpcAuthV1.AuthZReq) (*grpcAuthV1.AuthZRes, error) { _, res, err := s.authorize.ServeGRPC(ctx, req) if err != nil { @@ -63,6 +85,11 @@ func encodeAuthenticateResponse(_ context.Context, grpcRes interface{}) (interfa return &grpcAuthV1.AuthNRes{Id: res.id, UserId: res.userID, DomainId: res.domainID}, nil } +func encodeAuthenticatePATResponse(_ context.Context, grpcRes interface{}) (interface{}, error) { + res := grpcRes.(authenticateRes) + return &grpcAuthV1.AuthNRes{Id: res.id, UserId: res.userID}, nil +} + func decodeAuthorizeRequest(_ context.Context, grpcReq interface{}) (interface{}, error) { req := grpcReq.(*grpcAuthV1.AuthZReq) return authReq{ @@ -81,3 +108,24 @@ func encodeAuthorizeResponse(_ context.Context, grpcRes interface{}) (interface{ res := grpcRes.(authorizeRes) return &grpcAuthV1.AuthZRes{Authorized: res.authorized, Id: res.id}, nil } + +func decodeAuthorizePATRequest(_ context.Context, grpcReq interface{}) (interface{}, error) { + req := grpcReq.(*grpcAuthV1.AuthZPatReq) + return authPATReq{ + userID: req.GetUserId(), + patID: req.GetPatId(), + platformEntityType: auth.PlatformEntityType(req.GetPlatformEntityType()), + optionalDomainID: req.GetOptionalDomainId(), + optionalDomainEntityType: auth.DomainEntityType(req.GetOptionalDomainEntityType()), + operation: auth.OperationType(req.GetOperation()), + entityIDs: req.GetEntityIds(), + }, nil +} + +func (s *authGrpcServer) AuthorizePAT(ctx context.Context, req *grpcAuthV1.AuthZPatReq) (*grpcAuthV1.AuthZRes, error) { + _, res, err := s.authorizePAT.ServeGRPC(ctx, req) + if err != nil { + return nil, grpcapi.EncodeError(err) + } + return res.(*grpcAuthV1.AuthZRes), nil +} diff --git a/auth/api/http/keys/endpoint_test.go b/auth/api/http/keys/endpoint_test.go index 5052010233..f9d2cf4a7e 100644 --- a/auth/api/http/keys/endpoint_test.go +++ b/auth/api/http/keys/endpoint_test.go @@ -69,12 +69,14 @@ func (tr testRequest) make() (*http.Response, error) { func newService() (auth.Service, *mocks.KeyRepository) { krepo := new(mocks.KeyRepository) + pRepo := new(mocks.PATSRepository) + hash := new(mocks.Hasher) idProvider := uuid.NewMock() pService := new(policymocks.Service) pEvaluator := new(policymocks.Evaluator) t := jwt.New([]byte(secret)) - return auth.New(krepo, idProvider, t, pEvaluator, pService, loginDuration, refreshDuration, invalidDuration), krepo + return auth.New(krepo, pRepo, hash, idProvider, t, pEvaluator, pService, loginDuration, refreshDuration, invalidDuration), krepo } func newServer(svc auth.Service) *httptest.Server { diff --git a/auth/api/http/pats/endpoint.go b/auth/api/http/pats/endpoint.go new file mode 100644 index 0000000000..45e6b3c607 --- /dev/null +++ b/auth/api/http/pats/endpoint.go @@ -0,0 +1,187 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package pats + +import ( + "context" + + "github.com/absmach/supermq/auth" + "github.com/go-kit/kit/endpoint" +) + +func createPATEndpoint(svc auth.Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + req := request.(createPatReq) + if err := req.validate(); err != nil { + return nil, err + } + + pat, err := svc.CreatePAT(ctx, req.token, req.Name, req.Description, req.Duration, req.Scope) + if err != nil { + return nil, err + } + + return createPatRes{pat}, nil + } +} + +func retrievePATEndpoint(svc auth.Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + req := request.(retrievePatReq) + if err := req.validate(); err != nil { + return nil, err + } + + pat, err := svc.RetrievePAT(ctx, req.token, req.id) + if err != nil { + return nil, err + } + + return retrievePatRes{pat}, nil + } +} + +func updatePATNameEndpoint(svc auth.Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + req := request.(updatePatNameReq) + if err := req.validate(); err != nil { + return nil, err + } + + pat, err := svc.UpdatePATName(ctx, req.token, req.id, req.Name) + if err != nil { + return nil, err + } + + return updatePatNameRes{pat}, nil + } +} + +func updatePATDescriptionEndpoint(svc auth.Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + req := request.(updatePatDescriptionReq) + if err := req.validate(); err != nil { + return nil, err + } + + pat, err := svc.UpdatePATDescription(ctx, req.token, req.id, req.Description) + if err != nil { + return nil, err + } + + return updatePatDescriptionRes{pat}, nil + } +} + +func listPATSEndpoint(svc auth.Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + req := request.(listPatsReq) + if err := req.validate(); err != nil { + return nil, err + } + + pm := auth.PATSPageMeta{ + Limit: req.limit, + Offset: req.offset, + } + patsPage, err := svc.ListPATS(ctx, req.token, pm) + if err != nil { + return nil, err + } + + return listPatsRes{patsPage}, nil + } +} + +func deletePATEndpoint(svc auth.Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + req := request.(deletePatReq) + if err := req.validate(); err != nil { + return nil, err + } + + if err := svc.DeletePAT(ctx, req.token, req.id); err != nil { + return nil, err + } + + return deletePatRes{}, nil + } +} + +func resetPATSecretEndpoint(svc auth.Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + req := request.(resetPatSecretReq) + if err := req.validate(); err != nil { + return nil, err + } + + pat, err := svc.ResetPATSecret(ctx, req.token, req.id, req.Duration) + if err != nil { + return nil, err + } + + return resetPatSecretRes{pat}, nil + } +} + +func revokePATSecretEndpoint(svc auth.Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + req := request.(revokePatSecretReq) + if err := req.validate(); err != nil { + return nil, err + } + + if err := svc.RevokePATSecret(ctx, req.token, req.id); err != nil { + return nil, err + } + + return revokePatSecretRes{}, nil + } +} + +func addPATScopeEntryEndpoint(svc auth.Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + req := request.(addPatScopeEntryReq) + if err := req.validate(); err != nil { + return nil, err + } + + scope, err := svc.AddPATScopeEntry(ctx, req.token, req.id, req.PlatformEntityType, req.OptionalDomainID, req.OptionalDomainEntityType, req.Operation, req.EntityIDs...) + if err != nil { + return nil, err + } + + return addPatScopeEntryRes{scope}, nil + } +} + +func removePATScopeEntryEndpoint(svc auth.Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + req := request.(removePatScopeEntryReq) + if err := req.validate(); err != nil { + return nil, err + } + + scope, err := svc.RemovePATScopeEntry(ctx, req.token, req.id, req.PlatformEntityType, req.OptionalDomainID, req.OptionalDomainEntityType, req.Operation, req.EntityIDs...) + if err != nil { + return nil, err + } + return removePatScopeEntryRes{scope}, nil + } +} + +func clearPATAllScopeEntryEndpoint(svc auth.Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + req := request.(clearAllScopeEntryReq) + if err := req.validate(); err != nil { + return nil, err + } + + if err := svc.ClearPATAllScopeEntry(ctx, req.token, req.id); err != nil { + return nil, err + } + + return clearAllScopeEntryRes{}, nil + } +} diff --git a/auth/api/http/pats/requests.go b/auth/api/http/pats/requests.go new file mode 100644 index 0000000000..39d69efe11 --- /dev/null +++ b/auth/api/http/pats/requests.go @@ -0,0 +1,303 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package pats + +import ( + "encoding/json" + "strings" + "time" + + "github.com/absmach/supermq/auth" + "github.com/absmach/supermq/pkg/apiutil" +) + +type createPatReq struct { + token string + Name string `json:"name,omitempty"` + Description string `json:"description,omitempty"` + Duration time.Duration `json:"duration,omitempty"` + Scope auth.Scope `json:"scope,omitempty"` +} + +func (cpr *createPatReq) UnmarshalJSON(data []byte) error { + var temp struct { + Name string `json:"name,omitempty"` + Description string `json:"description,omitempty"` + Duration string `json:"duration,omitempty"` + Scope auth.Scope `json:"scope,omitempty"` + } + if err := json.Unmarshal(data, &temp); err != nil { + return err + } + duration, err := time.ParseDuration(temp.Duration) + if err != nil { + return err + } + cpr.Name = temp.Name + cpr.Description = temp.Description + cpr.Duration = duration + cpr.Scope = temp.Scope + return nil +} + +func (req createPatReq) validate() (err error) { + if req.token == "" { + return apiutil.ErrBearerToken + } + + if strings.TrimSpace(req.Name) == "" { + return apiutil.ErrMissingName + } + + return nil +} + +type retrievePatReq struct { + token string + id string +} + +func (req retrievePatReq) validate() (err error) { + if req.token == "" { + return apiutil.ErrBearerToken + } + if req.id == "" { + return apiutil.ErrMissingID + } + return nil +} + +type updatePatNameReq struct { + token string + id string + Name string `json:"name,omitempty"` +} + +func (req updatePatNameReq) validate() (err error) { + if req.token == "" { + return apiutil.ErrBearerToken + } + if req.id == "" { + return apiutil.ErrMissingID + } + if strings.TrimSpace(req.Name) == "" { + return apiutil.ErrMissingName + } + return nil +} + +type updatePatDescriptionReq struct { + token string + id string + Description string `json:"description,omitempty"` +} + +func (req updatePatDescriptionReq) validate() (err error) { + if req.token == "" { + return apiutil.ErrBearerToken + } + if req.id == "" { + return apiutil.ErrMissingID + } + if strings.TrimSpace(req.Description) == "" { + return apiutil.ErrMissingDescription + } + return nil +} + +type listPatsReq struct { + token string + offset uint64 + limit uint64 +} + +func (req listPatsReq) validate() (err error) { + if req.token == "" { + return apiutil.ErrBearerToken + } + return nil +} + +type deletePatReq struct { + token string + id string +} + +func (req deletePatReq) validate() (err error) { + if req.token == "" { + return apiutil.ErrBearerToken + } + if req.id == "" { + return apiutil.ErrMissingID + } + return nil +} + +type resetPatSecretReq struct { + token string + id string + Duration time.Duration `json:"duration,omitempty"` +} + +func (rspr *resetPatSecretReq) UnmarshalJSON(data []byte) error { + var temp struct { + Duration string `json:"duration,omitempty"` + } + + err := json.Unmarshal(data, &temp) + if err != nil { + return err + } + rspr.Duration, err = time.ParseDuration(temp.Duration) + if err != nil { + return err + } + return nil +} + +func (req resetPatSecretReq) validate() (err error) { + if req.token == "" { + return apiutil.ErrBearerToken + } + if req.id == "" { + return apiutil.ErrMissingID + } + return nil +} + +type revokePatSecretReq struct { + token string + id string +} + +func (req revokePatSecretReq) validate() (err error) { + if req.token == "" { + return apiutil.ErrBearerToken + } + if req.id == "" { + return apiutil.ErrMissingID + } + return nil +} + +type addPatScopeEntryReq struct { + token string + id string + PlatformEntityType auth.PlatformEntityType `json:"platform_entity_type,omitempty"` + OptionalDomainID string `json:"optional_domain_id,omitempty"` + OptionalDomainEntityType auth.DomainEntityType `json:"optional_domain_entity_type,omitempty"` + Operation auth.OperationType `json:"operation,omitempty"` + EntityIDs []string `json:"entity_ids,omitempty"` +} + +func (apser *addPatScopeEntryReq) UnmarshalJSON(data []byte) error { + var temp struct { + PlatformEntityType string `json:"platform_entity_type,omitempty"` + OptionalDomainID string `json:"optional_domain_id,omitempty"` + OptionalDomainEntityType string `json:"optional_domain_entity_type,omitempty"` + Operation string `json:"operation,omitempty"` + EntityIDs []string `json:"entity_ids,omitempty"` + } + + if err := json.Unmarshal(data, &temp); err != nil { + return err + } + + pet, err := auth.ParsePlatformEntityType(temp.PlatformEntityType) + if err != nil { + return err + } + odt, err := auth.ParseDomainEntityType(temp.OptionalDomainEntityType) + if err != nil { + return err + } + op, err := auth.ParseOperationType(temp.Operation) + if err != nil { + return err + } + apser.PlatformEntityType = pet + apser.OptionalDomainID = temp.OptionalDomainID + apser.OptionalDomainEntityType = odt + apser.Operation = op + apser.EntityIDs = temp.EntityIDs + return nil +} + +func (req addPatScopeEntryReq) validate() (err error) { + if req.token == "" { + return apiutil.ErrBearerToken + } + if req.id == "" { + return apiutil.ErrMissingID + } + return nil +} + +type removePatScopeEntryReq struct { + token string + id string + PlatformEntityType auth.PlatformEntityType `json:"platform_entity_type,omitempty"` + OptionalDomainID string `json:"optional_domain_id,omitempty"` + OptionalDomainEntityType auth.DomainEntityType `json:"optional_domain_entity_type,omitempty"` + Operation auth.OperationType `json:"operation,omitempty"` + EntityIDs []string `json:"entity_ids,omitempty"` +} + +func (rpser *removePatScopeEntryReq) UnmarshalJSON(data []byte) error { + var temp struct { + PlatformEntityType string `json:"platform_entity_type,omitempty"` + OptionalDomainID string `json:"optional_domain_id,omitempty"` + OptionalDomainEntityType string `json:"optional_domain_entity_type,omitempty"` + Operation string `json:"operation,omitempty"` + EntityIDs []string `json:"entity_ids,omitempty"` + } + + if err := json.Unmarshal(data, &temp); err != nil { + return err + } + + pet, err := auth.ParsePlatformEntityType(temp.PlatformEntityType) + if err != nil { + return err + } + odt, err := auth.ParseDomainEntityType(temp.OptionalDomainEntityType) + if err != nil { + return err + } + op, err := auth.ParseOperationType(temp.Operation) + if err != nil { + return err + } + rpser.PlatformEntityType = pet + rpser.OptionalDomainID = temp.OptionalDomainID + rpser.OptionalDomainEntityType = odt + rpser.Operation = op + rpser.EntityIDs = temp.EntityIDs + return nil +} + +func (req removePatScopeEntryReq) validate() (err error) { + if req.token == "" { + return apiutil.ErrBearerToken + } + if req.id == "" { + return apiutil.ErrMissingID + } + return nil +} + +type clearAllScopeEntryReq struct { + token string + id string +} + +func (req clearAllScopeEntryReq) validate() (err error) { + if req.token == "" { + return apiutil.ErrBearerToken + } + if req.id == "" { + return apiutil.ErrMissingID + } + return nil +} diff --git a/auth/api/http/pats/responses.go b/auth/api/http/pats/responses.go new file mode 100644 index 0000000000..fe47f63c71 --- /dev/null +++ b/auth/api/http/pats/responses.go @@ -0,0 +1,194 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package pats + +import ( + "net/http" + + "github.com/absmach/supermq" + "github.com/absmach/supermq/auth" +) + +var ( + _ supermq.Response = (*createPatRes)(nil) + _ supermq.Response = (*retrievePatRes)(nil) + _ supermq.Response = (*updatePatNameRes)(nil) + _ supermq.Response = (*updatePatDescriptionRes)(nil) + _ supermq.Response = (*deletePatRes)(nil) + _ supermq.Response = (*resetPatSecretRes)(nil) + _ supermq.Response = (*revokePatSecretRes)(nil) + _ supermq.Response = (*addPatScopeEntryRes)(nil) + _ supermq.Response = (*removePatScopeEntryRes)(nil) + _ supermq.Response = (*clearAllScopeEntryRes)(nil) +) + +type createPatRes struct { + auth.PAT +} + +func (res createPatRes) Code() int { + return http.StatusCreated +} + +func (res createPatRes) Headers() map[string]string { + return map[string]string{} +} + +func (res createPatRes) Empty() bool { + return false +} + +type retrievePatRes struct { + auth.PAT +} + +func (res retrievePatRes) Code() int { + return http.StatusOK +} + +func (res retrievePatRes) Headers() map[string]string { + return map[string]string{} +} + +func (res retrievePatRes) Empty() bool { + return false +} + +type updatePatNameRes struct { + auth.PAT +} + +func (res updatePatNameRes) Code() int { + return http.StatusAccepted +} + +func (res updatePatNameRes) Headers() map[string]string { + return map[string]string{} +} + +func (res updatePatNameRes) Empty() bool { + return false +} + +type updatePatDescriptionRes struct { + auth.PAT +} + +func (res updatePatDescriptionRes) Code() int { + return http.StatusAccepted +} + +func (res updatePatDescriptionRes) Headers() map[string]string { + return map[string]string{} +} + +func (res updatePatDescriptionRes) Empty() bool { + return false +} + +type listPatsRes struct { + auth.PATSPage +} + +func (res listPatsRes) Code() int { + return http.StatusOK +} + +func (res listPatsRes) Headers() map[string]string { + return map[string]string{} +} + +func (res listPatsRes) Empty() bool { + return false +} + +type deletePatRes struct{} + +func (res deletePatRes) Code() int { + return http.StatusNoContent +} + +func (res deletePatRes) Headers() map[string]string { + return map[string]string{} +} + +func (res deletePatRes) Empty() bool { + return true +} + +type resetPatSecretRes struct { + auth.PAT +} + +func (res resetPatSecretRes) Code() int { + return http.StatusOK +} + +func (res resetPatSecretRes) Headers() map[string]string { + return map[string]string{} +} + +func (res resetPatSecretRes) Empty() bool { + return false +} + +type revokePatSecretRes struct{} + +func (res revokePatSecretRes) Code() int { + return http.StatusNoContent +} + +func (res revokePatSecretRes) Headers() map[string]string { + return map[string]string{} +} + +func (res revokePatSecretRes) Empty() bool { + return true +} + +type addPatScopeEntryRes struct { + auth.Scope +} + +func (res addPatScopeEntryRes) Code() int { + return http.StatusOK +} + +func (res addPatScopeEntryRes) Headers() map[string]string { + return map[string]string{} +} + +func (res addPatScopeEntryRes) Empty() bool { + return false +} + +type removePatScopeEntryRes struct { + auth.Scope +} + +func (res removePatScopeEntryRes) Code() int { + return http.StatusOK +} + +func (res removePatScopeEntryRes) Headers() map[string]string { + return map[string]string{} +} + +func (res removePatScopeEntryRes) Empty() bool { + return false +} + +type clearAllScopeEntryRes struct{} + +func (res clearAllScopeEntryRes) Code() int { + return http.StatusOK +} + +func (res clearAllScopeEntryRes) Headers() map[string]string { + return map[string]string{} +} + +func (res clearAllScopeEntryRes) Empty() bool { + return true +} diff --git a/auth/api/http/pats/transport.go b/auth/api/http/pats/transport.go new file mode 100644 index 0000000000..d32ac4064c --- /dev/null +++ b/auth/api/http/pats/transport.go @@ -0,0 +1,300 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package pats + +import ( + "context" + "encoding/json" + "log/slog" + "net/http" + "strings" + + "github.com/absmach/supermq/auth" + "github.com/absmach/supermq/internal/api" + "github.com/absmach/supermq/pkg/apiutil" + "github.com/absmach/supermq/pkg/errors" + "github.com/go-chi/chi/v5" + kithttp "github.com/go-kit/kit/transport/http" +) + +const ( + contentType = "application/json" + defInterval = "30d" + patPrefix = "pat_" +) + +// MakeHandler returns a HTTP handler for API endpoints. +func MakeHandler(svc auth.Service, mux *chi.Mux, logger *slog.Logger) *chi.Mux { + opts := []kithttp.ServerOption{ + kithttp.ServerErrorEncoder(apiutil.LoggingErrorEncoder(logger, api.EncodeError)), + } + mux.Route("/pats", func(r chi.Router) { + r.Post("/", kithttp.NewServer( + createPATEndpoint(svc), + decodeCreatePATRequest, + api.EncodeResponse, + opts..., + ).ServeHTTP) + + r.Get("/", kithttp.NewServer( + listPATSEndpoint(svc), + decodeListPATSRequest, + api.EncodeResponse, + opts..., + ).ServeHTTP) + + r.Route("/{id}", func(r chi.Router) { + r.Get("/", kithttp.NewServer( + retrievePATEndpoint(svc), + decodeRetrievePATRequest, + api.EncodeResponse, + opts..., + ).ServeHTTP) + + r.Patch("/name", kithttp.NewServer( + updatePATNameEndpoint(svc), + decodeUpdatePATNameRequest, + api.EncodeResponse, + opts..., + ).ServeHTTP) + + r.Patch("/description", kithttp.NewServer( + updatePATDescriptionEndpoint(svc), + decodeUpdatePATDescriptionRequest, + api.EncodeResponse, + opts..., + ).ServeHTTP) + + r.Delete("/", kithttp.NewServer( + deletePATEndpoint(svc), + decodeDeletePATRequest, + api.EncodeResponse, + opts..., + ).ServeHTTP) + + r.Route("/secret", func(r chi.Router) { + r.Patch("/reset", kithttp.NewServer( + resetPATSecretEndpoint(svc), + decodeResetPATSecretRequest, + api.EncodeResponse, + opts..., + ).ServeHTTP) + + r.Patch("/revoke", kithttp.NewServer( + revokePATSecretEndpoint(svc), + decodeRevokePATSecretRequest, + api.EncodeResponse, + opts..., + ).ServeHTTP) + }) + + r.Route("/scope", func(r chi.Router) { + r.Patch("/add", kithttp.NewServer( + addPATScopeEntryEndpoint(svc), + decodeAddPATScopeEntryRequest, + api.EncodeResponse, + opts..., + ).ServeHTTP) + + r.Patch("/remove", kithttp.NewServer( + removePATScopeEntryEndpoint(svc), + decodeRemovePATScopeEntryRequest, + api.EncodeResponse, + opts..., + ).ServeHTTP) + + r.Delete("/", kithttp.NewServer( + clearPATAllScopeEntryEndpoint(svc), + decodeClearPATAllScopeEntryRequest, + api.EncodeResponse, + opts..., + ).ServeHTTP) + }) + }) + }) + return mux +} + +func decodeCreatePATRequest(_ context.Context, r *http.Request) (interface{}, error) { + if !strings.Contains(r.Header.Get("Content-Type"), contentType) { + return nil, apiutil.ErrUnsupportedContentType + } + token := apiutil.ExtractBearerToken(r) + if strings.HasPrefix(token, patPrefix) { + return nil, apiutil.ErrUnsupportedTokenType + } + req := createPatReq{token: token} + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(err, errors.ErrMalformedEntity)) + } + return req, nil +} + +func decodeRetrievePATRequest(_ context.Context, r *http.Request) (interface{}, error) { + token := apiutil.ExtractBearerToken(r) + if strings.HasPrefix(token, patPrefix) { + return nil, apiutil.ErrUnsupportedTokenType + } + + req := retrievePatReq{ + token: token, + id: chi.URLParam(r, "id"), + } + return req, nil +} + +func decodeUpdatePATNameRequest(_ context.Context, r *http.Request) (interface{}, error) { + if !strings.Contains(r.Header.Get("Content-Type"), contentType) { + return nil, apiutil.ErrUnsupportedContentType + } + token := apiutil.ExtractBearerToken(r) + if strings.HasPrefix(token, patPrefix) { + return nil, apiutil.ErrUnsupportedTokenType + } + req := updatePatNameReq{ + token: token, + id: chi.URLParam(r, "id"), + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + return nil, errors.Wrap(errors.ErrMalformedEntity, err) + } + return req, nil +} + +func decodeUpdatePATDescriptionRequest(_ context.Context, r *http.Request) (interface{}, error) { + if !strings.Contains(r.Header.Get("Content-Type"), contentType) { + return nil, apiutil.ErrUnsupportedContentType + } + + token := apiutil.ExtractBearerToken(r) + if strings.HasPrefix(token, patPrefix) { + return nil, apiutil.ErrUnsupportedTokenType + } + req := updatePatDescriptionReq{ + token: token, + id: chi.URLParam(r, "id"), + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + return nil, errors.Wrap(errors.ErrMalformedEntity, err) + } + return req, nil +} + +func decodeListPATSRequest(_ context.Context, r *http.Request) (interface{}, error) { + l, err := apiutil.ReadNumQuery[uint64](r, api.LimitKey, api.DefLimit) + if err != nil { + return nil, errors.Wrap(apiutil.ErrValidation, err) + } + o, err := apiutil.ReadNumQuery[uint64](r, api.OffsetKey, api.DefOffset) + if err != nil { + return nil, errors.Wrap(apiutil.ErrValidation, err) + } + token := apiutil.ExtractBearerToken(r) + if strings.HasPrefix(token, patPrefix) { + return nil, apiutil.ErrUnsupportedTokenType + } + req := listPatsReq{ + token: token, + limit: l, + offset: o, + } + return req, nil +} + +func decodeDeletePATRequest(_ context.Context, r *http.Request) (interface{}, error) { + token := apiutil.ExtractBearerToken(r) + if strings.HasPrefix(token, patPrefix) { + return nil, apiutil.ErrUnsupportedTokenType + } + return deletePatReq{ + token: token, + id: chi.URLParam(r, "id"), + }, nil +} + +func decodeResetPATSecretRequest(_ context.Context, r *http.Request) (interface{}, error) { + if !strings.Contains(r.Header.Get("Content-Type"), contentType) { + return nil, apiutil.ErrUnsupportedContentType + } + + token := apiutil.ExtractBearerToken(r) + if strings.HasPrefix(token, patPrefix) { + return nil, apiutil.ErrUnsupportedTokenType + } + req := resetPatSecretReq{ + token: token, + id: chi.URLParam(r, "id"), + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + return nil, errors.Wrap(errors.ErrMalformedEntity, err) + } + return req, nil +} + +func decodeRevokePATSecretRequest(_ context.Context, r *http.Request) (interface{}, error) { + token := apiutil.ExtractBearerToken(r) + if strings.HasPrefix(token, patPrefix) { + return nil, apiutil.ErrUnsupportedTokenType + } + return revokePatSecretReq{ + token: token, + id: chi.URLParam(r, "id"), + }, nil +} + +func decodeAddPATScopeEntryRequest(_ context.Context, r *http.Request) (interface{}, error) { + if !strings.Contains(r.Header.Get("Content-Type"), contentType) { + return nil, apiutil.ErrUnsupportedContentType + } + + token := apiutil.ExtractBearerToken(r) + if strings.HasPrefix(token, patPrefix) { + return nil, apiutil.ErrUnsupportedTokenType + } + + req := addPatScopeEntryReq{ + token: token, + id: chi.URLParam(r, "id"), + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + return nil, errors.Wrap(errors.ErrMalformedEntity, err) + } + return req, nil +} + +func decodeRemovePATScopeEntryRequest(_ context.Context, r *http.Request) (interface{}, error) { + if !strings.Contains(r.Header.Get("Content-Type"), contentType) { + return nil, apiutil.ErrUnsupportedContentType + } + + token := apiutil.ExtractBearerToken(r) + if strings.HasPrefix(token, patPrefix) { + return nil, apiutil.ErrUnsupportedTokenType + } + + req := removePatScopeEntryReq{ + token: token, + id: chi.URLParam(r, "id"), + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + return nil, errors.Wrap(errors.ErrMalformedEntity, err) + } + return req, nil +} + +func decodeClearPATAllScopeEntryRequest(_ context.Context, r *http.Request) (interface{}, error) { + if !strings.Contains(r.Header.Get("Content-Type"), contentType) { + return nil, apiutil.ErrUnsupportedContentType + } + + token := apiutil.ExtractBearerToken(r) + if strings.HasPrefix(token, patPrefix) { + return nil, apiutil.ErrUnsupportedTokenType + } + + return clearAllScopeEntryReq{ + token: token, + id: chi.URLParam(r, "id"), + }, nil +} diff --git a/auth/api/http/transport.go b/auth/api/http/transport.go index 72a447c6e4..c3a8d723c0 100644 --- a/auth/api/http/transport.go +++ b/auth/api/http/transport.go @@ -9,6 +9,7 @@ import ( "github.com/absmach/supermq" "github.com/absmach/supermq/auth" "github.com/absmach/supermq/auth/api/http/keys" + "github.com/absmach/supermq/auth/api/http/pats" "github.com/go-chi/chi/v5" "github.com/prometheus/client_golang/prometheus/promhttp" ) @@ -18,6 +19,7 @@ func MakeHandler(svc auth.Service, logger *slog.Logger, instanceID string) http. mux := chi.NewRouter() mux = keys.MakeHandler(svc, mux, logger) + mux = pats.MakeHandler(svc, mux, logger) mux.Get("/health", supermq.Health("auth", instanceID)) mux.Handle("/metrics", promhttp.Handler()) diff --git a/auth/api/logging.go b/auth/api/logging.go index 77e704cc5b..94bacfd3d7 100644 --- a/auth/api/logging.go +++ b/auth/api/logging.go @@ -124,3 +124,253 @@ func (lm *loggingMiddleware) Authorize(ctx context.Context, pr policies.Policy) }(time.Now()) return lm.svc.Authorize(ctx, pr) } + +func (lm *loggingMiddleware) CreatePAT(ctx context.Context, token, name, description string, duration time.Duration, scope auth.Scope) (pa auth.PAT, err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.String("name", name), + slog.String("description", description), + slog.String("pat_duration", duration.String()), + slog.String("scope", scope.String()), + } + if err != nil { + args = append(args, slog.Any("error", err)) + lm.logger.Warn("Create PAT failed", args...) + return + } + lm.logger.Info("Create PAT completed successfully", args...) + }(time.Now()) + return lm.svc.CreatePAT(ctx, token, name, description, duration, scope) +} + +func (lm *loggingMiddleware) UpdatePATName(ctx context.Context, token, patID, name string) (pa auth.PAT, err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.String("pat_id", patID), + slog.String("name", name), + } + if err != nil { + args = append(args, slog.Any("error", err)) + lm.logger.Warn("Update PAT name failed", args...) + return + } + lm.logger.Info("Update PAT name completed successfully", args...) + }(time.Now()) + return lm.svc.UpdatePATName(ctx, token, patID, name) +} + +func (lm *loggingMiddleware) UpdatePATDescription(ctx context.Context, token, patID, description string) (pa auth.PAT, err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.String("pat_id", patID), + slog.String("description", description), + } + if err != nil { + args = append(args, slog.Any("error", err)) + lm.logger.Warn("Update PAT description failed", args...) + return + } + lm.logger.Info("Update PAT description completed successfully", args...) + }(time.Now()) + return lm.svc.UpdatePATDescription(ctx, token, patID, description) +} + +func (lm *loggingMiddleware) RetrievePAT(ctx context.Context, token, patID string) (pa auth.PAT, err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.String("pat_id", patID), + } + if err != nil { + args = append(args, slog.Any("error", err)) + lm.logger.Warn("Retrieve PAT failed", args...) + return + } + lm.logger.Info("Retrieve PAT completed successfully", args...) + }(time.Now()) + return lm.svc.RetrievePAT(ctx, token, patID) +} + +func (lm *loggingMiddleware) ListPATS(ctx context.Context, token string, pm auth.PATSPageMeta) (pp auth.PATSPage, err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.Uint64("limit", pm.Limit), + slog.Uint64("offset", pm.Offset), + } + if err != nil { + args = append(args, slog.Any("error", err)) + lm.logger.Warn("List PATS failed", args...) + return + } + lm.logger.Info("List PATS completed successfully", args...) + }(time.Now()) + return lm.svc.ListPATS(ctx, token, pm) +} + +func (lm *loggingMiddleware) DeletePAT(ctx context.Context, token, patID string) (err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.String("pat_id", patID), + } + if err != nil { + args = append(args, slog.Any("error", err)) + lm.logger.Warn("Delete PAT failed", args...) + return + } + lm.logger.Info("Delete PAT completed successfully", args...) + }(time.Now()) + return lm.svc.DeletePAT(ctx, token, patID) +} + +func (lm *loggingMiddleware) ResetPATSecret(ctx context.Context, token, patID string, duration time.Duration) (pa auth.PAT, err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.String("pat_id", patID), + slog.String("pat_duration", duration.String()), + } + if err != nil { + args = append(args, slog.Any("error", err)) + lm.logger.Warn("Reset PAT secret failed", args...) + return + } + lm.logger.Info("Reset PAT secret completed successfully", args...) + }(time.Now()) + return lm.svc.ResetPATSecret(ctx, token, patID, duration) +} + +func (lm *loggingMiddleware) RevokePATSecret(ctx context.Context, token, patID string) (err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.String("pat_id", patID), + } + if err != nil { + args = append(args, slog.Any("error", err)) + lm.logger.Warn("Revoke PAT secret failed", args...) + return + } + lm.logger.Info("Revoke PAT secret completed successfully", args...) + }(time.Now()) + return lm.svc.RevokePATSecret(ctx, token, patID) +} + +func (lm *loggingMiddleware) AddPATScopeEntry(ctx context.Context, token, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) (sc auth.Scope, err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.String("pat_id", patID), + slog.String("platform_entity_type", platformEntityType.String()), + slog.String("optional_domain_id", optionalDomainID), + slog.String("optional_domain_entity_type", optionalDomainEntityType.String()), + slog.String("operation", operation.String()), + slog.Any("entities", entityIDs), + } + if err != nil { + args = append(args, slog.Any("error", err)) + lm.logger.Warn("Add entry to PAT scope failed", args...) + return + } + lm.logger.Info("Add entry to PAT scope completed successfully", args...) + }(time.Now()) + return lm.svc.AddPATScopeEntry(ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) +} + +func (lm *loggingMiddleware) RemovePATScopeEntry(ctx context.Context, token, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) (sc auth.Scope, err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.String("pat_id", patID), + slog.String("platform_entity_type", platformEntityType.String()), + slog.String("optional_domain_id", optionalDomainID), + slog.String("optional_domain_entity_type", optionalDomainEntityType.String()), + slog.String("operation", operation.String()), + slog.Any("entities", entityIDs), + } + if err != nil { + args = append(args, slog.Any("error", err)) + lm.logger.Warn("Remove entry from PAT scope failed", args...) + return + } + lm.logger.Info("Remove entry from PAT scope completed successfully", args...) + }(time.Now()) + return lm.svc.RemovePATScopeEntry(ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) +} + +func (lm *loggingMiddleware) ClearPATAllScopeEntry(ctx context.Context, token, patID string) (err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.String("pat_id", patID), + } + if err != nil { + args = append(args, slog.Any("error", err)) + lm.logger.Warn("Clear all entry from PAT scope failed", args...) + return + } + lm.logger.Info("Clear all entry from PAT scope completed successfully", args...) + }(time.Now()) + return lm.svc.ClearPATAllScopeEntry(ctx, token, patID) +} + +func (lm *loggingMiddleware) IdentifyPAT(ctx context.Context, paToken string) (pa auth.PAT, err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + } + if err != nil { + args = append(args, slog.Any("error", err)) + lm.logger.Warn("Identify PAT failed", args...) + return + } + lm.logger.Info("Identify PAT completed successfully", args...) + }(time.Now()) + return lm.svc.IdentifyPAT(ctx, paToken) +} + +func (lm *loggingMiddleware) AuthorizePAT(ctx context.Context, userID, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) (err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.String("platform_entity_type", platformEntityType.String()), + slog.String("optional_domain_id", optionalDomainID), + slog.String("optional_domain_entity_type", optionalDomainEntityType.String()), + slog.String("operation", operation.String()), + slog.Any("entities", entityIDs), + } + if err != nil { + args = append(args, slog.Any("error", err)) + lm.logger.Warn("Authorize PAT failed complete successfully", args...) + return + } + lm.logger.Info("Authorize PAT completed successfully", args...) + }(time.Now()) + return lm.svc.AuthorizePAT(ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) +} + +func (lm *loggingMiddleware) CheckPAT(ctx context.Context, userID, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) (err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.String("user_id", userID), + slog.String("pat_id", patID), + slog.String("platform_entity_type", platformEntityType.String()), + slog.String("optional_domain_id", optionalDomainID), + slog.String("optional_domain_entity_type", optionalDomainEntityType.String()), + slog.String("operation", operation.String()), + slog.Any("entities", entityIDs), + } + if err != nil { + args = append(args, slog.Any("error", err)) + lm.logger.Warn("Check PAT failed complete successfully", args...) + return + } + lm.logger.Info("Check PAT completed successfully", args...) + }(time.Now()) + return lm.svc.CheckPAT(ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) +} diff --git a/auth/api/metrics.go b/auth/api/metrics.go index 93ae544b73..081165b3ff 100644 --- a/auth/api/metrics.go +++ b/auth/api/metrics.go @@ -74,3 +74,115 @@ func (ms *metricsMiddleware) Authorize(ctx context.Context, pr policies.Policy) }(time.Now()) return ms.svc.Authorize(ctx, pr) } + +func (ms *metricsMiddleware) CreatePAT(ctx context.Context, token, name, description string, duration time.Duration, scope auth.Scope) (auth.PAT, error) { + defer func(begin time.Time) { + ms.counter.With("method", "create_pat").Add(1) + ms.latency.With("method", "create_pat").Observe(time.Since(begin).Seconds()) + }(time.Now()) + return ms.svc.CreatePAT(ctx, token, name, description, duration, scope) +} + +func (ms *metricsMiddleware) UpdatePATName(ctx context.Context, token, patID, name string) (auth.PAT, error) { + defer func(begin time.Time) { + ms.counter.With("method", "update_pat_name").Add(1) + ms.latency.With("method", "update_pat_name").Observe(time.Since(begin).Seconds()) + }(time.Now()) + return ms.svc.UpdatePATName(ctx, token, patID, name) +} + +func (ms *metricsMiddleware) UpdatePATDescription(ctx context.Context, token, patID, description string) (auth.PAT, error) { + defer func(begin time.Time) { + ms.counter.With("method", "update_pat_description").Add(1) + ms.latency.With("method", "update_pat_description").Observe(time.Since(begin).Seconds()) + }(time.Now()) + return ms.svc.UpdatePATDescription(ctx, token, patID, description) +} + +func (ms *metricsMiddleware) RetrievePAT(ctx context.Context, token, patID string) (auth.PAT, error) { + defer func(begin time.Time) { + ms.counter.With("method", "retrieve_pat").Add(1) + ms.latency.With("method", "retrieve_pat").Observe(time.Since(begin).Seconds()) + }(time.Now()) + return ms.svc.RetrievePAT(ctx, token, patID) +} + +func (ms *metricsMiddleware) ListPATS(ctx context.Context, token string, pm auth.PATSPageMeta) (auth.PATSPage, error) { + defer func(begin time.Time) { + ms.counter.With("method", "list_pats").Add(1) + ms.latency.With("method", "list_pats").Observe(time.Since(begin).Seconds()) + }(time.Now()) + return ms.svc.ListPATS(ctx, token, pm) +} + +func (ms *metricsMiddleware) DeletePAT(ctx context.Context, token, patID string) error { + defer func(begin time.Time) { + ms.counter.With("method", "delete_pat").Add(1) + ms.latency.With("method", "delete_pat").Observe(time.Since(begin).Seconds()) + }(time.Now()) + return ms.svc.DeletePAT(ctx, token, patID) +} + +func (ms *metricsMiddleware) ResetPATSecret(ctx context.Context, token, patID string, duration time.Duration) (auth.PAT, error) { + defer func(begin time.Time) { + ms.counter.With("method", "reset_pat_secret").Add(1) + ms.latency.With("method", "reset_pat_secret").Observe(time.Since(begin).Seconds()) + }(time.Now()) + return ms.svc.ResetPATSecret(ctx, token, patID, duration) +} + +func (ms *metricsMiddleware) RevokePATSecret(ctx context.Context, token, patID string) error { + defer func(begin time.Time) { + ms.counter.With("method", "revoke_pat_secret").Add(1) + ms.latency.With("method", "revoke_pat_secret").Observe(time.Since(begin).Seconds()) + }(time.Now()) + return ms.svc.RevokePATSecret(ctx, token, patID) +} + +func (ms *metricsMiddleware) AddPATScopeEntry(ctx context.Context, token, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) (auth.Scope, error) { + defer func(begin time.Time) { + ms.counter.With("method", "add_pat_scope_entry").Add(1) + ms.latency.With("method", "add_pat_scope_entry").Observe(time.Since(begin).Seconds()) + }(time.Now()) + return ms.svc.AddPATScopeEntry(ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) +} + +func (ms *metricsMiddleware) RemovePATScopeEntry(ctx context.Context, token, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) (auth.Scope, error) { + defer func(begin time.Time) { + ms.counter.With("method", "remove_pat_scope_entry").Add(1) + ms.latency.With("method", "remove_pat_scope_entry").Observe(time.Since(begin).Seconds()) + }(time.Now()) + return ms.svc.RemovePATScopeEntry(ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) +} + +func (ms *metricsMiddleware) ClearPATAllScopeEntry(ctx context.Context, token, patID string) error { + defer func(begin time.Time) { + ms.counter.With("method", "clear_pat_all_scope_entry").Add(1) + ms.latency.With("method", "clear_pat_all_scope_entry").Observe(time.Since(begin).Seconds()) + }(time.Now()) + return ms.svc.ClearPATAllScopeEntry(ctx, token, patID) +} + +func (ms *metricsMiddleware) IdentifyPAT(ctx context.Context, paToken string) (auth.PAT, error) { + defer func(begin time.Time) { + ms.counter.With("method", "identify_pat").Add(1) + ms.latency.With("method", "identify_pat").Observe(time.Since(begin).Seconds()) + }(time.Now()) + return ms.svc.IdentifyPAT(ctx, paToken) +} + +func (ms *metricsMiddleware) AuthorizePAT(ctx context.Context, userID, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) error { + defer func(begin time.Time) { + ms.counter.With("method", "authorize_pat").Add(1) + ms.latency.With("method", "authorize_pat").Observe(time.Since(begin).Seconds()) + }(time.Now()) + return ms.svc.AuthorizePAT(ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) +} + +func (ms *metricsMiddleware) CheckPAT(ctx context.Context, userID, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) error { + defer func(begin time.Time) { + ms.counter.With("method", "check_pat").Add(1) + ms.latency.With("method", "check_pat").Observe(time.Since(begin).Seconds()) + }(time.Now()) + return ms.svc.CheckPAT(ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) +} diff --git a/auth/bolt/doc.go b/auth/bolt/doc.go new file mode 100644 index 0000000000..dcd06ac566 --- /dev/null +++ b/auth/bolt/doc.go @@ -0,0 +1,6 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +// Package bolt contains PAT repository implementations using +// bolt as the underlying database. +package bolt diff --git a/auth/bolt/init.go b/auth/bolt/init.go new file mode 100644 index 0000000000..2be5977dfe --- /dev/null +++ b/auth/bolt/init.go @@ -0,0 +1,21 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +// Package bolt contains PAT repository implementations using +// bolt as the underlying database. +package bolt + +import ( + "github.com/absmach/supermq/pkg/errors" + bolt "go.etcd.io/bbolt" +) + +var errInit = errors.New("failed to initialize BoltDB") + +func Init(tx *bolt.Tx, bucket string) error { + _, err := tx.CreateBucketIfNotExists([]byte(bucket)) + if err != nil { + return errors.Wrap(errInit, err) + } + return nil +} diff --git a/auth/bolt/pat.go b/auth/bolt/pat.go new file mode 100644 index 0000000000..e16f005842 --- /dev/null +++ b/auth/bolt/pat.go @@ -0,0 +1,812 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package bolt + +import ( + "bytes" + "context" + "encoding/binary" + "fmt" + "strings" + "time" + + "github.com/absmach/supermq/auth" + "github.com/absmach/supermq/pkg/errors" + repoerr "github.com/absmach/supermq/pkg/errors/repository" + bolt "go.etcd.io/bbolt" +) + +const ( + idKey = "id" + userKey = "user" + nameKey = "name" + descriptionKey = "description" + secretKey = "secret_key" + scopeKey = "scope" + issuedAtKey = "issued_at" + expiresAtKey = "expires_at" + updatedAtKey = "updated_at" + lastUsedAtKey = "last_used_at" + revokedKey = "revoked" + revokedAtKey = "revoked_at" + platformEntitiesKey = "platform_entities" + patKey = "pat" + + keySeparator = ":" + anyID = "*" +) + +var ( + activateValue = []byte{0x00} + revokedValue = []byte{0x01} + entityValue = []byte{0x02} + anyIDValue = []byte{0x03} + selectedIDsValue = []byte{0x04} +) + +type patRepo struct { + db *bolt.DB + bucketName string +} + +// NewPATSRepository instantiates a bolt +// implementation of PAT repository. +func NewPATSRepository(db *bolt.DB, bucketName string) auth.PATSRepository { + return &patRepo{ + db: db, + bucketName: bucketName, + } +} + +func (pr *patRepo) Save(ctx context.Context, pat auth.PAT) error { + idxKey := []byte(pat.User + keySeparator + patKey + keySeparator + pat.ID) + kv, err := patToKeyValue(pat) + if err != nil { + return err + } + return pr.db.Update(func(tx *bolt.Tx) error { + rootBucket, err := pr.retrieveRootBucket(tx) + if err != nil { + return errors.Wrap(repoerr.ErrCreateEntity, err) + } + b, err := pr.createUserBucket(rootBucket, pat.User) + if err != nil { + return errors.Wrap(repoerr.ErrCreateEntity, err) + } + for key, value := range kv { + fullKey := []byte(pat.ID + keySeparator + key) + if err := b.Put(fullKey, value); err != nil { + return errors.Wrap(repoerr.ErrCreateEntity, err) + } + } + if err := rootBucket.Put(idxKey, []byte(pat.ID)); err != nil { + return errors.Wrap(repoerr.ErrCreateEntity, err) + } + return nil + }) +} + +func (pr *patRepo) Retrieve(ctx context.Context, userID, patID string) (auth.PAT, error) { + prefix := []byte(patID + keySeparator) + kv := map[string][]byte{} + if err := pr.db.View(func(tx *bolt.Tx) error { + b, err := pr.retrieveUserBucket(tx, userID, patID, repoerr.ErrViewEntity) + if err != nil { + return err + } + c := b.Cursor() + for k, v := c.Seek(prefix); k != nil && bytes.HasPrefix(k, prefix); k, v = c.Next() { + kv[string(k)] = v + } + return nil + }); err != nil { + return auth.PAT{}, err + } + + return keyValueToPAT(kv) +} + +func (pr *patRepo) RetrieveSecretAndRevokeStatus(ctx context.Context, userID, patID string) (string, bool, bool, error) { + revoked := true + expired := false + keySecret := patID + keySeparator + secretKey + keyRevoked := patID + keySeparator + revokedKey + keyExpiresAt := patID + keySeparator + expiresAtKey + var secretHash string + if err := pr.db.View(func(tx *bolt.Tx) error { + b, err := pr.retrieveUserBucket(tx, userID, patID, repoerr.ErrViewEntity) + if err != nil { + return err + } + secretHash = string(b.Get([]byte(keySecret))) + revoked = bytesToBoolean(b.Get([]byte(keyRevoked))) + expiresAt := bytesToTime(b.Get([]byte(keyExpiresAt))) + expired = time.Now().After(expiresAt) + return nil + }); err != nil { + return "", true, true, err + } + return secretHash, revoked, expired, nil +} + +func (pr *patRepo) UpdateName(ctx context.Context, userID, patID, name string) (auth.PAT, error) { + return pr.updatePATField(ctx, userID, patID, nameKey, []byte(name)) +} + +func (pr *patRepo) UpdateDescription(ctx context.Context, userID, patID, description string) (auth.PAT, error) { + return pr.updatePATField(ctx, userID, patID, descriptionKey, []byte(description)) +} + +func (pr *patRepo) UpdateTokenHash(ctx context.Context, userID, patID, tokenHash string, expiryAt time.Time) (auth.PAT, error) { + prefix := []byte(patID + keySeparator) + kv := map[string][]byte{} + if err := pr.db.Update(func(tx *bolt.Tx) error { + b, err := pr.retrieveUserBucket(tx, userID, patID, repoerr.ErrUpdateEntity) + if err != nil { + return err + } + if err := b.Put([]byte(patID+keySeparator+secretKey), []byte(tokenHash)); err != nil { + return errors.Wrap(repoerr.ErrUpdateEntity, err) + } + if err := b.Put([]byte(patID+keySeparator+expiresAtKey), timeToBytes(expiryAt)); err != nil { + return errors.Wrap(repoerr.ErrUpdateEntity, err) + } + if err := b.Put([]byte(patID+keySeparator+updatedAtKey), timeToBytes(time.Now())); err != nil { + return errors.Wrap(repoerr.ErrUpdateEntity, err) + } + c := b.Cursor() + for k, v := c.Seek(prefix); k != nil && bytes.HasPrefix(k, prefix); k, v = c.Next() { + kv[string(k)] = v + } + return nil + }); err != nil { + return auth.PAT{}, err + } + return keyValueToPAT(kv) +} + +func (pr *patRepo) RetrieveAll(ctx context.Context, userID string, pm auth.PATSPageMeta) (auth.PATSPage, error) { + prefix := []byte(userID + keySeparator + patKey + keySeparator) + + patIDs := []string{} + if err := pr.db.View(func(tx *bolt.Tx) error { + b, err := pr.retrieveRootBucket(tx) + if err != nil { + return errors.Wrap(repoerr.ErrViewEntity, err) + } + c := b.Cursor() + for k, v := c.Seek(prefix); k != nil && bytes.HasPrefix(k, prefix); k, v = c.Next() { + if v != nil { + patIDs = append(patIDs, string(v)) + } + } + return nil + }); err != nil { + return auth.PATSPage{}, err + } + + total := len(patIDs) + + var pats []auth.PAT + + patsPage := auth.PATSPage{ + Total: uint64(total), + Limit: pm.Limit, + Offset: pm.Offset, + PATS: pats, + } + + if int(pm.Offset) >= total { + return patsPage, nil + } + + aLimit := pm.Limit + if rLimit := total - int(pm.Offset); int(pm.Limit) > rLimit { + aLimit = uint64(rLimit) + } + + for i := pm.Offset; i < pm.Offset+aLimit; i++ { + if int(i) < total { + pat, err := pr.Retrieve(ctx, userID, patIDs[i]) + if err != nil { + return patsPage, err + } + patsPage.PATS = append(patsPage.PATS, pat) + } + } + + return patsPage, nil +} + +func (pr *patRepo) Revoke(ctx context.Context, userID, patID string) error { + if err := pr.db.Update(func(tx *bolt.Tx) error { + b, err := pr.retrieveUserBucket(tx, userID, patID, repoerr.ErrUpdateEntity) + if err != nil { + return err + } + if err := b.Put([]byte(patID+keySeparator+revokedKey), revokedValue); err != nil { + return errors.Wrap(repoerr.ErrUpdateEntity, err) + } + if err := b.Put([]byte(patID+keySeparator+revokedAtKey), timeToBytes(time.Now())); err != nil { + return errors.Wrap(repoerr.ErrUpdateEntity, err) + } + return nil + }); err != nil { + return err + } + return nil +} + +func (pr *patRepo) Reactivate(ctx context.Context, userID, patID string) error { + if err := pr.db.Update(func(tx *bolt.Tx) error { + b, err := pr.retrieveUserBucket(tx, userID, patID, repoerr.ErrUpdateEntity) + if err != nil { + return err + } + if err := b.Put([]byte(patID+keySeparator+revokedKey), activateValue); err != nil { + return errors.Wrap(repoerr.ErrUpdateEntity, err) + } + if err := b.Put([]byte(patID+keySeparator+revokedAtKey), []byte{}); err != nil { + return errors.Wrap(repoerr.ErrUpdateEntity, err) + } + return nil + }); err != nil { + return err + } + return nil +} + +func (pr *patRepo) Remove(ctx context.Context, userID, patID string) error { + prefix := []byte(patID + keySeparator) + idxKey := []byte(userID + keySeparator + patKey + keySeparator + patID) + if err := pr.db.Update(func(tx *bolt.Tx) error { + b, err := pr.retrieveUserBucket(tx, userID, patID, repoerr.ErrRemoveEntity) + if err != nil { + return err + } + c := b.Cursor() + for k, _ := c.Seek(prefix); k != nil && bytes.HasPrefix(k, prefix); k, _ = c.Next() { + if err := b.Delete(k); err != nil { + return errors.Wrap(repoerr.ErrRemoveEntity, err) + } + } + rb, err := pr.retrieveRootBucket(tx) + if err != nil { + return err + } + if err := rb.Delete(idxKey); err != nil { + return errors.Wrap(repoerr.ErrRemoveEntity, err) + } + return nil + }); err != nil { + return err + } + + return nil +} + +func (pr *patRepo) AddScopeEntry(ctx context.Context, userID, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) (auth.Scope, error) { + prefix := []byte(patID + keySeparator + scopeKey) + rKV := make(map[string][]byte) + if err := pr.db.Update(func(tx *bolt.Tx) error { + b, err := pr.retrieveUserBucket(tx, userID, patID, repoerr.ErrCreateEntity) + if err != nil { + return err + } + kv, err := scopeEntryToKeyValue(platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) + if err != nil { + return err + } + for key, value := range kv { + fullKey := []byte(patID + keySeparator + key) + if err := b.Put(fullKey, value); err != nil { + return errors.Wrap(repoerr.ErrCreateEntity, err) + } + } + + c := b.Cursor() + for k, v := c.Seek(prefix); k != nil && bytes.HasPrefix(k, prefix); k, v = c.Next() { + rKV[string(k)] = v + } + return nil + }); err != nil { + return auth.Scope{}, err + } + + return parseKeyValueToScope(rKV) +} + +func (pr *patRepo) RemoveScopeEntry(ctx context.Context, userID, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) (auth.Scope, error) { + if len(entityIDs) == 0 { + return auth.Scope{}, repoerr.ErrMalformedEntity + } + prefix := []byte(patID + keySeparator + scopeKey) + rKV := make(map[string][]byte) + if err := pr.db.Update(func(tx *bolt.Tx) error { + b, err := pr.retrieveUserBucket(tx, userID, patID, repoerr.ErrRemoveEntity) + if err != nil { + return err + } + kv, err := scopeEntryToKeyValue(platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) + if err != nil { + return err + } + for key := range kv { + fullKey := []byte(patID + keySeparator + key) + if err := b.Delete(fullKey); err != nil { + return errors.Wrap(repoerr.ErrRemoveEntity, err) + } + } + c := b.Cursor() + for k, v := c.Seek(prefix); k != nil && bytes.HasPrefix(k, prefix); k, v = c.Next() { + rKV[string(k)] = v + } + return nil + }); err != nil { + return auth.Scope{}, err + } + return parseKeyValueToScope(rKV) +} + +func (pr *patRepo) CheckScopeEntry(ctx context.Context, userID, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) error { + return pr.db.Update(func(tx *bolt.Tx) error { + b, err := pr.retrieveUserBucket(tx, userID, patID, repoerr.ErrViewEntity) + if err != nil { + return errors.Wrap(repoerr.ErrViewEntity, err) + } + srootKey, err := scopeRootKey(platformEntityType, optionalDomainID, optionalDomainEntityType, operation) + if err != nil { + return errors.Wrap(repoerr.ErrViewEntity, err) + } + + rootKey := patID + keySeparator + srootKey + if value := b.Get([]byte(rootKey)); bytes.Equal(value, anyIDValue) { + return nil + } + for _, entity := range entityIDs { + value := b.Get([]byte(rootKey + keySeparator + entity)) + if !bytes.Equal(value, entityValue) { + return repoerr.ErrNotFound + } + } + return nil + }) +} + +func (pr *patRepo) RemoveAllScopeEntry(ctx context.Context, userID, patID string) error { + return nil +} + +func (pr *patRepo) updatePATField(_ context.Context, userID, patID, key string, value []byte) (auth.PAT, error) { + prefix := []byte(patID + keySeparator) + kv := map[string][]byte{} + if err := pr.db.Update(func(tx *bolt.Tx) error { + b, err := pr.retrieveUserBucket(tx, userID, patID, repoerr.ErrUpdateEntity) + if err != nil { + return err + } + if err := b.Put([]byte(patID+keySeparator+key), value); err != nil { + return errors.Wrap(repoerr.ErrUpdateEntity, err) + } + if err := b.Put([]byte(patID+keySeparator+updatedAtKey), timeToBytes(time.Now())); err != nil { + return errors.Wrap(repoerr.ErrUpdateEntity, err) + } + c := b.Cursor() + for k, v := c.Seek(prefix); k != nil && bytes.HasPrefix(k, prefix); k, v = c.Next() { + kv[string(k)] = v + } + return nil + }); err != nil { + return auth.PAT{}, err + } + return keyValueToPAT(kv) +} + +func (pr *patRepo) createUserBucket(rootBucket *bolt.Bucket, userID string) (*bolt.Bucket, error) { + userBucket, err := rootBucket.CreateBucketIfNotExists([]byte(userID)) + if err != nil { + return nil, errors.Wrap(repoerr.ErrCreateEntity, fmt.Errorf("failed to retrieve or create bucket for user %s : %w", userID, err)) + } + + return userBucket, nil +} + +func (pr *patRepo) retrieveUserBucket(tx *bolt.Tx, userID, patID string, wrap error) (*bolt.Bucket, error) { + rootBucket, err := pr.retrieveRootBucket(tx) + if err != nil { + return nil, errors.Wrap(wrap, err) + } + + vPatID := rootBucket.Get([]byte(userID + keySeparator + patKey + keySeparator + patID)) + if vPatID == nil { + return nil, repoerr.ErrNotFound + } + + userBucket := rootBucket.Bucket([]byte(userID)) + if userBucket == nil { + return nil, errors.Wrap(wrap, fmt.Errorf("user %s not found", userID)) + } + return userBucket, nil +} + +func (pr *patRepo) retrieveRootBucket(tx *bolt.Tx) (*bolt.Bucket, error) { + rootBucket := tx.Bucket([]byte(pr.bucketName)) + if rootBucket == nil { + return nil, fmt.Errorf("bucket %s not found", pr.bucketName) + } + return rootBucket, nil +} + +func patToKeyValue(pat auth.PAT) (map[string][]byte, error) { + kv := map[string][]byte{ + idKey: []byte(pat.ID), + userKey: []byte(pat.User), + nameKey: []byte(pat.Name), + descriptionKey: []byte(pat.Description), + secretKey: []byte(pat.Secret), + issuedAtKey: timeToBytes(pat.IssuedAt), + expiresAtKey: timeToBytes(pat.ExpiresAt), + updatedAtKey: timeToBytes(pat.UpdatedAt), + lastUsedAtKey: timeToBytes(pat.LastUsedAt), + revokedKey: booleanToBytes(pat.Revoked), + revokedAtKey: timeToBytes(pat.RevokedAt), + } + scopeKV, err := scopeToKeyValue(pat.Scope) + if err != nil { + return nil, err + } + for k, v := range scopeKV { + kv[k] = v + } + return kv, nil +} + +func scopeToKeyValue(scope auth.Scope) (map[string][]byte, error) { + kv := map[string][]byte{} + for opType, scopeValue := range scope.Users { + tempKV, err := scopeEntryToKeyValue(auth.PlatformUsersScope, "", auth.DomainNullScope, opType, scopeValue.Values()...) + if err != nil { + return nil, err + } + for k, v := range tempKV { + kv[k] = v + } + } + for opType, scopeValue := range scope.Dashboard { + tempKV, err := scopeEntryToKeyValue(auth.PlatformDashBoardScope, "", auth.DomainNullScope, opType, scopeValue.Values()...) + if err != nil { + return nil, err + } + for k, v := range tempKV { + kv[k] = v + } + } + for opType, scopeValue := range scope.Messaging { + tempKV, err := scopeEntryToKeyValue(auth.PlatformMesagingScope, "", auth.DomainNullScope, opType, scopeValue.Values()...) + if err != nil { + return nil, err + } + for k, v := range tempKV { + kv[k] = v + } + } + for domainID, domainScope := range scope.Domains { + for opType, scopeValue := range domainScope.DomainManagement { + tempKV, err := scopeEntryToKeyValue(auth.PlatformDomainsScope, domainID, auth.DomainManagementScope, opType, scopeValue.Values()...) + if err != nil { + return nil, errors.Wrap(repoerr.ErrCreateEntity, err) + } + for k, v := range tempKV { + kv[k] = v + } + } + for entityType, scope := range domainScope.Entities { + for opType, scopeValue := range scope { + tempKV, err := scopeEntryToKeyValue(auth.PlatformDomainsScope, domainID, entityType, opType, scopeValue.Values()...) + if err != nil { + return nil, errors.Wrap(repoerr.ErrCreateEntity, err) + } + for k, v := range tempKV { + kv[k] = v + } + } + } + } + return kv, nil +} + +func scopeEntryToKeyValue(platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) (map[string][]byte, error) { + if len(entityIDs) == 0 { + return nil, repoerr.ErrMalformedEntity + } + + rootKey, err := scopeRootKey(platformEntityType, optionalDomainID, optionalDomainEntityType, operation) + if err != nil { + return nil, err + } + if len(entityIDs) == 1 && entityIDs[0] == anyID { + return map[string][]byte{rootKey: anyIDValue}, nil + } + + kv := map[string][]byte{rootKey: selectedIDsValue} + + for _, entryID := range entityIDs { + if entryID == anyID { + return nil, repoerr.ErrMalformedEntity + } + kv[rootKey+keySeparator+entryID] = entityValue + } + + return kv, nil +} + +func scopeRootKey(platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType) (string, error) { + op, err := operation.ValidString() + if err != nil { + return "", errors.Wrap(repoerr.ErrMalformedEntity, err) + } + + var rootKey strings.Builder + + rootKey.WriteString(scopeKey) + rootKey.WriteString(keySeparator) + rootKey.WriteString(platformEntityType.String()) + rootKey.WriteString(keySeparator) + + switch platformEntityType { + case auth.PlatformUsersScope: + rootKey.WriteString(op) + case auth.PlatformDashBoardScope: + rootKey.WriteString(op) + case auth.PlatformMesagingScope: + rootKey.WriteString(op) + case auth.PlatformDomainsScope: + if optionalDomainID == "" { + return "", fmt.Errorf("failed to add platform %s scope: invalid domain id", platformEntityType.String()) + } + odet, err := optionalDomainEntityType.ValidString() + if err != nil { + return "", errors.Wrap(repoerr.ErrMalformedEntity, err) + } + rootKey.WriteString(optionalDomainID) + rootKey.WriteString(keySeparator) + rootKey.WriteString(odet) + rootKey.WriteString(keySeparator) + rootKey.WriteString(op) + default: + return "", errors.Wrap(repoerr.ErrMalformedEntity, fmt.Errorf("invalid platform entity type %s", platformEntityType.String())) + } + + return rootKey.String(), nil +} + +func keyValueToBasicPAT(kv map[string][]byte) auth.PAT { + var pat auth.PAT + for k, v := range kv { + switch { + case strings.HasSuffix(k, keySeparator+idKey): + pat.ID = string(v) + case strings.HasSuffix(k, keySeparator+userKey): + pat.User = string(v) + case strings.HasSuffix(k, keySeparator+nameKey): + pat.Name = string(v) + case strings.HasSuffix(k, keySeparator+descriptionKey): + pat.Description = string(v) + case strings.HasSuffix(k, keySeparator+issuedAtKey): + pat.IssuedAt = bytesToTime(v) + case strings.HasSuffix(k, keySeparator+expiresAtKey): + pat.ExpiresAt = bytesToTime(v) + case strings.HasSuffix(k, keySeparator+updatedAtKey): + pat.UpdatedAt = bytesToTime(v) + case strings.HasSuffix(k, keySeparator+lastUsedAtKey): + pat.LastUsedAt = bytesToTime(v) + case strings.HasSuffix(k, keySeparator+revokedKey): + pat.Revoked = bytesToBoolean(v) + case strings.HasSuffix(k, keySeparator+revokedAtKey): + pat.RevokedAt = bytesToTime(v) + } + } + return pat +} + +func keyValueToPAT(kv map[string][]byte) (auth.PAT, error) { + pat := keyValueToBasicPAT(kv) + scope, err := parseKeyValueToScope(kv) + if err != nil { + return auth.PAT{}, err + } + pat.Scope = scope + return pat, nil +} + +func parseKeyValueToScope(kv map[string][]byte) (auth.Scope, error) { + scope := auth.Scope{ + Domains: make(map[string]auth.DomainScope), + } + for key, value := range kv { + if strings.Index(key, keySeparator+scopeKey+keySeparator) > 0 { + keyParts := strings.Split(key, keySeparator) + + platformEntityType, err := auth.ParsePlatformEntityType(keyParts[2]) + if err != nil { + return auth.Scope{}, errors.Wrap(repoerr.ErrViewEntity, err) + } + + switch platformEntityType { + case auth.PlatformUsersScope: + scope.Users, err = parseOperation(platformEntityType, scope.Users, key, keyParts, value) + if err != nil { + return auth.Scope{}, errors.Wrap(repoerr.ErrViewEntity, err) + } + + case auth.PlatformDashBoardScope: + scope.Dashboard, err = parseOperation(platformEntityType, scope.Dashboard, key, keyParts, value) + if err != nil { + return auth.Scope{}, errors.Wrap(repoerr.ErrViewEntity, err) + } + + case auth.PlatformMesagingScope: + scope.Messaging, err = parseOperation(platformEntityType, scope.Messaging, key, keyParts, value) + if err != nil { + return auth.Scope{}, errors.Wrap(repoerr.ErrViewEntity, err) + } + + case auth.PlatformDomainsScope: + if len(keyParts) < 6 { + return auth.Scope{}, fmt.Errorf("invalid scope key format: %s", key) + } + domainID := keyParts[3] + if scope.Domains == nil { + scope.Domains = make(map[string]auth.DomainScope) + } + if _, ok := scope.Domains[domainID]; !ok { + scope.Domains[domainID] = auth.DomainScope{} + } + domainScope := scope.Domains[domainID] + + entityType := keyParts[4] + + switch entityType { + case auth.DomainManagementScope.String(): + domainScope.DomainManagement, err = parseOperation(platformEntityType, domainScope.DomainManagement, key, keyParts, value) + if err != nil { + return auth.Scope{}, errors.Wrap(repoerr.ErrViewEntity, err) + } + default: + etype, err := auth.ParseDomainEntityType(entityType) + if err != nil { + return auth.Scope{}, fmt.Errorf("key %s invalid entity type %s : %w", key, entityType, err) + } + if domainScope.Entities == nil { + domainScope.Entities = make(map[auth.DomainEntityType]auth.OperationScope) + } + if _, ok := domainScope.Entities[etype]; !ok { + domainScope.Entities[etype] = auth.OperationScope{} + } + entityOperationScope := domainScope.Entities[etype] + entityOperationScope, err = parseOperation(platformEntityType, entityOperationScope, key, keyParts, value) + if err != nil { + return auth.Scope{}, errors.Wrap(repoerr.ErrViewEntity, err) + } + domainScope.Entities[etype] = entityOperationScope + } + scope.Domains[domainID] = domainScope + default: + return auth.Scope{}, errors.Wrap(repoerr.ErrViewEntity, fmt.Errorf("invalid platform entity type : %s", platformEntityType.String())) + } + } + } + return scope, nil +} + +func parseOperation(platformEntityType auth.PlatformEntityType, opScope auth.OperationScope, key string, keyParts []string, value []byte) (auth.OperationScope, error) { + if opScope == nil { + opScope = make(map[auth.OperationType]auth.ScopeValue) + } + + if err := validateOperation(platformEntityType, opScope, key, keyParts, value); err != nil { + return auth.OperationScope{}, err + } + + switch string(value) { + case string(entityValue): + opType, err := auth.ParseOperationType(keyParts[len(keyParts)-2]) + if err != nil { + return auth.OperationScope{}, errors.Wrap(repoerr.ErrViewEntity, err) + } + entityID := keyParts[len(keyParts)-1] + + if _, oValueExists := opScope[opType]; !oValueExists { + opScope[opType] = &auth.SelectedIDs{} + } + oValue := opScope[opType] + if err := oValue.AddValues(entityID); err != nil { + return auth.OperationScope{}, fmt.Errorf("failed to add scope key %s with entity value %v : %w", key, entityID, err) + } + opScope[opType] = oValue + case string(anyIDValue): + opType, err := auth.ParseOperationType(keyParts[len(keyParts)-1]) + if err != nil { + return auth.OperationScope{}, errors.Wrap(repoerr.ErrViewEntity, err) + } + if oValue, oValueExists := opScope[opType]; oValueExists && oValue != nil { + if _, ok := oValue.(*auth.AnyIDs); !ok { + return auth.OperationScope{}, fmt.Errorf("failed to add scope key %s with entity anyIDs scope value : key already initialized with different type", key) + } + } + opScope[opType] = &auth.AnyIDs{} + case string(selectedIDsValue): + opType, err := auth.ParseOperationType(keyParts[len(keyParts)-1]) + if err != nil { + return auth.OperationScope{}, errors.Wrap(repoerr.ErrViewEntity, err) + } + oValue, oValueExists := opScope[opType] + if oValueExists && oValue != nil { + if _, ok := oValue.(*auth.SelectedIDs); !ok { + return auth.OperationScope{}, fmt.Errorf("failed to add scope key %s with entity selectedIDs scope value : key already initialized with different type", key) + } + } + if !oValueExists { + opScope[opType] = &auth.SelectedIDs{} + } + default: + return auth.OperationScope{}, fmt.Errorf("key %s have invalid value %v", key, value) + } + return opScope, nil +} + +func validateOperation(platformEntityType auth.PlatformEntityType, opScope auth.OperationScope, key string, keyParts []string, value []byte) error { + expectedKeyPartsLength := 0 + switch string(value) { + case string(entityValue): + switch platformEntityType { + case auth.PlatformDomainsScope: + expectedKeyPartsLength = 7 + case auth.PlatformUsersScope, auth.PlatformDashBoardScope, auth.PlatformMesagingScope: + expectedKeyPartsLength = 5 + default: + return fmt.Errorf("invalid platform entity type : %s", platformEntityType.String()) + } + case string(selectedIDsValue), string(anyIDValue): + switch platformEntityType { + case auth.PlatformDomainsScope: + expectedKeyPartsLength = 6 + case auth.PlatformUsersScope, auth.PlatformDashBoardScope, auth.PlatformMesagingScope: + expectedKeyPartsLength = 4 + default: + return fmt.Errorf("invalid platform entity type : %s", platformEntityType.String()) + } + default: + return fmt.Errorf("key %s have invalid value %v", key, value) + } + if len(keyParts) != expectedKeyPartsLength { + return fmt.Errorf("invalid scope key format: %s", key) + } + return nil +} + +func timeToBytes(t time.Time) []byte { + timeBytes := make([]byte, 8) + binary.BigEndian.PutUint64(timeBytes, uint64(t.Unix())) + return timeBytes +} + +func bytesToTime(b []byte) time.Time { + timeAtSeconds := binary.BigEndian.Uint64(b) + return time.Unix(int64(timeAtSeconds), 0) +} + +func booleanToBytes(b bool) []byte { + if b { + return []byte{1} + } + return []byte{0} +} + +func bytesToBoolean(b []byte) bool { + if len(b) > 1 || b[0] != activateValue[0] { + return true + } + return false +} diff --git a/auth/hasher.go b/auth/hasher.go new file mode 100644 index 0000000000..ada2352bbe --- /dev/null +++ b/auth/hasher.go @@ -0,0 +1,17 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package auth + +// Hasher specifies an API for generating hashes of an arbitrary textual +// content. +// +//go:generate mockery --name Hasher --output=./mocks --filename hasher.go --quiet --note "Copyright (c) Abstract Machines" +type Hasher interface { + // Hash generates the hashed string from plain-text. + Hash(string) (string, error) + + // Compare compares plain-text version to the hashed one. An error should + // indicate failed comparison. + Compare(string, string) error +} diff --git a/auth/hasher/doc.go b/auth/hasher/doc.go new file mode 100644 index 0000000000..98be992262 --- /dev/null +++ b/auth/hasher/doc.go @@ -0,0 +1,6 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +// Package hasher contains the domain concept definitions needed to +// support Magistrala users password hasher sub-service functionality. +package hasher diff --git a/auth/hasher/hasher.go b/auth/hasher/hasher.go new file mode 100644 index 0000000000..a9e4b2df08 --- /dev/null +++ b/auth/hasher/hasher.go @@ -0,0 +1,86 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package hasher + +import ( + "encoding/base64" + "fmt" + "math/rand" + "strings" + "time" + + "github.com/absmach/supermq/auth" + "github.com/absmach/supermq/pkg/errors" + "golang.org/x/crypto/scrypt" +) + +var ( + errHashToken = errors.New("failed to generate hash for token") + errHashCompare = errors.New("failed to generate hash for given compare string") + errToken = errors.New("given token and hash are not same") + errSalt = errors.New("failed to generate salt") + errInvalidHashStore = errors.New("invalid stored hash format") + errDecode = errors.New("failed to decode") +) + +var _ auth.Hasher = (*bcryptHasher)(nil) + +type bcryptHasher struct{} + +// New instantiates a bcrypt-based hasher implementation. +func New() auth.Hasher { + return &bcryptHasher{} +} + +func (bh *bcryptHasher) Hash(token string) (string, error) { + salt, err := generateSalt(25) + if err != nil { + return "", err + } + // N is kept 16384 to make faster and added large salt, since PAT will be access by automation scripts in high frequency. + hash, err := scrypt.Key([]byte(token), salt, 16384, 8, 1, 32) + if err != nil { + return "", errors.Wrap(errHashToken, err) + } + + return fmt.Sprintf("%s.%s", base64.StdEncoding.EncodeToString(hash), base64.StdEncoding.EncodeToString(salt)), nil +} + +func (bh *bcryptHasher) Compare(plain, hashed string) error { + parts := strings.Split(hashed, ".") + if len(parts) != 2 { + return errInvalidHashStore + } + + actHash, err := base64.StdEncoding.DecodeString(parts[0]) + if err != nil { + return errors.Wrap(errDecode, err) + } + + salt, err := base64.StdEncoding.DecodeString(parts[1]) + if err != nil { + return errors.Wrap(errDecode, err) + } + + derivedHash, err := scrypt.Key([]byte(plain), salt, 16384, 8, 1, 32) + if err != nil { + return errors.Wrap(errHashCompare, err) + } + + if string(derivedHash) == string(actHash) { + return nil + } + + return errToken +} + +func generateSalt(length int) ([]byte, error) { + rand.New(rand.NewSource(time.Now().UnixNano())) + salt := make([]byte, length) + _, err := rand.Read(salt) + if err != nil { + return nil, errors.Wrap(errSalt, err) + } + return salt, nil +} diff --git a/auth/keys.go b/auth/keys.go index aa21ee48e0..e273119014 100644 --- a/auth/keys.go +++ b/auth/keys.go @@ -30,6 +30,8 @@ const ( RecoveryKey // APIKey enables the one to act on behalf of the user. APIKey + // PersonalAccessToken represents token generated by user for automation. + PersonalAccessToken // InvitationKey is a key for inviting new users. InvitationKey ) @@ -44,6 +46,8 @@ func (kt KeyType) String() string { return "recovery" case APIKey: return "API" + case PersonalAccessToken: + return "pat" default: return "unknown" } diff --git a/auth/mocks/hasher.go b/auth/mocks/hasher.go new file mode 100644 index 0000000000..4c4425b257 --- /dev/null +++ b/auth/mocks/hasher.go @@ -0,0 +1,72 @@ +// Code generated by mockery v2.43.2. DO NOT EDIT. + +// Copyright (c) Abstract Machines + +package mocks + +import mock "github.com/stretchr/testify/mock" + +// Hasher is an autogenerated mock type for the Hasher type +type Hasher struct { + mock.Mock +} + +// Compare provides a mock function with given fields: _a0, _a1 +func (_m *Hasher) Compare(_a0 string, _a1 string) error { + ret := _m.Called(_a0, _a1) + + if len(ret) == 0 { + panic("no return value specified for Compare") + } + + var r0 error + if rf, ok := ret.Get(0).(func(string, string) error); ok { + r0 = rf(_a0, _a1) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Hash provides a mock function with given fields: _a0 +func (_m *Hasher) Hash(_a0 string) (string, error) { + ret := _m.Called(_a0) + + if len(ret) == 0 { + panic("no return value specified for Hash") + } + + var r0 string + var r1 error + if rf, ok := ret.Get(0).(func(string) (string, error)); ok { + return rf(_a0) + } + if rf, ok := ret.Get(0).(func(string) string); ok { + r0 = rf(_a0) + } else { + r0 = ret.Get(0).(string) + } + + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(_a0) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// NewHasher creates a new instance of Hasher. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewHasher(t interface { + mock.TestingT + Cleanup(func()) +}) *Hasher { + mock := &Hasher{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/auth/mocks/pats.go b/auth/mocks/pats.go new file mode 100644 index 0000000000..4b920bbd10 --- /dev/null +++ b/auth/mocks/pats.go @@ -0,0 +1,404 @@ +// Code generated by mockery v2.43.2. DO NOT EDIT. + +// Copyright (c) Abstract Machines + +package mocks + +import ( + context "context" + + auth "github.com/absmach/supermq/auth" + + mock "github.com/stretchr/testify/mock" + + time "time" +) + +// PATS is an autogenerated mock type for the PATS type +type PATS struct { + mock.Mock +} + +// AddPATScopeEntry provides a mock function with given fields: ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs +func (_m *PATS) AddPATScopeEntry(ctx context.Context, token string, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) (auth.Scope, error) { + _va := make([]interface{}, len(entityIDs)) + for _i := range entityIDs { + _va[_i] = entityIDs[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for AddPATScopeEntry") + } + + var r0 auth.Scope + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) (auth.Scope, error)); ok { + return rf(ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) auth.Scope); ok { + r0 = rf(ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) + } else { + r0 = ret.Get(0).(auth.Scope) + } + + if rf, ok := ret.Get(1).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) error); ok { + r1 = rf(ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// AuthorizePAT provides a mock function with given fields: ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs +func (_m *PATS) AuthorizePAT(ctx context.Context, userID string, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) error { + _va := make([]interface{}, len(entityIDs)) + for _i := range entityIDs { + _va[_i] = entityIDs[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for AuthorizePAT") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) error); ok { + r0 = rf(ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// CheckPAT provides a mock function with given fields: ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs +func (_m *PATS) CheckPAT(ctx context.Context, userID string, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) error { + _va := make([]interface{}, len(entityIDs)) + for _i := range entityIDs { + _va[_i] = entityIDs[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for CheckPAT") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) error); ok { + r0 = rf(ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// ClearPATAllScopeEntry provides a mock function with given fields: ctx, token, patID +func (_m *PATS) ClearPATAllScopeEntry(ctx context.Context, token string, patID string) error { + ret := _m.Called(ctx, token, patID) + + if len(ret) == 0 { + panic("no return value specified for ClearPATAllScopeEntry") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok { + r0 = rf(ctx, token, patID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// CreatePAT provides a mock function with given fields: ctx, token, name, description, duration, scope +func (_m *PATS) CreatePAT(ctx context.Context, token string, name string, description string, duration time.Duration, scope auth.Scope) (auth.PAT, error) { + ret := _m.Called(ctx, token, name, description, duration, scope) + + if len(ret) == 0 { + panic("no return value specified for CreatePAT") + } + + var r0 auth.PAT + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, string, time.Duration, auth.Scope) (auth.PAT, error)); ok { + return rf(ctx, token, name, description, duration, scope) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string, string, time.Duration, auth.Scope) auth.PAT); ok { + r0 = rf(ctx, token, name, description, duration, scope) + } else { + r0 = ret.Get(0).(auth.PAT) + } + + if rf, ok := ret.Get(1).(func(context.Context, string, string, string, time.Duration, auth.Scope) error); ok { + r1 = rf(ctx, token, name, description, duration, scope) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// DeletePAT provides a mock function with given fields: ctx, token, patID +func (_m *PATS) DeletePAT(ctx context.Context, token string, patID string) error { + ret := _m.Called(ctx, token, patID) + + if len(ret) == 0 { + panic("no return value specified for DeletePAT") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok { + r0 = rf(ctx, token, patID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// IdentifyPAT provides a mock function with given fields: ctx, paToken +func (_m *PATS) IdentifyPAT(ctx context.Context, paToken string) (auth.PAT, error) { + ret := _m.Called(ctx, paToken) + + if len(ret) == 0 { + panic("no return value specified for IdentifyPAT") + } + + var r0 auth.PAT + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (auth.PAT, error)); ok { + return rf(ctx, paToken) + } + if rf, ok := ret.Get(0).(func(context.Context, string) auth.PAT); ok { + r0 = rf(ctx, paToken) + } else { + r0 = ret.Get(0).(auth.PAT) + } + + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, paToken) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// ListPATS provides a mock function with given fields: ctx, token, pm +func (_m *PATS) ListPATS(ctx context.Context, token string, pm auth.PATSPageMeta) (auth.PATSPage, error) { + ret := _m.Called(ctx, token, pm) + + if len(ret) == 0 { + panic("no return value specified for ListPATS") + } + + var r0 auth.PATSPage + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, auth.PATSPageMeta) (auth.PATSPage, error)); ok { + return rf(ctx, token, pm) + } + if rf, ok := ret.Get(0).(func(context.Context, string, auth.PATSPageMeta) auth.PATSPage); ok { + r0 = rf(ctx, token, pm) + } else { + r0 = ret.Get(0).(auth.PATSPage) + } + + if rf, ok := ret.Get(1).(func(context.Context, string, auth.PATSPageMeta) error); ok { + r1 = rf(ctx, token, pm) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// RemovePATScopeEntry provides a mock function with given fields: ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs +func (_m *PATS) RemovePATScopeEntry(ctx context.Context, token string, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) (auth.Scope, error) { + _va := make([]interface{}, len(entityIDs)) + for _i := range entityIDs { + _va[_i] = entityIDs[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for RemovePATScopeEntry") + } + + var r0 auth.Scope + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) (auth.Scope, error)); ok { + return rf(ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) auth.Scope); ok { + r0 = rf(ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) + } else { + r0 = ret.Get(0).(auth.Scope) + } + + if rf, ok := ret.Get(1).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) error); ok { + r1 = rf(ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// ResetPATSecret provides a mock function with given fields: ctx, token, patID, duration +func (_m *PATS) ResetPATSecret(ctx context.Context, token string, patID string, duration time.Duration) (auth.PAT, error) { + ret := _m.Called(ctx, token, patID, duration) + + if len(ret) == 0 { + panic("no return value specified for ResetPATSecret") + } + + var r0 auth.PAT + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, time.Duration) (auth.PAT, error)); ok { + return rf(ctx, token, patID, duration) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string, time.Duration) auth.PAT); ok { + r0 = rf(ctx, token, patID, duration) + } else { + r0 = ret.Get(0).(auth.PAT) + } + + if rf, ok := ret.Get(1).(func(context.Context, string, string, time.Duration) error); ok { + r1 = rf(ctx, token, patID, duration) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// RetrievePAT provides a mock function with given fields: ctx, userID, patID +func (_m *PATS) RetrievePAT(ctx context.Context, userID string, patID string) (auth.PAT, error) { + ret := _m.Called(ctx, userID, patID) + + if len(ret) == 0 { + panic("no return value specified for RetrievePAT") + } + + var r0 auth.PAT + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string) (auth.PAT, error)); ok { + return rf(ctx, userID, patID) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string) auth.PAT); ok { + r0 = rf(ctx, userID, patID) + } else { + r0 = ret.Get(0).(auth.PAT) + } + + if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok { + r1 = rf(ctx, userID, patID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// RevokePATSecret provides a mock function with given fields: ctx, token, patID +func (_m *PATS) RevokePATSecret(ctx context.Context, token string, patID string) error { + ret := _m.Called(ctx, token, patID) + + if len(ret) == 0 { + panic("no return value specified for RevokePATSecret") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok { + r0 = rf(ctx, token, patID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// UpdatePATDescription provides a mock function with given fields: ctx, token, patID, description +func (_m *PATS) UpdatePATDescription(ctx context.Context, token string, patID string, description string) (auth.PAT, error) { + ret := _m.Called(ctx, token, patID, description) + + if len(ret) == 0 { + panic("no return value specified for UpdatePATDescription") + } + + var r0 auth.PAT + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, string) (auth.PAT, error)); ok { + return rf(ctx, token, patID, description) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string, string) auth.PAT); ok { + r0 = rf(ctx, token, patID, description) + } else { + r0 = ret.Get(0).(auth.PAT) + } + + if rf, ok := ret.Get(1).(func(context.Context, string, string, string) error); ok { + r1 = rf(ctx, token, patID, description) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// UpdatePATName provides a mock function with given fields: ctx, token, patID, name +func (_m *PATS) UpdatePATName(ctx context.Context, token string, patID string, name string) (auth.PAT, error) { + ret := _m.Called(ctx, token, patID, name) + + if len(ret) == 0 { + panic("no return value specified for UpdatePATName") + } + + var r0 auth.PAT + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, string) (auth.PAT, error)); ok { + return rf(ctx, token, patID, name) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string, string) auth.PAT); ok { + r0 = rf(ctx, token, patID, name) + } else { + r0 = ret.Get(0).(auth.PAT) + } + + if rf, ok := ret.Get(1).(func(context.Context, string, string, string) error); ok { + r1 = rf(ctx, token, patID, name) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// NewPATS creates a new instance of PATS. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewPATS(t interface { + mock.TestingT + Cleanup(func()) +}) *PATS { + mock := &PATS{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/auth/mocks/patsrepo.go b/auth/mocks/patsrepo.go new file mode 100644 index 0000000000..c15a8752dc --- /dev/null +++ b/auth/mocks/patsrepo.go @@ -0,0 +1,401 @@ +// Code generated by mockery v2.43.2. DO NOT EDIT. + +// Copyright (c) Abstract Machines + +package mocks + +import ( + context "context" + + auth "github.com/absmach/supermq/auth" + + mock "github.com/stretchr/testify/mock" + + time "time" +) + +// PATSRepository is an autogenerated mock type for the PATSRepository type +type PATSRepository struct { + mock.Mock +} + +// AddScopeEntry provides a mock function with given fields: ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs +func (_m *PATSRepository) AddScopeEntry(ctx context.Context, userID string, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) (auth.Scope, error) { + _va := make([]interface{}, len(entityIDs)) + for _i := range entityIDs { + _va[_i] = entityIDs[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for AddScopeEntry") + } + + var r0 auth.Scope + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) (auth.Scope, error)); ok { + return rf(ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) auth.Scope); ok { + r0 = rf(ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) + } else { + r0 = ret.Get(0).(auth.Scope) + } + + if rf, ok := ret.Get(1).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) error); ok { + r1 = rf(ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// CheckScopeEntry provides a mock function with given fields: ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs +func (_m *PATSRepository) CheckScopeEntry(ctx context.Context, userID string, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) error { + _va := make([]interface{}, len(entityIDs)) + for _i := range entityIDs { + _va[_i] = entityIDs[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for CheckScopeEntry") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) error); ok { + r0 = rf(ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Reactivate provides a mock function with given fields: ctx, userID, patID +func (_m *PATSRepository) Reactivate(ctx context.Context, userID string, patID string) error { + ret := _m.Called(ctx, userID, patID) + + if len(ret) == 0 { + panic("no return value specified for Reactivate") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok { + r0 = rf(ctx, userID, patID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Remove provides a mock function with given fields: ctx, userID, patID +func (_m *PATSRepository) Remove(ctx context.Context, userID string, patID string) error { + ret := _m.Called(ctx, userID, patID) + + if len(ret) == 0 { + panic("no return value specified for Remove") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok { + r0 = rf(ctx, userID, patID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// RemoveAllScopeEntry provides a mock function with given fields: ctx, userID, patID +func (_m *PATSRepository) RemoveAllScopeEntry(ctx context.Context, userID string, patID string) error { + ret := _m.Called(ctx, userID, patID) + + if len(ret) == 0 { + panic("no return value specified for RemoveAllScopeEntry") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok { + r0 = rf(ctx, userID, patID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// RemoveScopeEntry provides a mock function with given fields: ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs +func (_m *PATSRepository) RemoveScopeEntry(ctx context.Context, userID string, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) (auth.Scope, error) { + _va := make([]interface{}, len(entityIDs)) + for _i := range entityIDs { + _va[_i] = entityIDs[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for RemoveScopeEntry") + } + + var r0 auth.Scope + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) (auth.Scope, error)); ok { + return rf(ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) auth.Scope); ok { + r0 = rf(ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) + } else { + r0 = ret.Get(0).(auth.Scope) + } + + if rf, ok := ret.Get(1).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) error); ok { + r1 = rf(ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Retrieve provides a mock function with given fields: ctx, userID, patID +func (_m *PATSRepository) Retrieve(ctx context.Context, userID string, patID string) (auth.PAT, error) { + ret := _m.Called(ctx, userID, patID) + + if len(ret) == 0 { + panic("no return value specified for Retrieve") + } + + var r0 auth.PAT + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string) (auth.PAT, error)); ok { + return rf(ctx, userID, patID) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string) auth.PAT); ok { + r0 = rf(ctx, userID, patID) + } else { + r0 = ret.Get(0).(auth.PAT) + } + + if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok { + r1 = rf(ctx, userID, patID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// RetrieveAll provides a mock function with given fields: ctx, userID, pm +func (_m *PATSRepository) RetrieveAll(ctx context.Context, userID string, pm auth.PATSPageMeta) (auth.PATSPage, error) { + ret := _m.Called(ctx, userID, pm) + + if len(ret) == 0 { + panic("no return value specified for RetrieveAll") + } + + var r0 auth.PATSPage + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, auth.PATSPageMeta) (auth.PATSPage, error)); ok { + return rf(ctx, userID, pm) + } + if rf, ok := ret.Get(0).(func(context.Context, string, auth.PATSPageMeta) auth.PATSPage); ok { + r0 = rf(ctx, userID, pm) + } else { + r0 = ret.Get(0).(auth.PATSPage) + } + + if rf, ok := ret.Get(1).(func(context.Context, string, auth.PATSPageMeta) error); ok { + r1 = rf(ctx, userID, pm) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// RetrieveSecretAndRevokeStatus provides a mock function with given fields: ctx, userID, patID +func (_m *PATSRepository) RetrieveSecretAndRevokeStatus(ctx context.Context, userID string, patID string) (string, bool, bool, error) { + ret := _m.Called(ctx, userID, patID) + + if len(ret) == 0 { + panic("no return value specified for RetrieveSecretAndRevokeStatus") + } + + var r0 string + var r1 bool + var r2 bool + var r3 error + if rf, ok := ret.Get(0).(func(context.Context, string, string) (string, bool, bool, error)); ok { + return rf(ctx, userID, patID) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string) string); ok { + r0 = rf(ctx, userID, patID) + } else { + r0 = ret.Get(0).(string) + } + + if rf, ok := ret.Get(1).(func(context.Context, string, string) bool); ok { + r1 = rf(ctx, userID, patID) + } else { + r1 = ret.Get(1).(bool) + } + + if rf, ok := ret.Get(2).(func(context.Context, string, string) bool); ok { + r2 = rf(ctx, userID, patID) + } else { + r2 = ret.Get(2).(bool) + } + + if rf, ok := ret.Get(3).(func(context.Context, string, string) error); ok { + r3 = rf(ctx, userID, patID) + } else { + r3 = ret.Error(3) + } + + return r0, r1, r2, r3 +} + +// Revoke provides a mock function with given fields: ctx, userID, patID +func (_m *PATSRepository) Revoke(ctx context.Context, userID string, patID string) error { + ret := _m.Called(ctx, userID, patID) + + if len(ret) == 0 { + panic("no return value specified for Revoke") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok { + r0 = rf(ctx, userID, patID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Save provides a mock function with given fields: ctx, pat +func (_m *PATSRepository) Save(ctx context.Context, pat auth.PAT) error { + ret := _m.Called(ctx, pat) + + if len(ret) == 0 { + panic("no return value specified for Save") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, auth.PAT) error); ok { + r0 = rf(ctx, pat) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// UpdateDescription provides a mock function with given fields: ctx, userID, patID, description +func (_m *PATSRepository) UpdateDescription(ctx context.Context, userID string, patID string, description string) (auth.PAT, error) { + ret := _m.Called(ctx, userID, patID, description) + + if len(ret) == 0 { + panic("no return value specified for UpdateDescription") + } + + var r0 auth.PAT + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, string) (auth.PAT, error)); ok { + return rf(ctx, userID, patID, description) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string, string) auth.PAT); ok { + r0 = rf(ctx, userID, patID, description) + } else { + r0 = ret.Get(0).(auth.PAT) + } + + if rf, ok := ret.Get(1).(func(context.Context, string, string, string) error); ok { + r1 = rf(ctx, userID, patID, description) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// UpdateName provides a mock function with given fields: ctx, userID, patID, name +func (_m *PATSRepository) UpdateName(ctx context.Context, userID string, patID string, name string) (auth.PAT, error) { + ret := _m.Called(ctx, userID, patID, name) + + if len(ret) == 0 { + panic("no return value specified for UpdateName") + } + + var r0 auth.PAT + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, string) (auth.PAT, error)); ok { + return rf(ctx, userID, patID, name) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string, string) auth.PAT); ok { + r0 = rf(ctx, userID, patID, name) + } else { + r0 = ret.Get(0).(auth.PAT) + } + + if rf, ok := ret.Get(1).(func(context.Context, string, string, string) error); ok { + r1 = rf(ctx, userID, patID, name) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// UpdateTokenHash provides a mock function with given fields: ctx, userID, patID, tokenHash, expiryAt +func (_m *PATSRepository) UpdateTokenHash(ctx context.Context, userID string, patID string, tokenHash string, expiryAt time.Time) (auth.PAT, error) { + ret := _m.Called(ctx, userID, patID, tokenHash, expiryAt) + + if len(ret) == 0 { + panic("no return value specified for UpdateTokenHash") + } + + var r0 auth.PAT + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, string, time.Time) (auth.PAT, error)); ok { + return rf(ctx, userID, patID, tokenHash, expiryAt) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string, string, time.Time) auth.PAT); ok { + r0 = rf(ctx, userID, patID, tokenHash, expiryAt) + } else { + r0 = ret.Get(0).(auth.PAT) + } + + if rf, ok := ret.Get(1).(func(context.Context, string, string, string, time.Time) error); ok { + r1 = rf(ctx, userID, patID, tokenHash, expiryAt) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// NewPATSRepository creates a new instance of PATSRepository. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewPATSRepository(t interface { + mock.TestingT + Cleanup(func()) +}) *PATSRepository { + mock := &PATSRepository{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/auth/mocks/service.go b/auth/mocks/service.go index fcfb9a997a..0a01591776 100644 --- a/auth/mocks/service.go +++ b/auth/mocks/service.go @@ -12,6 +12,8 @@ import ( mock "github.com/stretchr/testify/mock" policies "github.com/absmach/supermq/pkg/policies" + + time "time" ) // Service is an autogenerated mock type for the Service type @@ -19,6 +21,41 @@ type Service struct { mock.Mock } +// AddPATScopeEntry provides a mock function with given fields: ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs +func (_m *Service) AddPATScopeEntry(ctx context.Context, token string, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) (auth.Scope, error) { + _va := make([]interface{}, len(entityIDs)) + for _i := range entityIDs { + _va[_i] = entityIDs[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for AddPATScopeEntry") + } + + var r0 auth.Scope + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) (auth.Scope, error)); ok { + return rf(ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) auth.Scope); ok { + r0 = rf(ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) + } else { + r0 = ret.Get(0).(auth.Scope) + } + + if rf, ok := ret.Get(1).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) error); ok { + r1 = rf(ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // Authorize provides a mock function with given fields: ctx, pr func (_m *Service) Authorize(ctx context.Context, pr policies.Policy) error { ret := _m.Called(ctx, pr) @@ -37,6 +74,120 @@ func (_m *Service) Authorize(ctx context.Context, pr policies.Policy) error { return r0 } +// AuthorizePAT provides a mock function with given fields: ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs +func (_m *Service) AuthorizePAT(ctx context.Context, userID string, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) error { + _va := make([]interface{}, len(entityIDs)) + for _i := range entityIDs { + _va[_i] = entityIDs[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for AuthorizePAT") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) error); ok { + r0 = rf(ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// CheckPAT provides a mock function with given fields: ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs +func (_m *Service) CheckPAT(ctx context.Context, userID string, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) error { + _va := make([]interface{}, len(entityIDs)) + for _i := range entityIDs { + _va[_i] = entityIDs[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for CheckPAT") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) error); ok { + r0 = rf(ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// ClearPATAllScopeEntry provides a mock function with given fields: ctx, token, patID +func (_m *Service) ClearPATAllScopeEntry(ctx context.Context, token string, patID string) error { + ret := _m.Called(ctx, token, patID) + + if len(ret) == 0 { + panic("no return value specified for ClearPATAllScopeEntry") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok { + r0 = rf(ctx, token, patID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// CreatePAT provides a mock function with given fields: ctx, token, name, description, duration, scope +func (_m *Service) CreatePAT(ctx context.Context, token string, name string, description string, duration time.Duration, scope auth.Scope) (auth.PAT, error) { + ret := _m.Called(ctx, token, name, description, duration, scope) + + if len(ret) == 0 { + panic("no return value specified for CreatePAT") + } + + var r0 auth.PAT + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, string, time.Duration, auth.Scope) (auth.PAT, error)); ok { + return rf(ctx, token, name, description, duration, scope) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string, string, time.Duration, auth.Scope) auth.PAT); ok { + r0 = rf(ctx, token, name, description, duration, scope) + } else { + r0 = ret.Get(0).(auth.PAT) + } + + if rf, ok := ret.Get(1).(func(context.Context, string, string, string, time.Duration, auth.Scope) error); ok { + r1 = rf(ctx, token, name, description, duration, scope) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// DeletePAT provides a mock function with given fields: ctx, token, patID +func (_m *Service) DeletePAT(ctx context.Context, token string, patID string) error { + ret := _m.Called(ctx, token, patID) + + if len(ret) == 0 { + panic("no return value specified for DeletePAT") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok { + r0 = rf(ctx, token, patID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + // Identify provides a mock function with given fields: ctx, token func (_m *Service) Identify(ctx context.Context, token string) (auth.Key, error) { ret := _m.Called(ctx, token) @@ -65,6 +216,34 @@ func (_m *Service) Identify(ctx context.Context, token string) (auth.Key, error) return r0, r1 } +// IdentifyPAT provides a mock function with given fields: ctx, paToken +func (_m *Service) IdentifyPAT(ctx context.Context, paToken string) (auth.PAT, error) { + ret := _m.Called(ctx, paToken) + + if len(ret) == 0 { + panic("no return value specified for IdentifyPAT") + } + + var r0 auth.PAT + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (auth.PAT, error)); ok { + return rf(ctx, paToken) + } + if rf, ok := ret.Get(0).(func(context.Context, string) auth.PAT); ok { + r0 = rf(ctx, paToken) + } else { + r0 = ret.Get(0).(auth.PAT) + } + + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, paToken) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // Issue provides a mock function with given fields: ctx, token, key func (_m *Service) Issue(ctx context.Context, token string, key auth.Key) (auth.Token, error) { ret := _m.Called(ctx, token, key) @@ -93,6 +272,97 @@ func (_m *Service) Issue(ctx context.Context, token string, key auth.Key) (auth. return r0, r1 } +// ListPATS provides a mock function with given fields: ctx, token, pm +func (_m *Service) ListPATS(ctx context.Context, token string, pm auth.PATSPageMeta) (auth.PATSPage, error) { + ret := _m.Called(ctx, token, pm) + + if len(ret) == 0 { + panic("no return value specified for ListPATS") + } + + var r0 auth.PATSPage + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, auth.PATSPageMeta) (auth.PATSPage, error)); ok { + return rf(ctx, token, pm) + } + if rf, ok := ret.Get(0).(func(context.Context, string, auth.PATSPageMeta) auth.PATSPage); ok { + r0 = rf(ctx, token, pm) + } else { + r0 = ret.Get(0).(auth.PATSPage) + } + + if rf, ok := ret.Get(1).(func(context.Context, string, auth.PATSPageMeta) error); ok { + r1 = rf(ctx, token, pm) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// RemovePATScopeEntry provides a mock function with given fields: ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs +func (_m *Service) RemovePATScopeEntry(ctx context.Context, token string, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) (auth.Scope, error) { + _va := make([]interface{}, len(entityIDs)) + for _i := range entityIDs { + _va[_i] = entityIDs[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for RemovePATScopeEntry") + } + + var r0 auth.Scope + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) (auth.Scope, error)); ok { + return rf(ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) auth.Scope); ok { + r0 = rf(ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) + } else { + r0 = ret.Get(0).(auth.Scope) + } + + if rf, ok := ret.Get(1).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) error); ok { + r1 = rf(ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// ResetPATSecret provides a mock function with given fields: ctx, token, patID, duration +func (_m *Service) ResetPATSecret(ctx context.Context, token string, patID string, duration time.Duration) (auth.PAT, error) { + ret := _m.Called(ctx, token, patID, duration) + + if len(ret) == 0 { + panic("no return value specified for ResetPATSecret") + } + + var r0 auth.PAT + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, time.Duration) (auth.PAT, error)); ok { + return rf(ctx, token, patID, duration) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string, time.Duration) auth.PAT); ok { + r0 = rf(ctx, token, patID, duration) + } else { + r0 = ret.Get(0).(auth.PAT) + } + + if rf, ok := ret.Get(1).(func(context.Context, string, string, time.Duration) error); ok { + r1 = rf(ctx, token, patID, duration) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // RetrieveKey provides a mock function with given fields: ctx, token, id func (_m *Service) RetrieveKey(ctx context.Context, token string, id string) (auth.Key, error) { ret := _m.Called(ctx, token, id) @@ -121,6 +391,34 @@ func (_m *Service) RetrieveKey(ctx context.Context, token string, id string) (au return r0, r1 } +// RetrievePAT provides a mock function with given fields: ctx, userID, patID +func (_m *Service) RetrievePAT(ctx context.Context, userID string, patID string) (auth.PAT, error) { + ret := _m.Called(ctx, userID, patID) + + if len(ret) == 0 { + panic("no return value specified for RetrievePAT") + } + + var r0 auth.PAT + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string) (auth.PAT, error)); ok { + return rf(ctx, userID, patID) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string) auth.PAT); ok { + r0 = rf(ctx, userID, patID) + } else { + r0 = ret.Get(0).(auth.PAT) + } + + if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok { + r1 = rf(ctx, userID, patID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // Revoke provides a mock function with given fields: ctx, token, id func (_m *Service) Revoke(ctx context.Context, token string, id string) error { ret := _m.Called(ctx, token, id) @@ -139,6 +437,80 @@ func (_m *Service) Revoke(ctx context.Context, token string, id string) error { return r0 } +// RevokePATSecret provides a mock function with given fields: ctx, token, patID +func (_m *Service) RevokePATSecret(ctx context.Context, token string, patID string) error { + ret := _m.Called(ctx, token, patID) + + if len(ret) == 0 { + panic("no return value specified for RevokePATSecret") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok { + r0 = rf(ctx, token, patID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// UpdatePATDescription provides a mock function with given fields: ctx, token, patID, description +func (_m *Service) UpdatePATDescription(ctx context.Context, token string, patID string, description string) (auth.PAT, error) { + ret := _m.Called(ctx, token, patID, description) + + if len(ret) == 0 { + panic("no return value specified for UpdatePATDescription") + } + + var r0 auth.PAT + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, string) (auth.PAT, error)); ok { + return rf(ctx, token, patID, description) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string, string) auth.PAT); ok { + r0 = rf(ctx, token, patID, description) + } else { + r0 = ret.Get(0).(auth.PAT) + } + + if rf, ok := ret.Get(1).(func(context.Context, string, string, string) error); ok { + r1 = rf(ctx, token, patID, description) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// UpdatePATName provides a mock function with given fields: ctx, token, patID, name +func (_m *Service) UpdatePATName(ctx context.Context, token string, patID string, name string) (auth.PAT, error) { + ret := _m.Called(ctx, token, patID, name) + + if len(ret) == 0 { + panic("no return value specified for UpdatePATName") + } + + var r0 auth.PAT + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, string) (auth.PAT, error)); ok { + return rf(ctx, token, patID, name) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string, string) auth.PAT); ok { + r0 = rf(ctx, token, patID, name) + } else { + r0 = ret.Get(0).(auth.PAT) + } + + if rf, ok := ret.Get(1).(func(context.Context, string, string, string) error); ok { + r1 = rf(ctx, token, patID, name) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // NewService creates a new instance of Service. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. // The first argument is typically a *testing.T value. func NewService(t interface { diff --git a/auth/pat.go b/auth/pat.go new file mode 100644 index 0000000000..4a168fa0ec --- /dev/null +++ b/auth/pat.go @@ -0,0 +1,804 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package auth + +import ( + "context" + "encoding/json" + "fmt" + "time" + + "github.com/absmach/supermq/pkg/errors" +) + +var errAddEntityToAnyIDs = errors.New("could not add entity id to any ID scope value") + +// Define OperationType. +type OperationType uint32 + +const ( + CreateOp OperationType = iota + ReadOp + ListOp + UpdateOp + DeleteOp + ShareOp + UnshareOp + PublishOp + SubscribeOp +) + +const ( + createOpStr = "create" + readOpStr = "read" + listOpStr = "list" + updateOpStr = "update" + deleteOpStr = "delete" + shareOpStr = "share" + UnshareOpStr = "unshare" + PublishOpStr = "publish" + SubscribeOpStr = "subscribe" +) + +func (ot OperationType) String() string { + switch ot { + case CreateOp: + return createOpStr + case ReadOp: + return readOpStr + case ListOp: + return listOpStr + case UpdateOp: + return updateOpStr + case DeleteOp: + return deleteOpStr + case ShareOp: + return shareOpStr + case UnshareOp: + return UnshareOpStr + case PublishOp: + return PublishOpStr + case SubscribeOp: + return SubscribeOpStr + default: + return fmt.Sprintf("unknown operation type %d", ot) + } +} + +func (ot OperationType) ValidString() (string, error) { + str := ot.String() + if str == fmt.Sprintf("unknown operation type %d", ot) { + return "", errors.New(str) + } + return str, nil +} + +func ParseOperationType(ot string) (OperationType, error) { + switch ot { + case createOpStr: + return CreateOp, nil + case readOpStr: + return ReadOp, nil + case listOpStr: + return ListOp, nil + case updateOpStr: + return UpdateOp, nil + case deleteOpStr: + return DeleteOp, nil + case shareOpStr: + return ShareOp, nil + case UnshareOpStr: + return UnshareOp, nil + case PublishOpStr: + return PublishOp, nil + case SubscribeOpStr: + return SubscribeOp, nil + default: + return 0, fmt.Errorf("unknown operation type %s", ot) + } +} + +func (ot OperationType) MarshalJSON() ([]byte, error) { + return []byte(ot.String()), nil +} + +func (ot OperationType) MarshalText() (text []byte, err error) { + return []byte(ot.String()), nil +} + +func (ot *OperationType) UnmarshalText(data []byte) (err error) { + *ot, err = ParseOperationType(string(data)) + return err +} + +// Define DomainEntityType. +type DomainEntityType uint32 + +const ( + DomainManagementScope DomainEntityType = iota + DomainGroupsScope + DomainChannelsScope + DomainClientsScope + DomainNullScope +) + +const ( + domainManagementScopeStr = "domain_management" + domainGroupsScopeStr = "groups" + domainChannelsScopeStr = "channels" + domainClientsScopeStr = "clients" +) + +func (det DomainEntityType) String() string { + switch det { + case DomainManagementScope: + return domainManagementScopeStr + case DomainGroupsScope: + return domainGroupsScopeStr + case DomainChannelsScope: + return domainChannelsScopeStr + case DomainClientsScope: + return domainClientsScopeStr + default: + return fmt.Sprintf("unknown domain entity type %d", det) + } +} + +func (det DomainEntityType) ValidString() (string, error) { + str := det.String() + if str == fmt.Sprintf("unknown operation type %d", det) { + return "", errors.New(str) + } + return str, nil +} + +func ParseDomainEntityType(det string) (DomainEntityType, error) { + switch det { + case domainManagementScopeStr: + return DomainManagementScope, nil + case domainGroupsScopeStr: + return DomainGroupsScope, nil + case domainChannelsScopeStr: + return DomainChannelsScope, nil + case domainClientsScopeStr: + return DomainClientsScope, nil + default: + return 0, fmt.Errorf("unknown domain entity type %s", det) + } +} + +func (det DomainEntityType) MarshalJSON() ([]byte, error) { + return []byte(det.String()), nil +} + +func (det DomainEntityType) MarshalText() ([]byte, error) { + return []byte(det.String()), nil +} + +func (det *DomainEntityType) UnmarshalText(data []byte) (err error) { + *det, err = ParseDomainEntityType(string(data)) + return err +} + +// Define DomainEntityType. +type PlatformEntityType uint32 + +const ( + PlatformUsersScope PlatformEntityType = iota + PlatformDomainsScope + PlatformDashBoardScope + PlatformMesagingScope +) + +const ( + platformUsersScopeStr = "users" + platformDomainsScopeStr = "domains" + PlatformDashBoardScopeStr = "dashboard" + PlatformMesagingScopeStr = "messaging" +) + +func (pet PlatformEntityType) String() string { + switch pet { + case PlatformUsersScope: + return platformUsersScopeStr + case PlatformDomainsScope: + return platformDomainsScopeStr + case PlatformDashBoardScope: + return PlatformDashBoardScopeStr + case PlatformMesagingScope: + return PlatformMesagingScopeStr + default: + return fmt.Sprintf("unknown platform entity type %d", pet) + } +} + +func (pet PlatformEntityType) ValidString() (string, error) { + str := pet.String() + if str == fmt.Sprintf("unknown platform entity type %d", pet) { + return "", errors.New(str) + } + return str, nil +} + +func ParsePlatformEntityType(pet string) (PlatformEntityType, error) { + switch pet { + case platformUsersScopeStr: + return PlatformUsersScope, nil + case platformDomainsScopeStr: + return PlatformDomainsScope, nil + default: + return 0, fmt.Errorf("unknown platform entity type %s", pet) + } +} + +func (pet PlatformEntityType) MarshalJSON() ([]byte, error) { + return []byte(pet.String()), nil +} + +func (pet PlatformEntityType) MarshalText() (text []byte, err error) { + return []byte(pet.String()), nil +} + +func (pet *PlatformEntityType) UnmarshalText(data []byte) (err error) { + *pet, err = ParsePlatformEntityType(string(data)) + return err +} + +// ScopeValue interface for Any entity ids or for sets of entity ids. +type ScopeValue interface { + Contains(id string) bool + Values() []string + AddValues(ids ...string) error + RemoveValues(ids ...string) error +} + +// AnyIDs implements ScopeValue for any entity id value. +type AnyIDs struct{} + +func (s AnyIDs) Contains(id string) bool { return true } +func (s AnyIDs) Values() []string { return []string{"*"} } +func (s *AnyIDs) AddValues(ids ...string) error { return errAddEntityToAnyIDs } +func (s *AnyIDs) RemoveValues(ids ...string) error { return errAddEntityToAnyIDs } + +// SelectedIDs implements ScopeValue for sets of entity ids. +type SelectedIDs map[string]struct{} + +func (s SelectedIDs) Contains(id string) bool { _, ok := s[id]; return ok } +func (s SelectedIDs) Values() []string { + values := []string{} + for value := range s { + values = append(values, value) + } + return values +} + +func (s *SelectedIDs) AddValues(ids ...string) error { + if *s == nil { + *s = make(SelectedIDs) + } + for _, id := range ids { + (*s)[id] = struct{}{} + } + return nil +} + +func (s *SelectedIDs) RemoveValues(ids ...string) error { + if *s == nil { + return nil + } + for _, id := range ids { + delete(*s, id) + } + return nil +} + +// OperationScope contains map of OperationType with value of AnyIDs or SelectedIDs. +type OperationScope map[OperationType]ScopeValue + +func (os *OperationScope) UnmarshalJSON(data []byte) error { + type tempOperationScope map[OperationType]json.RawMessage + + var tempScope tempOperationScope + if err := json.Unmarshal(data, &tempScope); err != nil { + return err + } + // Initialize the Operations map + *os = OperationScope{} + + for opType, rawMessage := range tempScope { + var stringValue string + var stringArrayValue []string + + // Try to unmarshal as string + if err := json.Unmarshal(rawMessage, &stringValue); err == nil { + if err := os.Add(opType, stringValue); err != nil { + return err + } + continue + } + + // Try to unmarshal as []string + if err := json.Unmarshal(rawMessage, &stringArrayValue); err == nil { + if err := os.Add(opType, stringArrayValue...); err != nil { + return err + } + continue + } + + // If neither unmarshalling succeeded, return an error + return fmt.Errorf("invalid ScopeValue for OperationType %v", opType) + } + + return nil +} + +func (os OperationScope) MarshalJSON() ([]byte, error) { + tempOperationScope := make(map[OperationType]interface{}) + for oType, scope := range os { + value := scope.Values() + if len(value) == 1 && value[0] == "*" { + tempOperationScope[oType] = "*" + continue + } + tempOperationScope[oType] = value + } + + b, err := json.Marshal(tempOperationScope) + if err != nil { + return nil, err + } + return b, nil +} + +func (os *OperationScope) Add(operation OperationType, entityIDs ...string) error { + var value ScopeValue + + if os == nil { + os = &OperationScope{} + } + + if len(entityIDs) == 0 { + return fmt.Errorf("entity ID is missing") + } + switch { + case len(entityIDs) == 1 && entityIDs[0] == "*": + value = &AnyIDs{} + default: + var sids SelectedIDs + for _, entityID := range entityIDs { + if entityID == "*" { + return fmt.Errorf("list contains wildcard") + } + if sids == nil { + sids = make(SelectedIDs) + } + sids[entityID] = struct{}{} + } + value = &sids + } + (*os)[operation] = value + return nil +} + +func (os *OperationScope) Delete(operation OperationType, entityIDs ...string) error { + if os == nil { + return nil + } + + opEntityIDs, exists := (*os)[operation] + if !exists { + return nil + } + + if len(entityIDs) == 0 { + return fmt.Errorf("failed to delete operation %s: entity ID is missing", operation.String()) + } + + switch eIDs := opEntityIDs.(type) { + case *AnyIDs: + if !(len(entityIDs) == 1 && entityIDs[0] == "*") { + return fmt.Errorf("failed to delete operation %s: invalid list", operation.String()) + } + delete((*os), operation) + return nil + case *SelectedIDs: + for _, entityID := range entityIDs { + if !eIDs.Contains(entityID) { + return fmt.Errorf("failed to delete operation %s: invalid entity ID in list", operation.String()) + } + } + for _, entityID := range entityIDs { + delete(*eIDs, entityID) + if len(*eIDs) == 0 { + delete((*os), operation) + } + } + return nil + default: + return fmt.Errorf("failed to delete operation: invalid entity id type %d", operation) + } +} + +func (os *OperationScope) Check(operation OperationType, entityIDs ...string) bool { + if os == nil { + return false + } + + if scopeValue, ok := (*os)[operation]; ok { + if len(entityIDs) == 0 { + _, ok := scopeValue.(*AnyIDs) + return ok + } + for _, entityID := range entityIDs { + if !scopeValue.Contains(entityID) { + return false + } + } + return true + } + + return false +} + +type DomainScope struct { + DomainManagement OperationScope `json:"domain_management,omitempty"` + Entities map[DomainEntityType]OperationScope `json:"entities,omitempty"` +} + +// Add entry in Domain scope. +func (ds *DomainScope) Add(domainEntityType DomainEntityType, operation OperationType, entityIDs ...string) error { + if ds == nil { + return fmt.Errorf("failed to add domain %s scope: domain_scope is nil and not initialized", domainEntityType) + } + + if domainEntityType < DomainManagementScope || domainEntityType > DomainClientsScope { + return fmt.Errorf("failed to add domain %d scope: invalid domain entity type", domainEntityType) + } + if domainEntityType == DomainManagementScope { + if err := ds.DomainManagement.Add(operation, entityIDs...); err != nil { + return fmt.Errorf("failed to delete domain management scope: %w", err) + } + } + + if ds.Entities == nil { + ds.Entities = make(map[DomainEntityType]OperationScope) + } + + opReg, ok := ds.Entities[domainEntityType] + if !ok { + opReg = OperationScope{} + } + + if err := opReg.Add(operation, entityIDs...); err != nil { + return fmt.Errorf("failed to add domain %s scope: %w ", domainEntityType.String(), err) + } + ds.Entities[domainEntityType] = opReg + return nil +} + +// Delete entry in Domain scope. +func (ds *DomainScope) Delete(domainEntityType DomainEntityType, operation OperationType, entityIDs ...string) error { + if ds == nil { + return nil + } + + if domainEntityType < DomainManagementScope || domainEntityType > DomainClientsScope { + return fmt.Errorf("failed to delete domain %d scope: invalid domain entity type", domainEntityType) + } + if ds.Entities == nil { + return nil + } + + if domainEntityType == DomainManagementScope { + if err := ds.DomainManagement.Delete(operation, entityIDs...); err != nil { + return fmt.Errorf("failed to delete domain management scope: %w", err) + } + } + + os, exists := ds.Entities[domainEntityType] + if !exists { + return nil + } + + if err := os.Delete(operation, entityIDs...); err != nil { + return fmt.Errorf("failed to delete domain %s scope: %w", domainEntityType.String(), err) + } + + if len(os) == 0 { + delete(ds.Entities, domainEntityType) + } + return nil +} + +// Check entry in Domain scope. +func (ds *DomainScope) Check(domainEntityType DomainEntityType, operation OperationType, ids ...string) bool { + if ds.Entities == nil { + return false + } + if domainEntityType < DomainManagementScope || domainEntityType > DomainClientsScope { + return false + } + if domainEntityType == DomainManagementScope { + return ds.DomainManagement.Check(operation, ids...) + } + os, exists := ds.Entities[domainEntityType] + if !exists { + return false + } + + return os.Check(operation, ids...) +} + +// Example Scope as JSON +// +// { +// "users": { +// "create": ["*"], +// "read": ["*"], +// "list": ["*"], +// "update": ["*"], +// "delete": ["*"] +// }, +// "domains": { +// "domain_1": { +// "entities": { +// "groups": { +// "create": ["*"] // this for all groups in domain +// }, +// "channels": { +// // for particular channel in domain +// "delete": [ +// "channel1", +// "channel2" +// ] +// }, +// "things": { +// "update": ["*"] // this for all things in domain +// } +// } +// } +// } +// } +type Scope struct { + Users OperationScope `json:"users,omitempty"` + Domains map[string]DomainScope `json:"domains,omitempty"` + Dashboard OperationScope `json:"dashboard,omitempty"` + Messaging OperationScope `json:"messaging,omitempty"` +} + +// Add entry in Domain scope. +func (s *Scope) Add(platformEntityType PlatformEntityType, optionalDomainID string, optionalDomainEntityType DomainEntityType, operation OperationType, entityIDs ...string) error { + if s == nil { + return fmt.Errorf("failed to add platform %s scope: scope is nil and not initialized", platformEntityType.String()) + } + switch platformEntityType { + case PlatformUsersScope: + if err := s.Users.Add(operation, entityIDs...); err != nil { + return fmt.Errorf("failed to add platform %s scope: %w", platformEntityType.String(), err) + } + case PlatformDashBoardScope: + if err := s.Dashboard.Add(operation, entityIDs...); err != nil { + return fmt.Errorf("failed to add platform %s scope: %w", platformEntityType.String(), err) + } + case PlatformMesagingScope: + if err := s.Messaging.Add(operation, entityIDs...); err != nil { + return fmt.Errorf("failed to add platform %s scope: %w", platformEntityType.String(), err) + } + case PlatformDomainsScope: + if optionalDomainID == "" { + return fmt.Errorf("failed to add platform %s scope: invalid domain id", platformEntityType.String()) + } + if len(s.Domains) == 0 { + s.Domains = make(map[string]DomainScope) + } + + ds, ok := s.Domains[optionalDomainID] + if !ok { + ds = DomainScope{} + } + if err := ds.Add(optionalDomainEntityType, operation, entityIDs...); err != nil { + return fmt.Errorf("failed to add platform %s id %s scope : %w", platformEntityType.String(), optionalDomainID, err) + } + s.Domains[optionalDomainID] = ds + default: + return fmt.Errorf("failed to add platform %d scope: invalid platform entity type ", platformEntityType) + } + return nil +} + +// Delete entry in Domain scope. +func (s *Scope) Delete(platformEntityType PlatformEntityType, optionalDomainID string, optionalDomainEntityType DomainEntityType, operation OperationType, entityIDs ...string) error { + if s == nil { + return nil + } + switch platformEntityType { + case PlatformUsersScope: + if err := s.Users.Delete(operation, entityIDs...); err != nil { + return fmt.Errorf("failed to delete platform %s scope: %w", platformEntityType.String(), err) + } + case PlatformDashBoardScope: + if err := s.Dashboard.Delete(operation, entityIDs...); err != nil { + return fmt.Errorf("failed to delete platform %s scope: %w", platformEntityType.String(), err) + } + case PlatformMesagingScope: + if err := s.Messaging.Delete(operation, entityIDs...); err != nil { + return fmt.Errorf("failed to delete platform %s scope: %w", platformEntityType.String(), err) + } + case PlatformDomainsScope: + if optionalDomainID == "" { + return fmt.Errorf("failed to delete platform %s scope: invalid domain id", platformEntityType.String()) + } + ds, ok := s.Domains[optionalDomainID] + if !ok { + return nil + } + if err := ds.Delete(optionalDomainEntityType, operation, entityIDs...); err != nil { + return fmt.Errorf("failed to delete platform %s id %s scope : %w", platformEntityType.String(), optionalDomainID, err) + } + default: + return fmt.Errorf("failed to add platform %d scope: invalid platform entity type ", platformEntityType) + } + return nil +} + +// Check entry in Domain scope. +func (s *Scope) Check(platformEntityType PlatformEntityType, optionalDomainID string, optionalDomainEntityType DomainEntityType, operation OperationType, entityIDs ...string) bool { + if s == nil { + return false + } + switch platformEntityType { + case PlatformUsersScope: + return s.Users.Check(operation, entityIDs...) + case PlatformDashBoardScope: + return s.Dashboard.Check(operation, entityIDs...) + case PlatformMesagingScope: + return s.Messaging.Check(operation, entityIDs...) + case PlatformDomainsScope: + ds, ok := s.Domains[optionalDomainID] + if !ok { + return false + } + return ds.Check(optionalDomainEntityType, operation, entityIDs...) + default: + return false + } +} + +func (s *Scope) String() string { + str, err := json.Marshal(s) // , "", " ") + if err != nil { + return fmt.Sprintf("failed to convert scope to string: json marshal error :%s", err.Error()) + } + return string(str) +} + +// PAT represents Personal Access Token. +type PAT struct { + ID string `json:"id,omitempty"` + User string `json:"user,omitempty"` + Name string `json:"name,omitempty"` + Description string `json:"description,omitempty"` + Secret string `json:"secret,omitempty"` + Scope Scope `json:"scope,omitempty"` + IssuedAt time.Time `json:"issued_at,omitempty"` + ExpiresAt time.Time `json:"expires_at,omitempty"` + UpdatedAt time.Time `json:"updated_at,omitempty"` + LastUsedAt time.Time `json:"last_used_at,omitempty"` + Revoked bool `json:"revoked,omitempty"` + RevokedAt time.Time `json:"revoked_at,omitempty"` +} + +type PATSPageMeta struct { + Offset uint64 `json:"offset"` + Limit uint64 `json:"limit"` +} +type PATSPage struct { + Total uint64 `json:"total"` + Offset uint64 `json:"offset"` + Limit uint64 `json:"limit"` + PATS []PAT `json:"pats"` +} + +func (pat *PAT) String() string { + str, err := json.MarshalIndent(pat, "", " ") + if err != nil { + return fmt.Sprintf("failed to convert PAT to string: json marshal error :%s", err.Error()) + } + return string(str) +} + +// Expired verifies if the key is expired. +func (pat PAT) Expired() bool { + return pat.ExpiresAt.UTC().Before(time.Now().UTC()) +} + +// PATS specifies function which are required for Personal access Token implementation. +//go:generate mockery --name PATS --output=./mocks --filename pats.go --quiet --note "Copyright (c) Abstract Machines" + +type PATS interface { + // Create function creates new PAT for given valid inputs. + CreatePAT(ctx context.Context, token, name, description string, duration time.Duration, scope Scope) (PAT, error) + + // UpdateName function updates the name for the given PAT ID. + UpdatePATName(ctx context.Context, token, patID, name string) (PAT, error) + + // UpdateDescription function updates the description for the given PAT ID. + UpdatePATDescription(ctx context.Context, token, patID, description string) (PAT, error) + + // Retrieve function retrieves the PAT for given ID. + RetrievePAT(ctx context.Context, userID string, patID string) (PAT, error) + + // List function lists all the PATs for the user. + ListPATS(ctx context.Context, token string, pm PATSPageMeta) (PATSPage, error) + + // Delete function deletes the PAT for given ID. + DeletePAT(ctx context.Context, token, patID string) error + + // ResetSecret function reset the secret and creates new secret for the given ID. + ResetPATSecret(ctx context.Context, token, patID string, duration time.Duration) (PAT, error) + + // RevokeSecret function revokes the secret for the given ID. + RevokePATSecret(ctx context.Context, token, patID string) error + + // AddScope function adds a new scope entry. + AddPATScopeEntry(ctx context.Context, token, patID string, platformEntityType PlatformEntityType, optionalDomainID string, optionalDomainEntityType DomainEntityType, operation OperationType, entityIDs ...string) (Scope, error) + + // RemoveScope function removes a scope entry. + RemovePATScopeEntry(ctx context.Context, token, patID string, platformEntityType PlatformEntityType, optionalDomainID string, optionalDomainEntityType DomainEntityType, operation OperationType, entityIDs ...string) (Scope, error) + + // ClearAllScope function removes all scope entry. + ClearPATAllScopeEntry(ctx context.Context, token, patID string) error + + // IdentifyPAT function will valid the secret. + IdentifyPAT(ctx context.Context, paToken string) (PAT, error) + + // AuthorizePAT function will valid the secret and check the given scope exists. + AuthorizePAT(ctx context.Context, userID, patID string, platformEntityType PlatformEntityType, optionalDomainID string, optionalDomainEntityType DomainEntityType, operation OperationType, entityIDs ...string) error + + // CheckPAT function will check the given scope exists. + CheckPAT(ctx context.Context, userID, patID string, platformEntityType PlatformEntityType, optionalDomainID string, optionalDomainEntityType DomainEntityType, operation OperationType, entityIDs ...string) error +} + +// PATSRepository specifies PATS persistence API. +// +//go:generate mockery --name PATSRepository --output=./mocks --filename patsrepo.go --quiet --note "Copyright (c) Abstract Machines" +type PATSRepository interface { + // Save persists the PAT + Save(ctx context.Context, pat PAT) (err error) + + // Retrieve retrieves users PAT by its unique identifier. + Retrieve(ctx context.Context, userID, patID string) (pat PAT, err error) + + // RetrieveSecretAndRevokeStatus retrieves secret and revoke status of PAT by its unique identifier. + RetrieveSecretAndRevokeStatus(ctx context.Context, userID, patID string) (string, bool, bool, error) + + // UpdateName updates the name of a PAT. + UpdateName(ctx context.Context, userID, patID, name string) (PAT, error) + + // UpdateDescription updates the description of a PAT. + UpdateDescription(ctx context.Context, userID, patID, description string) (PAT, error) + + // UpdateTokenHash updates the token hash of a PAT. + UpdateTokenHash(ctx context.Context, userID, patID, tokenHash string, expiryAt time.Time) (PAT, error) + + // RetrieveAll retrieves all PATs belongs to userID. + RetrieveAll(ctx context.Context, userID string, pm PATSPageMeta) (pats PATSPage, err error) + + // Revoke PAT with provided ID. + Revoke(ctx context.Context, userID, patID string) error + + // Reactivate PAT with provided ID. + Reactivate(ctx context.Context, userID, patID string) error + + // Remove removes Key with provided ID. + Remove(ctx context.Context, userID, patID string) error + + AddScopeEntry(ctx context.Context, userID, patID string, platformEntityType PlatformEntityType, optionalDomainID string, optionalDomainEntityType DomainEntityType, operation OperationType, entityIDs ...string) (Scope, error) + + RemoveScopeEntry(ctx context.Context, userID, patID string, platformEntityType PlatformEntityType, optionalDomainID string, optionalDomainEntityType DomainEntityType, operation OperationType, entityIDs ...string) (Scope, error) + + CheckScopeEntry(ctx context.Context, userID, patID string, platformEntityType PlatformEntityType, optionalDomainID string, optionalDomainEntityType DomainEntityType, operation OperationType, entityIDs ...string) error + + RemoveAllScopeEntry(ctx context.Context, userID, patID string) error +} diff --git a/auth/service.go b/auth/service.go index 9ca2ff0d23..579d8466fd 100644 --- a/auth/service.go +++ b/auth/service.go @@ -5,6 +5,8 @@ package auth import ( "context" + "encoding/base64" + "math/rand" "strings" "time" @@ -12,11 +14,15 @@ import ( "github.com/absmach/supermq/pkg/errors" svcerr "github.com/absmach/supermq/pkg/errors/service" "github.com/absmach/supermq/pkg/policies" + "github.com/google/uuid" ) const ( - recoveryDuration = 5 * time.Minute - defLimit = 100 + recoveryDuration = 5 * time.Minute + defLimit = 100 + randStr = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890!@#$%^&&*|+-=" + patPrefix = "pat" + patSecretSeparator = "_" ) var ( @@ -29,6 +35,17 @@ var ( errRetrieve = errors.New("failed to retrieve key data") errIdentify = errors.New("failed to validate token") errPlatform = errors.New("invalid platform id") + + errMalformedPAT = errors.New("malformed personal access token") + errFailedToParseUUID = errors.New("failed to parse string to UUID") + errInvalidLenFor2UUIDs = errors.New("invalid input length for 2 UUID, excepted 32 byte") + errRevokedPAT = errors.New("revoked pat") + errCreatePAT = errors.New("failed to create PAT") + errUpdatePAT = errors.New("failed to update PAT") + errRetrievePAT = errors.New("failed to retrieve PAT") + errDeletePAT = errors.New("failed to delete PAT") + errRevokePAT = errors.New("failed to revoke PAT") + errClearAllScope = errors.New("failed to clear all entry in scope") ) // Authz represents a authorization service. It exposes @@ -75,12 +92,15 @@ type Authn interface { type Service interface { Authn Authz + PATS } var _ Service = (*service)(nil) type service struct { keys KeyRepository + pats PATSRepository + hasher Hasher idProvider supermq.IDProvider evaluator policies.Evaluator policysvc policies.Service @@ -91,10 +111,12 @@ type service struct { } // New instantiates the auth service implementation. -func New(keys KeyRepository, idp supermq.IDProvider, tokenizer Tokenizer, policyEvaluator policies.Evaluator, policyService policies.Service, loginDuration, refreshDuration, invitationDuration time.Duration) Service { +func New(keys KeyRepository, pats PATSRepository, hasher Hasher, idp supermq.IDProvider, tokenizer Tokenizer, policyEvaluator policies.Evaluator, policyService policies.Service, loginDuration, refreshDuration, invitationDuration time.Duration) Service { return &service{ tokenizer: tokenizer, keys: keys, + pats: pats, + hasher: hasher, idProvider: idp, evaluator: policyEvaluator, policysvc: policyService, @@ -434,3 +456,259 @@ func DecodeDomainUserID(domainUserID string) (string, string) { return "", "" } } + +func (svc service) CreatePAT(ctx context.Context, token, name, description string, duration time.Duration, scope Scope) (PAT, error) { + key, err := svc.Identify(ctx, token) + if err != nil { + return PAT{}, err + } + + id, err := svc.idProvider.ID() + if err != nil { + return PAT{}, errors.Wrap(svcerr.ErrCreateEntity, err) + } + secret, hash, err := svc.generateSecretAndHash(key.User, id) + if err != nil { + return PAT{}, errors.Wrap(svcerr.ErrCreateEntity, err) + } + + now := time.Now() + pat := PAT{ + ID: id, + User: key.User, + Name: name, + Description: description, + Secret: hash, + IssuedAt: now, + ExpiresAt: now.Add(duration), + Scope: scope, + } + if err := svc.pats.Save(ctx, pat); err != nil { + return PAT{}, errors.Wrap(errCreatePAT, err) + } + pat.Secret = secret + return pat, nil +} + +func (svc service) UpdatePATName(ctx context.Context, token, patID, name string) (PAT, error) { + key, err := svc.Identify(ctx, token) + if err != nil { + return PAT{}, err + } + pat, err := svc.pats.UpdateName(ctx, key.User, patID, name) + if err != nil { + return PAT{}, errors.Wrap(errUpdatePAT, err) + } + return pat, nil +} + +func (svc service) UpdatePATDescription(ctx context.Context, token, patID, description string) (PAT, error) { + key, err := svc.Identify(ctx, token) + if err != nil { + return PAT{}, err + } + pat, err := svc.pats.UpdateDescription(ctx, key.User, patID, description) + if err != nil { + return PAT{}, errors.Wrap(errUpdatePAT, err) + } + return pat, nil +} + +func (svc service) RetrievePAT(ctx context.Context, token, patID string) (PAT, error) { + key, err := svc.Identify(ctx, token) + if err != nil { + return PAT{}, err + } + + pat, err := svc.pats.Retrieve(ctx, key.User, patID) + if err != nil { + return PAT{}, errors.Wrap(errRetrievePAT, err) + } + return pat, nil +} + +func (svc service) ListPATS(ctx context.Context, token string, pm PATSPageMeta) (PATSPage, error) { + key, err := svc.Identify(ctx, token) + if err != nil { + return PATSPage{}, err + } + patsPage, err := svc.pats.RetrieveAll(ctx, key.User, pm) + if err != nil { + return PATSPage{}, errors.Wrap(errRetrievePAT, err) + } + return patsPage, nil +} + +func (svc service) DeletePAT(ctx context.Context, token, patID string) error { + key, err := svc.Identify(ctx, token) + if err != nil { + return err + } + if err := svc.pats.Remove(ctx, key.User, patID); err != nil { + return errors.Wrap(errDeletePAT, err) + } + return nil +} + +func (svc service) ResetPATSecret(ctx context.Context, token, patID string, duration time.Duration) (PAT, error) { + key, err := svc.Identify(ctx, token) + if err != nil { + return PAT{}, err + } + + // Generate new HashToken take place here + secret, hash, err := svc.generateSecretAndHash(key.User, patID) + if err != nil { + return PAT{}, errors.Wrap(svcerr.ErrUpdateEntity, err) + } + + pat, err := svc.pats.UpdateTokenHash(ctx, key.User, patID, hash, time.Now().Add(duration)) + if err != nil { + return PAT{}, errors.Wrap(svcerr.ErrUpdateEntity, err) + } + + if err := svc.pats.Reactivate(ctx, key.User, patID); err != nil { + return PAT{}, errors.Wrap(svcerr.ErrUpdateEntity, err) + } + pat.Secret = secret + pat.Revoked = false + pat.RevokedAt = time.Time{} + return pat, nil +} + +func (svc service) RevokePATSecret(ctx context.Context, token, patID string) error { + key, err := svc.Identify(ctx, token) + if err != nil { + return err + } + + if err := svc.pats.Revoke(ctx, key.User, patID); err != nil { + return errors.Wrap(errRevokePAT, err) + } + return nil +} + +func (svc service) AddPATScopeEntry(ctx context.Context, token, patID string, platformEntityType PlatformEntityType, optionalDomainID string, optionalDomainEntityType DomainEntityType, operation OperationType, entityIDs ...string) (Scope, error) { + key, err := svc.Identify(ctx, token) + if err != nil { + return Scope{}, err + } + scope, err := svc.pats.AddScopeEntry(ctx, key.User, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) + if err != nil { + return Scope{}, errors.Wrap(errRevokePAT, err) + } + return scope, nil +} + +func (svc service) RemovePATScopeEntry(ctx context.Context, token, patID string, platformEntityType PlatformEntityType, optionalDomainID string, optionalDomainEntityType DomainEntityType, operation OperationType, entityIDs ...string) (Scope, error) { + key, err := svc.Identify(ctx, token) + if err != nil { + return Scope{}, err + } + scope, err := svc.pats.RemoveScopeEntry(ctx, key.User, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) + if err != nil { + return Scope{}, err + } + return scope, nil +} + +func (svc service) ClearPATAllScopeEntry(ctx context.Context, token, patID string) error { + key, err := svc.Identify(ctx, token) + if err != nil { + return err + } + if err := svc.pats.RemoveAllScopeEntry(ctx, key.User, patID); err != nil { + return errors.Wrap(errClearAllScope, err) + } + return nil +} + +func (svc service) IdentifyPAT(ctx context.Context, secret string) (PAT, error) { + parts := strings.Split(secret, patSecretSeparator) + if len(parts) != 3 && parts[0] != patPrefix { + return PAT{}, errors.Wrap(svcerr.ErrAuthentication, errMalformedPAT) + } + userID, patID, err := decode(parts[1]) + if err != nil { + return PAT{}, errors.Wrap(svcerr.ErrAuthentication, errMalformedPAT) + } + secretHash, revoked, expired, err := svc.pats.RetrieveSecretAndRevokeStatus(ctx, userID.String(), patID.String()) + if err != nil { + return PAT{}, errors.Wrap(svcerr.ErrAuthentication, err) + } + if revoked { + return PAT{}, errors.Wrap(svcerr.ErrAuthentication, errRevokedPAT) + } + if expired { + return PAT{}, errors.Wrap(svcerr.ErrAuthentication, ErrExpiry) + } + if err := svc.hasher.Compare(secret, secretHash); err != nil { + return PAT{}, errors.Wrap(svcerr.ErrAuthentication, err) + } + return PAT{ID: patID.String(), User: userID.String()}, nil +} + +func (svc service) AuthorizePAT(ctx context.Context, userID, patID string, platformEntityType PlatformEntityType, optionalDomainID string, optionalDomainEntityType DomainEntityType, operation OperationType, entityIDs ...string) error { + res, err := svc.RetrievePAT(ctx, userID, patID) + if err != nil { + return err + } + if err := svc.pats.CheckScopeEntry(ctx, res.User, res.ID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...); err != nil { + return errors.Wrap(svcerr.ErrAuthorization, err) + } + return nil +} + +func (svc service) CheckPAT(ctx context.Context, userID, patID string, platformEntityType PlatformEntityType, optionalDomainID string, optionalDomainEntityType DomainEntityType, operation OperationType, entityIDs ...string) error { + if err := svc.pats.CheckScopeEntry(ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...); err != nil { + return errors.Wrap(svcerr.ErrAuthorization, err) + } + return nil +} + +func (svc service) generateSecretAndHash(userID, patID string) (string, string, error) { + uID, err := uuid.Parse(userID) + if err != nil { + return "", "", errors.Wrap(errFailedToParseUUID, err) + } + pID, err := uuid.Parse(patID) + if err != nil { + return "", "", errors.Wrap(errFailedToParseUUID, err) + } + + secret := patPrefix + patSecretSeparator + encode(uID, pID) + patSecretSeparator + generateRandomString(100) + secretHash, err := svc.hasher.Hash(secret) + return secret, secretHash, err +} + +func encode(userID, patID uuid.UUID) string { + c := append(userID[:], patID[:]...) + return base64.StdEncoding.EncodeToString(c) +} + +func decode(encoded string) (uuid.UUID, uuid.UUID, error) { + data, err := base64.StdEncoding.DecodeString(encoded) + if err != nil { + return uuid.Nil, uuid.Nil, err + } + + if len(data) != 32 { + return uuid.Nil, uuid.Nil, errInvalidLenFor2UUIDs + } + + var userID, patID uuid.UUID + copy(userID[:], data[:16]) + copy(patID[:], data[16:]) + + return userID, patID, nil +} + +func generateRandomString(n int) string { + letterRunes := []rune(randStr) + rand.New(rand.NewSource(time.Now().UnixNano())) + b := make([]rune, n) + for i := range b { + b[i] = letterRunes[rand.Intn(len(letterRunes))] + } + return string(b) +} diff --git a/auth/service_test.go b/auth/service_test.go index f323389415..18115fd4c7 100644 --- a/auth/service_test.go +++ b/auth/service_test.go @@ -49,12 +49,16 @@ var ( krepo *mocks.KeyRepository pService *policymocks.Service pEvaluator *policymocks.Evaluator + patsrepo *mocks.PATSRepository + hasher *mocks.Hasher ) func newService() (auth.Service, string) { krepo = new(mocks.KeyRepository) pService = new(policymocks.Service) pEvaluator = new(policymocks.Evaluator) + patsrepo = new(mocks.PATSRepository) + hasher = new(mocks.Hasher) idProvider := uuid.NewMock() t := jwt.New([]byte(secret)) @@ -68,7 +72,7 @@ func newService() (auth.Service, string) { } token, _ := t.Issue(key) - return auth.New(krepo, idProvider, t, pEvaluator, pService, loginDuration, refreshDuration, invalidDuration), token + return auth.New(krepo, patsrepo, hasher, idProvider, t, pEvaluator, pService, loginDuration, refreshDuration, invalidDuration), token } func TestIssue(t *testing.T) { diff --git a/auth/tracing/tracing.go b/auth/tracing/tracing.go index 6df82e8cc3..0321eaf3bd 100644 --- a/auth/tracing/tracing.go +++ b/auth/tracing/tracing.go @@ -6,6 +6,7 @@ package tracing import ( "context" "fmt" + "time" "github.com/absmach/supermq/auth" "github.com/absmach/supermq/pkg/policies" @@ -74,3 +75,141 @@ func (tm *tracingMiddleware) Authorize(ctx context.Context, pr policies.Policy) return tm.svc.Authorize(ctx, pr) } + +func (tm *tracingMiddleware) CreatePAT(ctx context.Context, token, name, description string, duration time.Duration, scope auth.Scope) (auth.PAT, error) { + ctx, span := tm.tracer.Start(ctx, "create_pat", trace.WithAttributes( + attribute.String("name", name), + attribute.String("description", description), + attribute.String("duration", duration.String()), + attribute.String("scope", scope.String()), + )) + defer span.End() + return tm.svc.CreatePAT(ctx, token, name, description, duration, scope) +} + +func (tm *tracingMiddleware) UpdatePATName(ctx context.Context, token, patID, name string) (auth.PAT, error) { + ctx, span := tm.tracer.Start(ctx, "update_pat_name", trace.WithAttributes( + attribute.String("pat_id", patID), + attribute.String("name", name), + )) + defer span.End() + return tm.svc.UpdatePATName(ctx, token, patID, name) +} + +func (tm *tracingMiddleware) UpdatePATDescription(ctx context.Context, token, patID, description string) (auth.PAT, error) { + ctx, span := tm.tracer.Start(ctx, "update_pat_description", trace.WithAttributes( + attribute.String("pat_id", patID), + attribute.String("description", description), + )) + defer span.End() + return tm.svc.UpdatePATDescription(ctx, token, patID, description) +} + +func (tm *tracingMiddleware) RetrievePAT(ctx context.Context, token, patID string) (auth.PAT, error) { + ctx, span := tm.tracer.Start(ctx, "retrieve_pat", trace.WithAttributes( + attribute.String("pat_id", patID), + )) + defer span.End() + return tm.svc.RetrievePAT(ctx, token, patID) +} + +func (tm *tracingMiddleware) ListPATS(ctx context.Context, token string, pm auth.PATSPageMeta) (auth.PATSPage, error) { + ctx, span := tm.tracer.Start(ctx, "list_pat", trace.WithAttributes( + attribute.Int64("limit", int64(pm.Limit)), + attribute.Int64("offset", int64(pm.Offset)), + )) + defer span.End() + return tm.svc.ListPATS(ctx, token, pm) +} + +func (tm *tracingMiddleware) DeletePAT(ctx context.Context, token, patID string) error { + ctx, span := tm.tracer.Start(ctx, "delete_pat", trace.WithAttributes( + attribute.String("pat_id", patID), + )) + defer span.End() + return tm.svc.DeletePAT(ctx, token, patID) +} + +func (tm *tracingMiddleware) ResetPATSecret(ctx context.Context, token, patID string, duration time.Duration) (auth.PAT, error) { + ctx, span := tm.tracer.Start(ctx, "reset_pat_secret", trace.WithAttributes( + attribute.String("pat_id", patID), + attribute.String("duration", duration.String()), + )) + defer span.End() + return tm.svc.ResetPATSecret(ctx, token, patID, duration) +} + +func (tm *tracingMiddleware) RevokePATSecret(ctx context.Context, token, patID string) error { + ctx, span := tm.tracer.Start(ctx, "revoke_pat_secret", trace.WithAttributes( + attribute.String("pat_id", patID), + )) + defer span.End() + return tm.svc.RevokePATSecret(ctx, token, patID) +} + +func (tm *tracingMiddleware) AddPATScopeEntry(ctx context.Context, token, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) (auth.Scope, error) { + ctx, span := tm.tracer.Start(ctx, "add_pat_scope_entry", trace.WithAttributes( + attribute.String("pat_id", patID), + attribute.String("platform_entity", platformEntityType.String()), + attribute.String("optional_domain_id", optionalDomainID), + attribute.String("optional_domain_entity", optionalDomainEntityType.String()), + attribute.String("operation", operation.String()), + attribute.StringSlice("entities", entityIDs), + )) + defer span.End() + return tm.svc.AddPATScopeEntry(ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) +} + +func (tm *tracingMiddleware) RemovePATScopeEntry(ctx context.Context, token, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) (auth.Scope, error) { + ctx, span := tm.tracer.Start(ctx, "remove_pat_scope_entry", trace.WithAttributes( + attribute.String("pat_id", patID), + attribute.String("platform_entity", platformEntityType.String()), + attribute.String("optional_domain_id", optionalDomainID), + attribute.String("optional_domain_entity", optionalDomainEntityType.String()), + attribute.String("operation", operation.String()), + attribute.StringSlice("entities", entityIDs), + )) + defer span.End() + return tm.svc.RemovePATScopeEntry(ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) +} + +func (tm *tracingMiddleware) ClearPATAllScopeEntry(ctx context.Context, token, patID string) error { + ctx, span := tm.tracer.Start(ctx, "clear_pat_all_scope_entry", trace.WithAttributes( + attribute.String("pat_id", patID), + )) + defer span.End() + return tm.svc.ClearPATAllScopeEntry(ctx, token, patID) +} + +func (tm *tracingMiddleware) IdentifyPAT(ctx context.Context, paToken string) (auth.PAT, error) { + ctx, span := tm.tracer.Start(ctx, "identity_pat") + defer span.End() + return tm.svc.IdentifyPAT(ctx, paToken) +} + +func (tm *tracingMiddleware) AuthorizePAT(ctx context.Context, userID, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) error { + ctx, span := tm.tracer.Start(ctx, "authorize_pat", trace.WithAttributes( + attribute.String("pat_id", patID), + attribute.String("platform_entity", platformEntityType.String()), + attribute.String("optional_domain_id", optionalDomainID), + attribute.String("optional_domain_entity", optionalDomainEntityType.String()), + attribute.String("operation", operation.String()), + attribute.StringSlice("entities", entityIDs), + )) + defer span.End() + return tm.svc.AuthorizePAT(ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) +} + +func (tm *tracingMiddleware) CheckPAT(ctx context.Context, userID, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) error { + ctx, span := tm.tracer.Start(ctx, "check_pat", trace.WithAttributes( + attribute.String("user_id", userID), + attribute.String("patID", patID), + attribute.String("platform_entity", platformEntityType.String()), + attribute.String("optional_domain_id", optionalDomainID), + attribute.String("optional_domain_entity", optionalDomainEntityType.String()), + attribute.String("operation", operation.String()), + attribute.StringSlice("entities", entityIDs), + )) + defer span.End() + return tm.svc.CheckPAT(ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) +} diff --git a/cmd/auth/main.go b/cmd/auth/main.go index 6c52de48b6..7960c74f03 100644 --- a/cmd/auth/main.go +++ b/cmd/auth/main.go @@ -19,15 +19,17 @@ import ( authgrpcapi "github.com/absmach/supermq/auth/api/grpc/auth" tokengrpcapi "github.com/absmach/supermq/auth/api/grpc/token" httpapi "github.com/absmach/supermq/auth/api/http" + "github.com/absmach/supermq/auth/bolt" + "github.com/absmach/supermq/auth/hasher" "github.com/absmach/supermq/auth/jwt" apostgres "github.com/absmach/supermq/auth/postgres" "github.com/absmach/supermq/auth/tracing" + boltclient "github.com/absmach/supermq/internal/clients/bolt" grpcAuthV1 "github.com/absmach/supermq/internal/grpc/auth/v1" grpcTokenV1 "github.com/absmach/supermq/internal/grpc/token/v1" smqlog "github.com/absmach/supermq/logger" "github.com/absmach/supermq/pkg/jaeger" "github.com/absmach/supermq/pkg/policies/spicedb" - "github.com/absmach/supermq/pkg/postgres" pgclient "github.com/absmach/supermq/pkg/postgres" "github.com/absmach/supermq/pkg/prometheus" "github.com/absmach/supermq/pkg/server" @@ -39,6 +41,7 @@ import ( "github.com/authzed/grpcutil" "github.com/caarlos0/env/v11" "github.com/jmoiron/sqlx" + "go.etcd.io/bbolt" "go.opentelemetry.io/otel/trace" "golang.org/x/sync/errgroup" "google.golang.org/grpc" @@ -51,6 +54,7 @@ const ( envPrefixHTTP = "SMQ_AUTH_HTTP_" envPrefixGrpc = "SMQ_AUTH_GRPC_" envPrefixDB = "SMQ_AUTH_DB_" + envPrefixPATDB = "SMQ_AUTH_PAT_DB_" defDB = "auth" defSvcHTTPPort = "8189" defSvcGRPCPort = "8181" @@ -131,7 +135,23 @@ func main() { exitCode = 1 return } - svc := newService(ctx, db, tracer, cfg, dbConfig, logger, spicedbclient) + + boltDBConfig := boltclient.Config{} + if err := env.ParseWithOptions(&boltDBConfig, env.Options{Prefix: envPrefixPATDB}); err != nil { + logger.Error(fmt.Sprintf("failed to parse bolt db config : %s\n", err.Error())) + exitCode = 1 + return + } + + bClient, err := boltclient.Connect(boltDBConfig, bolt.Init) + if err != nil { + logger.Error(fmt.Sprintf("failed to connect to bolt db : %s\n", err.Error())) + exitCode = 1 + return + } + defer bClient.Close() + + svc := newService(ctx, db, tracer, cfg, dbConfig, logger, spicedbclient, bClient, boltDBConfig) grpcServerConfig := server.Config{Port: defSvcGRPCPort} if err := env.ParseWithOptions(&grpcServerConfig, env.Options{Prefix: envPrefixGrpc}); err != nil { @@ -211,9 +231,11 @@ func initSchema(ctx context.Context, client *authzed.ClientWithExperimental, sch return nil } -func newService(_ context.Context, db *sqlx.DB, tracer trace.Tracer, cfg config, dbConfig pgclient.Config, logger *slog.Logger, spicedbClient *authzed.ClientWithExperimental) auth.Service { - database := postgres.NewDatabase(db, dbConfig, tracer) +func newService(_ context.Context, db *sqlx.DB, tracer trace.Tracer, cfg config, dbConfig pgclient.Config, logger *slog.Logger, spicedbClient *authzed.ClientWithExperimental, bClient *bbolt.DB, bConfig boltclient.Config) auth.Service { + database := pgclient.NewDatabase(db, dbConfig, tracer) keysRepo := apostgres.New(database) + patsRepo := bolt.NewPATSRepository(bClient, bConfig.Bucket) + hasher := hasher.New() idProvider := uuid.New() pEvaluator := spicedb.NewPolicyEvaluator(spicedbClient, logger) @@ -221,7 +243,7 @@ func newService(_ context.Context, db *sqlx.DB, tracer trace.Tracer, cfg config, t := jwt.New([]byte(cfg.SecretKey)) - svc := auth.New(keysRepo, idProvider, t, pEvaluator, pService, cfg.AccessDuration, cfg.RefreshDuration, cfg.InvitationDuration) + svc := auth.New(keysRepo, patsRepo, hasher, idProvider, t, pEvaluator, pService, cfg.AccessDuration, cfg.RefreshDuration, cfg.InvitationDuration) svc = api.LoggingMiddleware(svc, logger) counter, latency := prometheus.MakeMetrics("auth", "api") svc = api.MetricsMiddleware(svc, counter, latency) diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml index 6d66eb3227..5320b57386 100644 --- a/docker/docker-compose.yml +++ b/docker/docker-compose.yml @@ -17,6 +17,7 @@ volumes: supermq-mqtt-broker-volume: supermq-spicedb-db-volume: supermq-auth-db-volume: + supermq-pat-db-volume: supermq-domains-db-volume: supermq-invitations-db-volume: supermq-ui-db-volume: @@ -136,6 +137,7 @@ services: - supermq-base-net volumes: - ./spicedb/schema.zed:${SMQ_SPICEDB_SCHEMA_FILE} + - supermq-pat-db-volume:/supermq-data # Auth gRPC mTLS server certificates - type: bind source: ${SMQ_AUTH_GRPC_SERVER_CERT:-ssl/certs/dummy/server_cert} diff --git a/go.mod b/go.mod index 63c239d480..b21b776607 100644 --- a/go.mod +++ b/go.mod @@ -18,6 +18,7 @@ require ( github.com/go-chi/chi/v5 v5.1.0 github.com/go-kit/kit v0.13.0 github.com/gofrs/uuid/v5 v5.3.0 + github.com/google/uuid v1.6.0 github.com/gookit/color v1.5.4 github.com/gorilla/websocket v1.5.3 github.com/hashicorp/vault/api v1.15.0 @@ -44,6 +45,7 @@ require ( github.com/spf13/viper v1.19.0 github.com/sqids/sqids-go v0.4.1 github.com/stretchr/testify v1.10.0 + go.etcd.io/bbolt v1.3.11 go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.57.0 go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.58.0 go.opentelemetry.io/otel v1.33.0 @@ -98,7 +100,6 @@ require ( github.com/goccy/go-json v0.10.3 // indirect github.com/gogo/protobuf v1.3.2 // indirect github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect - github.com/google/uuid v1.6.0 // indirect github.com/gopherjs/gopherjs v1.17.2 // indirect github.com/grpc-ecosystem/go-grpc-middleware v1.4.0 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.24.0 // indirect diff --git a/go.sum b/go.sum index 3e39870e2d..76b159a832 100644 --- a/go.sum +++ b/go.sum @@ -465,6 +465,8 @@ github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5t github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q= go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA= go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A= +go.etcd.io/bbolt v1.3.11 h1:yGEzV1wPz2yVCLsD8ZAiGHhHVlczyC9d1rP43/VCRJ0= +go.etcd.io/bbolt v1.3.11/go.mod h1:dksAq7YMXoljX0xu6VF5DMZGbhYYoLUalEiSySYAS4I= go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.57.0 h1:qtFISDHKolvIxzSs0gIaiPUPR0Cucb0F2coHC7ZLdps= go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.57.0/go.mod h1:Y+Pop1Q6hCOnETWTW4NROK/q1hv50hM7yDaUTjG8lp8= go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.58.0 h1:yd02MEjBdJkG3uabWP9apV+OuWRIXGDuJEUJbOHmCFU= diff --git a/internal/api/auth.go b/internal/api/auth.go index 3fd4998519..107b5207a2 100644 --- a/internal/api/auth.go +++ b/internal/api/auth.go @@ -7,6 +7,7 @@ import ( "context" "net/http" + "github.com/absmach/supermq/auth" "github.com/absmach/supermq/pkg/apiutil" smqauthn "github.com/absmach/supermq/pkg/authn" "github.com/go-chi/chi/v5" @@ -14,7 +15,9 @@ import ( type sessionKeyType string -const SessionKey = sessionKeyType("session") +const ( + SessionKey = sessionKeyType("session") +) func AuthenticateMiddleware(authn smqauthn.Authentication, domainCheck bool) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { @@ -24,8 +27,10 @@ func AuthenticateMiddleware(authn smqauthn.Authentication, domainCheck bool) fun EncodeError(r.Context(), apiutil.ErrBearerToken, w) return } + var resp smqauthn.Session + var err error - resp, err := authn.Authenticate(r.Context(), token) + resp, err = authn.Authenticate(r.Context(), token) if err != nil { EncodeError(r.Context(), err, w) return @@ -38,7 +43,7 @@ func AuthenticateMiddleware(authn smqauthn.Authentication, domainCheck bool) fun return } resp.DomainID = domain - resp.DomainUserID = domain + "_" + resp.UserID + resp.DomainUserID = auth.EncodeDomainUserID(domain, resp.UserID) } ctx := context.WithValue(r.Context(), SessionKey, resp) diff --git a/internal/api/common.go b/internal/api/common.go index e576d93dc1..9f28ac852f 100644 --- a/internal/api/common.go +++ b/internal/api/common.go @@ -134,7 +134,8 @@ func EncodeError(_ context.Context, err error, w http.ResponseWriter) { case errors.Contains(err, svcerr.ErrAuthentication), errors.Contains(err, apiutil.ErrBearerToken), - errors.Contains(err, svcerr.ErrLogin): + errors.Contains(err, svcerr.ErrLogin), + errors.Contains(err, apiutil.ErrUnsupportedTokenType): err = unwrap(err) w.WriteHeader(http.StatusUnauthorized) case errors.Contains(err, svcerr.ErrMalformedEntity), @@ -184,6 +185,8 @@ func EncodeError(_ context.Context, err error, w http.ResponseWriter) { errors.Contains(err, apiutil.ErrLenSearchQuery), errors.Contains(err, apiutil.ErrMissingDomainID), errors.Contains(err, certs.ErrFailedReadFromPKI), + errors.Contains(err, apiutil.ErrMissingUserID), + errors.Contains(err, apiutil.ErrMissingPATID), errors.Contains(err, apiutil.ErrMissingUsername), errors.Contains(err, apiutil.ErrMissingFirstName), errors.Contains(err, apiutil.ErrMissingLastName), diff --git a/internal/clients/bolt/bolt.go b/internal/clients/bolt/bolt.go new file mode 100644 index 0000000000..8e2afebf97 --- /dev/null +++ b/internal/clients/bolt/bolt.go @@ -0,0 +1,83 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package bolt + +import ( + "io/fs" + "strconv" + "time" + + "github.com/absmach/supermq/pkg/errors" + "github.com/caarlos0/env/v11" + bolt "go.etcd.io/bbolt" +) + +var ( + errConfig = errors.New("failed to load BoltDB configuration") + errConnect = errors.New("failed to connect to BoltDB database") + errInit = errors.New("failed to initialize to BoltDB database") +) + +type FileMode fs.FileMode + +func (fm *FileMode) UnmarshalText(text []byte) error { + temp, err := strconv.ParseUint(string(text), 8, 32) + if err != nil { + return err + } + *fm = FileMode(temp) + return nil +} + +// Config contains BoltDB specific parameters. +type Config struct { + FileDirPath string `env:"FILE_DIR_PATH" envDefault:"./supermq-data"` + FileName string `env:"FILE_NAME" envDefault:"supermq-pat.db"` + FileMode FileMode `env:"FILE_MODE" envDefault:"0600"` + Bucket string `env:"BUCKET" envDefault:"supermq"` + Timeout time.Duration `env:"TIMEOUT" envDefault:"0"` +} + +// Setup load configuration from environment and creates new BoltDB. +func Setup(envPrefix string, initFn func(*bolt.Tx, string) error) (*bolt.DB, error) { + return SetupDB(envPrefix, initFn) +} + +// SetupDB load configuration from environment,. +func SetupDB(envPrefix string, initFn func(*bolt.Tx, string) error) (*bolt.DB, error) { + cfg := Config{} + if err := env.ParseWithOptions(&cfg, env.Options{Prefix: envPrefix}); err != nil { + return nil, errors.Wrap(errConfig, err) + } + bdb, err := Connect(cfg, initFn) + if err != nil { + return nil, err + } + + return bdb, nil +} + +// Connect establishes connection to the BoltDB. +func Connect(cfg Config, initFn func(*bolt.Tx, string) error) (*bolt.DB, error) { + filePath := cfg.FileDirPath + "/" + cfg.FileName + db, err := bolt.Open(filePath, fs.FileMode(cfg.FileMode), nil) + if err != nil { + return nil, errors.Wrap(errConnect, err) + } + if initFn != nil { + if err := Init(db, cfg, initFn); err != nil { + return nil, err + } + } + return db, nil +} + +func Init(db *bolt.DB, cfg Config, initFn func(*bolt.Tx, string) error) error { + if err := db.Update(func(tx *bolt.Tx) error { + return initFn(tx, cfg.Bucket) + }); err != nil { + return errors.Wrap(errInit, err) + } + return nil +} diff --git a/internal/clients/bolt/doc.go b/internal/clients/bolt/doc.go new file mode 100644 index 0000000000..24fc0f92a5 --- /dev/null +++ b/internal/clients/bolt/doc.go @@ -0,0 +1,9 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +// Package BoltDB contains the domain concept definitions needed to support +// Supermq BoltDB database functionality. +// +// It provides the abstraction of the BoltDB database service, which is used +// to configure, setup and connect to the BoltDB database. +package bolt diff --git a/internal/grpc/auth/v1/auth.pb.go b/internal/grpc/auth/v1/auth.pb.go index 9e2e31c237..ec8075357c 100644 --- a/internal/grpc/auth/v1/auth.pb.go +++ b/internal/grpc/auth/v1/auth.pb.go @@ -238,6 +238,99 @@ func (x *AuthZReq) GetObjectType() string { return "" } +type AuthZPatReq struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + UserId string `protobuf:"bytes,1,opt,name=user_id,json=userId,proto3" json:"user_id,omitempty"` // User id + PatId string `protobuf:"bytes,2,opt,name=pat_id,json=patId,proto3" json:"pat_id,omitempty"` // Pat id + PlatformEntityType uint32 `protobuf:"varint,3,opt,name=platform_entity_type,json=platformEntityType,proto3" json:"platform_entity_type,omitempty"` // Platform entity type + OptionalDomainId string `protobuf:"bytes,4,opt,name=optional_domain_id,json=optionalDomainId,proto3" json:"optional_domain_id,omitempty"` // Optional domain id + OptionalDomainEntityType uint32 `protobuf:"varint,5,opt,name=optional_domain_entity_type,json=optionalDomainEntityType,proto3" json:"optional_domain_entity_type,omitempty"` // Optional domain entity type + Operation uint32 `protobuf:"varint,6,opt,name=operation,proto3" json:"operation,omitempty"` // Operation + EntityIds []string `protobuf:"bytes,7,rep,name=entity_ids,json=entityIds,proto3" json:"entity_ids,omitempty"` // EntityIDs +} + +func (x *AuthZPatReq) Reset() { + *x = AuthZPatReq{} + mi := &file_auth_v1_auth_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *AuthZPatReq) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*AuthZPatReq) ProtoMessage() {} + +func (x *AuthZPatReq) ProtoReflect() protoreflect.Message { + mi := &file_auth_v1_auth_proto_msgTypes[3] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use AuthZPatReq.ProtoReflect.Descriptor instead. +func (*AuthZPatReq) Descriptor() ([]byte, []int) { + return file_auth_v1_auth_proto_rawDescGZIP(), []int{3} +} + +func (x *AuthZPatReq) GetUserId() string { + if x != nil { + return x.UserId + } + return "" +} + +func (x *AuthZPatReq) GetPatId() string { + if x != nil { + return x.PatId + } + return "" +} + +func (x *AuthZPatReq) GetPlatformEntityType() uint32 { + if x != nil { + return x.PlatformEntityType + } + return 0 +} + +func (x *AuthZPatReq) GetOptionalDomainId() string { + if x != nil { + return x.OptionalDomainId + } + return "" +} + +func (x *AuthZPatReq) GetOptionalDomainEntityType() uint32 { + if x != nil { + return x.OptionalDomainEntityType + } + return 0 +} + +func (x *AuthZPatReq) GetOperation() uint32 { + if x != nil { + return x.Operation + } + return 0 +} + +func (x *AuthZPatReq) GetEntityIds() []string { + if x != nil { + return x.EntityIds + } + return nil +} + type AuthZRes struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache @@ -249,7 +342,7 @@ type AuthZRes struct { func (x *AuthZRes) Reset() { *x = AuthZRes{} - mi := &file_auth_v1_auth_proto_msgTypes[3] + mi := &file_auth_v1_auth_proto_msgTypes[4] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -261,7 +354,7 @@ func (x *AuthZRes) String() string { func (*AuthZRes) ProtoMessage() {} func (x *AuthZRes) ProtoReflect() protoreflect.Message { - mi := &file_auth_v1_auth_proto_msgTypes[3] + mi := &file_auth_v1_auth_proto_msgTypes[4] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -274,7 +367,7 @@ func (x *AuthZRes) ProtoReflect() protoreflect.Message { // Deprecated: Use AuthZRes.ProtoReflect.Descriptor instead. func (*AuthZRes) Descriptor() ([]byte, []int) { - return file_auth_v1_auth_proto_rawDescGZIP(), []int{3} + return file_auth_v1_auth_proto_rawDescGZIP(), []int{4} } func (x *AuthZRes) GetAuthorized() bool { @@ -321,22 +414,47 @@ var file_auth_v1_auth_proto_rawDesc = []byte{ 0x06, 0x6f, 0x62, 0x6a, 0x65, 0x63, 0x74, 0x18, 0x08, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x6f, 0x62, 0x6a, 0x65, 0x63, 0x74, 0x12, 0x1f, 0x0a, 0x0b, 0x6f, 0x62, 0x6a, 0x65, 0x63, 0x74, 0x5f, 0x74, 0x79, 0x70, 0x65, 0x18, 0x09, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x6f, 0x62, 0x6a, 0x65, - 0x63, 0x74, 0x54, 0x79, 0x70, 0x65, 0x22, 0x3a, 0x0a, 0x08, 0x41, 0x75, 0x74, 0x68, 0x5a, 0x52, - 0x65, 0x73, 0x12, 0x1e, 0x0a, 0x0a, 0x61, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x65, 0x64, - 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x61, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, - 0x65, 0x64, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, - 0x69, 0x64, 0x32, 0x7a, 0x0a, 0x0b, 0x41, 0x75, 0x74, 0x68, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, - 0x65, 0x12, 0x33, 0x0a, 0x09, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x65, 0x12, 0x11, + 0x63, 0x74, 0x54, 0x79, 0x70, 0x65, 0x22, 0x99, 0x02, 0x0a, 0x0b, 0x41, 0x75, 0x74, 0x68, 0x5a, + 0x50, 0x61, 0x74, 0x52, 0x65, 0x71, 0x12, 0x17, 0x0a, 0x07, 0x75, 0x73, 0x65, 0x72, 0x5f, 0x69, + 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x75, 0x73, 0x65, 0x72, 0x49, 0x64, 0x12, + 0x15, 0x0a, 0x06, 0x70, 0x61, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x05, 0x70, 0x61, 0x74, 0x49, 0x64, 0x12, 0x30, 0x0a, 0x14, 0x70, 0x6c, 0x61, 0x74, 0x66, 0x6f, + 0x72, 0x6d, 0x5f, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x5f, 0x74, 0x79, 0x70, 0x65, 0x18, 0x03, + 0x20, 0x01, 0x28, 0x0d, 0x52, 0x12, 0x70, 0x6c, 0x61, 0x74, 0x66, 0x6f, 0x72, 0x6d, 0x45, 0x6e, + 0x74, 0x69, 0x74, 0x79, 0x54, 0x79, 0x70, 0x65, 0x12, 0x2c, 0x0a, 0x12, 0x6f, 0x70, 0x74, 0x69, + 0x6f, 0x6e, 0x61, 0x6c, 0x5f, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x5f, 0x69, 0x64, 0x18, 0x04, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x10, 0x6f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x61, 0x6c, 0x44, 0x6f, + 0x6d, 0x61, 0x69, 0x6e, 0x49, 0x64, 0x12, 0x3d, 0x0a, 0x1b, 0x6f, 0x70, 0x74, 0x69, 0x6f, 0x6e, + 0x61, 0x6c, 0x5f, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x5f, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, + 0x5f, 0x74, 0x79, 0x70, 0x65, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x18, 0x6f, 0x70, 0x74, + 0x69, 0x6f, 0x6e, 0x61, 0x6c, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x45, 0x6e, 0x74, 0x69, 0x74, + 0x79, 0x54, 0x79, 0x70, 0x65, 0x12, 0x1c, 0x0a, 0x09, 0x6f, 0x70, 0x65, 0x72, 0x61, 0x74, 0x69, + 0x6f, 0x6e, 0x18, 0x06, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x09, 0x6f, 0x70, 0x65, 0x72, 0x61, 0x74, + 0x69, 0x6f, 0x6e, 0x12, 0x1d, 0x0a, 0x0a, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x5f, 0x69, 0x64, + 0x73, 0x18, 0x07, 0x20, 0x03, 0x28, 0x09, 0x52, 0x09, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x49, + 0x64, 0x73, 0x22, 0x3a, 0x0a, 0x08, 0x41, 0x75, 0x74, 0x68, 0x5a, 0x52, 0x65, 0x73, 0x12, 0x1e, + 0x0a, 0x0a, 0x61, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x65, 0x64, 0x18, 0x01, 0x20, 0x01, + 0x28, 0x08, 0x52, 0x0a, 0x61, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x65, 0x64, 0x12, 0x0e, + 0x0a, 0x02, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x69, 0x64, 0x32, 0xf0, + 0x01, 0x0a, 0x0b, 0x41, 0x75, 0x74, 0x68, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x33, + 0x0a, 0x09, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x65, 0x12, 0x11, 0x2e, 0x61, 0x75, + 0x74, 0x68, 0x2e, 0x76, 0x31, 0x2e, 0x41, 0x75, 0x74, 0x68, 0x5a, 0x52, 0x65, 0x71, 0x1a, 0x11, 0x2e, 0x61, 0x75, 0x74, 0x68, 0x2e, 0x76, 0x31, 0x2e, 0x41, 0x75, 0x74, 0x68, 0x5a, 0x52, 0x65, + 0x73, 0x22, 0x00, 0x12, 0x39, 0x0a, 0x0c, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x65, + 0x50, 0x41, 0x54, 0x12, 0x14, 0x2e, 0x61, 0x75, 0x74, 0x68, 0x2e, 0x76, 0x31, 0x2e, 0x41, 0x75, + 0x74, 0x68, 0x5a, 0x50, 0x61, 0x74, 0x52, 0x65, 0x71, 0x1a, 0x11, 0x2e, 0x61, 0x75, 0x74, 0x68, + 0x2e, 0x76, 0x31, 0x2e, 0x41, 0x75, 0x74, 0x68, 0x5a, 0x52, 0x65, 0x73, 0x22, 0x00, 0x12, 0x36, + 0x0a, 0x0c, 0x41, 0x75, 0x74, 0x68, 0x65, 0x6e, 0x74, 0x69, 0x63, 0x61, 0x74, 0x65, 0x12, 0x11, + 0x2e, 0x61, 0x75, 0x74, 0x68, 0x2e, 0x76, 0x31, 0x2e, 0x41, 0x75, 0x74, 0x68, 0x4e, 0x52, 0x65, 0x71, 0x1a, 0x11, 0x2e, 0x61, 0x75, 0x74, 0x68, 0x2e, 0x76, 0x31, 0x2e, 0x41, 0x75, 0x74, 0x68, - 0x5a, 0x52, 0x65, 0x73, 0x22, 0x00, 0x12, 0x36, 0x0a, 0x0c, 0x41, 0x75, 0x74, 0x68, 0x65, 0x6e, - 0x74, 0x69, 0x63, 0x61, 0x74, 0x65, 0x12, 0x11, 0x2e, 0x61, 0x75, 0x74, 0x68, 0x2e, 0x76, 0x31, - 0x2e, 0x41, 0x75, 0x74, 0x68, 0x4e, 0x52, 0x65, 0x71, 0x1a, 0x11, 0x2e, 0x61, 0x75, 0x74, 0x68, - 0x2e, 0x76, 0x31, 0x2e, 0x41, 0x75, 0x74, 0x68, 0x4e, 0x52, 0x65, 0x73, 0x22, 0x00, 0x42, 0x32, - 0x5a, 0x30, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x61, 0x62, 0x73, - 0x6d, 0x61, 0x63, 0x68, 0x2f, 0x73, 0x75, 0x70, 0x65, 0x72, 0x6d, 0x71, 0x2f, 0x69, 0x6e, 0x74, - 0x65, 0x72, 0x6e, 0x61, 0x6c, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x61, 0x75, 0x74, 0x68, 0x2f, - 0x76, 0x31, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x4e, 0x52, 0x65, 0x73, 0x22, 0x00, 0x12, 0x39, 0x0a, 0x0f, 0x41, 0x75, 0x74, 0x68, 0x65, 0x6e, + 0x74, 0x69, 0x63, 0x61, 0x74, 0x65, 0x50, 0x41, 0x54, 0x12, 0x11, 0x2e, 0x61, 0x75, 0x74, 0x68, + 0x2e, 0x76, 0x31, 0x2e, 0x41, 0x75, 0x74, 0x68, 0x4e, 0x52, 0x65, 0x71, 0x1a, 0x11, 0x2e, 0x61, + 0x75, 0x74, 0x68, 0x2e, 0x76, 0x31, 0x2e, 0x41, 0x75, 0x74, 0x68, 0x4e, 0x52, 0x65, 0x73, 0x22, + 0x00, 0x42, 0x32, 0x5a, 0x30, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, + 0x61, 0x62, 0x73, 0x6d, 0x61, 0x63, 0x68, 0x2f, 0x73, 0x75, 0x70, 0x65, 0x72, 0x6d, 0x71, 0x2f, + 0x69, 0x6e, 0x74, 0x65, 0x72, 0x6e, 0x61, 0x6c, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x61, 0x75, + 0x74, 0x68, 0x2f, 0x76, 0x31, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( @@ -351,20 +469,25 @@ func file_auth_v1_auth_proto_rawDescGZIP() []byte { return file_auth_v1_auth_proto_rawDescData } -var file_auth_v1_auth_proto_msgTypes = make([]protoimpl.MessageInfo, 4) +var file_auth_v1_auth_proto_msgTypes = make([]protoimpl.MessageInfo, 5) var file_auth_v1_auth_proto_goTypes = []any{ - (*AuthNReq)(nil), // 0: auth.v1.AuthNReq - (*AuthNRes)(nil), // 1: auth.v1.AuthNRes - (*AuthZReq)(nil), // 2: auth.v1.AuthZReq - (*AuthZRes)(nil), // 3: auth.v1.AuthZRes + (*AuthNReq)(nil), // 0: auth.v1.AuthNReq + (*AuthNRes)(nil), // 1: auth.v1.AuthNRes + (*AuthZReq)(nil), // 2: auth.v1.AuthZReq + (*AuthZPatReq)(nil), // 3: auth.v1.AuthZPatReq + (*AuthZRes)(nil), // 4: auth.v1.AuthZRes } var file_auth_v1_auth_proto_depIdxs = []int32{ 2, // 0: auth.v1.AuthService.Authorize:input_type -> auth.v1.AuthZReq - 0, // 1: auth.v1.AuthService.Authenticate:input_type -> auth.v1.AuthNReq - 3, // 2: auth.v1.AuthService.Authorize:output_type -> auth.v1.AuthZRes - 1, // 3: auth.v1.AuthService.Authenticate:output_type -> auth.v1.AuthNRes - 2, // [2:4] is the sub-list for method output_type - 0, // [0:2] is the sub-list for method input_type + 3, // 1: auth.v1.AuthService.AuthorizePAT:input_type -> auth.v1.AuthZPatReq + 0, // 2: auth.v1.AuthService.Authenticate:input_type -> auth.v1.AuthNReq + 0, // 3: auth.v1.AuthService.AuthenticatePAT:input_type -> auth.v1.AuthNReq + 4, // 4: auth.v1.AuthService.Authorize:output_type -> auth.v1.AuthZRes + 4, // 5: auth.v1.AuthService.AuthorizePAT:output_type -> auth.v1.AuthZRes + 1, // 6: auth.v1.AuthService.Authenticate:output_type -> auth.v1.AuthNRes + 1, // 7: auth.v1.AuthService.AuthenticatePAT:output_type -> auth.v1.AuthNRes + 4, // [4:8] is the sub-list for method output_type + 0, // [0:4] is the sub-list for method input_type 0, // [0:0] is the sub-list for extension type_name 0, // [0:0] is the sub-list for extension extendee 0, // [0:0] is the sub-list for field type_name @@ -381,7 +504,7 @@ func file_auth_v1_auth_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: file_auth_v1_auth_proto_rawDesc, NumEnums: 0, - NumMessages: 4, + NumMessages: 5, NumExtensions: 0, NumServices: 1, }, diff --git a/internal/grpc/auth/v1/auth_grpc.pb.go b/internal/grpc/auth/v1/auth_grpc.pb.go index c7661e872e..d85fabae7c 100644 --- a/internal/grpc/auth/v1/auth_grpc.pb.go +++ b/internal/grpc/auth/v1/auth_grpc.pb.go @@ -22,8 +22,10 @@ import ( const _ = grpc.SupportPackageIsVersion9 const ( - AuthService_Authorize_FullMethodName = "/auth.v1.AuthService/Authorize" - AuthService_Authenticate_FullMethodName = "/auth.v1.AuthService/Authenticate" + AuthService_Authorize_FullMethodName = "/auth.v1.AuthService/Authorize" + AuthService_AuthorizePAT_FullMethodName = "/auth.v1.AuthService/AuthorizePAT" + AuthService_Authenticate_FullMethodName = "/auth.v1.AuthService/Authenticate" + AuthService_AuthenticatePAT_FullMethodName = "/auth.v1.AuthService/AuthenticatePAT" ) // AuthServiceClient is the client API for AuthService service. @@ -34,7 +36,9 @@ const ( // and authorization functionalities for SuperMQ services. type AuthServiceClient interface { Authorize(ctx context.Context, in *AuthZReq, opts ...grpc.CallOption) (*AuthZRes, error) + AuthorizePAT(ctx context.Context, in *AuthZPatReq, opts ...grpc.CallOption) (*AuthZRes, error) Authenticate(ctx context.Context, in *AuthNReq, opts ...grpc.CallOption) (*AuthNRes, error) + AuthenticatePAT(ctx context.Context, in *AuthNReq, opts ...grpc.CallOption) (*AuthNRes, error) } type authServiceClient struct { @@ -55,6 +59,16 @@ func (c *authServiceClient) Authorize(ctx context.Context, in *AuthZReq, opts .. return out, nil } +func (c *authServiceClient) AuthorizePAT(ctx context.Context, in *AuthZPatReq, opts ...grpc.CallOption) (*AuthZRes, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(AuthZRes) + err := c.cc.Invoke(ctx, AuthService_AuthorizePAT_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + func (c *authServiceClient) Authenticate(ctx context.Context, in *AuthNReq, opts ...grpc.CallOption) (*AuthNRes, error) { cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(AuthNRes) @@ -65,6 +79,16 @@ func (c *authServiceClient) Authenticate(ctx context.Context, in *AuthNReq, opts return out, nil } +func (c *authServiceClient) AuthenticatePAT(ctx context.Context, in *AuthNReq, opts ...grpc.CallOption) (*AuthNRes, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(AuthNRes) + err := c.cc.Invoke(ctx, AuthService_AuthenticatePAT_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + // AuthServiceServer is the server API for AuthService service. // All implementations must embed UnimplementedAuthServiceServer // for forward compatibility. @@ -73,7 +97,9 @@ func (c *authServiceClient) Authenticate(ctx context.Context, in *AuthNReq, opts // and authorization functionalities for SuperMQ services. type AuthServiceServer interface { Authorize(context.Context, *AuthZReq) (*AuthZRes, error) + AuthorizePAT(context.Context, *AuthZPatReq) (*AuthZRes, error) Authenticate(context.Context, *AuthNReq) (*AuthNRes, error) + AuthenticatePAT(context.Context, *AuthNReq) (*AuthNRes, error) mustEmbedUnimplementedAuthServiceServer() } @@ -87,9 +113,15 @@ type UnimplementedAuthServiceServer struct{} func (UnimplementedAuthServiceServer) Authorize(context.Context, *AuthZReq) (*AuthZRes, error) { return nil, status.Errorf(codes.Unimplemented, "method Authorize not implemented") } +func (UnimplementedAuthServiceServer) AuthorizePAT(context.Context, *AuthZPatReq) (*AuthZRes, error) { + return nil, status.Errorf(codes.Unimplemented, "method AuthorizePAT not implemented") +} func (UnimplementedAuthServiceServer) Authenticate(context.Context, *AuthNReq) (*AuthNRes, error) { return nil, status.Errorf(codes.Unimplemented, "method Authenticate not implemented") } +func (UnimplementedAuthServiceServer) AuthenticatePAT(context.Context, *AuthNReq) (*AuthNRes, error) { + return nil, status.Errorf(codes.Unimplemented, "method AuthenticatePAT not implemented") +} func (UnimplementedAuthServiceServer) mustEmbedUnimplementedAuthServiceServer() {} func (UnimplementedAuthServiceServer) testEmbeddedByValue() {} @@ -129,6 +161,24 @@ func _AuthService_Authorize_Handler(srv interface{}, ctx context.Context, dec fu return interceptor(ctx, in, info, handler) } +func _AuthService_AuthorizePAT_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(AuthZPatReq) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(AuthServiceServer).AuthorizePAT(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: AuthService_AuthorizePAT_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(AuthServiceServer).AuthorizePAT(ctx, req.(*AuthZPatReq)) + } + return interceptor(ctx, in, info, handler) +} + func _AuthService_Authenticate_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { in := new(AuthNReq) if err := dec(in); err != nil { @@ -147,6 +197,24 @@ func _AuthService_Authenticate_Handler(srv interface{}, ctx context.Context, dec return interceptor(ctx, in, info, handler) } +func _AuthService_AuthenticatePAT_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(AuthNReq) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(AuthServiceServer).AuthenticatePAT(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: AuthService_AuthenticatePAT_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(AuthServiceServer).AuthenticatePAT(ctx, req.(*AuthNReq)) + } + return interceptor(ctx, in, info, handler) +} + // AuthService_ServiceDesc is the grpc.ServiceDesc for AuthService service. // It's only intended for direct use with grpc.RegisterService, // and not to be introspected or modified (even as a copy) @@ -158,10 +226,18 @@ var AuthService_ServiceDesc = grpc.ServiceDesc{ MethodName: "Authorize", Handler: _AuthService_Authorize_Handler, }, + { + MethodName: "AuthorizePAT", + Handler: _AuthService_AuthorizePAT_Handler, + }, { MethodName: "Authenticate", Handler: _AuthService_Authenticate_Handler, }, + { + MethodName: "AuthenticatePAT", + Handler: _AuthService_AuthenticatePAT_Handler, + }, }, Streams: []grpc.StreamDesc{}, Metadata: "auth/v1/auth.proto", diff --git a/internal/proto/auth/v1/auth.proto b/internal/proto/auth/v1/auth.proto index 9da75a0688..323eef454d 100644 --- a/internal/proto/auth/v1/auth.proto +++ b/internal/proto/auth/v1/auth.proto @@ -10,7 +10,9 @@ option go_package = "github.com/absmach/supermq/internal/grpc/auth/v1"; // and authorization functionalities for SuperMQ services. service AuthService { rpc Authorize(AuthZReq) returns (AuthZRes) {} + rpc AuthorizePAT(AuthZPatReq) returns (AuthZRes) {} rpc Authenticate(AuthNReq) returns (AuthNRes) {} + rpc AuthenticatePAT(AuthNReq) returns (AuthNRes) {} } @@ -36,6 +38,16 @@ message AuthZReq { string object_type = 9; // Client, User, Group } +message AuthZPatReq { + string user_id = 1; // User id + string pat_id = 2; // Pat id + uint32 platform_entity_type = 3; // Platform entity type + string optional_domain_id = 4; // Optional domain id + uint32 optional_domain_entity_type = 5; // Optional domain entity type + uint32 operation = 6; // Operation + repeated string entity_ids = 7; // EntityIDs +} + message AuthZRes { bool authorized = 1; string id = 2; diff --git a/pkg/apiutil/errors.go b/pkg/apiutil/errors.go index 0459da3f50..3e6bc34cc5 100644 --- a/pkg/apiutil/errors.go +++ b/pkg/apiutil/errors.go @@ -241,4 +241,16 @@ var ( ErrInvalidProfilePictureURL = errors.New("invalid profile picture url") ErrMultipleEntitiesFilter = errors.New("multiple entities are provided in filter are not supported") + + // ErrMissingDescription indicates missing description. + ErrMissingDescription = errors.New("missing description") + + // ErrUnsupportedTokenType indicates that this type of token is not supported. + ErrUnsupportedTokenType = errors.New("unsupported content token type") + + // ErrMissingUserID indicates missing user ID. + ErrMissingUserID = errors.New("missing user id") + + // ErrMissingPATID indicates missing pat ID. + ErrMissingPATID = errors.New("missing pat id") ) diff --git a/pkg/authn/authn.go b/pkg/authn/authn.go index 62663b3734..5bfe92cd0b 100644 --- a/pkg/authn/authn.go +++ b/pkg/authn/authn.go @@ -7,7 +7,29 @@ import ( "context" ) +type TokenType uint32 + +const ( + // AccessToken represents token generated by user. + AccessToken TokenType = iota + // PersonalAccessToken represents token generated by user for automation. + PersonalAccessToken +) + +func (t TokenType) String() string { + switch t { + case AccessToken: + return "access token" + case PersonalAccessToken: + return "pat" + default: + return "unknown" + } +} + type Session struct { + Type TokenType + ID string DomainUserID string UserID string DomainID string diff --git a/pkg/authn/authsvc/authn.go b/pkg/authn/authsvc/authn.go index a49bbd14fe..ee5ef69f51 100644 --- a/pkg/authn/authsvc/authn.go +++ b/pkg/authn/authsvc/authn.go @@ -5,6 +5,7 @@ package authsvc import ( "context" + "strings" "github.com/absmach/supermq/auth/api/grpc/auth" grpcAuthV1 "github.com/absmach/supermq/internal/grpc/auth/v1" @@ -14,6 +15,8 @@ import ( grpchealth "google.golang.org/grpc/health/grpc_health_v1" ) +const patPrefix = "pat_" + type authentication struct { authSvcClient grpcAuthV1.AuthServiceClient } @@ -38,9 +41,18 @@ func NewAuthentication(ctx context.Context, cfg grpcclient.Config) (authn.Authen } func (a authentication) Authenticate(ctx context.Context, token string) (authn.Session, error) { + if strings.HasPrefix(token, patPrefix) { + res, err := a.authSvcClient.AuthenticatePAT(ctx, &grpcAuthV1.AuthNReq{Token: token}) + if err != nil { + return authn.Session{}, errors.Wrap(errors.ErrAuthentication, err) + } + + return authn.Session{Type: authn.PersonalAccessToken, ID: res.GetId(), UserID: res.GetUserId()}, nil + } res, err := a.authSvcClient.Authenticate(ctx, &grpcAuthV1.AuthNReq{Token: token}) if err != nil { return authn.Session{}, errors.Wrap(errors.ErrAuthentication, err) } - return authn.Session{DomainUserID: res.GetId(), UserID: res.GetUserId(), DomainID: res.GetDomainId()}, nil + + return authn.Session{Type: authn.AccessToken, DomainUserID: res.GetId(), UserID: res.GetUserId(), DomainID: res.GetDomainId()}, nil }