From f41db557f75b99dfc2940af17dc18537d6c409ff Mon Sep 17 00:00:00 2001 From: nyagamunene Date: Tue, 24 Sep 2024 13:37:04 +0300 Subject: [PATCH] Address comments Signed-off-by: nyagamunene --- postgres/certs.go | 21 ++++++++++++++------- postgres/init.go | 2 +- service.go | 33 +++++++++++++++++++++++++++++++++ 3 files changed, 48 insertions(+), 8 deletions(-) diff --git a/postgres/certs.go b/postgres/certs.go index 1434607..cab1b81 100644 --- a/postgres/certs.go +++ b/postgres/certs.go @@ -55,7 +55,7 @@ func (repo certsRepo) CreateCert(ctx context.Context, cert certs.Certificate) er // RetrieveLog retrieves computation log from the database. func (repo certsRepo) RetrieveCert(ctx context.Context, serialNumber string) (certs.Certificate, error) { - q := `SELECT * FROM certs WHERE serial_number = $1` + q := `SELECT serial_number, certificate, key, entity_id, revoked, expiry_time FROM certs WHERE serial_number = $1` var cert certs.Certificate if err := repo.db.QueryRowxContext(ctx, q, serialNumber).StructScan(&cert); err != nil { if err == sql.ErrNoRows { @@ -71,13 +71,13 @@ func (repo certsRepo) GetCAs(ctx context.Context, caType ...certs.CertType) ([]c q := `SELECT serial_number, key, certificate, expiry_time, revoked, type FROM certs WHERE type = ANY($1)` var certificates []certs.Certificate - types := make([]int, 0, len(caType)) + types := make([]string, 0, len(caType)) for i, t := range caType { - types[i] = int(t) + types[i] = t.String() } if len(types) == 0 { - types = []int{int(certs.RootCA), int(certs.IntermediateCA)} + types = []string{certs.RootCA.String(), certs.IntermediateCA.String()} } rows, err := repo.db.QueryContext(ctx, q, types) @@ -86,6 +86,7 @@ func (repo certsRepo) GetCAs(ctx context.Context, caType ...certs.CertType) ([]c } defer rows.Close() + var certType string for rows.Next() { cert := &certs.Certificate{} if err := rows.Scan( @@ -94,11 +95,17 @@ func (repo certsRepo) GetCAs(ctx context.Context, caType ...certs.CertType) ([]c &cert.Certificate, &cert.ExpiryTime, &cert.Revoked, - &cert.Type, + &certType, ); err != nil { return []certs.Certificate{}, errors.Wrap(certs.ErrViewEntity, err) } + crtType, err := certs.CertTypeFromString(certType) + if err != nil { + return []certs.Certificate{}, errors.Wrap(certs.ErrViewEntity, err) + } + cert.Type = crtType + certificates = append(certificates, *cert) } @@ -130,9 +137,9 @@ func (repo certsRepo) ListCerts(ctx context.Context, pm certs.PageMetadata) (cer q := `SELECT serial_number, revoked, expiry_time, entity_id FROM certs %s LIMIT :limit OFFSET :offset` var condition string if pm.EntityID != "" { - condition = `WHERE entity_id = :entity_id AND type = 2` + condition = fmt.Sprintf(`WHERE entity_id = :entity_id AND type = '%s'`, certs.ClientCert.String()) } else { - condition = `WHERE type = 2` + condition = fmt.Sprintf(`WHERE type = '%s'`, certs.ClientCert.String()) } q = fmt.Sprintf(q, condition) var certificates []certs.Certificate diff --git a/postgres/init.go b/postgres/init.go index 2c86abc..0b3f15c 100644 --- a/postgres/init.go +++ b/postgres/init.go @@ -21,7 +21,7 @@ func Migration() *migrate.MemoryMigrationSource { revoked BOOLEAN, expiry_time TIMESTAMP, entity_id VARCHAR(36), - type INT, + type TEXT CHECK (type IN ('RootCA', 'IntermediateCA', 'ClientCert')), PRIMARY KEY (serial_number) )`, }, diff --git a/service.go b/service.go index e905e2b..6f6c1bd 100644 --- a/service.go +++ b/service.go @@ -45,6 +45,39 @@ const ( ClientCert ) +const ( + Root = "RootCA" + Inter = "IntermediateCA" + Client = "ClientCert" + Unknown = "Unknown" +) + +func (c CertType) String() string { + switch c { + case RootCA: + return Root + case IntermediateCA: + return Inter + case ClientCert: + return Client + default: + return Unknown + } +} + +func CertTypeFromString(s string) (CertType, error) { + switch s { + case Root: + return RootCA, nil + case Inter: + return IntermediateCA, nil + case Client: + return ClientCert, nil + default: + return -1, errors.New("unknown cert type") + } +} + type CA struct { Type CertType Certificate *x509.Certificate