diff --git a/cmd/main.go b/cmd/main.go index 7f264389..2ec3fbb5 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -129,6 +129,7 @@ func main() { logger.Error(fmt.Sprintf("Error in agent service: %s", err)) return } + defer svc.Close() svc = api.LoggingMiddleware(svc, logger) svc = api.MetricsMiddleware( @@ -402,7 +403,7 @@ func StopSignalHandler(ctx context.Context, cancel context.CancelFunc, logger lo shutdownCtx, shutdownCancel := context.WithTimeout(ctx, 5*time.Second) defer shutdownCancel() if err := server.Shutdown(shutdownCtx); err != nil { - return fmt.Errorf("Failed to shutdown %s server: %v", svcName, err) + return fmt.Errorf("failed to shutdown %s server: %v", svcName, err) } return fmt.Errorf("%s service shutdown by signal: %s", svcName, sig) case <-ctx.Done(): diff --git a/pkg/agent/api/endpoints_test.go b/pkg/agent/api/endpoints_test.go index 13a547ec..7c94fefb 100644 --- a/pkg/agent/api/endpoints_test.go +++ b/pkg/agent/api/endpoints_test.go @@ -111,6 +111,9 @@ func TestPublish(t *testing.T) { {"publish data", data, http.StatusOK}, {"publish data with invalid data", "}", http.StatusInternalServerError}, } + t.Cleanup(func() { + assert.Nil(t, svc.Close()) + }) for _, tc := range cases { req := testRequest{ diff --git a/pkg/agent/api/logging.go b/pkg/agent/api/logging.go index 5d989319..2e93f479 100644 --- a/pkg/agent/api/logging.go +++ b/pkg/agent/api/logging.go @@ -119,3 +119,16 @@ func (lm loggingMiddleware) Terminal(uuid, cmdStr string) (err error) { return lm.svc.Terminal(uuid, cmdStr) } + +func (lm loggingMiddleware) Close() (err error) { + defer func(begin time.Time) { + message := fmt.Sprintf("Method close took %s to complete", time.Since(begin)) + if err != nil { + lm.logger.Warn(fmt.Sprintf("%s with error: %s.", message, err)) + return + } + lm.logger.Info(fmt.Sprintf("%s without errors.", message)) + }(time.Now()) + + return lm.svc.Close() +} diff --git a/pkg/agent/api/metrics.go b/pkg/agent/api/metrics.go index 5e58742d..c22d4b21 100644 --- a/pkg/agent/api/metrics.go +++ b/pkg/agent/api/metrics.go @@ -96,9 +96,18 @@ func (ms *metricsMiddleware) Publish(topic, payload string) error { func (ms *metricsMiddleware) Terminal(topic, payload string) error { defer func(begin time.Time) { - ms.counter.With("method", "publish").Add(1) - ms.latency.With("method", "publish").Observe(time.Since(begin).Seconds()) + ms.counter.With("method", "terminal").Add(1) + ms.latency.With("method", "terminal").Observe(time.Since(begin).Seconds()) }(time.Now()) return ms.svc.Terminal(topic, payload) } + +func (ms *metricsMiddleware) Close() error { + defer func(begin time.Time) { + ms.counter.With("method", "close").Add(1) + ms.latency.With("method", "close").Observe(time.Since(begin).Seconds()) + }(time.Now()) + + return ms.svc.Close() +} diff --git a/pkg/agent/config.go b/pkg/agent/config.go index 36b495c6..cbb59eb0 100644 --- a/pkg/agent/config.go +++ b/pkg/agent/config.go @@ -6,7 +6,6 @@ package agent import ( "crypto/tls" "encoding/json" - "fmt" "os" "time" @@ -14,6 +13,13 @@ import ( "github.com/pelletier/go-toml" ) +var ( + ErrWritingToml = errors.New("error writing to toml file") + errReadingFile = errors.New("error reading config file") + errUnmarshalToml = errors.New("error unmarshaling toml") + errMarshalToml = errors.New("error marshaling toml") +) + type ServerConfig struct { Port string `toml:"port" json:"port"` BrokerURL string `toml:"broker_url" json:"broker_url"` @@ -86,24 +92,24 @@ func NewConfig(sc ServerConfig, cc ChanConfig, ec EdgexConfig, lc LogConfig, mc func SaveConfig(c Config) error { b, err := toml.Marshal(c) if err != nil { - return errors.New(fmt.Sprintf("Error reading config file: %s", err)) + return errors.Wrap(errMarshalToml, err) } if err := os.WriteFile(c.File, b, 0644); err != nil { - return errors.New(fmt.Sprintf("Error writing toml: %s", err)) + return errors.Wrap(ErrWritingToml, err) } return nil } -// Read - retrieve config from a file. +// ReadConfig - retrieve config from a file. func ReadConfig(file string) (Config, error) { data, err := os.ReadFile(file) c := Config{} if err != nil { - return c, errors.New(fmt.Sprintf("Error reading config file: %s", err)) + return Config{}, errors.Wrap(errReadingFile, err) } if err := toml.Unmarshal(data, &c); err != nil { - return Config{}, errors.New(fmt.Sprintf("Error unmarshaling toml: %s", err)) + return Config{}, errors.Wrap(errUnmarshalToml, err) } return c, nil } diff --git a/pkg/agent/config_test.go b/pkg/agent/config_test.go new file mode 100644 index 00000000..e461f80f --- /dev/null +++ b/pkg/agent/config_test.go @@ -0,0 +1,103 @@ +package agent + +import ( + "fmt" + "os" + "strings" + "testing" + + "github.com/mainflux/mainflux/pkg/errors" + "github.com/stretchr/testify/assert" +) + +func TestReadConfig(t *testing.T) { + // Create a temporary config file for testing. + tempFile, err := os.CreateTemp("", "config.toml") + if err != nil { + t.Fatalf("Failed to create temporary file: %v", err) + } + defer os.Remove(tempFile.Name()) + tempFile2, err := os.CreateTemp("", "invalid.toml") + if err != nil { + t.Fatalf("Failed to create temporary file: %v", err) + } + defer os.Remove(tempFile2.Name()) + + sampleConfig := ` + File = "config.toml" + + [channels] + control = "" + data = "" + + [edgex] + url = "http://localhost:48090/api/v1/" + + [heartbeat] + interval = "10s" + + [log] + level = "info" + + [mqtt] + ca_cert = "" + ca_path = "ca.crt" + cert_path = "thing.cert" + client_cert = "" + client_key = "" + mtls = false + password = "" + priv_key_path = "thing.key" + qos = 0 + retain = false + skip_tls_ver = true + url = "localhost:1883" + username = "" + + [server] + nats_url = "nats://127.0.0.1:4222" + port = "9999" + + [terminal] + session_timeout = "1m0s" +` + + if _, writeErr := tempFile.WriteString(sampleConfig); writeErr != nil { + t.Fatalf("Failed to write to temporary file: %v", writeErr) + } + tempFile.Close() + + if _, writeErr := tempFile2.WriteString(strings.ReplaceAll(sampleConfig, "[", "")); writeErr != nil { + t.Fatalf("Failed to write to temporary file: %v", writeErr) + } + tempFile2.Close() + + tests := []struct { + name string + fileName string + expectedErr error + }{ + { + name: "failed to read file", + fileName: "invalidFile.toml", + expectedErr: errReadingFile, + }, + { + name: "invalid toml", + fileName: tempFile2.Name(), + expectedErr: errUnmarshalToml, + }, + { + name: "successful read", + fileName: tempFile.Name(), + expectedErr: nil, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + _, err := ReadConfig(test.fileName) + assert.True(t, errors.Contains(err, test.expectedErr), fmt.Sprintf("expected %v got %v", test.expectedErr, err)) + }) + } +} diff --git a/pkg/agent/heartbeat.go b/pkg/agent/heartbeat.go index 3d9b2efc..0bf3eae3 100644 --- a/pkg/agent/heartbeat.go +++ b/pkg/agent/heartbeat.go @@ -1,6 +1,7 @@ package agent import ( + "context" "sync" "time" ) @@ -33,11 +34,12 @@ type Info struct { type Heartbeat interface { Update() Info() Info + Close() } // interval - duration of interval // if service doesnt send heartbeat during interval it is marked offline. -func NewHeartbeat(name, svcType string, interval time.Duration) Heartbeat { +func NewHeartbeat(ctx context.Context, name, svcType string, interval time.Duration) Heartbeat { ticker := time.NewTicker(interval) s := svc{ info: Info{ @@ -49,13 +51,14 @@ func NewHeartbeat(name, svcType string, interval time.Duration) Heartbeat { ticker: ticker, interval: interval, } - s.listen() + go s.listen(ctx) return &s } -func (s *svc) listen() { - go func() { - for range s.ticker.C { +func (s *svc) listen(ctx context.Context) { + for { + select { + case <-s.ticker.C: // TODO - we can disable ticker when the status gets OFFLINE // and on the next heartbeat enable it again. s.mu.Lock() @@ -63,8 +66,10 @@ func (s *svc) listen() { s.info.Status = offline } s.mu.Unlock() + case <-ctx.Done(): + return } - }() + } } func (s *svc) Update() { @@ -75,5 +80,12 @@ func (s *svc) Update() { } func (s *svc) Info() Info { - return s.info + s.mu.Lock() + defer s.mu.Unlock() + info := s.info + return info +} + +func (s *svc) Close() { + s.ticker.Stop() } diff --git a/pkg/agent/heartbeat_test.go b/pkg/agent/heartbeat_test.go new file mode 100644 index 00000000..f913315c --- /dev/null +++ b/pkg/agent/heartbeat_test.go @@ -0,0 +1,72 @@ +package agent + +import ( + "context" + "testing" + "time" +) + +const ( + name = "TestService" + serviceType = "TestType" + interval = 2 * time.Second +) + +func TestNewHeartbeat(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + heartbeat := NewHeartbeat(ctx, name, serviceType, interval) + + // Check initial status and info + info := heartbeat.Info() + if info.Name != name { + t.Errorf("Expected name to be %s, but got %s", name, info.Name) + } + if info.Type != serviceType { + t.Errorf("Expected type to be %s, but got %s", serviceType, info.Type) + } + if info.Status != online { + t.Errorf("Expected initial status to be %s, but got %s", online, info.Status) + } + t.Cleanup(func() { + cancel() + heartbeat.Close() + }) +} + +func TestHeartbeat_Update(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + heartbeat := NewHeartbeat(ctx, name, serviceType, interval) + + // Sleep for more than the interval to simulate an update + time.Sleep(3 * time.Second) + + heartbeat.Update() + + // Check if the status has been updated to online + info := heartbeat.Info() + if info.Status != online { + t.Errorf("Expected status to be %s, but got %s", online, info.Status) + } + t.Cleanup(func() { + cancel() + heartbeat.Close() + }) +} + +func TestHeartbeat_StatusOffline(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + heartbeat := NewHeartbeat(ctx, name, serviceType, interval) + + // Sleep for more than two intervals to simulate offline status + time.Sleep(5 * time.Second) + + // Check if the status has been updated to offline + info := heartbeat.Info() + if info.Status != offline { + t.Errorf("Expected status to be %s, but got %s", offline, info.Status) + } + t.Cleanup(func() { + cancel() + heartbeat.Close() + }) +} diff --git a/pkg/agent/service.go b/pkg/agent/service.go index 98b445b2..8506c6be 100644 --- a/pkg/agent/service.go +++ b/pkg/agent/service.go @@ -107,6 +107,9 @@ type Service interface { // Publish message. Publish(string, string) error + + // Closes all connections. + Close() error } var _ Service = (*agent)(nil) @@ -121,13 +124,21 @@ type agent struct { terminals map[string]terminal.Session } +func (ag *agent) Close() error { + ag.mqttClient.Disconnect(1) + for _, svc := range ag.svcs { + svc.Close() + } + return ag.broker.Close() +} + func (ag *agent) handle(ctx context.Context, pub messaging.Publisher, logger log.Logger, cfg HeartbeatConfig) handleFunc { return func(msg *messaging.Message) error { sub := msg.Channel tok := strings.Split(sub, ".") if len(tok) < 3 { - ag.logger.Error(fmt.Sprintf("Failed: Subject has incorrect length %s", sub)) - return fmt.Errorf("Failed: Subject has incorrect length %s", sub) + ag.logger.Error(fmt.Sprintf("failed: subject has incorrect length %s", sub)) + return fmt.Errorf("failed: subject has incorrect length %s", sub) } svcname := tok[1] svctype := tok[2] @@ -135,7 +146,7 @@ func (ag *agent) handle(ctx context.Context, pub messaging.Publisher, logger log // if there is multiple instances of the same service // we will have to add another distinction. if _, ok := ag.svcs[svcname]; !ok { - svc := NewHeartbeat(svcname, svctype, cfg.Interval) + svc := NewHeartbeat(ctx, svcname, svctype, cfg.Interval) ag.svcs[svcname] = svc ag.logger.Info(fmt.Sprintf("Services '%s-%s' registered", svcname, svctype)) } diff --git a/pkg/bootstrap/bootstrap.go b/pkg/bootstrap/bootstrap.go index a24ceef2..a805fdb5 100644 --- a/pkg/bootstrap/bootstrap.go +++ b/pkg/bootstrap/bootstrap.go @@ -25,6 +25,11 @@ import ( const exportConfigFile = "/configs/export/config.toml" +var ( + errInvalidBootstrapRetriesValue = errors.New("invalid BOOTSTRAP_RETRIES value") + errInvalidBootstrapRetryDelay = errors.New("invalid BOOTSTRAP_RETRY_DELAY_SECONDS value") +) + // Config represents the parameters for bootstrapping. type Config struct { URL string @@ -46,20 +51,20 @@ type ConfigContent struct { } type deviceConfig struct { - MainfluxID string `json:"mainflux_id"` - MainfluxKey string `json:"mainflux_key"` - MainfluxChannels []bootstrap.Channel `json:"mainflux_channels"` - ClientKey string `json:"client_key"` - ClientCert string `json:"client_cert"` - CaCert string `json:"ca_cert"` - SvcsConf ServicesConfig `json:"-"` + ThingID string `json:"thing_id"` + ThingKey string `json:"thing_key"` + Channels []bootstrap.Channel `json:"channels"` + ClientKey string `json:"client_key"` + ClientCert string `json:"client_cert"` + CaCert string `json:"ca_cert"` + SvcsConf ServicesConfig `json:"-"` } // Bootstrap - Retrieve device config. func Bootstrap(cfg Config, logger log.Logger, file string) error { retries, err := strconv.ParseUint(cfg.Retries, 10, 64) if err != nil { - return errors.New(fmt.Sprintf("Invalid BOOTSTRAP_RETRIES value: %s", err)) + return errors.Wrap(errInvalidBootstrapRetriesValue, err) } if retries == 0 { @@ -69,7 +74,7 @@ func Bootstrap(cfg Config, logger log.Logger, file string) error { retryDelaySec, err := strconv.ParseUint(cfg.RetryDelaySec, 10, 64) if err != nil { - return errors.New(fmt.Sprintf("Invalid BOOTSTRAP_RETRY_DELAY_SECONDS value: %s", err)) + return errors.Wrap(errInvalidBootstrapRetryDelay, err) } logger.Info(fmt.Sprintf("Requesting config for %s from %s", cfg.ID, cfg.URL)) @@ -91,15 +96,15 @@ func Bootstrap(cfg Config, logger log.Logger, file string) error { } } - if len(dc.MainfluxChannels) < 2 { + if len(dc.Channels) < 2 { return agent.ErrMalformedEntity } - ctrlChan := dc.MainfluxChannels[0].ID - dataChan := dc.MainfluxChannels[1].ID - if dc.MainfluxChannels[0].Metadata["type"] == "data" { - ctrlChan = dc.MainfluxChannels[1].ID - dataChan = dc.MainfluxChannels[0].ID + ctrlChan := dc.Channels[0].ID + dataChan := dc.Channels[1].ID + if dc.Channels[0].Metadata["type"] == "data" { + ctrlChan = dc.Channels[1].ID + dataChan = dc.Channels[0].ID } sc := dc.SvcsConf.Agent.Server @@ -111,8 +116,8 @@ func Bootstrap(cfg Config, logger log.Logger, file string) error { lc := dc.SvcsConf.Agent.Log mc := dc.SvcsConf.Agent.MQTT - mc.Password = dc.MainfluxKey - mc.Username = dc.MainfluxID + mc.Password = dc.ThingKey + mc.Username = dc.ThingID mc.ClientCert = dc.ClientCert mc.ClientKey = dc.ClientKey mc.CaCert = dc.CaCert @@ -215,7 +220,6 @@ func getConfig(bsID, bsKey, bsSvrURL string, skipTLS bool, logger log.Logger) (d if err := json.Unmarshal([]byte(body), &h); err != nil { return deviceConfig{}, err } - fmt.Println(h.Content) sc := ServicesConfig{} if err := json.Unmarshal([]byte(h.Content), &sc); err != nil { return deviceConfig{}, err diff --git a/pkg/bootstrap/bootstrap_test.go b/pkg/bootstrap/bootstrap_test.go new file mode 100644 index 00000000..b94c0373 --- /dev/null +++ b/pkg/bootstrap/bootstrap_test.go @@ -0,0 +1,137 @@ +package bootstrap + +import ( + "fmt" + "io" + "net/http" + "net/http/httptest" + "os" + "testing" + + "github.com/mainflux/agent/pkg/agent" + "github.com/mainflux/mainflux/logger" + "github.com/mainflux/mainflux/pkg/errors" + "github.com/stretchr/testify/assert" +) + +func TestBootstrap(t *testing.T) { + // Create a mock HTTP server to handle requests from the getConfig function. + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Authorization") != "Thing mockKey" && r.Header.Get("Authorization") != "Thing invalidChannels" { + http.Error(w, "Invalid authorization header", http.StatusUnauthorized) + return + } + resp := ` + { + "thing_id": "e22c383a-d2ab-47c1-89cd-903955da993d", + "thing_key": "fc987711-1828-461b-aa4b-16d5b2c642fe", + "channels": [ + %s + ], + "content": "{\"agent\":{\"edgex\":{\"url\":\"http://localhost:48090/api/v1/\"},\"heartbeat\":{\"interval\":\"30s\"},\"log\":{\"level\":\"debug\"},\"mqtt\":{\"mtls\":false,\"qos\":0,\"retain\":false,\"skip_tls_ver\":true,\"url\":\"tcp://mainflux-domain.com:1883\"},\"server\":{\"nats_url\":\"localhost:4222\",\"port\":\"9000\"},\"terminal\":{\"session_timeout\":\"30s\"}},\"export\":{\"exp\":{\"cache_db\":\"0\",\"cache_pass\":\"\",\"cache_url\":\"localhost:6379\",\"log_level\":\"debug\",\"nats\":\"nats://localhost:4222\",\"port\":\"8172\"},\"mqtt\":{\"ca_path\":\"ca.crt\",\"cert_path\":\"thing.crt\",\"channel\":\"\",\"host\":\"tcp://mainflux-domain.com:1883\",\"mtls\":false,\"password\":\"\",\"priv_key_path\":\"thing.key\",\"qos\":0,\"retain\":false,\"skip_tls_ver\":false,\"username\":\"\"},\"routes\":[{\"mqtt_topic\":\"\",\"nats_topic\":\"channels\",\"subtopic\":\"\",\"type\":\"mfx\",\"workers\":10},{\"mqtt_topic\":\"\",\"nats_topic\":\"export\",\"subtopic\":\"\",\"type\":\"default\",\"workers\":10}]}}" + } + ` + if r.Header.Get("Authorization") == "Thing invalidChannels" { + // Simulate a malformed response. + channels := ` + { + "id": "fa5f9ba8-a1fc-4380-9edb-d0c23eaa24ec", + "name": "control-channel", + "metadata": { + "type": "control" + } + } + ` + resp = fmt.Sprintf(resp, channels) + w.WriteHeader(http.StatusOK) + if _, err := io.WriteString(w, resp); err != nil { + t.Errorf(err.Error()) + } + return + } + // Simulate a successful response. + channels := ` + { + "id": "fa5f9ba8-a1fc-4380-9edb-d0c23eaa24ec", + "name": "control-channel", + "metadata": { + "type": "control" + } + }, + { + "id": "24e5473e-3cbe-43d9-8a8b-a725ff918c0e", + "name": "data-channel", + "metadata": { + "type": "data" + } + }, + { + "id": "1eac45c2-0f72-4089-b255-ebd2e5732bbb", + "name": "export-channel", + "metadata": { + "type": "export" + } + } + ` + resp = fmt.Sprintf(resp, channels) + w.WriteHeader(http.StatusOK) + if _, err := io.WriteString(w, resp); err != nil { + t.Errorf(err.Error()) + } + })) + defer mockServer.Close() + mockLogger := logger.NewMock() + tests := []struct { + name string + config Config + file string + expectedErr error + }{ + { + name: "invalid retries type", + config: Config{Retries: "invalid"}, + expectedErr: errInvalidBootstrapRetriesValue, + }, + { + name: "zero retires", + config: Config{Retries: "0"}, + expectedErr: nil, + }, + { + name: "invalid retry delay", + config: Config{Retries: "1", RetryDelaySec: "e"}, + expectedErr: errInvalidBootstrapRetryDelay, + }, + { + name: "authorization error", + config: Config{Retries: "1", RetryDelaySec: "1", URL: mockServer.URL, Key: "wrongKey"}, + expectedErr: nil, + }, + { + name: "malformed channels", + config: Config{Retries: "1", RetryDelaySec: "1", URL: mockServer.URL, Key: "invalidChannels"}, + expectedErr: errors.ErrMalformedEntity, + }, + { + name: "successful configuration", + config: Config{Retries: "1", RetryDelaySec: "1", URL: mockServer.URL, Key: "mockKey"}, + expectedErr: agent.ErrWritingToml, + }, + { + name: "successful configuration", + config: Config{Retries: "1", RetryDelaySec: "1", URL: mockServer.URL, Key: "mockKey"}, + expectedErr: nil, + file: "config.toml", + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + err := Bootstrap(test.config, mockLogger, test.file) + assert.True(t, errors.Contains(err, test.expectedErr), fmt.Sprintf("expected %v got %v", test.expectedErr, err)) + }) + } + t.Cleanup(func() { + os.Remove("config.toml") + }) +} diff --git a/pkg/conn/conn_test.go b/pkg/conn/conn_test.go new file mode 100644 index 00000000..8ac25451 --- /dev/null +++ b/pkg/conn/conn_test.go @@ -0,0 +1,137 @@ +package conn + +import ( + "context" + "testing" + "time" + + mqtt "github.com/eclipse/paho.mqtt.golang" + "github.com/mainflux/agent/pkg/agent" + "github.com/mainflux/mainflux/logger" + "github.com/mainflux/mainflux/pkg/messaging" + "github.com/stretchr/testify/assert" +) + +// Mocks for testing. +type mockService struct{} + +func (m *mockService) Config() agent.Config { return agent.Config{} } +func (m *mockService) Services() []agent.Info { return []agent.Info{} } +func (m *mockService) Publish(string, string) error { return nil } +func (m *mockService) AddConfig(agent.Config) error { return nil } +func (m *mockService) Control(uuid, command string) error { return nil } +func (m *mockService) Execute(uuid, command string) (string, error) { return "", nil } +func (m *mockService) ServiceConfig(ctx context.Context, uuid, command string) error { + return nil +} +func (m *mockService) Close() error { return nil } +func (m *mockService) Terminal(uuid, command string) error { return nil } + +type mockMQTTClient struct { + subscribeErr error + waitErr error +} + +// AddRoute implements mqtt.Client. +func (*mockMQTTClient) AddRoute(topic string, callback mqtt.MessageHandler) { + panic("unimplemented") +} + +// Connect implements mqtt.Client. +func (*mockMQTTClient) Connect() mqtt.Token { + panic("unimplemented") +} + +// Disconnect implements mqtt.Client. +func (*mockMQTTClient) Disconnect(quiesce uint) { + panic("unimplemented") +} + +// IsConnected implements mqtt.Client. +func (*mockMQTTClient) IsConnected() bool { + panic("unimplemented") +} + +// IsConnectionOpen implements mqtt.Client. +func (*mockMQTTClient) IsConnectionOpen() bool { + panic("unimplemented") +} + +// OptionsReader implements mqtt.Client. +func (*mockMQTTClient) OptionsReader() mqtt.ClientOptionsReader { + panic("unimplemented") +} + +// Publish implements mqtt.Client. +func (*mockMQTTClient) Publish(topic string, qos byte, retained bool, payload interface{}) mqtt.Token { + panic("unimplemented") +} + +// SubscribeMultiple implements mqtt.Client. +func (*mockMQTTClient) SubscribeMultiple(filters map[string]byte, callback mqtt.MessageHandler) mqtt.Token { + panic("unimplemented") +} + +// Unsubscribe implements mqtt.Client. +func (*mockMQTTClient) Unsubscribe(topics ...string) mqtt.Token { + panic("unimplemented") +} + +func (m *mockMQTTClient) Subscribe(topic string, qos byte, callback mqtt.MessageHandler) mqtt.Token { + return &mockToken{err: m.subscribeErr} +} + +func (m *mockMQTTClient) Wait() bool { + return m.waitErr == nil +} + +type mockToken struct { + err error +} + +func (m *mockToken) Wait() bool { return true } +func (m *mockToken) WaitTimeout(time.Duration) bool { return true } +func (m *mockToken) Error() error { return m.err } +func (m *mockToken) Done() <-chan struct{} { + x := make(chan struct{}) + return x +} + +type mockMessageBroker struct { + publishErr error +} + +// Close implements messaging.PubSub. +func (*mockMessageBroker) Close() error { + panic("unimplemented") +} + +// Subscribe implements messaging.PubSub. +func (*mockMessageBroker) Subscribe(ctx context.Context, id string, topic string, handler messaging.MessageHandler) error { + panic("unimplemented") +} + +// Unsubscribe implements messaging.PubSub. +func (*mockMessageBroker) Unsubscribe(ctx context.Context, id string, topic string) error { + panic("unimplemented") +} + +func (m *mockMessageBroker) Publish(ctx context.Context, topic string, msg *messaging.Message) error { + return m.publishErr +} + +func TestBroker_Subscribe(t *testing.T) { + svc := &mockService{} + client := &mockMQTTClient{} + chann := "test" + messBroker := &mockMessageBroker{} + + broker := NewBroker(svc, client, chann, messBroker, logger.NewMock()) + + assert.NotNil(t, broker) + + ctx := context.Background() + err := broker.Subscribe(ctx) + + assert.NoError(t, err) +} diff --git a/pkg/edgex/client_test.go b/pkg/edgex/client_test.go new file mode 100644 index 00000000..51750427 --- /dev/null +++ b/pkg/edgex/client_test.go @@ -0,0 +1,138 @@ +package edgex + +import ( + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/mainflux/mainflux/logger" +) + +const expectedResponse = "Response" + +func TestPushOperation(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + t.Errorf("Expected POST request, got %s", r.Method) + } + + expectedURL := "/operation" + if r.URL.String() != expectedURL { + t.Errorf("Expected URL %s, got %s", expectedURL, r.URL.String()) + } + + expectedBody := `{"action":"start","services":["service1","service2"]}` + bodyBytes, _ := io.ReadAll(r.Body) + if string(bodyBytes) != expectedBody { + t.Errorf("Expected request body %s, got %s", expectedBody, string(bodyBytes)) + } + + w.WriteHeader(http.StatusOK) + if _, err := w.Write([]byte(expectedResponse)); err != nil { + t.Errorf("error writing response %v", err) + } + })) + defer server.Close() + + client := NewClient(server.URL+"/", logger.NewMock()) + + response, err := client.PushOperation([]string{"start", "service1", "service2"}) + if err != nil { + t.Errorf("Error calling PushOperation: %v", err) + } + + if response != expectedResponse { + t.Errorf("Expected response %s, got %s", expectedResponse, response) + } +} + +func TestFetchConfig(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + t.Errorf("Expected GET request, got %s", r.Method) + } + + expectedURL := "/config/start,service1,service2" + if r.URL.String() != expectedURL { + t.Errorf("Expected URL %s, got %s", expectedURL, r.URL.String()) + } + + w.WriteHeader(http.StatusOK) + if _, err := w.Write([]byte(expectedResponse)); err != nil { + t.Errorf("error writing response %v", err) + } + })) + defer server.Close() + + client := NewClient(server.URL+"/", logger.NewMock()) + + response, err := client.FetchConfig([]string{"start", "service1", "service2"}) + if err != nil { + t.Errorf("Error calling FetchConfig: %v", err) + } + + if response != expectedResponse { + t.Errorf("Expected response %s, got %s", expectedResponse, response) + } +} + +func TestFetchMetrics(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + t.Errorf("Expected GET request, got %s", r.Method) + } + + expectedURL := "/metrics/start,service1,service2" + if r.URL.String() != expectedURL { + t.Errorf("Expected URL %s, got %s", expectedURL, r.URL.String()) + } + + w.WriteHeader(http.StatusOK) + if _, err := w.Write([]byte(expectedResponse)); err != nil { + t.Errorf("error writing response %v", err) + } + })) + defer server.Close() + + client := NewClient(server.URL+"/", logger.NewMock()) + + response, err := client.FetchMetrics([]string{"start", "service1", "service2"}) + if err != nil { + t.Errorf("Error calling FetchMetrics: %v", err) + } + + if response != expectedResponse { + t.Errorf("Expected response %s, got %s", expectedResponse, response) + } +} + +func TestPing(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + t.Errorf("Expected GET request, got %s", r.Method) + } + + expectedURL := "/ping" + if r.URL.String() != expectedURL { + t.Errorf("Expected URL %s, got %s", expectedURL, r.URL.String()) + } + + w.WriteHeader(http.StatusOK) + if _, err := w.Write([]byte(expectedResponse)); err != nil { + t.Errorf("error writing response %v", err) + } + })) + defer server.Close() + + client := NewClient(server.URL+"/", logger.NewMock()) + + response, err := client.Ping() + if err != nil { + t.Errorf("Error calling Ping: %v", err) + } + + if response != expectedResponse { + t.Errorf("Expected response %s, got %s", expectedResponse, response) + } +} diff --git a/pkg/terminal/terminal.go b/pkg/terminal/terminal.go index 6987a3ea..0a04207c 100644 --- a/pkg/terminal/terminal.go +++ b/pkg/terminal/terminal.go @@ -13,7 +13,6 @@ import ( "github.com/mainflux/agent/pkg/encoder" "github.com/mainflux/mainflux/logger" - "github.com/mainflux/mainflux/pkg/errors" ) const ( @@ -54,7 +53,7 @@ func NewSession(uuid string, timeout time.Duration, publish func(channel, payloa c := exec.Command("bash") ptmx, err := pty.Start(c) if err != nil { - return t, errors.New(err.Error()) + return t, err } t.ptmx = ptmx @@ -120,7 +119,7 @@ func (t *term) Send(p []byte) error { nr, err := io.Copy(t.ptmx, in) t.logger.Debug(fmt.Sprintf("Written to ptmx: %d", nr)) if err != nil { - return errors.New(err.Error()) + return err } return nil } diff --git a/pkg/terminal/terminal_test.go b/pkg/terminal/terminal_test.go new file mode 100644 index 00000000..0b7bc636 --- /dev/null +++ b/pkg/terminal/terminal_test.go @@ -0,0 +1,84 @@ +package terminal + +import ( + "errors" + "testing" + "time" + + "github.com/mainflux/mainflux/logger" + "github.com/stretchr/testify/assert" +) + +const ( + uuid = "test-uuid" + timeout = 5 * time.Second +) + +// MockPublish is a mock function for the publish function used in NewSession. +func mockPublish(channel, payload string) error { + return nil +} + +func mockPublishFail(channel, payload string) error { + return errors.New("") +} + +func TestWrite(t *testing.T) { + t.Run("successful publish", func(t *testing.T) { + session, err := NewSession(uuid, timeout, mockPublish, logger.NewMock()) + if err != nil { + t.Fatalf("Expected no error, but got: %v", err) + } + + // Simulate writing data to the session + data := []byte("test data") + n, err := session.Write(data) + if err != nil { + t.Fatalf("Expected no error, but got: %v", err) + } + assert.Equal(t, len(data), n) + }) + t.Run("failed publish", func(t *testing.T) { + session, err := NewSession(uuid, timeout, mockPublishFail, logger.NewMock()) + if err != nil { + t.Fatalf("Expected no error, but got: %v", err) + } + + // Simulate writing data to the session + data := []byte("test data") + _, err = session.Write(data) + assert.NotNil(t, err) + }) +} + +func TestSend(t *testing.T) { + session, err := NewSession(uuid, timeout, mockPublish, logger.NewMock()) + if err != nil { + t.Fatalf("Expected no error, but got: %v", err) + } + + // Simulate sending data to the session + data := []byte("test data") + + if err = session.Send(data); err != nil { + t.Fatalf("Expected no error, but got: %v", err) + } + +} + +func TestIsDone(t *testing.T) { + publish := mockPublish + + session, err := NewSession(uuid, timeout, publish, logger.NewMock()) + if err != nil { + t.Fatalf("Expected no error, but got: %v", err) + } + + // Wait for the "done" channel to be closed or for a timeout, and perform assertions accordingly. + select { + case <-session.IsDone(): + // Session is done as expected. + case <-time.After(10 * time.Second): + t.Fatalf("Expected session to be done, but it is still running.") + } +}