diff --git a/resolver/enum.go b/resolver/enum.go index ee96fc2b..f6fa160e 100644 --- a/resolver/enum.go +++ b/resolver/enum.go @@ -15,10 +15,18 @@ func NewEnumType(enum *Enum, repeated bool) *Type { } func (e *Enum) HasValue(name string) bool { + if e == nil { + return false + } + return e.Value(name) != nil } func (e *Enum) Value(name string) *EnumValue { + if e == nil { + return nil + } + if strings.Contains(name, ".") { enumFQDNPrefix := e.FQDN() + "." if !strings.HasPrefix(name, enumFQDNPrefix) { @@ -48,6 +56,9 @@ func (e *Enum) AttributeMap() map[string][]*EnumValue { } func (e *Enum) Package() *Package { + if e == nil { + return nil + } if e.File == nil { return nil } @@ -55,6 +66,9 @@ func (e *Enum) Package() *Package { } func (e *Enum) GoPackage() *GoPackage { + if e == nil { + return nil + } if e.File == nil { return nil } @@ -62,6 +76,9 @@ func (e *Enum) GoPackage() *GoPackage { } func (e *Enum) PackageName() string { + if e == nil { + return "" + } pkg := e.Package() if pkg == nil { return "" @@ -70,5 +87,9 @@ func (e *Enum) PackageName() string { } func (e *EnumExpr) ReferenceNames() []string { + if e == nil { + return nil + } + return e.By.ReferenceNames() } diff --git a/resolver/fqdn.go b/resolver/fqdn.go index 036a6dd4..e14683f1 100644 --- a/resolver/fqdn.go +++ b/resolver/fqdn.go @@ -6,14 +6,23 @@ import ( ) func (s *Service) FQDN() string { + if s == nil { + return "" + } return fmt.Sprintf("%s.%s", s.PackageName(), s.Name) } func (m *Method) FQDN() string { + if m == nil { + return "" + } return fmt.Sprintf("%s/%s", m.Service.FQDN(), m.Name) } func (m *Message) FQDN() string { + if m == nil { + return "" + } return strings.Join( append(append([]string{m.PackageName()}, m.ParentMessageNames()...), m.Name), ".", @@ -21,6 +30,9 @@ func (m *Message) FQDN() string { } func (f *Field) FQDN() string { + if f == nil { + return "" + } if f.Message == nil { return f.Name } @@ -28,6 +40,9 @@ func (f *Field) FQDN() string { } func (f *OneofField) FQDN() string { + if f == nil { + return "" + } return fmt.Sprintf("%s.%s", f.Oneof.Message.FQDN(), f.Name) } @@ -42,10 +57,16 @@ func (e *Enum) FQDN() string { } func (v *EnumValue) FQDN() string { + if v == nil { + return "" + } return fmt.Sprintf("%s.%s", v.Enum.FQDN(), v.Value) } func (t *Type) FQDN() string { + if t == nil { + return "" + } var repeated string if t.Repeated { repeated = "repeated " @@ -66,5 +87,8 @@ func (t *Type) FQDN() string { } func (n *MessageDependencyGraphNode) FQDN() string { + if n == nil { + return "" + } return fmt.Sprintf("%s_%s", n.BaseMessage.FQDN(), n.VariableDefinition.Name) } diff --git a/resolver/message.go b/resolver/message.go index ba60c08f..c4da5079 100644 --- a/resolver/message.go +++ b/resolver/message.go @@ -74,6 +74,9 @@ func newMessageArgument(msg *Message) *Message { } func (m *Message) ParentMessageNames() []string { + if m == nil { + return nil + } if m.ParentMessage == nil { return []string{} } @@ -81,6 +84,9 @@ func (m *Message) ParentMessageNames() []string { } func (m *Message) Package() *Package { + if m == nil { + return nil + } if m.File == nil { return nil } @@ -88,6 +94,9 @@ func (m *Message) Package() *Package { } func (m *Message) HasRule() bool { + if m == nil { + return false + } if m.Rule == nil { return false } @@ -111,6 +120,9 @@ func (m *Message) IsEnumSelector() bool { } func (m *Message) HasResolvers() bool { + if m == nil { + return false + } if m.Rule == nil { return false } @@ -135,6 +147,9 @@ func (m *Message) HasResolvers() bool { } func (m *Message) VariableDefinitionGroups() []VariableDefinitionGroup { + if m == nil { + return nil + } if m.Rule == nil { return nil } @@ -168,6 +183,10 @@ func (m *Message) VariableDefinitionGroups() []VariableDefinitionGroup { } func (m *Message) AllVariableDefinitions() VariableDefinitions { + if m == nil { + return nil + } + var defs VariableDefinitions for _, group := range m.VariableDefinitionGroups() { defs = append(defs, group.VariableDefinitions()...) @@ -176,6 +195,9 @@ func (m *Message) AllVariableDefinitions() VariableDefinitions { } func (m *Message) HasCELValue() bool { + if m == nil { + return false + } if m.Rule == nil { return false } @@ -224,10 +246,18 @@ func (m *Message) HasCELValue() bool { } func (m *Message) HasCustomResolver() bool { + if m == nil { + return false + } + return m.Rule != nil && m.Rule.CustomResolver } func (m *Message) HasRuleEveryFields() bool { + if m == nil { + return false + } + for _, field := range m.Fields { if !field.HasRule() { return false @@ -237,10 +267,18 @@ func (m *Message) HasRuleEveryFields() bool { } func (m *Message) HasCustomResolverFields() bool { + if m == nil { + return false + } + return len(m.CustomResolverFields()) != 0 } func (m *Message) UseAllNameReference() { + if m == nil { + return + } + if m.Rule == nil { return } @@ -257,6 +295,10 @@ func (m *Message) UseAllNameReference() { } func (e *MessageExpr) ReferenceNames() []string { + if e == nil { + return nil + } + var refNames []string for _, arg := range e.Args { refNames = append(refNames, arg.Value.ReferenceNames()...) @@ -265,6 +307,10 @@ func (e *MessageExpr) ReferenceNames() []string { } func (m *Message) ReferenceNames() []string { + if m == nil { + return nil + } + if m.Rule == nil { return nil } @@ -284,6 +330,10 @@ func (m *Message) ReferenceNames() []string { } func (m *Message) CustomResolverFields() []*Field { + if m == nil { + return nil + } + fields := make([]*Field, 0, len(m.Fields)) for _, field := range m.Fields { if field.HasCustomResolver() { @@ -294,6 +344,9 @@ func (m *Message) CustomResolverFields() []*Field { } func (m *Message) GoPackage() *GoPackage { + if m == nil { + return nil + } if m.File == nil { return nil } @@ -301,6 +354,9 @@ func (m *Message) GoPackage() *GoPackage { } func (m *Message) PackageName() string { + if m == nil { + return "" + } pkg := m.Package() if pkg == nil { return "" @@ -309,6 +365,9 @@ func (m *Message) PackageName() string { } func (m *Message) FileName() string { + if m == nil { + return "" + } if m.File == nil { return "" } @@ -316,10 +375,18 @@ func (m *Message) FileName() string { } func (m *Message) HasField(name string) bool { + if m == nil { + return false + } + return m.Field(name) != nil } func (m *Message) Field(name string) *Field { + if m == nil { + return nil + } + for _, field := range m.Fields { if field.Name == name { return field @@ -329,6 +396,10 @@ func (m *Message) Field(name string) *Field { } func (m *Message) Oneof(name string) *Oneof { + if m == nil { + return nil + } + for _, field := range m.Fields { if field.Oneof == nil { continue @@ -341,6 +412,10 @@ func (m *Message) Oneof(name string) *Oneof { } func (m *Message) AllMessages() []*Message { + if m == nil { + return nil + } + ret := []*Message{m} for _, msg := range m.NestedMessages { ret = append(ret, msg.AllMessages()...) @@ -349,6 +424,10 @@ func (m *Message) AllMessages() []*Message { } func (m *Message) AllEnums() []*Enum { + if m == nil { + return nil + } + enums := m.Enums for _, msg := range m.NestedMessages { enums = append(enums, msg.AllEnums()...) @@ -357,6 +436,10 @@ func (m *Message) AllEnums() []*Enum { } func (m *Message) HasFieldRule() bool { + if m == nil { + return false + } + for _, field := range m.Fields { if field.HasRule() { return true @@ -366,6 +449,9 @@ func (m *Message) HasFieldRule() bool { } func (m *Message) DependencyGraphTreeFormat() string { + if m == nil { + return "" + } if m.Rule == nil { return "" } @@ -373,6 +459,10 @@ func (m *Message) DependencyGraphTreeFormat() string { } func (m *Message) TypeConversionDecls() []*TypeConversionDecl { + if m == nil { + return nil + } + convertedFQDNMap := make(map[string]struct{}) var decls []*TypeConversionDecl for _, def := range m.AllVariableDefinitions() { @@ -438,6 +528,10 @@ func (m *Message) TypeConversionDecls() []*TypeConversionDecl { } func (m *Message) CustomResolvers() []*CustomResolver { + if m == nil { + return nil + } + var ret []*CustomResolver if m.HasCustomResolver() { ret = append(ret, &CustomResolver{Message: m}) @@ -459,6 +553,10 @@ func (m *Message) CustomResolvers() []*CustomResolver { } func (m *Message) customResolvers(def *VariableDefinition) []*CustomResolver { + if m == nil { + return nil + } + var ret []*CustomResolver if def != nil { for _, expr := range def.MessageExprs() { @@ -472,6 +570,10 @@ func (m *Message) customResolvers(def *VariableDefinition) []*CustomResolver { } func (m *Message) GoPackageDependencies() []*GoPackage { + if m == nil { + return nil + } + pkgMap := map[*GoPackage]struct{}{} gopkg := m.GoPackage() pkgMap[gopkg] = struct{}{} @@ -524,6 +626,10 @@ func (m *Message) DependServices() []*Service { } func (m *Message) dependServices(defMap map[*VariableDefinition]struct{}) []*Service { + if m == nil { + return nil + } + var svcs []*Service for _, def := range m.AllVariableDefinitions() { svcs = append(svcs, dependServicesByDefinition(def, defMap)...) diff --git a/resolver/resolver.go b/resolver/resolver.go index 0195d41b..41a05e76 100644 --- a/resolver/resolver.go +++ b/resolver/resolver.go @@ -603,12 +603,13 @@ func (r *Resolver) lookupPackageNameMapFromVariableDefinitionSet(defSet *Variabl } case v.Expr.Map != nil: if v.Expr.Map.Expr != nil { - if v.Expr.Map.Expr.By != nil { - maps.Copy(pkgNameMap, r.lookupPackageNameMapFromCELValue(v.Expr.Map.Expr.By)) + expr := v.Expr.Map.Expr + if expr.By != nil { + maps.Copy(pkgNameMap, r.lookupPackageNameMapFromCELValue(expr.By)) } - if v.Expr.Map.Expr.Message != nil { - pkgNameMap[v.Expr.Map.Expr.Message.Message.PackageName()] = struct{}{} - maps.Copy(pkgNameMap, r.lookupPackageNameMapFromMessageArguments(v.Expr.Map.Expr.Message.Args)) + if expr.Message != nil && expr.Message.Message != nil { + pkgNameMap[expr.Message.Message.PackageName()] = struct{}{} + maps.Copy(pkgNameMap, r.lookupPackageNameMapFromMessageArguments(expr.Message.Args)) } } case v.Expr.Validation != nil: diff --git a/resolver/service.go b/resolver/service.go index 3c478427..18d98909 100644 --- a/resolver/service.go +++ b/resolver/service.go @@ -6,6 +6,9 @@ import ( ) func (s *Service) GoPackage() *GoPackage { + if s == nil { + return nil + } if s.File == nil { return nil } @@ -13,6 +16,9 @@ func (s *Service) GoPackage() *GoPackage { } func (s *Service) Package() *Package { + if s == nil { + return nil + } if s.File == nil { return nil } @@ -20,6 +26,9 @@ func (s *Service) Package() *Package { } func (s *Service) PackageName() string { + if s == nil { + return "" + } pkg := s.Package() if pkg == nil { return "" @@ -28,6 +37,9 @@ func (s *Service) PackageName() string { } func (s *Service) Method(name string) *Method { + if s == nil { + return nil + } for _, method := range s.Methods { if method.Name == name { return method @@ -37,6 +49,9 @@ func (s *Service) Method(name string) *Method { } func (s *Service) HasMessageInMethod(msg *Message) bool { + if s == nil { + return false + } for _, mtd := range s.Methods { if mtd.Request == msg { return true @@ -63,6 +78,9 @@ func (s *Service) HasMessageInVariables(msg *Message) bool { } func (s *Service) GoPackageDependencies() []*GoPackage { + if s == nil { + return nil + } pkgMap := map[*GoPackage]struct{}{} pkgMap[s.GoPackage()] = struct{}{} for _, dep := range s.ServiceDependencies() { @@ -81,6 +99,9 @@ type CustomResolver struct { } func (r *CustomResolver) FQDN() string { + if r == nil { + return "" + } if r.Field != nil { return fmt.Sprintf("%s.%s", r.Message.FQDN(), r.Field.Name) } @@ -88,6 +109,9 @@ func (r *CustomResolver) FQDN() string { } func (s *Service) CustomResolvers() []*CustomResolver { + if s == nil { + return nil + } resolverMap := make(map[string]*CustomResolver) for _, method := range s.Methods { for _, resolver := range method.FederationResponse().CustomResolvers() { @@ -105,10 +129,16 @@ func (s *Service) CustomResolvers() []*CustomResolver { } func (s *Service) ExistsCustomResolver() bool { + if s == nil { + return false + } return len(s.CustomResolvers()) != 0 } func (s *Service) ServiceDependencies() []*ServiceDependency { + if s == nil { + return nil + } useServices := s.UseServices() deps := make([]*ServiceDependency, 0, len(useServices)) depSvcMap := map[string]*ServiceDependency{} @@ -124,6 +154,9 @@ func (s *Service) ServiceDependencies() []*ServiceDependency { } func (s *Service) UseServices() []*Service { + if s == nil { + return nil + } svcMap := map[*Service]struct{}{} for _, method := range s.Methods { for _, svc := range method.Response.DependServices() {