diff --git a/pkg/clients/postgres/clients.go b/pkg/clients/postgres/clients.go index 8086616ab..750f47ef6 100644 --- a/pkg/clients/postgres/clients.go +++ b/pkg/clients/postgres/clients.go @@ -154,7 +154,7 @@ func (repo ClientRepository) RetrieveAll(ctx context.Context, pm clients.Page) ( q := fmt.Sprintf(`SELECT c.id, c.name, c.tags, c.identity, c.metadata, COALESCE(c.owner_id, '') AS owner_id, c.status, c.created_at, c.updated_at, COALESCE(c.updated_by, '') AS updated_by FROM clients c %s ORDER BY c.created_at LIMIT :limit OFFSET :offset;`, query) - dbPage, err := toDBClientsPage(pm) + dbPage, err := ToDBClientsPage(pm) if err != nil { return clients.ClientsPage{}, errors.Wrap(postgres.ErrFailedToRetrieveAll, err) } @@ -205,7 +205,7 @@ func (repo ClientRepository) RetrieveAllBasicInfo(ctx context.Context, pm client q := fmt.Sprintf(`SELECT c.id, c.name, c.tags, c.identity FROM clients c %s ORDER BY c.created_at LIMIT :limit OFFSET :offset;`, query) - dbPage, err := toDBClientsPage(pm) + dbPage, err := ToDBClientsPage(pm) if err != nil { return clients.ClientsPage{}, errors.Wrap(postgres.ErrFailedToRetrieveAll, err) } @@ -262,7 +262,7 @@ func (repo ClientRepository) RetrieveAllByIDs(ctx context.Context, pm clients.Pa q := fmt.Sprintf(`SELECT c.id, c.name, c.tags, c.identity, c.metadata, COALESCE(c.owner_id, '') AS owner_id, c.status, c.created_at, c.updated_at, COALESCE(c.updated_by, '') AS updated_by FROM clients c %s ORDER BY c.created_at LIMIT :limit OFFSET :offset;`, query) - dbPage, err := toDBClientsPage(pm) + dbPage, err := ToDBClientsPage(pm) if err != nil { return clients.ClientsPage{}, errors.Wrap(postgres.ErrFailedToRetrieveAll, err) } @@ -428,7 +428,7 @@ func ToClient(c DBClient) (clients.Client, error) { }, nil } -func toDBClientsPage(pm clients.Page) (dbClientsPage, error) { +func ToDBClientsPage(pm clients.Page) (dbClientsPage, error) { _, data, err := postgres.CreateMetadataQuery("", pm.Metadata) if err != nil { return dbClientsPage{}, errors.Wrap(repoerr.ErrViewEntity, err) diff --git a/users/mocks/clients.go b/users/mocks/clients.go index 54b3403c9..e348e1d9a 100644 --- a/users/mocks/clients.go +++ b/users/mocks/clients.go @@ -178,3 +178,9 @@ func (m *Repository) RetrieveAllByIDs(ctx context.Context, pm mgclients.Page) (m return ret.Get(0).(mgclients.ClientsPage), ret.Error(1) } + +func (m *Repository) RetrieveNames(ctx context.Context, pm mgclients.Page) (mgclients.ClientsPage, error) { + ret := m.Called(ctx, pm) + + return ret.Get(0).(mgclients.ClientsPage), ret.Error(1) +} diff --git a/users/postgres/clients.go b/users/postgres/clients.go index 6025bc41b..f9a9bb41b 100644 --- a/users/postgres/clients.go +++ b/users/postgres/clients.go @@ -6,6 +6,8 @@ package postgres import ( "context" "database/sql" + "fmt" + "strings" "github.com/absmach/magistrala/internal/postgres" mgclients "github.com/absmach/magistrala/pkg/clients" @@ -28,6 +30,9 @@ type Repository interface { Save(ctx context.Context, client mgclients.Client) (mgclients.Client, error) CheckSuperAdmin(ctx context.Context, adminID string) error + + // RetrieveNames returns a list of client names that match the given query. + RetrieveNames(ctx context.Context, pm mgclients.Page) (mgclients.ClientsPage, error) } // NewRepository instantiates a PostgreSQL @@ -85,3 +90,77 @@ func (repo clientRepo) CheckSuperAdmin(ctx context.Context, adminID string) erro } return nil } + +func (repo clientRepo) RetrieveNames(ctx context.Context, pm mgclients.Page) (mgclients.ClientsPage, error) { + query := constructQuery(pm) + + q := fmt.Sprintf("SELECT name FROM clients %s LIMIT :limit OFFSET :offset", query) + + dbPage, err := pgclients.ToDBClientsPage(pm) + if err != nil { + return mgclients.ClientsPage{}, errors.Wrap(postgres.ErrFailedToRetrieveAll, err) + } + + rows, err := repo.DB.NamedQueryContext(ctx, q, dbPage) + if err != nil { + return mgclients.ClientsPage{}, errors.Wrap(postgres.ErrFailedToRetrieveAll, err) + } + defer rows.Close() + + var items []mgclients.Client + for rows.Next() { + dbc := pgclients.DBClient{} + if err := rows.StructScan(&dbc); err != nil { + return mgclients.ClientsPage{}, errors.Wrap(repoerr.ErrViewEntity, err) + } + + c, err := pgclients.ToClient(dbc) + if err != nil { + return mgclients.ClientsPage{}, err + } + + items = append(items, c) + } + cq := fmt.Sprintf(`SELECT COUNT(*) FROM clients c %s;`, query) + + total, err := postgres.Total(ctx, repo.DB, cq, dbPage) + if err != nil { + return mgclients.ClientsPage{}, errors.Wrap(repoerr.ErrViewEntity, err) + } + + page := mgclients.ClientsPage{ + Clients: items, + Page: mgclients.Page{ + Total: total, + Offset: pm.Offset, + Limit: pm.Limit, + }, + } + + return page, nil +} + +func constructQuery(pm mgclients.Page) string { + var query []string + var emq string + if pm.Name != "" { + query = append(query, fmt.Sprintf("name ILIKE '%%%s%%'", pm.Name)) + } + if pm.Identity != "" { + query = append(query, fmt.Sprintf("identity ILIKE '%%%s%%'", pm.Identity)) + } + + if len(query) > 0 { + emq = fmt.Sprintf("WHERE %s", strings.Join(query, " AND ")) + } + + if pm.Order != "" && (pm.Order == "name" || pm.Order == "email" || pm.Order == "created_at" || pm.Order == "updated_at") { + emq = fmt.Sprintf("%s ORDER BY %s", emq, pm.Order) + } + + if pm.Dir != "" && (pm.Dir == "asc" || pm.Dir == "desc") { + emq = fmt.Sprintf("%s %s", emq, pm.Dir) + } + + return emq +} diff --git a/users/postgres/clients_test.go b/users/postgres/clients_test.go index c5cfaaacc..dfaeb93da 100644 --- a/users/postgres/clients_test.go +++ b/users/postgres/clients_test.go @@ -256,3 +256,211 @@ func TestIsPlatformAdmin(t *testing.T) { assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.err, err)) } } + +func TestClientsRetrieveNames(t *testing.T) { + t.Cleanup(func() { + _, err := db.Exec("DELETE FROM clients") + require.Nil(t, err, fmt.Sprintf("clean clients unexpected error: %s", err)) + }) + repo := cpostgres.NewRepository(database) + + nusers := 100 + users := make([]mgclients.Client, nusers) + + name := namesgen.Generate() + + for i := 0; i < nusers; i++ { + client := mgclients.Client{ + ID: testsutil.GenerateUUID(t), + Name: fmt.Sprintf("%s-%d", name, i), + Credentials: mgclients.Credentials{ + Identity: fmt.Sprintf("%s-%d@example.com", name, i), + Secret: password, + }, + Metadata: mgclients.Metadata{}, + Status: mgclients.EnabledStatus, + } + _, err := repo.Save(context.Background(), client) + require.Nil(t, err, fmt.Sprintf("save client unexpected error: %s", err)) + + users[i] = mgclients.Client{ + Name: client.Name, + } + } + + cases := []struct { + desc string + page mgclients.Page + response mgclients.ClientsPage + err error + }{ + { + desc: "retrieve all clients", + page: mgclients.Page{ + Offset: 0, + Limit: uint64(nusers), + }, + response: mgclients.ClientsPage{ + Clients: removeIdentities(users), + Page: mgclients.Page{ + Total: uint64(nusers), + Offset: 0, + Limit: uint64(nusers), + }, + }, + err: nil, + }, + { + desc: "retrieve all clients with offset", + page: mgclients.Page{ + Offset: 10, + Limit: uint64(nusers), + }, + response: mgclients.ClientsPage{ + Clients: removeIdentities(users[10:]), + Page: mgclients.Page{ + Total: uint64(nusers), + Offset: 10, + Limit: uint64(nusers), + }, + }, + err: nil, + }, + { + desc: "retrieve all clients with limit", + page: mgclients.Page{ + Offset: 0, + Limit: 10, + }, + response: mgclients.ClientsPage{ + Clients: removeIdentities(users[:10]), + Page: mgclients.Page{ + Total: uint64(nusers), + Offset: 0, + Limit: 10, + }, + }, + err: nil, + }, + { + desc: "retrieve all clients with offset and limit", + page: mgclients.Page{ + Offset: 10, + Limit: 10, + }, + response: mgclients.ClientsPage{ + Clients: removeIdentities(users[10:20]), + Page: mgclients.Page{ + Total: uint64(nusers), + Offset: 10, + Limit: 10, + }, + }, + err: nil, + }, + { + desc: "retrieve all clients with name", + page: mgclients.Page{ + Name: users[0].Name[:1], + Offset: 0, + Limit: 10, + }, + response: mgclients.ClientsPage{ + Clients: findClient(users, users[0].Name[:1], true, false, 0, 10), + Page: mgclients.Page{ + Total: uint64(nusers), + Offset: 0, + Limit: 10, + }, + }, + err: nil, + }, + { + desc: "retrieve all clients with name", + page: mgclients.Page{ + Name: users[0].Name[:4], + Offset: 0, + Limit: 10, + }, + response: mgclients.ClientsPage{ + Clients: findClient(users, users[0].Name[:4], true, false, 0, 10), + Page: mgclients.Page{ + Total: uint64(nusers), + Offset: 0, + Limit: 10, + }, + }, + err: nil, + }, + { + desc: "retrieve all clients unknown name", + page: mgclients.Page{ + Name: "unknown", + Offset: 0, + Limit: 10, + }, + response: mgclients.ClientsPage{ + Clients: []mgclients.Client(nil), + Page: mgclients.Page{ + Total: 0, + Offset: 0, + Limit: 10, + }, + }, + err: nil, + }, + { + desc: "retrieve all clients unknown identity", + page: mgclients.Page{ + Identity: "unknown", + Offset: 0, + Limit: 10, + }, + response: mgclients.ClientsPage{ + Clients: []mgclients.Client(nil), + Page: mgclients.Page{ + Total: 0, + Offset: 0, + Limit: 10, + }, + }, + err: nil, + }, + } + for _, tc := range cases { + resp, err := repo.RetrieveNames(context.Background(), tc.page) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + if err == nil { + assert.Equal(t, tc.response, resp, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.response, resp)) + } + } +} + +func findClient(clients []mgclients.Client, query string, name, email bool, offset, limit uint64) []mgclients.Client { + clis := []mgclients.Client{} + for _, client := range clients { + if name && strings.Contains(client.Name, query) { + clis = append(clis, client) + } + if email && strings.Contains(client.Credentials.Identity, query) { + clis = append(clis, client) + } + } + + if offset > uint64(len(clis)) { + return []mgclients.Client{} + } + + if limit > uint64(len(clis)) { + return clis[offset:] + } + + return clis[offset:limit] +} + +func removeIdentities(clients []mgclients.Client) []mgclients.Client { + for i := range clients { + clients[i].Credentials.Identity = "" + } + return clients +}