Skip to content

Commit

Permalink
Merge pull request #187 from zama-ai/davidk/ciphertext-preload
Browse files Browse the repository at this point in the history
feat: implement ciphertext cache preload upon restart
  • Loading branch information
david-zk authored Dec 13, 2024
2 parents 2670d71 + 7fdd54a commit 6be5c69
Show file tree
Hide file tree
Showing 2 changed files with 149 additions and 9 deletions.
136 changes: 127 additions & 9 deletions fhevm-engine/fhevm-go-native/fhevm/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,9 @@ type ExecutorApi interface {
// We pass current block number to know at which
// block ciphertext should be materialized inside blockchain state.
CreateSession(blockNumber int64) ExecutorSession
// Preload ciphertexts into cache and perform initial computations,
// should be called once after blockchain node initialization
PreloadCiphertexts(blockNumber int64, api ChainStorageApi) error
}

type SegmentId int
Expand Down Expand Up @@ -230,6 +233,114 @@ func (executorApi *ApiImpl) CreateSession(blockNumber int64) ExecutorSession {
}
}

func (executorApi *ApiImpl) PreloadCiphertexts(blockNumber int64, api ChainStorageApi) error {
computations := executorApi.loadComputationsFromStateToCache(blockNumber, api)
if computations > 0 {
return executorProcessPendingComputations(executorApi)
}

return nil
}

func (executorApi *ApiImpl) loadComputationsFromStateToCache(startBlockNumber int64, api ChainStorageApi) int {
loadStartTime := time.Now()
computations := 0
defer func() {
duration := time.Since(loadStartTime)
fmt.Printf("ciphertext cache preloaded with %d ciphertexts in %dms\n", computations, duration.Milliseconds())
}()

// TODO: figure out the limit how long in future blocks we should preload
lastBlockToPreload := startBlockNumber + 30

executorApi.cache.lock.Lock()
defer executorApi.cache.lock.Unlock()

for block := startBlockNumber; block < lastBlockToPreload; block++ {
countAddress := blockNumberToQueueItemCountAddress(block)
ciphertextsInBlock := api.GetState(executorApi.contractStorageAddress, countAddress).Big()
inBlock := ciphertextsInBlock.Int64()
queue := make([]*ComputationToInsert, 0)
enqueuedCiphertext := make(map[string]bool)

if inBlock == 0 {
continue
}

computations += int(inBlock)

for ctNum := 0; ctNum < int(inBlock); ctNum++ {
layout := blockQueueStorageLayout(block, int64(ctNum))
metadata := bytesToMetadata(api.GetState(executorApi.contractStorageAddress, layout.metadata))
outputHandle := api.GetState(executorApi.contractStorageAddress, layout.outputHandle)
computation := &ComputationToInsert{
segmentId: 0,
Operation: metadata.Operation,
OutputHandle: outputHandle[:],
CommitBlockId: block,
}

if isBinaryOp(metadata.Operation) {
firstOpHandle := api.GetState(executorApi.contractStorageAddress, layout.firstOperand)
firstOpCt := ReadBytesToAddress(api, executorApi.contractStorageAddress, firstOpHandle)

computation.Operands = append(computation.Operands, ComputationOperand{
IsScalar: false,
Handle: firstOpHandle[:],
CompressedCiphertext: firstOpCt,
FheUintType: handleType(firstOpHandle[:]),
})

if metadata.IsBigScalar {
// TODO: implement big scalar
} else if metadata.IsScalar {
secondOpHandle := api.GetState(executorApi.contractStorageAddress, layout.secondOperand)
computation.Operands = append(computation.Operands, ComputationOperand{
IsScalar: true,
Handle: secondOpHandle[:],
FheUintType: handleType(firstOpHandle[:]),
})
} else {
secondOpHandle := api.GetState(executorApi.contractStorageAddress, layout.secondOperand)
secondOpCt := ReadBytesToAddress(api, executorApi.contractStorageAddress, secondOpHandle)

computation.Operands = append(computation.Operands, ComputationOperand{
IsScalar: false,
Handle: secondOpHandle[:],
CompressedCiphertext: secondOpCt,
FheUintType: handleType(secondOpHandle[:]),
})
}
} else if isUnaryOp(metadata.Operation) {
firstOpAddress := api.GetState(executorApi.contractStorageAddress, layout.firstOperand)
firstOpCt := ReadBytesToAddress(api, executorApi.contractStorageAddress, firstOpAddress)

computation.Operands = append(computation.Operands, ComputationOperand{
IsScalar: false,
Handle: firstOpAddress[:],
CompressedCiphertext: firstOpCt,
FheUintType: handleType(firstOpAddress[:]),
})
} else {
// TODO: handle all special functions to load their ciphertext arguments
}

if !enqueuedCiphertext[string(computation.OutputHandle)] {
queue = append(queue, computation)
enqueuedCiphertext[string(computation.OutputHandle)] = true
}
}

ctsToCompute := &BlockCiphertextQueue{
queue: queue,
enqueuedCiphertext: enqueuedCiphertext,
}
executorApi.cache.ciphertextsToCompute[block] = ctsToCompute
}

return computations
}

