diff --git a/slashing/slashing.go b/slashing/slashing.go new file mode 100644 index 00000000..e817b341 --- /dev/null +++ b/slashing/slashing.go @@ -0,0 +1,107 @@ +package slashing + +import ( + "fmt" + "github.com/ethereum/go-ethereum/common" + "strings" +) + +type StakeSource string + +const ( + StakeSourceSlashable StakeSource = "slashable" + StakeSourceNonSlashable StakeSource = "non-slashable" + StakeSourceTotal StakeSource = "total" +) + +type OperatorSet struct { + AvsAddress common.Address + OperatorSetId int32 +} + +type State struct { + TotalMagnitude uint64 + OperatorSets []OperatorSet + SlashableMagnitudes []uint64 +} + +func (s *State) Print() { + fmt.Println(strings.Repeat("-", 60)) + fmt.Println("Total Magnitude: ", s.TotalMagnitude) + fmt.Println() + for i, operatorSet := range s.OperatorSets { + fmt.Println("Operator Set: ", i) + fmt.Println("AVS Address: ", operatorSet.AvsAddress.Hex()) + fmt.Println("Operator Set ID: ", operatorSet.OperatorSetId) + fmt.Println("Magnitude: ", s.SlashableMagnitudes[i]) + fmt.Println() + } + fmt.Println(strings.Repeat("-", 60)) +} + +func CalculateMagnitudes( + currentState State, + operatorSet OperatorSet, + stakeSource StakeSource, + bips float64, + debug bool, +) (*State, error) { + if debug { + currentState.Print() + } + slashableProportion := bips / 10_000 + if stakeSource == StakeSourceSlashable { + currentSlashableProportion := float64(sumUInt64Array(currentState.SlashableMagnitudes)) / float64(currentState.TotalMagnitude) + if currentSlashableProportion <= slashableProportion { + return nil, fmt.Errorf("slashable proportion to allocate %f is less than or equal to available slashable proportion %f", slashableProportion, currentSlashableProportion) + } + + nonSlashableProportion := (float64(currentState.TotalMagnitude) - float64(sumUInt64Array(currentState.SlashableMagnitudes))) / float64(currentState.TotalMagnitude) + if debug { + fmt.Println("nonSlashableProportion: ", nonSlashableProportion) + } + + allocatedMagnitude := float64(sumUInt64Array(currentState.SlashableMagnitudes)) + if debug { + fmt.Println("allocatedMagnitude: ", allocatedMagnitude) + } + + totalMagnitudeNew := allocatedMagnitude / (1 - slashableProportion - nonSlashableProportion) + if debug { + fmt.Println("totalMagnitudeNew: ", totalMagnitudeNew) + } + + slashableMagnitude := slashableProportion * totalMagnitudeNew + if debug { + fmt.Println("slashableMagnitude: ", slashableMagnitude) + } + + nonSlashableMagnitude := nonSlashableProportion * totalMagnitudeNew + if debug { + fmt.Println("nonSlashableMagnitude: ", nonSlashableMagnitude) + } + + opSetToUpdate := []OperatorSet{operatorSet} + slashableMagnitudeSet := []uint64{uint64(slashableMagnitude)} + + newState := &State{ + TotalMagnitude: uint64(totalMagnitudeNew), + OperatorSets: opSetToUpdate, + SlashableMagnitudes: slashableMagnitudeSet, + } + if debug { + newState.Print() + } + return newState, nil + } else { + return nil, fmt.Errorf("unimplemented stake source %s", stakeSource) + } +} + +func sumUInt64Array(arr []uint64) uint64 { + sum := uint64(0) + for _, v := range arr { + sum += v + } + return sum +} diff --git a/slashing/slashing_test.go b/slashing/slashing_test.go new file mode 100644 index 00000000..9bf5d51c --- /dev/null +++ b/slashing/slashing_test.go @@ -0,0 +1,146 @@ +package slashing + +import ( + "testing" + + "github.com/ethereum/go-ethereum/common" + + "github.com/stretchr/testify/assert" +) + +func TestCalculateNewState(t *testing.T) { + + var tests = []struct { + name string + oldState State + stakeSource StakeSource + operatorSet OperatorSet + bips float64 + newState *State + debug bool + expectErr bool + }{ + { + name: "Simple case where stake is sourced from slashable stake. A new operator set is added", + oldState: State{ + TotalMagnitude: 10, + OperatorSets: []OperatorSet{ + { + AvsAddress: common.HexToAddress("0xabc"), + OperatorSetId: 1, + }, + { + AvsAddress: common.HexToAddress("0xabc"), + OperatorSetId: 2, + }, + }, + SlashableMagnitudes: []uint64{1, 1}, + }, + stakeSource: StakeSourceSlashable, + operatorSet: OperatorSet{ + AvsAddress: common.HexToAddress("0xabc"), + OperatorSetId: 4, + }, + bips: 1000, + newState: &State{ + TotalMagnitude: 20.0, + OperatorSets: []OperatorSet{ + { + AvsAddress: common.HexToAddress("0xabc"), + OperatorSetId: 4, + }, + }, + SlashableMagnitudes: []uint64{2}, + }, + }, + { + name: "Simple case where stake is sourced from slashable stake. Existing operator set is updated", + oldState: State{ + TotalMagnitude: 10, + OperatorSets: []OperatorSet{ + { + AvsAddress: common.HexToAddress("0xabc"), + OperatorSetId: 1, + }, + { + AvsAddress: common.HexToAddress("0xabc"), + OperatorSetId: 2, + }, + }, + SlashableMagnitudes: []uint64{1, 1}, + }, + stakeSource: StakeSourceSlashable, + operatorSet: OperatorSet{ + AvsAddress: common.HexToAddress("0xabc"), + OperatorSetId: 1, + }, + bips: 1500, + newState: &State{ + TotalMagnitude: 40, + OperatorSets: []OperatorSet{ + { + AvsAddress: common.HexToAddress("0xabc"), + OperatorSetId: 1, + }, + }, + SlashableMagnitudes: []uint64{6}, + }, + }, + { + name: "New slashable stake is equal to available slashable stake so it should fail", + oldState: State{ + TotalMagnitude: 10, + OperatorSets: []OperatorSet{ + { + AvsAddress: common.HexToAddress("0xabc"), + OperatorSetId: 1, + }, + { + AvsAddress: common.HexToAddress("0xabc"), + OperatorSetId: 2, + }, + }, + SlashableMagnitudes: []uint64{1, 1}, + }, + stakeSource: StakeSourceSlashable, + operatorSet: OperatorSet{ + AvsAddress: common.HexToAddress("0xabc"), + OperatorSetId: 4, + }, + bips: 2000, + expectErr: true, + }, + //{ + // name: "Simple case where stake is sourced from non slashable stake", + // oldState: State{ + // TotalMagnitude: 10, + // OperatorSets: []int{1, 2}, + // SlashableMagnitudes: []float64{1, 1}, + // }, + // stakeSource: StakeSourceNonSlashable, + // operatorSet: 4, + // bips: 100, + // newState: &State{ + // TotalMagnitude: 10.0, + // OperatorSets: []int{4}, + // SlashableMagnitudes: []float64{1.0}, + // }, + //}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + newState, err := CalculateMagnitudes(tt.oldState, tt.operatorSet, tt.stakeSource, tt.bips, tt.debug) + if tt.expectErr { + assert.Error(t, err) + + } else { + assert.NoError(t, err) + assert.Equal(t, tt.newState, newState) + } + + }) + + } + +}