Skip to content

Commit

Permalink
Merge branch 'main' into CAPPL-318
Browse files Browse the repository at this point in the history
  • Loading branch information
justinkaseman authored Dec 12, 2024
2 parents 53bd064 + 6a43e61 commit 7a5b2dd
Show file tree
Hide file tree
Showing 6 changed files with 143 additions and 20 deletions.
55 changes: 46 additions & 9 deletions pkg/capabilities/consensus/ocr3/aggregators/reduce_aggregator.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,17 @@ const (
DEVIATION_TYPE_ANY = "any"
// DEVIATION_TYPE_PERCENT is a numeric percentage difference
DEVIATION_TYPE_PERCENT = "percent"
// DEVIATION_TYPE_ABSOLUTE is a numeric absolute difference
// DEVIATION_TYPE_ABSOLUTE is a numeric unsigned difference
DEVIATION_TYPE_ABSOLUTE = "absolute"
REPORT_FORMAT_MAP = "map"
REPORT_FORMAT_ARRAY = "array"
REPORT_FORMAT_VALUE = "value"
MODE_QUORUM_OCR = "ocr"
MODE_QUORUM_ANY = "any"

DEFAULT_REPORT_FORMAT = REPORT_FORMAT_MAP
DEFAULT_OUTPUT_FIELD_NAME = "Reports"
DEFAULT_MODE_QUORUM = MODE_QUORUM_ANY
)