func (sessionApi *SessionImpl) Commit(blockNumber int64, storage ChainStorageApi) error {
err := sessionApi.sessionStore.Commit(storage)
if err != nil {
Expand Down Expand Up @@ -530,12 +641,13 @@ func (dbApi *EvmStorageComputationStore) InsertComputationBatch(evmStorage Chain

for _, comp := range bucket {
// don't have duplicates, from possibly evaluating multiple trie caches
if !ctsStorage.enqueuedCiphertext[common.Bytes2Hex(comp.OutputHandle)] {
if !ctsStorage.enqueuedCiphertext[string(comp.OutputHandle)] {
// we must fill the raw ciphertext values here from storage so cache
// would have ciphertexts to compute on, as cache doesn't have easy
// access to the evm state
dbApi.hydrateComputationFromEvmState(evmStorage, comp)
ctsStorage.queue = append(ctsStorage.queue, comp)
ctsStorage.enqueuedCiphertext[string(comp.OutputHandle)] = true
}
}
}
Expand Down Expand Up @@ -766,18 +878,20 @@ func InitExecutor() (ExecutorApi, error) {

workAvailableChan := make(chan bool, 10)

cache := &CiphertextCache{
lock: sync.RWMutex{},
blocksCiphertexts: make(map[int64]*CacheBlockData),
ciphertextsToCompute: make(map[int64]*BlockCiphertextQueue),
workAvailableChan: workAvailableChan,
lastCacheGc: time.Now(),
}

apiImpl := &ApiImpl{
address: fhevmContractAddress,
aclContractAddress: aclContractAddress,
contractStorageAddress: storageAddress,
executorUrl: executorUrl,
cache: &CiphertextCache{
lock: sync.RWMutex{},
blocksCiphertexts: make(map[int64]*CacheBlockData),
ciphertextsToCompute: make(map[int64]*BlockCiphertextQueue),
workAvailableChan: workAvailableChan,
lastCacheGc: time.Now(),
},
cache: cache,
}

// run executor worker in the background
Expand Down Expand Up @@ -885,8 +999,12 @@ func executorProcessPendingComputations(impl *ApiImpl) error {
if err != nil {
return err
}
ciphertexts := response.GetResultCiphertexts()
if ciphertexts == nil {
return errors.New(response.GetError().String())
}

outCts := response.GetResultCiphertexts().Ciphertexts
outCts := ciphertexts.Ciphertexts
fmt.Printf("got %d ciphertext responses from the executor\n", len(outCts))
for _, ct := range outCts {
theBlock, exists := ctToBlockIndex[string(ct.Handle)]
Expand Down
22 changes: 22 additions & 0 deletions fhevm-engine/fhevm-go-native/fhevm/fhelib_ops.go
Original file line number Diff line number Diff line change
Expand Up @@ -1688,3 +1688,25 @@ func getThreeFheOperands(sess ExecutorSession, input []byte) (first []byte, seco

return input[0:32], input[32:64], input[64:96], nil
}

func isBinaryOp(op FheOp) bool {
switch op {
case FheAdd, FheBitAnd, FheBitOr, FheBitXor, FheDiv, FheEq, FheGe, FheGt, FheLe, FheLt, FheMax, FheMin, FheMul, FheNe, FheRem, FheRotl, FheRotr, FheShl, FheShr, FheSub:
return true
case FheCast, FheNeg, FheNot, FheRand, FheRandBounded, FheIfThenElse, TrivialEncrypt:
return false
default:
return false
}
}

func isUnaryOp(op FheOp) bool {
switch op {
case FheNeg, FheNot:
return true
case FheAdd, FheBitAnd, FheBitOr, FheBitXor, FheDiv, FheEq, FheGe, FheGt, FheLe, FheLt, FheMax, FheMin, FheMul, FheNe, FheRem, FheRotl, FheRotr, FheShl, FheShr, FheSub, FheCast, FheRand, FheRandBounded, FheIfThenElse, TrivialEncrypt:
return false
default:
return false
}
}

0 comments on commit 6be5c69

Please sign in to comment.