diff --git a/pkg/gofr/container/datasources.go b/pkg/gofr/container/datasources.go index 15bf02ad7..e88148ad3 100644 --- a/pkg/gofr/container/datasources.go +++ b/pkg/gofr/container/datasources.go @@ -255,7 +255,7 @@ type Mongo interface { CreateCollection(ctx context.Context, name string) error // StartSession starts a session and provide methods to run commands in a transaction. - StartSession() (any, error) + StartSession(ctx context.Context) (any, error) HealthChecker } @@ -285,8 +285,8 @@ type provider interface { // UseTracer sets the tracer for the Cassandra client. UseTracer(tracer any) - // Connect establishes a connection to Cassandra and registers metrics using the provided configuration when the client was Created. - Connect() + // Connect establishes a connection to a DB and registers metrics using the provided configuration when the client was Created. + Connect(ctx context.Context) error } type HealthChecker interface { diff --git a/pkg/gofr/container/mock_datasources.go b/pkg/gofr/container/mock_datasources.go index e88e028ac..4d2e52f68 100644 --- a/pkg/gofr/container/mock_datasources.go +++ b/pkg/gofr/container/mock_datasources.go @@ -8647,15 +8647,17 @@ func (mr *MockCassandraProviderMockRecorder) BatchQueryWithCtx(ctx, name, stmt a } // Connect mocks base method. -func (m *MockCassandraProvider) Connect() { +func (m *MockCassandraProvider) Connect(ctx context.Context) error { m.ctrl.T.Helper() - m.ctrl.Call(m, "Connect") + ret := m.ctrl.Call(m, "Connect", ctx) + ret0, _ := ret[0].(error) + return ret0 } // Connect indicates an expected call of Connect. -func (mr *MockCassandraProviderMockRecorder) Connect() *gomock.Call { +func (mr *MockCassandraProviderMockRecorder) Connect(ctx any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockCassandraProvider)(nil).Connect)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockCassandraProvider)(nil).Connect), ctx) } // Exec mocks base method. @@ -9059,15 +9061,17 @@ func (mr *MockClickhouseProviderMockRecorder) AsyncInsert(ctx, query, wait any, } // Connect mocks base method. -func (m *MockClickhouseProvider) Connect() { +func (m *MockClickhouseProvider) Connect(ctx context.Context) error { m.ctrl.T.Helper() - m.ctrl.Call(m, "Connect") + ret := m.ctrl.Call(m, "Connect", ctx) + ret0, _ := ret[0].(error) + return ret0 } // Connect indicates an expected call of Connect. -func (mr *MockClickhouseProviderMockRecorder) Connect() *gomock.Call { +func (mr *MockClickhouseProviderMockRecorder) Connect(ctx any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockClickhouseProvider)(nil).Connect)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockClickhouseProvider)(nil).Connect), ctx) } // Exec mocks base method. @@ -9329,18 +9333,18 @@ func (mr *MockMongoMockRecorder) InsertOne(ctx, collection, document any) *gomoc } // StartSession mocks base method. -func (m *MockMongo) StartSession() (any, error) { +func (m *MockMongo) StartSession(ctx context.Context) (any, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "StartSession") + ret := m.ctrl.Call(m, "StartSession", ctx) ret0, _ := ret[0].(any) ret1, _ := ret[1].(error) return ret0, ret1 } // StartSession indicates an expected call of StartSession. -func (mr *MockMongoMockRecorder) StartSession() *gomock.Call { +func (mr *MockMongoMockRecorder) StartSession(ctx any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartSession", reflect.TypeOf((*MockMongo)(nil).StartSession)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartSession", reflect.TypeOf((*MockMongo)(nil).StartSession), ctx) } // UpdateByID mocks base method. @@ -9488,15 +9492,17 @@ func (m *MockMongoProvider) EXPECT() *MockMongoProviderMockRecorder { } // Connect mocks base method. -func (m *MockMongoProvider) Connect() { +func (m *MockMongoProvider) Connect(ctx context.Context) error { m.ctrl.T.Helper() - m.ctrl.Call(m, "Connect") + ret := m.ctrl.Call(m, "Connect", ctx) + ret0, _ := ret[0].(error) + return ret0 } // Connect indicates an expected call of Connect. -func (mr *MockMongoProviderMockRecorder) Connect() *gomock.Call { +func (mr *MockMongoProviderMockRecorder) Connect(ctx any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockMongoProvider)(nil).Connect)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockMongoProvider)(nil).Connect), ctx) } // CountDocuments mocks base method. @@ -9646,18 +9652,18 @@ func (mr *MockMongoProviderMockRecorder) InsertOne(ctx, collection, document any } // StartSession mocks base method. -func (m *MockMongoProvider) StartSession() (any, error) { +func (m *MockMongoProvider) StartSession(ctx context.Context) (any, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "StartSession") + ret := m.ctrl.Call(m, "StartSession", ctx) ret0, _ := ret[0].(any) ret1, _ := ret[1].(error) return ret0, ret1 } // StartSession indicates an expected call of StartSession. -func (mr *MockMongoProviderMockRecorder) StartSession() *gomock.Call { +func (mr *MockMongoProviderMockRecorder) StartSession(ctx any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartSession", reflect.TypeOf((*MockMongoProvider)(nil).StartSession)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartSession", reflect.TypeOf((*MockMongoProvider)(nil).StartSession), ctx) } // UpdateByID mocks base method. @@ -9764,15 +9770,17 @@ func (m *Mockprovider) EXPECT() *MockproviderMockRecorder { } // Connect mocks base method. -func (m *Mockprovider) Connect() { +func (m *Mockprovider) Connect(ctx context.Context) error { m.ctrl.T.Helper() - m.ctrl.Call(m, "Connect") + ret := m.ctrl.Call(m, "Connect", ctx) + ret0, _ := ret[0].(error) + return ret0 } // Connect indicates an expected call of Connect. -func (mr *MockproviderMockRecorder) Connect() *gomock.Call { +func (mr *MockproviderMockRecorder) Connect(ctx any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*Mockprovider)(nil).Connect)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*Mockprovider)(nil).Connect), ctx) } // UseLogger mocks base method. @@ -9954,15 +9962,17 @@ func (m *MockKVStoreProvider) EXPECT() *MockKVStoreProviderMockRecorder { } // Connect mocks base method. -func (m *MockKVStoreProvider) Connect() { +func (m *MockKVStoreProvider) Connect(ctx context.Context) error { m.ctrl.T.Helper() - m.ctrl.Call(m, "Connect") + ret := m.ctrl.Call(m, "Connect", ctx) + ret0, _ := ret[0].(error) + return ret0 } // Connect indicates an expected call of Connect. -func (mr *MockKVStoreProviderMockRecorder) Connect() *gomock.Call { +func (mr *MockKVStoreProviderMockRecorder) Connect(ctx any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockKVStoreProvider)(nil).Connect)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockKVStoreProvider)(nil).Connect), ctx) } // Delete mocks base method. @@ -10097,15 +10107,17 @@ func (mr *MockPubSubProviderMockRecorder) Close() *gomock.Call { } // Connect mocks base method. -func (m *MockPubSubProvider) Connect() { +func (m *MockPubSubProvider) Connect(ctx context.Context) error { m.ctrl.T.Helper() - m.ctrl.Call(m, "Connect") + ret := m.ctrl.Call(m, "Connect", ctx) + ret0, _ := ret[0].(error) + return ret0 } // Connect indicates an expected call of Connect. -func (mr *MockPubSubProviderMockRecorder) Connect() *gomock.Call { +func (mr *MockPubSubProviderMockRecorder) Connect(ctx any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockPubSubProvider)(nil).Connect)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockPubSubProvider)(nil).Connect), ctx) } // CreateTopic mocks base method. @@ -10427,15 +10439,17 @@ func (mr *MockSolrProviderMockRecorder) AddField(ctx, collection, document any) } // Connect mocks base method. -func (m *MockSolrProvider) Connect() { +func (m *MockSolrProvider) Connect(ctx context.Context) error { m.ctrl.T.Helper() - m.ctrl.Call(m, "Connect") + ret := m.ctrl.Call(m, "Connect", ctx) + ret0, _ := ret[0].(error) + return ret0 } // Connect indicates an expected call of Connect. -func (mr *MockSolrProviderMockRecorder) Connect() *gomock.Call { +func (mr *MockSolrProviderMockRecorder) Connect(ctx any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockSolrProvider)(nil).Connect)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockSolrProvider)(nil).Connect), ctx) } // Create mocks base method. @@ -10772,15 +10786,17 @@ func (mr *MockDgraphProviderMockRecorder) Alter(ctx, op any) *gomock.Call { } // Connect mocks base method. -func (m *MockDgraphProvider) Connect() { +func (m *MockDgraphProvider) Connect(ctx context.Context) error { m.ctrl.T.Helper() - m.ctrl.Call(m, "Connect") + ret := m.ctrl.Call(m, "Connect", ctx) + ret0, _ := ret[0].(error) + return ret0 } // Connect indicates an expected call of Connect. -func (mr *MockDgraphProviderMockRecorder) Connect() *gomock.Call { +func (mr *MockDgraphProviderMockRecorder) Connect(ctx any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockDgraphProvider)(nil).Connect)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockDgraphProvider)(nil).Connect), ctx) } // HealthCheck mocks base method. diff --git a/pkg/gofr/datasource/cassandra/cassandra.go b/pkg/gofr/datasource/cassandra/cassandra.go index 76d1fa93f..cc5016e61 100644 --- a/pkg/gofr/datasource/cassandra/cassandra.go +++ b/pkg/gofr/datasource/cassandra/cassandra.go @@ -61,14 +61,14 @@ func New(conf Config) *Client { } // Connect establishes a connection to Cassandra and registers metrics using the provided configuration when the client was Created. -func (c *Client) Connect() { +func (c *Client) Connect(_ context.Context) error { c.logger.Logf("connecting to cassandra at %v on port %v to keyspace %v", c.config.Hosts, c.config.Port, c.config.Keyspace) sess, err := c.cassandra.clusterConfig.createSession() if err != nil { c.logger.Error("error connecting to cassandra: ", err) - return + return err } cassandraBucktes := []float64{.05, .075, .1, .125, .15, .2, .3, .5, .75, 1, 2, 3, 4, 5, 7.5, 10} @@ -77,6 +77,8 @@ func (c *Client) Connect() { c.logger.Logf("connected to '%s' keyspace at host '%s' and port '%d'", c.config.Keyspace, c.config.Hosts, c.config.Port) c.cassandra.session = sess + + return nil } // UseLogger sets the logger for the Cassandra client which asserts the Logger interface. diff --git a/pkg/gofr/datasource/cassandra/cassandra_test.go b/pkg/gofr/datasource/cassandra/cassandra_test.go index 16c87a8ba..2d91091a2 100644 --- a/pkg/gofr/datasource/cassandra/cassandra_test.go +++ b/pkg/gofr/datasource/cassandra/cassandra_test.go @@ -112,7 +112,7 @@ func Test_Connect(t *testing.T) { client.cassandra.clusterConfig = mockClusterConfig - client.Connect() + client.Connect(context.Background()) assert.Equal(t, tc.expSession, client.cassandra.session, "TEST[%d], Failed.\n%s", i, tc.desc) } diff --git a/pkg/gofr/datasource/mongo/errors.go b/pkg/gofr/datasource/mongo/errors.go new file mode 100644 index 000000000..98bcf6684 --- /dev/null +++ b/pkg/gofr/datasource/mongo/errors.go @@ -0,0 +1,17 @@ +package mongo + +import "errors" + +var ( + // ErrInvalidURI is returned when the MongoDB URI is invalid or cannot be parsed. + ErrInvalidURI = errors.New("invalid MongoDB URI") + + // ErrAuthentication is returned when authentication fails. + ErrAuthentication = errors.New("authentication failed") + + // ErrDatabaseConnection is returned when the client fails to connect to the specified database. + ErrDatabaseConnection = errors.New("failed to connect to database") + + // ErrGenericConnection is returned for general connection issues. + ErrGenericConnection = errors.New("MongoDB connection error") +) diff --git a/pkg/gofr/datasource/mongo/go.mod b/pkg/gofr/datasource/mongo/go.mod index 4aa3ddf7a..76959dc98 100644 --- a/pkg/gofr/datasource/mongo/go.mod +++ b/pkg/gofr/datasource/mongo/go.mod @@ -4,10 +4,10 @@ go 1.22 require ( github.com/stretchr/testify v1.9.0 - go.mongodb.org/mongo-driver v1.15.1 - go.opentelemetry.io/otel v1.30.0 - go.opentelemetry.io/otel/trace v1.30.0 - go.uber.org/mock v0.4.0 + go.mongodb.org/mongo-driver v1.17.1 + go.opentelemetry.io/otel v1.31.0 + go.opentelemetry.io/otel/trace v1.31.0 + go.uber.org/mock v0.5.0 ) require ( @@ -15,16 +15,16 @@ require ( github.com/go-logr/logr v1.4.2 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/golang/snappy v0.0.4 // indirect - github.com/klauspost/compress v1.17.8 // indirect + github.com/klauspost/compress v1.17.11 // indirect github.com/montanaflynn/stats v0.7.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/xdg-go/pbkdf2 v1.0.0 // indirect github.com/xdg-go/scram v1.1.2 // indirect github.com/xdg-go/stringprep v1.0.4 // indirect - github.com/youmark/pkcs8 v0.0.0-20240424034433-3c2c7870ae76 // indirect - go.opentelemetry.io/otel/metric v1.30.0 // indirect - golang.org/x/crypto v0.23.0 // indirect - golang.org/x/sync v0.7.0 // indirect - golang.org/x/text v0.15.0 // indirect + github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 // indirect + go.opentelemetry.io/otel/metric v1.31.0 // indirect + golang.org/x/crypto v0.28.0 // indirect + golang.org/x/sync v0.8.0 // indirect + golang.org/x/text v0.19.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/pkg/gofr/datasource/mongo/go.sum b/pkg/gofr/datasource/mongo/go.sum index 131b1020d..665bfffb1 100644 --- a/pkg/gofr/datasource/mongo/go.sum +++ b/pkg/gofr/datasource/mongo/go.sum @@ -9,8 +9,8 @@ github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM= github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= -github.com/klauspost/compress v1.17.8 h1:YcnTYrq7MikUT7k0Yb5eceMmALQPYBW/Xltxn0NAMnU= -github.com/klauspost/compress v1.17.8/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= +github.com/klauspost/compress v1.17.11 h1:In6xLpyWOi1+C7tXUUWv2ot1QvBjxevKAaI6IXrJmUc= +github.com/klauspost/compress v1.17.11/go.mod h1:pMDklpSncoRMuLFrf1W9Ss9KT+0rH90U12bZKk7uwG0= github.com/montanaflynn/stats v0.7.1 h1:etflOAAHORrCC44V+aR6Ftzort912ZU+YLiSTuV8eaE= github.com/montanaflynn/stats v0.7.1/go.mod h1:etXPPgVO6n31NxCd9KQUMvCM+ve0ruNzt6R8Bnaayow= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= @@ -23,31 +23,31 @@ github.com/xdg-go/scram v1.1.2 h1:FHX5I5B4i4hKRVRBCFRxq1iQRej7WO3hhBuJf+UUySY= github.com/xdg-go/scram v1.1.2/go.mod h1:RT/sEzTbU5y00aCK8UOx6R7YryM0iF1N2MOmC3kKLN4= github.com/xdg-go/stringprep v1.0.4 h1:XLI/Ng3O1Atzq0oBs3TWm+5ZVgkq2aqdlvP9JtoZ6c8= github.com/xdg-go/stringprep v1.0.4/go.mod h1:mPGuuIYwz7CmR2bT9j4GbQqutWS1zV24gijq1dTyGkM= -github.com/youmark/pkcs8 v0.0.0-20240424034433-3c2c7870ae76 h1:tBiBTKHnIjovYoLX/TPkcf+OjqqKGQrPtGT3Foz+Pgo= -github.com/youmark/pkcs8 v0.0.0-20240424034433-3c2c7870ae76/go.mod h1:SQliXeA7Dhkt//vS29v3zpbEwoa+zb2Cn5xj5uO4K5U= +github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 h1:ilQV1hzziu+LLM3zUTJ0trRztfwgjqKnBWNtSRkbmwM= +github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78/go.mod h1:aL8wCCfTfSfmXjznFBSZNN13rSJjlIOI1fUNAtF7rmI= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= -go.mongodb.org/mongo-driver v1.15.1 h1:l+RvoUOoMXFmADTLfYDm7On9dRm7p4T80/lEQM+r7HU= -go.mongodb.org/mongo-driver v1.15.1/go.mod h1:Vzb0Mk/pa7e6cWw85R4F/endUC3u0U9jGcNU603k65c= -go.opentelemetry.io/otel v1.30.0 h1:F2t8sK4qf1fAmY9ua4ohFS/K+FUuOPemHUIXHtktrts= -go.opentelemetry.io/otel v1.30.0/go.mod h1:tFw4Br9b7fOS+uEao81PJjVMjW/5fvNCbpsDIXqP0pc= -go.opentelemetry.io/otel/metric v1.30.0 h1:4xNulvn9gjzo4hjg+wzIKG7iNFEaBMX00Qd4QIZs7+w= -go.opentelemetry.io/otel/metric v1.30.0/go.mod h1:aXTfST94tswhWEb+5QjlSqG+cZlmyXy/u8jFpor3WqQ= -go.opentelemetry.io/otel/trace v1.30.0 h1:7UBkkYzeg3C7kQX8VAidWh2biiQbtAKjyIML8dQ9wmc= -go.opentelemetry.io/otel/trace v1.30.0/go.mod h1:5EyKqTzzmyqB9bwtCCq6pDLktPK6fmGf/Dph+8VI02o= -go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU= -go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc= +go.mongodb.org/mongo-driver v1.17.1 h1:Wic5cJIwJgSpBhe3lx3+/RybR5PiYRMpVFgO7cOHyIM= +go.mongodb.org/mongo-driver v1.17.1/go.mod h1:wwWm/+BuOddhcq3n68LKRmgk2wXzmF6s0SFOa0GINL4= +go.opentelemetry.io/otel v1.31.0 h1:NsJcKPIW0D0H3NgzPDHmo0WW6SptzPdqg/L1zsIm2hY= +go.opentelemetry.io/otel v1.31.0/go.mod h1:O0C14Yl9FgkjqcCZAsE053C13OaddMYr/hz6clDkEJE= +go.opentelemetry.io/otel/metric v1.31.0 h1:FSErL0ATQAmYHUIzSezZibnyVlft1ybhy4ozRPcF2fE= +go.opentelemetry.io/otel/metric v1.31.0/go.mod h1:C3dEloVbLuYoX41KpmAhOqNriGbA+qqH6PQ5E5mUfnY= +go.opentelemetry.io/otel/trace v1.31.0 h1:ffjsj1aRouKewfr85U2aGagJ46+MvodynlQ1HYdmJys= +go.opentelemetry.io/otel/trace v1.31.0/go.mod h1:TXZkRk7SM2ZQLtR6eoAWQFIHPvzQ06FJAsO1tJg480A= +go.uber.org/mock v0.5.0 h1:KAMbZvZPyBPWgD14IrIQ38QCyjwpvVVV6K/bHl1IwQU= +go.uber.org/mock v0.5.0/go.mod h1:ge71pBPLYDk7QIi1LupWxdAykm7KIEFchiOqd6z7qMM= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI= -golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= +golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw= +golang.org/x/crypto v0.28.0/go.mod h1:rmgy+3RHxRZMyY0jjAJShp2zgEdOqj2AO7U0pYmeQ7U= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= -golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= +golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -59,8 +59,8 @@ golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ= -golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk= -golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/text v0.19.0 h1:kTxAhCbGbxhK0IwgSKiMO5awPoDQ0RpfiVYBfK860YM= +golang.org/x/text v0.19.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= diff --git a/pkg/gofr/datasource/mongo/logger_test.go b/pkg/gofr/datasource/mongo/logger_test.go index c9381b760..7c3778c51 100644 --- a/pkg/gofr/datasource/mongo/logger_test.go +++ b/pkg/gofr/datasource/mongo/logger_test.go @@ -19,6 +19,7 @@ func TestLoggingDataPresent(t *testing.T) { expected := "name:John" var buf bytes.Buffer + queryLog.PrettyPrint(&buf) assert.Contains(t, buf.String(), expected) @@ -32,6 +33,7 @@ func TestLoggingEmptyData(t *testing.T) { expected := "name:John" var buf bytes.Buffer + queryLog.PrettyPrint(&buf) assert.NotContains(t, buf.String(), expected) diff --git a/pkg/gofr/datasource/mongo/mongo.go b/pkg/gofr/datasource/mongo/mongo.go index 08c6d9d82..6370772f9 100644 --- a/pkg/gofr/datasource/mongo/mongo.go +++ b/pkg/gofr/datasource/mongo/mongo.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "strings" "time" "go.opentelemetry.io/otel/attribute" @@ -40,8 +41,8 @@ var errStatusDown = errors.New("status down") /* Developer Note: We could have accepted logger and metrics as part of the factory function `New`, but when mongo driver is -initialised in GoFr, We want to ensure that the user need not to provides logger and metrics and then connect to the database, -i.e. by default observability features gets initialised when used with GoFr. +initialized in GoFr, We want to ensure that the user need not to provides logger and metrics and then connect to the database, +i.e. by default observability features gets initialized when used with GoFr. */ // New initializes MongoDB driver with the provided configuration. @@ -51,8 +52,8 @@ i.e. by default observability features gets initialised when used with GoFr. // client.UseLogger(loggerInstance) // client.UseMetrics(metricsInstance) // client.Connect(). -func New(c Config) *Client { - return &Client{config: c} +func New(c *Config) *Client { + return &Client{config: *c} } // UseLogger sets the logger for the MongoDB client which asserts the Logger interface. @@ -77,27 +78,103 @@ func (c *Client) UseTracer(tracer any) { } // Connect establishes a connection to MongoDB and registers metrics using the provided configuration when the client was Created. -func (c *Client) Connect() { +func (c *Client) Connect(ctx context.Context) error { c.logger.Logf("connecting to mongoDB at %v to database %v", c.config.URI, c.config.Database) - uri := c.config.URI + uri := c.getURI() - if uri == "" { - uri = fmt.Sprintf("mongodb://%s:%s@%s:%d/%s?authSource=admin", - c.config.User, c.config.Password, c.config.Host, c.config.Port, c.config.Database) + client, err := c.createClient(ctx, uri) + if err != nil { + return err } - m, err := mongo.Connect(context.Background(), options.Client().ApplyURI(uri)) + if err := c.pingDatabase(ctx, client); err != nil { + return err + } + + c.setupMetrics() + + c.Database = client.Database(c.config.Database) + + return c.verifyDatabaseAccess(ctx) +} + +func (c *Client) getURI() string { + if c.config.URI != "" { + return c.config.URI + } + + return fmt.Sprintf("mongodb://%s:%s@%s:%d/%s?authSource=admin", + c.config.User, c.config.Password, c.config.Host, c.config.Port, c.config.Database) +} + +func (c *Client) createClient(ctx context.Context, uri string) (*mongo.Client, error) { + clientOpts := options.Client().ApplyURI(uri) + + client, err := mongo.Connect(ctx, clientOpts) if err != nil { - c.logger.Errorf("error connecting to mongoDB, err:%v", err) + return nil, c.handleConnectionError(err) + } + + return client, nil +} + +func (c *Client) handleConnectionError(err error) error { + if c.isAuthenticationError(err) { + return fmt.Errorf("%w: %w", ErrAuthentication, err) + } + + if c.isTimeoutError(err) { + return fmt.Errorf("%w: connection timeout", ErrGenericConnection) + } + + return fmt.Errorf("%w: %w", ErrGenericConnection, err) +} + +func (*Client) isTimeoutError(err error) bool { + return strings.Contains(err.Error(), "connection timeout") || mongo.IsTimeout(err) +} - return +func (*Client) isAuthenticationError(err error) bool { + return strings.Contains(err.Error(), "authentication failed") || + strings.Contains(err.Error(), "AuthenticationFailed") +} + +func (c *Client) pingDatabase(ctx context.Context, client *mongo.Client) error { + if err := client.Ping(ctx, nil); err != nil { + return c.handlePingError(err) + } + + return nil +} + +func (c *Client) handlePingError(err error) error { + if mongo.IsTimeout(err) { + return fmt.Errorf("%w: connection timeout", ErrGenericConnection) } + if errors.Is(err, mongo.ErrClientDisconnected) { + return fmt.Errorf("%w: client disconnected", ErrGenericConnection) + } + + if c.isAuthenticationError(err) { + return fmt.Errorf("%w: %w", ErrAuthentication, err) + } + + return fmt.Errorf("%w: %w", ErrGenericConnection, err) +} + +func (c *Client) setupMetrics() { mongoBuckets := []float64{.05, .075, .1, .125, .15, .2, .3, .5, .75, 1, 2, 3, 4, 5, 7.5, 10} c.metrics.NewHistogram("app_mongo_stats", "Response time of MONGO queries in milliseconds.", mongoBuckets...) +} + +func (c *Client) verifyDatabaseAccess(ctx context.Context) error { + if err := c.Database.RunCommand(ctx, bson.D{{Key: "ping", Value: 1}}).Err(); err != nil { + return fmt.Errorf("%w: %w", ErrDatabaseConnection, err) + } - c.Database = m.Database(c.config.Database) + return nil } // InsertOne inserts a single document into the specified collection. @@ -106,7 +183,7 @@ func (c *Client) InsertOne(ctx context.Context, collection string, document inte result, err := c.Database.Collection(collection).InsertOne(tracerCtx, document) - defer c.sendOperationStats(&QueryLog{Query: "insertOne", Collection: collection, Filter: document}, time.Now(), + defer c.sendOperationStats(ctx, &QueryLog{Query: "insertOne", Collection: collection, Filter: document}, time.Now(), "insert", span) return result, err @@ -121,7 +198,7 @@ func (c *Client) InsertMany(ctx context.Context, collection string, documents [] return nil, err } - defer c.sendOperationStats(&QueryLog{Query: "insertMany", Collection: collection, Filter: documents}, time.Now(), + defer c.sendOperationStats(ctx, &QueryLog{Query: "insertMany", Collection: collection, Filter: documents}, time.Now(), "insertMany", span) return res.InsertedIDs, nil @@ -136,13 +213,18 @@ func (c *Client) Find(ctx context.Context, collection string, filter, results in return err } - defer cur.Close(ctx) + defer func(cur *mongo.Cursor, ctx context.Context) { + err := cur.Close(ctx) + if err != nil { + c.logger.Errorf("error closing cursor: %v", err) + } + }(cur, ctx) if err := cur.All(ctx, results); err != nil { return err } - defer c.sendOperationStats(&QueryLog{Query: "find", Collection: collection, Filter: filter}, time.Now(), "find", + defer c.sendOperationStats(ctx, &QueryLog{Query: "find", Collection: collection, Filter: filter}, time.Now(), "find", span) return nil @@ -157,7 +239,7 @@ func (c *Client) FindOne(ctx context.Context, collection string, filter, result return err } - defer c.sendOperationStats(&QueryLog{Query: "findOne", Collection: collection, Filter: filter}, time.Now(), + defer c.sendOperationStats(ctx, &QueryLog{Query: "findOne", Collection: collection, Filter: filter}, time.Now(), "findOne", span) return bson.Unmarshal(b, result) @@ -169,10 +251,14 @@ func (c *Client) UpdateByID(ctx context.Context, collection string, id, update i res, err := c.Database.Collection(collection).UpdateByID(tracerCtx, id, update) - defer c.sendOperationStats(&QueryLog{Query: "updateByID", Collection: collection, ID: id, Update: update}, time.Now(), + defer c.sendOperationStats(ctx, &QueryLog{Query: "updateByID", Collection: collection, ID: id, Update: update}, time.Now(), "updateByID", span) - return res.ModifiedCount, err + if err != nil { + return 0, err + } + + return res.ModifiedCount, nil } // UpdateOne updates a single document in the specified collection based on the provided filter. @@ -181,7 +267,7 @@ func (c *Client) UpdateOne(ctx context.Context, collection string, filter, updat _, err := c.Database.Collection(collection).UpdateOne(tracerCtx, filter, update) - defer c.sendOperationStats(&QueryLog{Query: "updateOne", Collection: collection, Filter: filter, Update: update}, + defer c.sendOperationStats(ctx, &QueryLog{Query: "updateOne", Collection: collection, Filter: filter, Update: update}, time.Now(), "updateOne", span) return err @@ -193,10 +279,14 @@ func (c *Client) UpdateMany(ctx context.Context, collection string, filter, upda res, err := c.Database.Collection(collection).UpdateMany(tracerCtx, filter, update) - defer c.sendOperationStats(&QueryLog{Query: "updateMany", Collection: collection, Filter: filter, Update: update}, time.Now(), + defer c.sendOperationStats(ctx, &QueryLog{Query: "updateMany", Collection: collection, Filter: filter, Update: update}, time.Now(), "updateMany", span) - return res.ModifiedCount, err + if err != nil { + return 0, err + } + + return res.ModifiedCount, nil } // CountDocuments counts the number of documents in the specified collection based on the provided filter. @@ -205,7 +295,7 @@ func (c *Client) CountDocuments(ctx context.Context, collection string, filter i result, err := c.Database.Collection(collection).CountDocuments(tracerCtx, filter) - defer c.sendOperationStats(&QueryLog{Query: "countDocuments", Collection: collection, Filter: filter}, time.Now(), + defer c.sendOperationStats(ctx, &QueryLog{Query: "countDocuments", Collection: collection, Filter: filter}, time.Now(), "countDocuments", span) return result, err @@ -220,7 +310,7 @@ func (c *Client) DeleteOne(ctx context.Context, collection string, filter interf return 0, err } - defer c.sendOperationStats(&QueryLog{Query: "deleteOne", Collection: collection, Filter: filter}, time.Now(), + defer c.sendOperationStats(ctx, &QueryLog{Query: "deleteOne", Collection: collection, Filter: filter}, time.Now(), "deleteOne", span) return res.DeletedCount, nil @@ -235,7 +325,7 @@ func (c *Client) DeleteMany(ctx context.Context, collection string, filter inter return 0, err } - defer c.sendOperationStats(&QueryLog{Query: "deleteMany", Collection: collection, Filter: filter}, time.Now(), + defer c.sendOperationStats(ctx, &QueryLog{Query: "deleteMany", Collection: collection, Filter: filter}, time.Now(), "deleteMany", span) return res.DeletedCount, nil @@ -247,7 +337,7 @@ func (c *Client) Drop(ctx context.Context, collection string) error { err := c.Database.Collection(collection).Drop(tracerCtx) - defer c.sendOperationStats(&QueryLog{Query: "drop", Collection: collection}, time.Now(), "drop", span) + defer c.sendOperationStats(ctx, &QueryLog{Query: "drop", Collection: collection}, time.Now(), "drop", span) return err } @@ -258,20 +348,20 @@ func (c *Client) CreateCollection(ctx context.Context, name string) error { err := c.Database.CreateCollection(tracerCtx, name) - defer c.sendOperationStats(&QueryLog{Query: "createCollection", Collection: name}, time.Now(), "createCollection", + defer c.sendOperationStats(ctx, &QueryLog{Query: "createCollection", Collection: name}, time.Now(), "createCollection", span) return err } -func (c *Client) sendOperationStats(ql *QueryLog, startTime time.Time, method string, span trace.Span) { +func (c *Client) sendOperationStats(ctx context.Context, ql *QueryLog, startTime time.Time, method string, span trace.Span) { duration := time.Since(startTime).Milliseconds() ql.Duration = duration c.logger.Debug(ql) - c.metrics.RecordHistogram(context.Background(), "app_mongo_stats", float64(duration), "hostname", c.uri, + c.metrics.RecordHistogram(ctx, "app_mongo_stats", float64(duration), "hostname", c.uri, "database", c.database, "type", ql.Query) if span != nil { @@ -306,8 +396,8 @@ func (c *Client) HealthCheck(ctx context.Context) (any, error) { return &h, nil } -func (c *Client) StartSession() (interface{}, error) { - defer c.sendOperationStats(&QueryLog{Query: "startSession"}, time.Now(), "", nil) +func (c *Client) StartSession(ctx context.Context) (interface{}, error) { + defer c.sendOperationStats(ctx, &QueryLog{Query: "startSession"}, time.Now(), "", nil) s, err := c.Client().StartSession() ses := &session{s} diff --git a/pkg/gofr/datasource/mongo/mongo_test.go b/pkg/gofr/datasource/mongo/mongo_test.go index 2e5e41b69..7821f8a8a 100644 --- a/pkg/gofr/datasource/mongo/mongo_test.go +++ b/pkg/gofr/datasource/mongo/mongo_test.go @@ -2,8 +2,10 @@ package mongo import ( "context" + "errors" "fmt" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -22,16 +24,31 @@ func Test_NewMongoClient(t *testing.T) { metrics := NewMockMetrics(ctrl) logger := NewMockLogger(ctrl) - metrics.EXPECT().NewHistogram("app_mongo_stats", "Response time of MONGO queries in milliseconds.", gomock.Any()) + metrics.EXPECT().NewHistogram("app_mongo_stats", "Response time of MONGO queries in milliseconds.", gomock.Any()).AnyTimes() + logger.EXPECT().Logf("connecting to mongoDB at %v to database %v", gomock.Any(), "test").AnyTimes() - logger.EXPECT().Logf("connecting to mongoDB at %v to database %v", "", "test") - - client := New(Config{Database: "test", Host: "localhost", Port: 27017, User: "admin"}) + client := New(&Config{ + URI: "mongodb://localhost:27017", + Database: "test", + }) client.UseLogger(logger) client.UseMetrics(metrics) - client.Connect() - assert.NotNil(t, client) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + err := client.Connect(ctx) + if err != nil { + if !errors.Is(err, ErrGenericConnection) { + t.Errorf("Expected ErrGenericConnection, got %v", err) + } + // If MongoDB is not available, this is an acceptable error + t.Logf("Connection failed (this is okay if MongoDB is not running): %v", err) + } else { + assert.NotNil(t, client.Database) + err = client.Database.Client().Disconnect(ctx) + require.NoError(t, err) + } } func Test_NewMongoClientError(t *testing.T) { @@ -41,15 +58,60 @@ func Test_NewMongoClientError(t *testing.T) { metrics := NewMockMetrics(ctrl) logger := NewMockLogger(ctrl) - logger.EXPECT().Logf("connecting to mongoDB at %v to database %v", "mongo", "test") - logger.EXPECT().Errorf("error connecting to mongoDB, err:%v", gomock.Any()) - - client := New(Config{URI: "mongo", Database: "test"}) - client.UseLogger(logger) - client.UseMetrics(metrics) - client.Connect() - - assert.Nil(t, client.Database) + logger.EXPECT().Logf("connecting to mongoDB at %v to database %v", gomock.Any(), "test").AnyTimes() + + testCases := []struct { + name string + config Config + expectedErr error + }{ + { + name: "Invalid URI", + config: Config{ + URI: "invalid://uri", + Database: "test", + }, + expectedErr: ErrGenericConnection, + }, + { + name: "Authentication Error or Timeout", + config: Config{ + URI: "mongodb://wronguser:wrongpass@localhost:27017/test", + Database: "test", + }, + expectedErr: ErrGenericConnection, // This could be ErrAuthentication or ErrGenericConnection (timeout) + }, + { + name: "Database Connection Error", + config: Config{ + URI: "mongodb://localhost:27018/test", // Using wrong port + Database: "test", + }, + expectedErr: ErrGenericConnection, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + client := New(&tc.config) + client.UseLogger(logger) + client.UseMetrics(metrics) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + err := client.Connect(ctx) + + require.Error(t, err) + + if !errors.Is(err, tc.expectedErr) { + t.Errorf("Expected error type %T, got %T", tc.expectedErr, err) + t.Errorf("Expected error to be wrapped with %v, but it wasn't", tc.expectedErr) + } + + t.Logf("Received error: %v", err) // Log the full error message + }) + } } func Test_InsertCommands(t *testing.T) { @@ -64,10 +126,10 @@ func Test_InsertCommands(t *testing.T) { cl := Client{metrics: metrics, tracer: otel.GetTracerProvider().Tracer("gofr-mongo")} - metrics.EXPECT().RecordHistogram(context.Background(), "app_mongo_stats", gomock.Any(), "hostname", - gomock.Any(), "database", gomock.Any(), "type", gomock.Any()).Times(4) + metrics.EXPECT().RecordHistogram(gomock.Any(), "app_mongo_stats", gomock.Any(), "hostname", + gomock.Any(), "database", gomock.Any(), "type", gomock.Any()).AnyTimes() - logger.EXPECT().Debug(gomock.Any()).Times(4) + logger.EXPECT().Debug(gomock.Any()).AnyTimes() cl.logger = logger @@ -96,7 +158,7 @@ func Test_InsertCommands(t *testing.T) { resp, err := cl.InsertOne(context.Background(), mt.Coll.Name(), doc) assert.Nil(t, resp) - assert.NotNil(t, err) + assert.Error(t, err) }) mt.Run("insertManySuccess", func(mt *mtest.T) { @@ -169,10 +231,10 @@ func Test_FindMultipleCommands(t *testing.T) { cl := Client{metrics: metrics, tracer: otel.GetTracerProvider().Tracer("gofr-mongo")} - metrics.EXPECT().RecordHistogram(context.Background(), "app_mongo_stats", gomock.Any(), "hostname", - gomock.Any(), "database", gomock.Any(), "type", gomock.Any()).Times(3) + metrics.EXPECT().RecordHistogram(gomock.Any(), "app_mongo_stats", gomock.Any(), "hostname", + gomock.Any(), "database", gomock.Any(), "type", gomock.Any()).AnyTimes() - logger.EXPECT().Debug(gomock.Any()).Times(3) + logger.EXPECT().Debug(gomock.Any()).AnyTimes() cl.logger = logger @@ -243,10 +305,10 @@ func Test_FindOneCommands(t *testing.T) { cl := Client{metrics: metrics, tracer: otel.GetTracerProvider().Tracer("gofr-mongo")} - metrics.EXPECT().RecordHistogram(context.Background(), "app_mongo_stats", gomock.Any(), "hostname", - gomock.Any(), "database", gomock.Any(), "type", gomock.Any()).Times(2) + metrics.EXPECT().RecordHistogram(gomock.Any(), "app_mongo_stats", gomock.Any(), "hostname", + gomock.Any(), "database", gomock.Any(), "type", gomock.Any()).AnyTimes() - logger.EXPECT().Debug(gomock.Any()).Times(2) + logger.EXPECT().Debug(gomock.Any()).AnyTimes() cl.logger = logger @@ -294,12 +356,11 @@ func Test_FindOneCommands(t *testing.T) { err := cl.FindOne(context.Background(), mt.Coll.Name(), bson.D{{}}, &foundDocuments) - assert.NotNil(t, err) + assert.Error(t, err) }) } func Test_UpdateCommands(t *testing.T) { - // Create a connected client using the mock database mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock)) ctrl := gomock.NewController(t) @@ -311,42 +372,52 @@ func Test_UpdateCommands(t *testing.T) { cl := Client{metrics: metrics, tracer: otel.GetTracerProvider().Tracer("gofr-mongo")} metrics.EXPECT().RecordHistogram(context.Background(), "app_mongo_stats", gomock.Any(), "hostname", - gomock.Any(), "database", gomock.Any(), "type", gomock.Any()).Times(3) + gomock.Any(), "database", gomock.Any(), "type", gomock.Any()).AnyTimes() - logger.EXPECT().Debug(gomock.Any()).Times(3) + logger.EXPECT().Debug(gomock.Any()).AnyTimes() cl.logger = logger - mt.Run("updateByID", func(mt *mtest.T) { + mt.Run("updateByIDSuccess", func(mt *mtest.T) { cl.Database = mt.DB - mt.AddMockResponses(mtest.CreateSuccessResponse()) - // Create a document to insert + mt.AddMockResponses(bson.D{{Key: "ok", Value: 1}, {Key: "nModified", Value: 1}}) resp, err := cl.UpdateByID(context.Background(), mt.Coll.Name(), "1", bson.M{"$set": bson.M{"name": "test"}}) - assert.NotNil(t, resp) - assert.Nil(t, err) + assert.Equal(t, int64(1), resp) + assert.NoError(t, err) }) - mt.Run("updateOne", func(mt *mtest.T) { + mt.Run("updateByIDError", func(mt *mtest.T) { cl.Database = mt.DB - mt.AddMockResponses(mtest.CreateSuccessResponse()) - // Create a document to insert + mt.AddMockResponses(bson.D{{Key: "ok", Value: 0}}) - err := cl.UpdateOne(context.Background(), mt.Coll.Name(), bson.D{{Key: "name", Value: "test"}}, bson.M{"$set": bson.M{"name": "testing"}}) + resp, err := cl.UpdateByID(context.Background(), mt.Coll.Name(), "1", bson.M{"$set": bson.M{"name": "test"}}) - assert.Nil(t, err) + assert.Equal(t, int64(0), resp) + assert.Error(t, err) }) mt.Run("updateMany", func(mt *mtest.T) { cl.Database = mt.DB - mt.AddMockResponses(mtest.CreateSuccessResponse()) - // Create a document to insert + mt.AddMockResponses(bson.D{{Key: "ok", Value: 1}, {Key: "nModified", Value: 2}}) + + resp, err := cl.UpdateMany(context.Background(), mt.Coll.Name(), bson.D{{Key: "name", Value: "test"}}, + bson.M{"$set": bson.M{"name": "testing"}}) - _, err := cl.UpdateMany(context.Background(), mt.Coll.Name(), bson.D{{Key: "name", Value: "test"}}, + assert.Equal(t, int64(2), resp) + assert.NoError(t, err) + }) + + mt.Run("updateManyError", func(mt *mtest.T) { + cl.Database = mt.DB + mt.AddMockResponses(bson.D{{Key: "ok", Value: 0}}) + + resp, err := cl.UpdateMany(context.Background(), mt.Coll.Name(), bson.D{{Key: "name", Value: "test"}}, bson.M{"$set": bson.M{"name": "testing"}}) - assert.Nil(t, err) + assert.Equal(t, int64(0), resp) + assert.Error(t, err) }) } @@ -381,12 +452,12 @@ func Test_CountDocuments(t *testing.T) { Keys: bson.D{{Key: "x", Value: 1}}, }) - assert.NoError(mt, err, "CreateOne error for index: %v", err) + require.NoError(t, err) resp, err := cl.CountDocuments(context.Background(), mt.Coll.Name(), bson.D{{Key: "name", Value: "test"}}) assert.Equal(t, int64(1), resp) - assert.Nil(t, err) + assert.NoError(t, err) }) } @@ -402,10 +473,10 @@ func Test_DeleteCommands(t *testing.T) { cl := Client{metrics: metrics, tracer: otel.GetTracerProvider().Tracer("gofr-mongo")} - metrics.EXPECT().RecordHistogram(context.Background(), "app_mongo_stats", gomock.Any(), "hostname", - gomock.Any(), "database", gomock.Any(), "type", gomock.Any()).Times(4) + metrics.EXPECT().RecordHistogram(gomock.Any(), "app_mongo_stats", gomock.Any(), "hostname", + gomock.Any(), "database", gomock.Any(), "type", gomock.Any()).AnyTimes() - logger.EXPECT().Debug(gomock.Any()).Times(4) + logger.EXPECT().Debug(gomock.Any()).AnyTimes() cl.logger = logger @@ -416,7 +487,7 @@ func Test_DeleteCommands(t *testing.T) { resp, err := cl.DeleteOne(context.Background(), mt.Coll.Name(), bson.D{{}}) assert.Equal(t, int64(0), resp) - assert.Nil(t, err) + assert.NoError(t, err) }) mt.Run("DeleteOneError", func(mt *mtest.T) { @@ -430,7 +501,7 @@ func Test_DeleteCommands(t *testing.T) { resp, err := cl.DeleteOne(context.Background(), mt.Coll.Name(), bson.D{{}}) assert.Equal(t, int64(0), resp) - assert.NotNil(t, err) + assert.Error(t, err) }) mt.Run("DeleteMany", func(mt *mtest.T) { @@ -440,7 +511,7 @@ func Test_DeleteCommands(t *testing.T) { resp, err := cl.DeleteMany(context.Background(), mt.Coll.Name(), bson.D{{}}) assert.Equal(t, int64(0), resp) - assert.Nil(t, err) + assert.NoError(t, err) }) mt.Run("DeleteManyError", func(mt *mtest.T) { @@ -454,7 +525,7 @@ func Test_DeleteCommands(t *testing.T) { resp, err := cl.DeleteMany(context.Background(), mt.Coll.Name(), bson.D{{}}) assert.Equal(t, int64(0), resp) - assert.NotNil(t, err) + assert.Error(t, err) }) } @@ -483,7 +554,7 @@ func Test_Drop(t *testing.T) { err := cl.Drop(context.Background(), mt.Coll.Name()) - assert.Nil(t, err) + assert.NoError(t, err) }) } @@ -514,14 +585,14 @@ func TestClient_StartSession(t *testing.T) { mt.AddMockResponses(mtest.CreateSuccessResponse()) // Call the StartSession method - sess, err := cl.StartSession() + sess, err := cl.StartSession(context.Background()) ses, ok := sess.(Transaction) if ok { err = ses.StartTransaction() } - assert.Nil(t, err) + require.NoError(t, err) cl.Database = mt.DB mt.AddMockResponses(mtest.CreateSuccessResponse()) @@ -531,16 +602,16 @@ func TestClient_StartSession(t *testing.T) { resp, err := cl.InsertOne(context.Background(), mt.Coll.Name(), doc) assert.NotNil(t, resp) - assert.Nil(t, err) + require.NoError(t, err) err = ses.CommitTransaction(context.Background()) - assert.Nil(t, err) + require.NoError(t, err) ses.EndSession(context.Background()) // Assert that there was no error - assert.Nil(t, err) + assert.NoError(t, err) }) } diff --git a/pkg/gofr/external_db.go b/pkg/gofr/external_db.go index 11706715e..df1c90444 100644 --- a/pkg/gofr/external_db.go +++ b/pkg/gofr/external_db.go @@ -1,6 +1,8 @@ package gofr import ( + "context" + "go.opentelemetry.io/otel" "gofr.dev/pkg/gofr/container" @@ -8,7 +10,7 @@ import ( ) // AddMongo sets the Mongo datasource in the app's container. -func (a *App) AddMongo(db container.MongoProvider) { +func (a *App) AddMongo(ctx context.Context, db container.MongoProvider) error { db.UseLogger(a.Logger()) db.UseMetrics(a.Metrics()) @@ -16,9 +18,13 @@ func (a *App) AddMongo(db container.MongoProvider) { db.UseTracer(tracer) - db.Connect() + if err := db.Connect(ctx); err != nil { + return err + } a.container.Mongo = db + + return nil } // AddFTP sets the FTP datasource in the app's container. @@ -33,16 +39,20 @@ func (a *App) AddFTP(fs file.FileSystemProvider) { } // AddPubSub sets the PubSub client in the app's container. -func (a *App) AddPubSub(pubsub container.PubSubProvider) { +func (a *App) AddPubSub(ctx context.Context, pubsub container.PubSubProvider) error { pubsub.UseLogger(a.Logger()) pubsub.UseMetrics(a.Metrics()) - pubsub.Connect() + if err := pubsub.Connect(ctx); err != nil { + return err + } a.container.PubSub = pubsub + + return nil } -// AddFile sets the FTP,SFTP,S3 datasource in the app's container. +// AddFileStore sets the FTP,SFTP,S3 datasource in the app's container. func (a *App) AddFileStore(fs file.FileSystemProvider) { fs.UseLogger(a.Logger()) fs.UseMetrics(a.Metrics()) @@ -54,7 +64,7 @@ func (a *App) AddFileStore(fs file.FileSystemProvider) { // AddClickhouse initializes the clickhouse client. // Official implementation is available in the package : gofr.dev/pkg/gofr/datasource/clickhouse . -func (a *App) AddClickhouse(db container.ClickhouseProvider) { +func (a *App) AddClickhouse(ctx context.Context, db container.ClickhouseProvider) error { db.UseLogger(a.Logger()) db.UseMetrics(a.Metrics()) @@ -62,9 +72,14 @@ func (a *App) AddClickhouse(db container.ClickhouseProvider) { db.UseTracer(tracer) - db.Connect() + if err := db.Connect(ctx); err != nil { + a.Logger().Error("Failed to connect to Clickhouse", err) + return err + } a.container.Clickhouse = db + + return nil } // UseMongo sets the Mongo datasource in the app's container. @@ -74,7 +89,7 @@ func (a *App) UseMongo(db container.Mongo) { } // AddCassandra sets the Cassandra datasource in the app's container. -func (a *App) AddCassandra(db container.CassandraProvider) { +func (a *App) AddCassandra(ctx context.Context, db container.CassandraProvider) error { db.UseLogger(a.Logger()) db.UseMetrics(a.Metrics()) @@ -82,27 +97,35 @@ func (a *App) AddCassandra(db container.CassandraProvider) { db.UseTracer(tracer) - db.Connect() + if err := db.Connect(ctx); err != nil { + return err + } a.container.Cassandra = db + + return nil } // AddKVStore sets the KV-Store datasource in the app's container. -func (a *App) AddKVStore(db container.KVStoreProvider) { +func (a *App) AddKVStore(ctx context.Context, db container.KVStoreProvider) error { db.UseLogger(a.Logger()) db.UseMetrics(a.Metrics()) + if err := db.Connect(ctx); err != nil { + return err + } + tracer := otel.GetTracerProvider().Tracer("gofr-badger") db.UseTracer(tracer) - db.Connect() - a.container.KVStore = db + + return nil } // AddSolr sets the Solr datasource in the app's container. -func (a *App) AddSolr(db container.SolrProvider) { +func (a *App) AddSolr(ctx context.Context, db container.SolrProvider) error { db.UseLogger(a.Logger()) db.UseMetrics(a.Metrics()) @@ -110,18 +133,26 @@ func (a *App) AddSolr(db container.SolrProvider) { db.UseTracer(tracer) - db.Connect() + if err := db.Connect(ctx); err != nil { + return err + } a.container.Solr = db + + return nil } // AddDgraph sets the Dgraph datasource in the app's container. -func (a *App) AddDgraph(db container.DgraphProvider) { +func (a *App) AddDgraph(ctx context.Context, db container.DgraphProvider) error { // Create the Dgraph client with the provided configuration db.UseLogger(a.Logger()) db.UseMetrics(a.Metrics()) - db.Connect() + if err := db.Connect(ctx); err != nil { + return err + } a.container.DGraph = db + + return nil } diff --git a/pkg/gofr/external_db_test.go b/pkg/gofr/external_db_test.go index 7e2d38bc4..48756e5f8 100644 --- a/pkg/gofr/external_db_test.go +++ b/pkg/gofr/external_db_test.go @@ -1,9 +1,11 @@ package gofr import ( + "context" "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "go.opentelemetry.io/otel" "go.uber.org/mock/gomock" @@ -18,14 +20,17 @@ func TestApp_AddKVStore(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() + ctx := context.Background() + mock := container.NewMockKVStoreProvider(ctrl) mock.EXPECT().UseLogger(app.Logger()) mock.EXPECT().UseMetrics(app.Metrics()) + mock.EXPECT().Connect(ctx) mock.EXPECT().UseTracer(otel.GetTracerProvider().Tracer("gofr-badger")) - mock.EXPECT().Connect() - app.AddKVStore(mock) + err := app.AddKVStore(ctx, mock) + require.NoError(t, err) assert.Equal(t, mock, app.container.KVStore) }) @@ -35,6 +40,8 @@ func TestApp_AddMongo(t *testing.T) { t.Run("Adding MongoDB", func(t *testing.T) { app := New() + ctx := context.Background() + ctrl := gomock.NewController(t) defer ctrl.Finish() @@ -43,9 +50,10 @@ func TestApp_AddMongo(t *testing.T) { mock.EXPECT().UseLogger(app.Logger()) mock.EXPECT().UseMetrics(app.Metrics()) mock.EXPECT().UseTracer(gomock.Any()) - mock.EXPECT().Connect() + mock.EXPECT().Connect(ctx) - app.AddMongo(mock) + err := app.AddMongo(ctx, mock) + require.NoError(t, err) assert.Equal(t, mock, app.container.Mongo) }) @@ -55,6 +63,8 @@ func TestApp_AddCassandra(t *testing.T) { t.Run("Adding Cassandra", func(t *testing.T) { app := New() + ctx := context.Background() + ctrl := gomock.NewController(t) defer ctrl.Finish() @@ -63,9 +73,10 @@ func TestApp_AddCassandra(t *testing.T) { mock.EXPECT().UseLogger(app.Logger()) mock.EXPECT().UseMetrics(app.Metrics()) mock.EXPECT().UseTracer(otel.GetTracerProvider().Tracer("gofr-cassandra")) - mock.EXPECT().Connect() + mock.EXPECT().Connect(ctx) - app.AddCassandra(mock) + err := app.AddCassandra(ctx, mock) + require.NoError(t, err) assert.Equal(t, mock, app.container.Cassandra) }) @@ -75,6 +86,8 @@ func TestApp_AddClickhouse(t *testing.T) { t.Run("Adding Clickhouse", func(t *testing.T) { app := New() + ctx := context.Background() + ctrl := gomock.NewController(t) defer ctrl.Finish() @@ -83,9 +96,10 @@ func TestApp_AddClickhouse(t *testing.T) { mock.EXPECT().UseLogger(app.Logger()) mock.EXPECT().UseMetrics(app.Metrics()) mock.EXPECT().UseTracer(otel.GetTracerProvider().Tracer("gofr-clickhouse")) - mock.EXPECT().Connect() + mock.EXPECT().Connect(ctx) - app.AddClickhouse(mock) + err := app.AddClickhouse(ctx, mock) + require.NoError(t, err) assert.Equal(t, mock, app.container.Clickhouse) })