Skip to content

Commit

Permalink
feat: add validation agent functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
fritterhoff committed Aug 18, 2023
1 parent e5f3309 commit f86ce26
Show file tree
Hide file tree
Showing 12 changed files with 473 additions and 15 deletions.
43 changes: 42 additions & 1 deletion .github/workflows/docker.yml
Original file line number Diff line number Diff line change
Expand Up @@ -47,4 +47,45 @@ jobs:
push: ${{ github.event_name != 'pull_request' }}
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}
file: docker/Dockerfile
file: docker/Dockerfile
build-agent:
runs-on: ubuntu-latest
env:
# Use docker.io for Docker Hub if empty
REGISTRY: ghcr.io
# github.repository as <account>/<repo>
IMAGE_NAME: ${{ github.repository }}
permissions:
contents: read
packages: write
steps:
- uses: actions/checkout@v2
- name: Set up Docker Buildx
id: buildx
uses: docker/setup-buildx-action@v1
- name: Log into registry ${{ env.REGISTRY }}
if: github.event_name != 'pull_request'
uses: docker/login-action@v1
with:
registry: ${{ env.REGISTRY }}
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Extract Docker metadata
id: meta
uses: docker/metadata-action@v3
with:
images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}-agent
tags: |
type=schedule
type=ref,event=branch
type=ref,event=tag
type=ref,event=pr
type=raw,value={{branch}}-{{sha}}-{{date 'X'}},enable=${{ github.event_name != 'pull_request' }}
- name: Build and push
uses: docker/build-push-action@v2
with:
context: .
push: ${{ github.event_name != 'pull_request' }}
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}
file: docker/Dockerfile.agent
37 changes: 33 additions & 4 deletions acme/challenge.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (

"github.com/fxamacker/cbor/v2"
"github.com/google/go-tpm/tpm2"
"github.com/sirupsen/logrus"
"golang.org/x/exp/slices"

"github.com/smallstep/go-attestation/attest"
Expand All @@ -36,6 +37,7 @@ import (
"go.step.sm/crypto/pemutil"
"go.step.sm/crypto/x509util"

"github.com/smallstep/certificates/acme/validation"
"github.com/smallstep/certificates/authority/provisioner"
)

Expand Down Expand Up @@ -117,11 +119,38 @@ func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWeb
if InsecurePortHTTP01 != 0 {
u.Host += ":" + strconv.Itoa(InsecurePortHTTP01)
}

go func() {
mqtt, ok := validation.FromContext(ctx)
if !ok {
return
}
req := validation.ValidationRequest{
Authz: ch.AuthorizationID,
Challenge: ch.ID,
Target: u.String(),
}
data, err := json.Marshal(req)
if err != nil {
return
}
if token := mqtt.GetClient().Publish(fmt.Sprintf("%s/jobs", mqtt.GetOrganization()), 1, false, data); token.Wait() && token.Error() != nil {
logrus.Warn(token.Error())
}
}()
vc := MustClientFromContext(ctx)
resp, err := vc.Get(u.String())
if err != nil {
return storeError(ctx, db, ch, false, WrapError(ErrorConnectionType, err,
resp, errHttp := vc.Get(u.String())
// get challenge again and check if it was already validated
chDb, errDb := db.GetChallenge(ctx, ch.ID, ch.AuthorizationID)
if errDb == nil {
logrus.WithField("challenge", chDb.ID).WithField("authz", chDb.AuthorizationID).Infof("challenge has status %s", chDb.Status)
if chDb.Status == StatusValid {
return nil
}
} else {
logrus.WithError(errDb).WithField("challenge", ch.ID).WithField("authz", ch.AuthorizationID).Warn("error getting challenge from db")
}
if errHttp != nil {
return storeError(ctx, db, ch, false, WrapError(ErrorConnectionType, errHttp,
"error doing http GET for url %s", u))
}
defer resp.Body.Close()
Expand Down
100 changes: 100 additions & 0 deletions acme/mqtt/client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
package mqtt

import (
"context"
"encoding/json"
"fmt"
"net/url"
"time"

mqtt "github.com/eclipse/paho.mqtt.golang"
"github.com/sirupsen/logrus"
"github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/acme/validation"
)

var clock acme.Clock

func Connect(acmeDB acme.DB, host, user, password, organization string) (validation.MqttClient, error) {
opts := mqtt.NewClientOptions()
opts.SetOrderMatters(false) // Allow out of order messages (use this option unless in order delivery is essential)
opts.ConnectTimeout = time.Second // Minimal delays on connect
opts.WriteTimeout = time.Second // Minimal delays on writes
opts.KeepAlive = 10 // Keepalive every 10 seconds so we quickly detect network outages
opts.PingTimeout = time.Second // local broker so response should be quick
opts.ConnectRetry = true
opts.AutoReconnect = true
opts.ClientID = "acme"
opts.Username = user
opts.Password = password
opts.AddBroker(fmt.Sprintf("ssl://%s:8883", host))
logrus.Infof("connecting to mqtt broker")
// Log events
opts.OnConnectionLost = func(cl mqtt.Client, err error) {
logrus.Println("mqtt connection lost")
}
opts.OnConnect = func(mqtt.Client) {
logrus.Println("mqtt connection established")
}
opts.OnReconnecting = func(mqtt.Client, *mqtt.ClientOptions) {
logrus.Println("mqtt attempting to reconnect")
}

client := mqtt.NewClient(opts)

if token := client.Connect(); token.WaitTimeout(30*time.Second) && token.Error() != nil {
logrus.Warn(token.Error())
return nil, token.Error()
}

go func() {
client.Subscribe(fmt.Sprintf("%s/data", organization), 1, func(client mqtt.Client, msg mqtt.Message) {
logrus.Printf("Received message on topic: %s\nMessage: %s\n", msg.Topic(), msg.Payload())
ctx := context.Background()
data := msg.Payload()
var payload validation.ValidationResponse
err := json.Unmarshal(data, &payload)
if err != nil {
logrus.Errorf("error unmarshalling payload: %v", err)
return
}

ch, err := acmeDB.GetChallenge(ctx, payload.Challenge, payload.Authz)
if err != nil {
logrus.Errorf("error getting challenge: %v", err)
return
}

acc, err := acmeDB.GetAccount(ctx, ch.AccountID)
if err != nil {
logrus.Errorf("error getting account: %v", err)
return
}
expected, err := acme.KeyAuthorization(ch.Token, acc.Key)

if payload.Content != expected || err != nil {
logrus.Errorf("invalid key authorization: %v", err)
return
}
u := &url.URL{Scheme: "http", Host: ch.Value, Path: fmt.Sprintf("/.well-known/acme-challenge/%s", ch.Token)}
logrus.Infof("challenge %s validated using mqtt", u.String())

if ch.Status != acme.StatusPending && ch.Status != acme.StatusValid {
return
}

ch.Status = acme.StatusValid
ch.Error = nil
ch.ValidatedAt = clock.Now().Format(time.RFC3339)

if err = acmeDB.UpdateChallenge(ctx, ch); err != nil {
logrus.Errorf("error updating challenge: %v", err)
} else {
logrus.Infof("challenge %s updated to valid", u.String())
}

})
}()
connection := validation.BrokerConnection{Client: client, Organization: organization}
return connection, nil
}
60 changes: 60 additions & 0 deletions acme/validation/client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package validation

import (
"context"

mqtt "github.com/eclipse/paho.mqtt.golang"
)

type ValidationResponse struct {
Authz string `json:"authz"`
Challenge string `json:"challenge"`
Content string `json:"content"`
}

type ValidationRequest struct {
Authz string `json:"authz"`
Challenge string `json:"challenge"`
Target string `json:"target"`
}

type validationKey struct{}

type MqttClient interface {
GetClient() mqtt.Client
GetOrganization() string
}

type BrokerConnection struct {
Client mqtt.Client
Organization string
}

func (b BrokerConnection) GetClient() mqtt.Client {
return b.Client
}

func (b BrokerConnection) GetOrganization() string {
return b.Organization
}

// NewContext adds the given validation client to the context.
func NewContext(ctx context.Context, a MqttClient) context.Context {
return context.WithValue(ctx, validationKey{}, a)
}

// FromContext returns the validation client from the given context.
func FromContext(ctx context.Context) (a MqttClient, ok bool) {
a, ok = ctx.Value(validationKey{}).(MqttClient)
return
}

// MustFromContext returns the validation client from the given context. It will
// panic if no validation client is not in the context.
func MustFromContext(ctx context.Context) MqttClient {
if a, ok := FromContext(ctx); !ok {
panic("validation client is not in the context")
} else {
return a
}
}
8 changes: 8 additions & 0 deletions authority/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,13 @@ var (
}
)

type MqttConfig struct {
Host string `json:"host"`
Username string `json:"username"`
Password string `json:"password"`
Organization string `json:"organization"`
}

// Config represents the CA configuration and it's mapped to a JSON object.
type Config struct {
Root multiString `json:"root"`
Expand All @@ -87,6 +94,7 @@ type Config struct {
SkipValidation bool `json:"-"`
Storage string `json:"storage,omitempty"`
ManagementHost string `json:"managementHost"`
ValidationBroker *MqttConfig `json:"validationBroker,omitempty"`

// Keeps record of the filename the Config is read from
loadedFromFilepath string
Expand Down
2 changes: 1 addition & 1 deletion ca/bootstrap_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func startCABootstrapServer() *httptest.Server {
if err != nil {
panic(err)
}
baseContext := buildContext(ca.auth, nil, nil, nil, nil)
baseContext := buildContext(ca.auth, nil, nil, nil, nil, nil)
srv.Config.Handler = ca.srv.Handler
srv.Config.BaseContext = func(net.Listener) context.Context {
return baseContext
Expand Down
24 changes: 22 additions & 2 deletions ca/ca.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,16 @@ import (
sdktrace "go.opentelemetry.io/otel/sdk/trace"
semconv "go.opentelemetry.io/otel/semconv/v1.4.0"

"github.com/sirupsen/logrus"

"github.com/go-chi/chi"
"github.com/go-chi/chi/middleware"
"github.com/pkg/errors"
"github.com/smallstep/certificates/acme"
acmeAPI "github.com/smallstep/certificates/acme/api"
acmeNoSQL "github.com/smallstep/certificates/acme/db/nosql"
acmeMqtt "github.com/smallstep/certificates/acme/mqtt"
"github.com/smallstep/certificates/acme/validation"
"github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/admin"
Expand Down Expand Up @@ -360,7 +364,20 @@ func (ca *CA) Init(cfg *config.Config) (*CA, error) {
if err != nil {
return nil, errors.Wrap(err, "error connecting to EAB")
}
baseContext := buildContext(auth, scepAuthority, acmeDB, acmeLinker, client)
var validationBroker validation.MqttClient
if cfg.ValidationBroker != nil {
if cfg.ValidationBroker.Password == "" {
// pick password from env
cfg.ValidationBroker.Password = os.Getenv("MQTT_PASSWORD")
}

validationBroker, err = acmeMqtt.Connect(acmeDB, cfg.ValidationBroker.Host, cfg.ValidationBroker.Username, cfg.ValidationBroker.Password, cfg.ValidationBroker.Organization)
if err != nil {
logrus.Warn("error connecting to validation broker. Only local validation will be available!")
}
}

baseContext := buildContext(auth, scepAuthority, acmeDB, acmeLinker, client, validationBroker)

ca.srv = server.New(cfg.Address, handler, tlsConfig)
ca.srv.BaseContext = func(net.Listener) context.Context {
Expand Down Expand Up @@ -407,7 +424,7 @@ func (ca *CA) shouldServeInsecureServer() bool {
}

// buildContext builds the server base context.
func buildContext(a *authority.Authority, scepAuthority *scep.Authority, acmeDB acme.DB, acmeLinker acme.Linker, eabClient pb.EABServiceClient) context.Context {
func buildContext(a *authority.Authority, scepAuthority *scep.Authority, acmeDB acme.DB, acmeLinker acme.Linker, eabClient pb.EABServiceClient, validationBroker validation.MqttClient) context.Context {
ctx := authority.NewContext(context.Background(), a)
if authDB := a.GetDatabase(); authDB != nil {
ctx = db.NewContext(ctx, authDB)
Expand All @@ -424,6 +441,9 @@ func buildContext(a *authority.Authority, scepAuthority *scep.Authority, acmeDB
if eabClient != nil {
ctx = eab.NewContext(ctx, eabClient)
}
if validationBroker != nil {
ctx = validation.NewContext(ctx, validationBroker)
}
return ctx
}

Expand Down
2 changes: 1 addition & 1 deletion ca/tls_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ func startCATestServer() *httptest.Server {
panic(err)
}
// Use a httptest.Server instead
baseContext := buildContext(ca.auth, nil, nil, nil, nil)
baseContext := buildContext(ca.auth, nil, nil, nil, nil, nil)
srv := startTestServer(baseContext, ca.srv.TLSConfig, ca.srv.Handler)
return srv
}
Expand Down
Loading

0 comments on commit f86ce26

Please sign in to comment.