Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tlv2 comments #252

Merged
merged 14 commits into from
Oct 26, 2022
Merged
6 changes: 4 additions & 2 deletions kafka/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,11 @@ func StartConsumers(providedServer *server.Server, logger *zerolog.Logger) error
failureCount++
continue
}
logger.Info().Msgf("Processing message for topic %s at offset %d", msg.Topic, msg.Offset)
logger.Info().Msgf("Reader Stats: %#v", consumer.Stats())
logger.Debug().Msgf("Processing message for topic %s at offset %d", msg.Topic, msg.Offset)
logger.Debug().Msgf("Reader Stats: %#v", consumer.Stats())
logger.Debug().Msgf("topicMappings: %+v", topicMappings)
for _, topicMapping := range topicMappings {
logger.Debug().Msgf("topic: %+v, topicMapping: %+v", msg.Topic, topicMapping.Topic)
if msg.Topic == topicMapping.Topic {
go func(
msg kafka.Message,
Expand Down
49 changes: 42 additions & 7 deletions kafka/signed_blinded_token_issuer_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,19 @@ func SignedBlindedTokenIssuerHandler(data []byte, producer *kafka.Writer, server
issuerError = 2
)

log.Debug().Msg("starting blinded token processor")

log.Info().Msg("deserialize signing request")

blindedTokenRequestSet, err := avroSchema.DeserializeSigningRequestSet(bytes.NewReader(data))
if err != nil {
return fmt.Errorf("request %s: failed avro deserialization: %w", blindedTokenRequestSet.Request_id, err)
}

logger := log.With().Str("request_id", blindedTokenRequestSet.Request_id).Logger()

logger.Debug().Msg("processing blinded token request for request_id")

var blindedTokenResults []avroSchema.SigningResultV2
if len(blindedTokenRequestSet.Data) > 1 {
// NOTE: When we start supporting multiple requests we will need to review
Expand All @@ -44,6 +50,7 @@ func SignedBlindedTokenIssuerHandler(data []byte, producer *kafka.Writer, server

OUTER:
for _, request := range blindedTokenRequestSet.Data {
logger.Debug().Msgf("processing request: %+v", request)
if request.Blinded_tokens == nil {
logger.Error().Err(errors.New("blinded tokens is empty")).Msg("")
blindedTokenResults = append(blindedTokenResults, avroSchema.SigningResultV2{
Expand All @@ -56,6 +63,7 @@ OUTER:
}

// check to see if issuer cohort will overflow
logger.Debug().Msgf("checking request cohort: %+v", request)
if request.Issuer_cohort > math.MaxInt16 || request.Issuer_cohort < math.MinInt16 {
logger.Error().Msg("invalid cohort")
blindedTokenResults = append(blindedTokenResults, avroSchema.SigningResultV2{
Expand All @@ -67,6 +75,7 @@ OUTER:
break OUTER
}

logger.Debug().Msgf("getting latest issuer: %+v - %+v", request.Issuer_type, request.Issuer_cohort)
issuer, appErr := server.GetLatestIssuer(request.Issuer_type, int16(request.Issuer_cohort))
if appErr != nil {
logger.Error().Err(appErr).Msg("error retrieving issuer")
Expand All @@ -79,6 +88,7 @@ OUTER:
break OUTER
}

logger.Debug().Msgf("checking if issuer is version 3: %+v", issuer)
// if this is a time aware issuer, make sure the request contains the appropriate number of blinded tokens
if issuer.Version == 3 && issuer.Buffer > 0 {
if len(request.Blinded_tokens)%(issuer.Buffer+issuer.Overlap) != 0 {
Expand All @@ -93,10 +103,12 @@ OUTER:
}
}

logger.Debug().Msgf("checking blinded tokens: %+v", request.Blinded_tokens)
var blindedTokens []*crypto.BlindedToken
// Iterate over the provided tokens and create data structure from them,
// grouping into a slice for approval
for _, stringBlindedToken := range request.Blinded_tokens {
logger.Debug().Msgf("blinded token: %+v", stringBlindedToken)
blindedToken := crypto.BlindedToken{}
err := blindedToken.UnmarshalText([]byte(stringBlindedToken))
if err != nil {
Expand All @@ -113,24 +125,38 @@ OUTER:
blindedTokens = append(blindedTokens, &blindedToken)
}

logger.Debug().Msgf("checking if issuer is time aware: %+v - %+v", issuer.Version, issuer.Buffer)
// if the issuer is time aware, we need to approve tokens
if issuer.Version == 3 && issuer.Buffer > 0 {
// number of tokens per signing key
// Calculate the number of tokens per signing key.
// Given the mod check this should be a multiple of the total tokens in the request.
var numT = len(request.Blinded_tokens) / (issuer.Buffer + issuer.Overlap)
// sign tokens with all the keys in buffer+overlap
for i := issuer.Buffer + issuer.Overlap; i > 0; i-- {
count := 0
for i := 0; i < len(blindedTokens); i += numT {
count++

logger.Debug().Msgf("version 3 issuer: %+v , numT: %+v", issuer, numT)
var (
blindedTokensSlice []*crypto.BlindedToken
signingKey *crypto.SigningKey
validFrom string
validTo string
)

signingKey = issuer.Keys[len(issuer.Keys)-i].SigningKey
validFrom = issuer.Keys[len(issuer.Keys)-i].StartAt.Format(time.RFC3339)
validTo = issuer.Keys[len(issuer.Keys)-i].EndAt.Format(time.RFC3339)
signingKey = issuer.Keys[len(issuer.Keys)-count].SigningKey
validFrom = issuer.Keys[len(issuer.Keys)-count].StartAt.Format(time.RFC3339)
validTo = issuer.Keys[len(issuer.Keys)-count].EndAt.Format(time.RFC3339)

blindedTokensSlice = blindedTokens[(i - numT):i]
// Calculate the next step size to retrieve. Given previous checks end should never
// be greater than the total number of tokens.
end := i + numT
if end > len(blindedTokens) {
return fmt.Errorf("request %s: error invalid token step length",
blindedTokenRequestSet.Request_id)
}

// Get the next group of tokens and approve
blindedTokensSlice = blindedTokens[i:end]
signedTokens, DLEQProof, err := btd.ApproveTokens(blindedTokensSlice, signingKey)
if err != nil {
// @TODO: If one token fails they will all fail. Assess this behavior
Expand All @@ -145,6 +171,8 @@ OUTER:
break OUTER
}

logger.Debug().Msg("marshalling proof")

marshaledDLEQProof, err := DLEQProof.MarshalText()
if err != nil {
return fmt.Errorf("request %s: could not marshal dleq proof: %w", blindedTokenRequestSet.Request_id, err)
Expand All @@ -170,6 +198,7 @@ OUTER:
marshaledSignedTokens = append(marshaledSignedTokens, string(marshaledToken[:]))
}

logger.Debug().Msg("getting public key")
publicKey := signingKey.PublicKey()
marshaledPublicKey, err := publicKey.MarshalText()
if err != nil {
Expand All @@ -195,6 +224,7 @@ OUTER:
signingKey = issuer.Keys[len(issuer.Keys)-1].SigningKey
}

logger.Debug().Msgf("approving tokens: %+v", blindedTokens)
// @TODO: If one token fails they will all fail. Assess this behavior
signedTokens, DLEQProof, err := btd.ApproveTokens(blindedTokens, signingKey)
if err != nil {
Expand Down Expand Up @@ -256,6 +286,7 @@ OUTER:
Request_id: blindedTokenRequestSet.Request_id,
Data: blindedTokenResults,
}
logger.Debug().Msgf("resultSet: %+v", resultSet)

var resultSetBuffer bytes.Buffer
err = resultSet.Serialize(&resultSetBuffer)
Expand All @@ -264,11 +295,15 @@ OUTER:
blindedTokenRequestSet.Request_id, resultSetBuffer.String(), err)
}

logger.Debug().Msg("ending blinded token request processor loop")
logger.Debug().Msgf("about to emit: %+v", resultSet)
err = Emit(producer, resultSetBuffer.Bytes(), log)
if err != nil {
logger.Error().Msgf("failed to emit: %+v", resultSet)
return fmt.Errorf("request %s: failed to emit results to topic %s: %w",
blindedTokenRequestSet.Request_id, producer.Topic, err)
}
logger.Debug().Msgf("emitted: %+v", resultSet)

return nil
}
1 change: 1 addition & 0 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ func main() {
if os.Getenv("ENV") != "production" {
zerolog.SetGlobalLevel(zerolog.TraceLevel)
}
zerolog.SetGlobalLevel(zerolog.TraceLevel)

srv := *server.DefaultServer

Expand Down
9 changes: 9 additions & 0 deletions server/cron.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,5 +27,14 @@ func (c *Server) SetupCronTasks() {
}); err != nil {
panic(err)
}
if _, err := cron.AddFunc(cadence, func() {
rows, err := c.deleteIssuerKeys("P1M")
if err != nil {
panic(err)
}
c.Logger.Infof("cron: delete issuers keys removed %d", rows)
}); err != nil {
panic(err)
}
cron.Start()
}
106 changes: 87 additions & 19 deletions server/db.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package server

import (
"context"
"database/sql"
"errors"
"fmt"
Expand Down Expand Up @@ -118,7 +119,7 @@ type RedemptionV2 struct {
TTL int64 `json:"TTL"`
}

// CacheInterface cach functions
// CacheInterface cache functions
type CacheInterface interface {
Get(k string) (interface{}, bool)
Delete(k string)
Expand Down Expand Up @@ -397,6 +398,56 @@ func (c *Server) fetchIssuersByCohort(issuerType string, issuerCohort int16) (*[
return &issuers, nil
}

func (c *Server) fetchIssuerByType(ctx context.Context, issuerType string) (*Issuer, error) {
if c.caches != nil {
if cached, found := c.caches["issuer"].Get(issuerType); found {
// TODO: check this
return cached.(*Issuer), nil
}
}

var issuerV3 issuer
err := c.db.GetContext(ctx, &issuerV3,
`SELECT *
FROM v3_issuers
WHERE issuer_type=$1
ORDER BY expires_at DESC NULLS LAST, created_at DESC`, issuerType)
if err != nil {
return nil, err
}

convertedIssuer, err := c.convertDBIssuer(issuerV3)
if err != nil {
return nil, err
}

if convertedIssuer.Keys == nil {
convertedIssuer.Keys = []IssuerKeys{}
}

var fetchIssuerKeys []issuerKeys
err = c.db.SelectContext(ctx, &fetchIssuerKeys, `SELECT * FROM v3_issuer_keys where issuer_id=$1
ORDER BY end_at DESC NULLS LAST, start_at DESC`, issuerV3.ID)
if err != nil {
return nil, err
}

for _, v := range fetchIssuerKeys {
k, err := c.convertDBIssuerKeys(v)
if err != nil {
c.Logger.Error("Failed to convert issuer keys from DB")
return nil, err
}
convertedIssuer.Keys = append(convertedIssuer.Keys, *k)
}

if c.caches != nil {
c.caches["issuer"].SetDefault(issuerType, issuerV3)
}

return convertedIssuer, nil
}

func (c *Server) fetchIssuers(issuerType string) (*[]Issuer, error) {
if c.caches != nil {
if cached, found := c.caches["issuers"].Get(issuerType); found {
Expand Down Expand Up @@ -564,14 +615,14 @@ func (c *Server) rotateIssuers() error {
err = tx.Commit()
}()

fetchedIssuers := []issuer{}
var fetchedIssuers []issuer
err = tx.Select(
&fetchedIssuers,
`SELECT * FROM v3_issuers
WHERE expires_at IS NOT NULL
AND last_rotated_at < NOW() - $1 * INTERVAL '1 day'
AND expires_at < NOW() + $1 * INTERVAL '1 day'
AND version >= 2
AND version <= 2
FOR UPDATE SKIP LOCKED`, cfg.DefaultDaysBeforeExpiry,
)
if err != nil {
Expand Down Expand Up @@ -619,26 +670,24 @@ func (c *Server) rotateIssuersV3() error {

fetchedIssuers := []issuer{}

// we need to get all of the v3 issuers that
// 1. are not expired
// we need to get all the v3 issuers that are
// 1. not expired
// 2. now is after valid_from
// 3. have max(issuer_v3.end_at) < buffer

err = tx.Select(
&fetchedIssuers,
`
select
i.issuer_id, i.issuer_type, i.issuer_cohort, i.max_tokens, i.version,
i.buffer, i.valid_from, i.last_rotated_at, i.expires_at, i.duration,
i.created_at
from
v3_issuers i
join v3_issuer_keys ik on (ik.issuer_id = i.issuer_id)
where
i.version = 3
and i.expires_at is not null and i.expires_at < now()
and greatest(ik.end_at) < now() + i.buffer * i.duration::interval
for update skip locked
select
i.issuer_id, i.issuer_type, i.issuer_cohort, i.max_tokens, i.version,i.buffer, i.valid_from, i.last_rotated_at, i.expires_at, i.duration,i.created_at
from
v3_issuers i
where
i.version = 3 and
i.expires_at is not null and
i.expires_at < now()
and (select max(end_at) from v3_issuer_keys where issuer_id=i.issuer_id) < now() + i.buffer * i.duration::interval
for update skip locked
`,
)
if err != nil {
Expand Down Expand Up @@ -669,6 +718,21 @@ func (c *Server) rotateIssuersV3() error {
return nil
}

// deleteIssuerKeys deletes v3 issuers keys that have ended more than the duration ago.
func (c *Server) deleteIssuerKeys(duration string) (int64, error) {
husobee marked this conversation as resolved.
Show resolved Hide resolved
result, err := c.db.Exec(`delete from v3_issuer_keys where issuer_id in (select issuer_id from v3_issuers where version = 3) and end_at < now() - $1::interval`, duration)
if err != nil {
return 0, fmt.Errorf("error deleting v3 issuer keys: %w", err)
}

rows, err := result.RowsAffected()
if err != nil {
return 0, fmt.Errorf("error deleting v3 issuer keys row affected: %w", err)
}

return rows, nil
}

// createIssuer - creation of a v3 issuer
func (c *Server) createV3Issuer(issuer Issuer) error {
defer incrementCounter(createIssuerCounter)
Expand Down Expand Up @@ -731,9 +795,11 @@ func txPopulateIssuerKeys(logger *logrus.Logger, tx *sqlx.Tx, issuer Issuer) err
err error
)

logger.Debug("checking if v3")
if issuer.Version == 3 {
// get the duration from the issuer
if issuer.Duration != nil {
logger.Debug("making sure duration is not nil")
duration, err = timeutils.ParseDuration(*issuer.Duration)
if err != nil {
return fmt.Errorf("failed to parse issuer duration: %w", err)
Expand Down Expand Up @@ -762,13 +828,14 @@ func txPopulateIssuerKeys(logger *logrus.Logger, tx *sqlx.Tx, issuer Issuer) err
start = &tmp
i = len(issuer.Keys)
}
logger.Debug("about to make the issuer keys")

valueFmtStr := ""

var keys []issuerKeys
var position = 0
// for i in buffer, create signing keys for each
for ; i < issuer.Buffer; i++ {
// Create signing keys for buffer and overlap
for ; i < issuer.Buffer+issuer.Overlap; i++ {
end := new(time.Time)
if duration != nil {
// start/end, increment every iteration
Expand Down Expand Up @@ -799,6 +866,7 @@ func txPopulateIssuerKeys(logger *logrus.Logger, tx *sqlx.Tx, issuer Issuer) err
tx.Rollback()
return err
}
logger.Infof("iteration key pubkey: %+v", pubKeyTxt)

tmpStart := *start
tmpEnd := *end
Expand Down
Loading