Skip to content

Commit

Permalink
Merge branch 'fix-key-verification'
Browse files Browse the repository at this point in the history
* fix-key-verification:
  Add some more logging
  Add some logs.
  fix panic due to null pointer error
  • Loading branch information
roeierez committed Jan 20, 2025
2 parents de962da + 87c64ce commit fbaed20
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 8 deletions.
20 changes: 12 additions & 8 deletions middleware/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,30 +26,34 @@ var ErrInternalError = fmt.Errorf("internal error")
var ErrInvalidSignature = fmt.Errorf("invalid signature")
var SignedMsgPrefix = []byte("realtimesync:")

func checkApiKey(config *config.Config, ctx context.Context, req interface{}) error {
func checkApiKey(config *config.Config, ctx context.Context, _ interface{}) error {
if config.CACert == nil || config.CACert.Raw == nil {
return nil
}

md, ok := metadata.FromIncomingContext(ctx)
if !ok {
return fmt.Errorf("Could not read request metadata")
return fmt.Errorf("could not read request metadata")
}

authHeader := md.Get("Authorization")[0]
authHeaders := md.Get("Authorization")
if len(authHeaders) == 0 {
return fmt.Errorf("invalid auth header")
}
authHeader := authHeaders[0]
if len(authHeader) <= 7 || !strings.HasPrefix(authHeader, "Bearer ") {
return fmt.Errorf("Invalid auth header")
return fmt.Errorf("invalid auth header")
}

apiKey := authHeader[7:]
block, err := base64.StdEncoding.DecodeString(apiKey)
if err != nil {
return fmt.Errorf("Could not decode auth header: %v", err)
return fmt.Errorf("could not decode auth header: %v", err)
}

cert, err := x509.ParseCertificate(block)
if err != nil {
return fmt.Errorf("Could not parse certificate: %v", err)
return fmt.Errorf("could not parse certificate: %v", err)
}

rootPool := x509.NewCertPool()
Expand All @@ -59,10 +63,10 @@ func checkApiKey(config *config.Config, ctx context.Context, req interface{}) er
Roots: rootPool,
})
if err != nil {
return fmt.Errorf("Certificate verification error: %v", err)
return fmt.Errorf("certificate verification error: %v", err)
}
if len(chains) != 1 || len(chains[0]) != 2 || !chains[0][0].Equal(cert) || !chains[0][1].Equal(config.CACert.Raw) {
return fmt.Errorf("Certificate verification error: invalid chain of trust")
return fmt.Errorf("certificate verification error: invalid chain of trust")
}

return nil
Expand Down
8 changes: 8 additions & 0 deletions syncer_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,14 @@ func (s *PersistentSyncerServer) Start(quitChan chan struct{}) {
}

func (s *PersistentSyncerServer) SetRecord(ctx context.Context, msg *proto.SetRecordRequest) (*proto.SetRecordReply, error) {
log.Println("SetRecord: started")
c, err := middleware.Authenticate(s.config, ctx, msg)
if err != nil {
log.Printf("SetRecord completed with auth error: %v\n", err)
return nil, err
}
pubkey := c.Value(middleware.USER_PUBKEY_CONTEXT_KEY).(string)
log.Printf("SetRecord: pubkey: %v\n", pubkey)
newRevision, err := s.storage.SetRecord(c, pubkey, msg.Record.Id, msg.Record.Data, msg.Record.Revision, msg.Record.SchemaVersion)

if err != nil {
Expand All @@ -81,18 +84,22 @@ func (s *PersistentSyncerServer) SetRecord(ctx context.Context, msg *proto.SetRe
newRecord := msg.Record
newRecord.Revision = newRevision
s.eventsManager.notifyChange(c.Value(middleware.USER_PUBKEY_CONTEXT_KEY).(string), newRecord)
log.Println("SetRecord: finished")
return &proto.SetRecordReply{
Status: proto.SetRecordStatus_SUCCESS,
NewRevision: newRevision,
}, nil
}

func (s *PersistentSyncerServer) ListChanges(ctx context.Context, msg *proto.ListChangesRequest) (*proto.ListChangesReply, error) {
log.Println("ListChanges: started")
c, err := middleware.Authenticate(s.config, ctx, msg)
if err != nil {
log.Printf("ListChanges completed with auth error: %v\n", err)
return nil, err
}
pubkey := c.Value(middleware.USER_PUBKEY_CONTEXT_KEY).(string)
log.Printf("ListChanges: pubkey: %v\n", pubkey)
changed, err := s.storage.ListChanges(c, pubkey, msg.SinceRevision)
if err != nil {
return nil, err
Expand All @@ -106,6 +113,7 @@ func (s *PersistentSyncerServer) ListChanges(ctx context.Context, msg *proto.Lis
SchemaVersion: r.SchemaVersion,
}
}
log.Printf("ListChanges: finished with %v records\n", len(records))
return &proto.ListChangesReply{
Changes: records,
}, nil
Expand Down

0 comments on commit fbaed20

Please sign in to comment.