Skip to content

Commit

Permalink
patched duplicated output, added dictionary meta
Browse files Browse the repository at this point in the history
  • Loading branch information
awitas committed Nov 29, 2022
1 parent a564e07 commit 00bbdfe
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 5 deletions.
9 changes: 9 additions & 0 deletions service/config/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"github.com/viant/tapper/config"
"os"
"path"
"time"
)

//Model represents model config
Expand All @@ -25,6 +26,14 @@ type Model struct {
Modified *Modified `json:",omitempty" yaml:",omitempty"`
Stream *config.Stream `json:",omitempty" yaml:",omitempty"`
shared.MetaInput `json:",omitempty" yaml:",inline"`
DictMeta DictionaryMeta
}

//DictionaryMeta represents dictionary meta
type DictionaryMeta struct {
Hash int
Reloaded time.Time
Error string
}

//UseDictionary returns true if dictionary can be used
Expand Down
12 changes: 10 additions & 2 deletions service/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -225,18 +225,22 @@ func (s *Service) reloadIfNeeded(ctx context.Context) error {
var dictionary *common.Dictionary
s.reconcileSignatureWithInput(signature)
if s.config.UseDictionary() {
s.config.DictMeta.Error = ""
if s.config.DictURL != "" {
if dictionary, err = s.loadDictionary(ctx, s.config.DictURL); err != nil {
s.config.DictMeta.Error = err.Error()
return err
}
} else {
dictionary, err = layers.Dictionary(model.Session, model.Graph, signature)
if err != nil {
s.config.DictMeta.Error = err.Error()
return err
}

}
s.dictionary = dictionary
s.config.DictMeta.Hash = dictionary.Hash
s.config.DictMeta.Reloaded = time.Now()
}

var inputs = make(map[string]*domain.Input)
Expand Down Expand Up @@ -576,9 +580,13 @@ func (srv *Service) reconcileSignatureWithInput(signature *domain.Signature) {
}
delete(byName, field.Name)
}

if len(signature.Outputs) > 0 {
outputIndex := srv.config.OutputIndex()
for _, output := range signature.Outputs {

if _, has := outputIndex[output.Name]; has {
continue
}
field := &shared.Field{Name: output.Name, DataType: output.DataType}
if field.DataType == "" {
field.SetRawType(reflect.TypeOf(""))
Expand Down
4 changes: 1 addition & 3 deletions service/tfmodel/signature.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ func Signature(model *tf.SavedModel) (*domain.Signature, error) {
result := &domain.Signature{
Method: signature.MethodName,
}

for k, v := range signature.Outputs {
output := domain.Output{}
output.Name = k
Expand All @@ -34,13 +33,12 @@ func Signature(model *tf.SavedModel) (*domain.Signature, error) {
tryAssignDataType(v, output)
result.Outputs = append(result.Outputs, output)
}
result.Output = result.Outputs[0]

result.Output = result.Outputs[0]
var inputs = make([]string, 0, len(signature.Inputs))
for k := range signature.Inputs {
inputs = append(inputs, k)
}

sort.Strings(inputs)
for _, k := range inputs {
v := signature.Inputs[k]
Expand Down
11 changes: 11 additions & 0 deletions shared/field.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,17 @@ type (
}
)

func (m *MetaInput) OutputIndex() map[string]int {
var outputIndex = map[string]int{}
if len(m.Outputs) == 0 {
return outputIndex
}
for i, f := range m.Outputs {
outputIndex[f.Name] = i
}
return outputIndex
}

func (d *MetaInput) KeysLen() int {
return len(d.Inputs)
}
Expand Down

0 comments on commit 00bbdfe

Please sign in to comment.