diff --git a/pkg/config/bee.go b/pkg/config/bee.go index 4c3c58d1..28e6c6e1 100644 --- a/pkg/config/bee.go +++ b/pkg/config/bee.go @@ -7,6 +7,10 @@ import ( "github.com/ethersphere/beekeeper/pkg/orchestration" ) +type Inheritable interface { + GetParentName() string +} + // BeeConfig represents Bee configuration type BeeConfig struct { // parent to inherit settings from @@ -56,6 +60,13 @@ type BeeConfig struct { WithdrawAddress *string `yaml:"withdrawal-addresses-whitelist"` } +func (b BeeConfig) GetParentName() string { + if b.Inherit != nil { + return b.Inherit.ParentName + } + return "" +} + // Export exports BeeConfig to orchestration.Config func (b *BeeConfig) Export() (o orchestration.Config) { localVal := reflect.ValueOf(b).Elem() diff --git a/pkg/config/cluster.go b/pkg/config/cluster.go index c5621ed5..536f3972 100644 --- a/pkg/config/cluster.go +++ b/pkg/config/cluster.go @@ -22,6 +22,13 @@ type Cluster struct { NodeGroups *map[string]ClusterNodeGroup `yaml:"node-groups"` } +func (b Cluster) GetParentName() string { + if b.Inherit != nil { + return b.Inherit.ParentName + } + return "" +} + // ClusterNodeGroup represents node group in the cluster type ClusterNodeGroup struct { cluster *Cluster diff --git a/pkg/config/config.go b/pkg/config/config.go index 92759b11..0ab8af08 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -32,8 +32,7 @@ func (c *Config) PrintYaml(w io.Writer) (err error) { if c == nil { return fmt.Errorf("config not initialized") } - enc := yaml.NewEncoder(w) - if err := enc.Encode(c); err != nil { + if err := yaml.NewEncoder(w).Encode(c); err != nil { return fmt.Errorf("config can not be encoded: %s", err.Error()) } return @@ -41,73 +40,82 @@ func (c *Config) PrintYaml(w io.Writer) (err error) { // merge combines Config objects using inheritance func (c *Config) merge() (err error) { - // merge BeeConfigs - mergedBC := map[string]BeeConfig{} - for name, v := range c.BeeConfigs { - if len(v.ParentName) == 0 { - mergedBC[name] = v - } else { - parent, ok := c.BeeConfigs[v.ParentName] - if !ok { - return fmt.Errorf("bee profile %s doesn't exist", v.ParentName) - } - p := reflect.ValueOf(&parent).Elem() - m := reflect.ValueOf(&v).Elem() - for i := 0; i < m.NumField(); i++ { - if m.Field(i).IsNil() && !p.Field(i).IsNil() { - m.Field(i).Set(p.Field(i)) - } - } - mergedBC[name] = m.Interface().(BeeConfig) - } + c.BeeConfigs, err = mergeConfigs(c.BeeConfigs) + if err != nil { + return fmt.Errorf("merging bee configs: %w", err) } - c.BeeConfigs = mergedBC - - // merge NodeGroups - mergedNG := map[string]NodeGroup{} - for name, v := range c.NodeGroups { - if len(v.ParentName) == 0 { - mergedNG[name] = v - } else { - parent, ok := c.NodeGroups[v.ParentName] - if !ok { - return fmt.Errorf("node group profile %s doesn't exist", v.ParentName) + + c.NodeGroups, err = mergeConfigs(c.NodeGroups) + if err != nil { + return fmt.Errorf("merging node groups: %w", err) + } + + c.Clusters, err = mergeConfigs(c.Clusters) + if err != nil { + return fmt.Errorf("merging clusters: %w", err) + } + + return +} + +func mergeConfigs[T any](configs map[string]T) (map[string]T, error) { + mergedConfigs := make(map[string]T) + visited := map[string]bool{} + + // recursively merge configs (internal function) + var mergeParent func(name string) (T, error) + mergeParent = func(name string) (T, error) { + if config, ok := mergedConfigs[name]; ok { + return config, nil // already merged + } + var zero T + + // detect circular inheritance + if visited[name] { + return zero, fmt.Errorf("circular inheritance detected with bee profile %s", name) + } + visited[name] = true + + v, ok := configs[name] + if !ok { + return zero, fmt.Errorf("bee profile %s doesn't exist", name) + } + + // check if T implements Inheritable to get parent name + vIneheritable, ok := any(v).(Inheritable) + if !ok { + return zero, fmt.Errorf("type %T does not implement Inheritable interface", v) + } + + // merge the parent + if len(vIneheritable.GetParentName()) > 0 { + parentConfig, err := mergeParent(vIneheritable.GetParentName()) + if err != nil { + return zero, err } - p := reflect.ValueOf(&parent).Elem() + + // merge parent fields into the current config + p := reflect.ValueOf(&parentConfig).Elem() m := reflect.ValueOf(&v).Elem() for i := 0; i < m.NumField(); i++ { if m.Field(i).IsNil() && !p.Field(i).IsNil() { m.Field(i).Set(p.Field(i)) } } - mergedNG[name] = m.Interface().(NodeGroup) } + + mergedConfigs[name] = v + delete(visited, name) // remove after merge + return v, nil } - c.NodeGroups = mergedNG - - // merge clusters - mergedC := map[string]Cluster{} - for name, v := range c.Clusters { - if len(v.ParentName) == 0 { - mergedC[name] = v - } else { - parent, ok := c.Clusters[v.ParentName] - if !ok { - return fmt.Errorf("bee profile %s doesn't exist", v.ParentName) - } - p := reflect.ValueOf(&parent).Elem() - m := reflect.ValueOf(&v).Elem() - for i := 0; i < m.NumField(); i++ { - if m.Field(i).IsNil() && !p.Field(i).IsNil() { - m.Field(i).Set(p.Field(i)) - } - } - mergedC[name] = m.Interface().(Cluster) + + for name := range configs { + if _, err := mergeParent(name); err != nil { + return nil, err } } - c.Clusters = mergedC - return + return mergedConfigs, nil } // Read reads given YAML files and unmarshals them into Config diff --git a/pkg/config/nodegroup.go b/pkg/config/nodegroup.go index 8c74df4f..3ea9b2a4 100644 --- a/pkg/config/nodegroup.go +++ b/pkg/config/nodegroup.go @@ -31,6 +31,13 @@ type NodeGroup struct { UpdateStrategy *string `yaml:"update-strategy"` } +func (b NodeGroup) GetParentName() string { + if b.Inherit != nil { + return b.Inherit.ParentName + } + return "" +} + // Export exports NodeGroup to orchestration.NodeGroupOptions func (n *NodeGroup) Export() (o orchestration.NodeGroupOptions) { localVal := reflect.ValueOf(n).Elem()