Skip to content

Commit

Permalink
updated cache input
Browse files Browse the repository at this point in the history
  • Loading branch information
awitas committed Nov 29, 2022
1 parent 00bbdfe commit 1391573
Show file tree
Hide file tree
Showing 6 changed files with 93 additions and 27 deletions.
9 changes: 1 addition & 8 deletions service/layers/dictionary.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"github.com/viant/mly/service/tfmodel"
"github.com/viant/mly/shared/common"
"hash"
"hash/fnv"
"sort"
"unsafe"
)
Expand Down Expand Up @@ -36,34 +35,28 @@ func Dictionary(session *tf.Session, graph *tf.Graph, signature *domain.Signatur
func DiscoverDictionary(session *tf.Session, graph *tf.Graph, layers []string) (*common.Dictionary, error) {
var result = &common.Dictionary{}
for _, name := range layers {
aHash := fnv.New64()
exported, err := tfmodel.Export(session, graph, name)
if err != nil {
return nil, err
}
layer := common.Layer{
Name: name,
}
hashValue := uint64(0)
switch vals := exported.(type) {
case []string:
layer.Strings = make([]string, len(vals))
copy(layer.Strings, vals)
sort.Strings(layer.Strings)
hashStrings(aHash, layer.Strings)
hashValue = aHash.Sum64()
case []int64:
layer.Ints = make([]int, len(vals))
copy(layer.Ints, *(*[]int)(unsafe.Pointer(&vals)))
sort.Ints(layer.Ints)
hashInts(aHash, layer.Ints)
hashValue = aHash.Sum64()
default:
return nil, fmt.Errorf("unsupported data type %T for %v", exported, name)
}
result.Layers = append(result.Layers, layer)
result.Hash += int(hashValue)
}
result.UpdateHash()
return result, nil
}

Expand Down
1 change: 1 addition & 0 deletions service/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,7 @@ func (s *Service) reloadIfNeeded(ctx context.Context) error {
s.config.DictMeta.Error = err.Error()
return err
}

}
s.dictionary = dictionary
s.config.DictMeta.Hash = dictionary.Hash
Expand Down
26 changes: 13 additions & 13 deletions shared/client/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,17 @@ import (
type Service struct {
Config
sync.RWMutex
dict *Dictionary
gmetrics *gmetric.Service
counter *gmetric.Operation
datastore datastore.Storer
mux sync.RWMutex
messages Messages
poolErr error
hostIndex int64
newStorable func() common.Storable
dictRefresh int32
httpClient http.Client
dict *Dictionary
gmetrics *gmetric.Service
counter *gmetric.Operation
datastore datastore.Storer
mux sync.RWMutex
messages Messages
poolErr error
hostIndex int64
newStorable func() common.Storable
dictRefreshPending int32
httpClient http.Client
}

//NewMessage returns a new message
Expand Down Expand Up @@ -362,7 +362,7 @@ func (s *Service) Close() error {
}

func (s *Service) refreshMetadata() {
defer atomic.StoreInt32(&s.dictRefresh, 0)
defer atomic.StoreInt32(&s.dictRefreshPending, 0)
if err := s.loadModelDictionary(); err != nil {
log.Printf("failed to refresh meta data: %v", err)
}
Expand Down Expand Up @@ -423,7 +423,7 @@ func (s *Service) handleResponse(ctx context.Context, target interface{}, cached
func (s *Service) assertDictHash(response *Response) {
dict := s.dictionary()
if dict != nil && response.DictHash != dict.hash {
if atomic.CompareAndSwapInt32(&s.dictRefresh, 0, 1) {
if atomic.CompareAndSwapInt32(&s.dictRefreshPending, 0, 1) {
go s.refreshMetadata()
}
}
Expand Down
56 changes: 54 additions & 2 deletions shared/common/dictionary.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
package common

import "github.com/viant/tapper/io"
import (
"encoding/binary"
"github.com/viant/tapper/io"
"hash"
"hash/fnv"
"sort"
"strings"
)

type (
//Dictionary represents model dictionary
Expand All @@ -15,11 +22,33 @@ type (
Strings []string
Ints []int
Floats []float32
Hash int
}

Layers []Layer
)

func (d *Dictionary) UpdateHash() int {
d.Hash = 0
for i := range d.Layers {
layer := &d.Layers[i]
aHash := fnv.New64()
if len(layer.Strings) > 0 {
sort.Strings(layer.Strings)
hashStrings(aHash, layer.Strings)
layer.Hash = int(aHash.Sum64())
} else if len(layer.Ints) > 0 {
sort.Ints(layer.Ints)
hashInts(aHash, layer.Ints)
layer.Hash = int(aHash.Sum64())
} else {
continue
}
d.Hash += layer.Hash
}
return d.Hash
}

func (l *Layers) Encoders() []io.Encoder {
var layers = make([]io.Encoder, len(*l))
for i := range *l {
Expand All @@ -32,7 +61,8 @@ func (l *Layer) Encode(stream io.Stream) {
stream.PutByte('\n')
stream.PutString("Name", l.Name)
if len(l.Strings) > 0 {
stream.PutStrings("Strings", l.Strings)
values := normalizeStrings(l.Strings)
stream.PutStrings("Strings", values)
}
if len(l.Ints) > 0 {
stream.PutInts("Ints", l.Ints)
Expand All @@ -44,5 +74,27 @@ func (l *Layer) Encode(stream io.Stream) {
}
stream.PutFloats("Floats", floats)
}
stream.PutInt("Hash", l.Hash)
}

func normalizeStrings(items []string) []string {
for i := range items {
s := items[i]
s = strings.ReplaceAll(s, "\"", "\\\"")
s = strings.ReplaceAll(s, "\n", "\\\n")
items[i] = s
}
return items
}

func hashStrings(hash hash.Hash, strings []string) {
for _, item := range strings {
hash.Write([]byte(item))
}
}

func hashInts(hash hash.Hash, ints []int) {
for _, item := range ints {
binary.Write(hash, binary.LittleEndian, item)
}
}
26 changes: 23 additions & 3 deletions tools/cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,25 +66,44 @@ func Run(args []string) {
}

func Discover(options *Options) error {
fs := afs.New()
if options.Operation == "dictHash" {
return discoverDictHash(options, fs)
}
model, err := loadModel(context.Background(), options.SourceURL)
if err != nil {
return err
}
fs := afs.New()
switch options.Operation {
case "signature":
return discoverSignature(options, model, fs)
case "layers":
return discoverLayers(options, model, fs)
case "config":
return discoverConfig(options, model, fs)

default:
return fmt.Errorf("unsupported option: '%v'", options.Operation)
}

}

func discoverDictHash(options *Options, fs afs.Service) error {
source, err := fs.DownloadWithURL(context.Background(), options.SourceURL)
if err != nil {
return err
}
dict := common.Dictionary{}
if err = sjson.Unmarshal(source, &dict); err != nil {
return err
}
fmt.Printf("dict hash: %v\n", dict.UpdateHash())
for _, l := range dict.Layers {
fmt.Printf("layer: %v hash: %v\n", l.Name, l.Hash)

}
return nil
}

const exportSuffix = "_lookup_index_table_lookup_table_export_values/LookupTableExportV2"

func discoverLayers(options *Options, model *tf.SavedModel, fs afs.Service) error {
Expand All @@ -109,7 +128,8 @@ func discoverLayers(options *Options, model *tf.SavedModel, fs afs.Service) erro
provider := msg.NewProvider(100*1024*1024, 1, json.New)
aMessage := provider.NewMessage()
dictLayers := common.Layers(dictionary.Layers)
aMessage.PutInt("Hash", 0)

aMessage.PutInt("Hash", dictionary.Hash)
aMessage.PutObjects("Layers", dictLayers.Encoders())
logger.Log(aMessage)
aMessage.Free()
Expand Down
2 changes: 1 addition & 1 deletion tools/option.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ type Options struct {
Mode string `short:"m" long:"mode" choice:"discover" choice:"run" description:"mode"`
SourceURL string `short:"s" long:"src" description:"source location"`
DestURL string `short:"d" long:"dest" description:"dest location"`
Operation string `short:"o" long:"opt" choice:"signature" choice:"layers" choice:"config"`
Operation string `short:"o" long:"opt" choice:"dictHash" choice:"signature" choice:"layers" choice:"config"`
ConfigURL string `short:"c" long:"config" `
}

Expand Down

0 comments on commit 1391573

Please sign in to comment.