type ReduceAggConfig struct {
Expand Down Expand Up @@ -70,8 +73,12 @@ type AggregationField struct {
InputKey string `mapstructure:"inputKey" json:"inputKey"`
// How the data set should be aggregated to a single value
// * median - take the centermost value of the sorted data set of observations. can only be used on numeric types. not a true median, because no average if two middle values.
// * mode - take the most frequent value. if tied, use the "first".
// * mode - take the most frequent value. if tied, use the "first". use "ModeQuorom" to configure the minimum number of seen values.
Method string `mapstructure:"method" json:"method" jsonschema:"enum=median,enum=mode" required:"true"`
// When using Method=mode, this will configure the minimum number of values that must be seen
// * ocr - (default) enforces that the number of matching values must be at least f+1, otherwise consensus fails
// * any - do not enforce any limit on the minimum viable count. this may result in unexpected answers if every observation is unique.
ModeQuorum string `mapstructure:"modeQuorum" json:"modeQuorum,omitempty" jsonschema:"enum=ocr,enum=any" default:"ocr"`
// The key that the aggregated data is put under
// If omitted, the InputKey will be used
OutputKey string `mapstructure:"outputKey" json:"outputKey"`
Expand Down Expand Up @@ -108,7 +115,7 @@ func (a *reduceAggregator) Aggregate(lggr logger.Logger, previousOutcome *types.
return nil, fmt.Errorf("not enough observations provided %s, have %d want %d", field.InputKey, len(vals), 2*f+1)
}

singleValue, err := reduce(field.Method, vals)
singleValue, err := reduce(field.Method, vals, f, field.ModeQuorum)
if err != nil {
return nil, fmt.Errorf("unable to reduce on method %s, err: %s", field.Method, err.Error())
}
Expand Down Expand Up @@ -335,12 +342,20 @@ func (a *reduceAggregator) extractValues(lggr logger.Logger, observations map[oc
return vals
}

func reduce(method string, items []values.Value) (values.Value, error) {
func reduce(method string, items []values.Value, f int, modeQuorum string) (values.Value, error) {
switch method {
case AGGREGATION_METHOD_MEDIAN:
return median(items)
case AGGREGATION_METHOD_MODE:
return mode(items)
value, count, err := mode(items)
if err != nil {
return value, err
}
err = modeHasQuorum(modeQuorum, count, f)
if err != nil {
return value, err
}
return value, err
default:
// invariant, config should be validated
return nil, fmt.Errorf("unsupported aggregation method %s", method)
Expand Down Expand Up @@ -408,18 +423,18 @@ func toDecimal(item values.Value) (decimal.Decimal, error) {
}
}

func mode(items []values.Value) (values.Value, error) {
func mode(items []values.Value) (values.Value, int, error) {
if len(items) == 0 {
// invariant, as long as f > 0 there should be items
return nil, errors.New("items cannot be empty")
return nil, 0, errors.New("items cannot be empty")
}

counts := make(map[[32]byte]*counter)
for _, item := range items {
marshalled, err := proto.MarshalOptions{Deterministic: true}.Marshal(values.Proto(item))
if err != nil {
// invariant: values should always be able to be proto marshalled
return nil, err
return nil, 0, err
}
sha := sha256.Sum256(marshalled)
elem, ok := counts[sha]
Expand Down Expand Up @@ -449,7 +464,22 @@ func mode(items []values.Value) (values.Value, error) {

// If more than one mode found, choose first

return modes[0], nil
return modes[0], maxCount, nil
}

func modeHasQuorum(quorumType string, count int, f int) error {
switch quorumType {
case MODE_QUORUM_ANY:
return nil
case MODE_QUORUM_OCR:
if count < f+1 {
return fmt.Errorf("mode quorum not reached. have: %d, want: %d", count, f+1)
}
return nil
default:
// invariant, config should be validated
return fmt.Errorf("unsupported mode quorum %s", quorumType)
}
}

func deviation(method string, previousValue values.Value, nextValue values.Value) (decimal.Decimal, error) {
Expand Down Expand Up @@ -561,6 +591,13 @@ func ParseConfigReduceAggregator(config values.Map) (ReduceAggConfig, error) {
if len(field.Method) == 0 || !isOneOf(field.Method, []string{AGGREGATION_METHOD_MEDIAN, AGGREGATION_METHOD_MODE}) {
return ReduceAggConfig{}, fmt.Errorf("aggregation field must contain a method. options: [%s, %s]", AGGREGATION_METHOD_MEDIAN, AGGREGATION_METHOD_MODE)
}
if field.Method == AGGREGATION_METHOD_MODE && len(field.ModeQuorum) == 0 {
field.ModeQuorum = MODE_QUORUM_OCR
parsedConfig.Fields[i].ModeQuorum = MODE_QUORUM_OCR
}
if field.Method == AGGREGATION_METHOD_MODE && !isOneOf(field.ModeQuorum, []string{MODE_QUORUM_ANY, MODE_QUORUM_OCR}) {
return ReduceAggConfig{}, fmt.Errorf("mode quorum must be one of options: [%s, %s]", MODE_QUORUM_ANY, MODE_QUORUM_OCR)
}
if len(field.DeviationString) > 0 && isOneOf(field.DeviationType, []string{DEVIATION_TYPE_NONE, DEVIATION_TYPE_ANY}) {
return ReduceAggConfig{}, fmt.Errorf("aggregation field cannot have deviation with a deviation type of %s", field.DeviationType)
}
Expand Down
49 changes: 49 additions & 0 deletions pkg/capabilities/consensus/ocr3/aggregators/reduce_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -634,6 +634,25 @@ func TestReduceAggregator_Aggregate(t *testing.T) {
return map[commontypes.OracleID][]values.Value{1: {mockValue}, 2: {mockValue}, 3: {mockValue}}
},
},
{
name: "reduce error mode with mode quorum of ocr",
previousOutcome: nil,
fields: []aggregators.AggregationField{
{
Method: "mode",
ModeQuorum: "ocr",
OutputKey: "Price",
},
},
extraConfig: map[string]any{},
observationsFactory: func() map[commontypes.OracleID][]values.Value {
mockValue, err := values.Wrap(true)
require.NoError(t, err)
mockValue2, err := values.Wrap(true)
require.NoError(t, err)
return map[commontypes.OracleID][]values.Value{1: {mockValue}, 2: {mockValue2}}
},
},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
Expand Down Expand Up @@ -825,6 +844,7 @@ func TestMedianAggregator_ParseConfig(t *testing.T) {
InputKey: "FeedID",
OutputKey: "FeedId",
Method: "mode",
ModeQuorum: "ocr",
DeviationString: "1.1",
Deviation: decimal.NewFromFloat(1.1),
DeviationType: "absolute",
Expand Down Expand Up @@ -1153,6 +1173,23 @@ func TestMedianAggregator_ParseConfig(t *testing.T) {
return vMap
},
},
{
name: "invalid mode quorum",
configFactory: func() *values.Map {
vMap, err := values.NewMap(map[string]any{
"fields": []aggregators.AggregationField{
{
InputKey: "Price",
Method: "mode",
ModeQuorum: "invalid",
OutputKey: "Price",
},
},
})
require.NoError(t, err)
return vMap
},
},
}

for _, tt := range cases {
Expand Down Expand Up @@ -1233,6 +1270,7 @@ func TestAggregateShouldReport(t *testing.T) {
InputKey: "FeedID",
OutputKey: "FeedID",
Method: "mode",
ModeQuorum: "any",
DeviationType: "any",
},
{
Expand Down Expand Up @@ -1278,6 +1316,7 @@ func TestAggregateShouldReport(t *testing.T) {
InputKey: "BoolField",
OutputKey: "BoolField",
Method: "mode",
ModeQuorum: "any",
DeviationType: "any",
},
{
Expand Down Expand Up @@ -1323,6 +1362,7 @@ func TestAggregateShouldReport(t *testing.T) {
InputKey: "FeedID",
OutputKey: "FeedID",
Method: "mode",
ModeQuorum: "any",
DeviationType: "any",
},
{
Expand Down Expand Up @@ -1368,6 +1408,7 @@ func TestAggregateShouldReport(t *testing.T) {
InputKey: "BoolField",
OutputKey: "BoolField",
Method: "mode",
ModeQuorum: "any",
DeviationType: "any",
},
{
Expand Down Expand Up @@ -1413,6 +1454,7 @@ func TestAggregateShouldReport(t *testing.T) {
InputKey: "FeedID",
OutputKey: "FeedID",
Method: "mode",
ModeQuorum: "any",
DeviationType: "any",
},
{
Expand Down Expand Up @@ -1458,6 +1500,7 @@ func TestAggregateShouldReport(t *testing.T) {
InputKey: "FeedID",
OutputKey: "FeedID",
Method: "mode",
ModeQuorum: "any",
DeviationType: "any",
},
{
Expand Down Expand Up @@ -1503,6 +1546,7 @@ func TestAggregateShouldReport(t *testing.T) {
InputKey: "FeedID",
OutputKey: "FeedID",
Method: "mode",
ModeQuorum: "any",
DeviationType: "any",
},
{
Expand Down Expand Up @@ -1548,6 +1592,7 @@ func TestAggregateShouldReport(t *testing.T) {
InputKey: "FeedID",
OutputKey: "FeedID",
Method: "mode",
ModeQuorum: "any",
DeviationType: "any",
},
{
Expand Down Expand Up @@ -1593,6 +1638,7 @@ func TestAggregateShouldReport(t *testing.T) {
InputKey: "FeedID",
OutputKey: "FeedID",
Method: "mode",
ModeQuorum: "any",
DeviationType: "any",
},
{
Expand Down Expand Up @@ -1638,6 +1684,7 @@ func TestAggregateShouldReport(t *testing.T) {
InputKey: "FeedID",
OutputKey: "FeedID",
Method: "mode",
ModeQuorum: "any",
DeviationType: "any",
},
{
Expand Down Expand Up @@ -1683,6 +1730,7 @@ func TestAggregateShouldReport(t *testing.T) {
InputKey: "FeedID",
OutputKey: "FeedID",
Method: "mode",
ModeQuorum: "any",
DeviationType: "any",
},
{
Expand Down Expand Up @@ -1728,6 +1776,7 @@ func TestAggregateShouldReport(t *testing.T) {
InputKey: "FeedID",
OutputKey: "FeedID",
Method: "mode",
ModeQuorum: "any",
DeviationType: "any",
},
{
Expand Down
13 changes: 12 additions & 1 deletion pkg/workflows/secrets/secrets.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"encoding/hex"
"encoding/json"
"fmt"
"strings"

"golang.org/x/crypto/nacl/box"
)
Expand Down Expand Up @@ -146,13 +147,23 @@ func DecryptSecretsForNode(
return nil, err
}

if payload.WorkflowOwner != workflowOwner {
if normalizeOwner(payload.WorkflowOwner) != normalizeOwner(workflowOwner) {
return nil, fmt.Errorf("invalid secrets bundle: got owner %s, expected %s", payload.WorkflowOwner, workflowOwner)
}

return payload.Secrets, nil
}

func normalizeOwner(owner string) string {
o := owner
if strings.HasPrefix(o, "0x") {
o = o[2:]
}

o = strings.ToLower(o)
return o
}

func ValidateEncryptedSecrets(secretsData []byte, encryptionPublicKeys map[string][32]byte, workflowOwner string) error {
var encryptedSecrets EncryptedSecretsResult
err := json.Unmarshal(secretsData, &encryptedSecrets)
Expand Down
13 changes: 12 additions & 1 deletion pkg/workflows/secrets/secrets_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"encoding/hex"
"encoding/json"
"errors"
"strings"
"testing"

"github.com/stretchr/testify/assert"
Expand All @@ -31,7 +32,7 @@ var (
"SECRET_A": {"one", "two", "three", "four"},
"SECRET_B": {"all"},
}
workflowOwner = "0x9ed925d8206a4f88a2f643b28b3035b315753cd6"
workflowOwner = "0xFbb30BD8E9D779044c3c30dd82e52a5FA1573388"
config = SecretsConfig{
SecretsNames: map[string][]string{
"SECRET_A": {"ENV_VAR_A_FOR_NODE_ONE", "ENV_VAR_A_FOR_NODE_TWO", "ENV_VAR_A_FOR_NODE_THREE", "ENV_VAR_A_FOR_NODE_FOUR"},
Expand Down Expand Up @@ -162,6 +163,16 @@ func TestEncryptDecrypt(t *testing.T) {
assert.ErrorContains(t, err, "invalid secrets bundle: got owner")
})

t.Run("owner without 0x prefix", func(st *testing.T) {
_, err = DecryptSecretsForNode(result, k, workflowOwner[2:])
require.NoError(t, err)
})

t.Run("owner with lower casing", func(st *testing.T) {
_, err = DecryptSecretsForNode(result, k, strings.ToLower(workflowOwner))
require.NoError(t, err)
})

t.Run("key not in metadata", func(st *testing.T) {
overriddenResult := EncryptedSecretsResult{
EncryptedSecrets: encryptedSecrets,
Expand Down
22 changes: 17 additions & 5 deletions pkg/workflows/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ func EncodeExecutionID(workflowID, eventID string) (string, error) {
return hex.EncodeToString(s.Sum(nil)), nil
}

func GenerateWorkflowIDFromStrings(owner string, workflow []byte, config []byte, secretsURL string) (string, error) {
func GenerateWorkflowIDFromStrings(owner string, name string, workflow []byte, config []byte, secretsURL string) (string, error) {
ownerWithoutPrefix := owner
if strings.HasPrefix(owner, "0x") {
ownerWithoutPrefix = owner[2:]
Expand All @@ -32,21 +32,29 @@ func GenerateWorkflowIDFromStrings(owner string, workflow []byte, config []byte,
return "", err
}

wid, err := GenerateWorkflowID(ownerb, workflow, config, secretsURL)
wid, err := GenerateWorkflowID(ownerb, name, workflow, config, secretsURL)
if err != nil {
return "", err
}

return hex.EncodeToString(wid[:]), nil
}

func GenerateWorkflowID(owner []byte, workflow []byte, config []byte, secretsURL string) ([32]byte, error) {
var (
versionByte = byte(0)
)

func GenerateWorkflowID(owner []byte, name string, workflow []byte, config []byte, secretsURL string) ([32]byte, error) {
s := sha256.New()
_, err := s.Write(owner)
if err != nil {
return [32]byte{}, err
}
_, err = s.Write([]byte(workflow))
_, err = s.Write([]byte(name))
if err != nil {
return [32]byte{}, err
}
_, err = s.Write(workflow)
if err != nil {
return [32]byte{}, err
}
Expand All @@ -58,5 +66,9 @@ func GenerateWorkflowID(owner []byte, workflow []byte, config []byte, secretsURL
if err != nil {
return [32]byte{}, err
}
return [32]byte(s.Sum(nil)), nil

sha := [32]byte(s.Sum(nil))
sha[0] = versionByte

return sha, nil
}
Loading

0 comments on commit 7a5b2dd

Please sign in to comment.