diff --git a/pkg/cdi/cache.go b/pkg/cdi/cache.go index c2f7fe34..f1e4fd02 100644 --- a/pkg/cdi/cache.go +++ b/pkg/cdi/cache.go @@ -28,6 +28,7 @@ import ( "github.com/fsnotify/fsnotify" oci "github.com/opencontainers/runtime-spec/specs-go" + "tags.cncf.io/container-device-interface/pkg/cdi/producer" cdi "tags.cncf.io/container-device-interface/specs-go" ) @@ -280,30 +281,31 @@ func (c *Cache) highestPrioritySpecDir() (string, int) { // priority Spec directory. If name has a "json" or "yaml" extension it // choses the encoding. Otherwise the default YAML encoding is used. func (c *Cache) WriteSpec(raw *cdi.Spec, name string) error { - var ( - specDir string - path string - prio int - spec *Spec - err error - ) - - specDir, prio = c.highestPrioritySpecDir() + specDir, _ := c.highestPrioritySpecDir() if specDir == "" { return errors.New("no Spec directories to write to") } - path = filepath.Join(specDir, name) - if ext := filepath.Ext(path); ext != ".json" && ext != ".yaml" { - path += defaultSpecExt + // Ideally we would like to pass the configured spec validator to the + // producer, but we would need to handle the synchronisation. + // Instead we call `validateSpec` here which is a no-op if no validator is + // configured. + if err := validateSpec(raw); err != nil { + return err } - spec, err = newSpec(raw, path, prio) + p, err := producer.New( + producer.WithOverwrite(true), + ) if err != nil { return err } - return spec.write(true) + path := filepath.Join(specDir, name) + if _, err := p.SaveSpec(raw, path); err != nil { + return err + } + return nil } // RemoveSpec removes a Spec with the given name from the highest diff --git a/pkg/cdi/container-edits.go b/pkg/cdi/container-edits.go index a7ac70d0..f868cc74 100644 --- a/pkg/cdi/container-edits.go +++ b/pkg/cdi/container-edits.go @@ -26,6 +26,7 @@ import ( oci "github.com/opencontainers/runtime-spec/specs-go" ocigen "github.com/opencontainers/runtime-tools/generate" + "tags.cncf.io/container-device-interface/pkg/cdi/producer" cdi "tags.cncf.io/container-device-interface/specs-go" ) @@ -167,32 +168,7 @@ func (e *ContainerEdits) Validate() error { if e == nil || e.ContainerEdits == nil { return nil } - - if err := ValidateEnv(e.Env); err != nil { - return fmt.Errorf("invalid container edits: %w", err) - } - for _, d := range e.DeviceNodes { - if err := (&DeviceNode{d}).Validate(); err != nil { - return err - } - } - for _, h := range e.Hooks { - if err := (&Hook{h}).Validate(); err != nil { - return err - } - } - for _, m := range e.Mounts { - if err := (&Mount{m}).Validate(); err != nil { - return err - } - } - if e.IntelRdt != nil { - if err := (&IntelRdt{e.IntelRdt}).Validate(); err != nil { - return err - } - } - - return nil + return producer.DefaultValidator.Validate(e.ContainerEdits) } // Append other edits into this one. If called with a nil receiver, @@ -220,43 +196,6 @@ func (e *ContainerEdits) Append(o *ContainerEdits) *ContainerEdits { return e } -// isEmpty returns true if these edits are empty. This is valid in a -// global Spec context but invalid in a Device context. -func (e *ContainerEdits) isEmpty() bool { - if e == nil { - return false - } - if len(e.Env) > 0 { - return false - } - if len(e.DeviceNodes) > 0 { - return false - } - if len(e.Hooks) > 0 { - return false - } - if len(e.Mounts) > 0 { - return false - } - if len(e.AdditionalGIDs) > 0 { - return false - } - if e.IntelRdt != nil { - return false - } - return true -} - -// ValidateEnv validates the given environment variables. -func ValidateEnv(env []string) error { - for _, v := range env { - if strings.IndexByte(v, byte('=')) <= 0 { - return fmt.Errorf("invalid environment variable %q", v) - } - } - return nil -} - // DeviceNode is a CDI Spec DeviceNode wrapper, used for validating DeviceNodes. type DeviceNode struct { *cdi.DeviceNode @@ -264,27 +203,7 @@ type DeviceNode struct { // Validate a CDI Spec DeviceNode. func (d *DeviceNode) Validate() error { - validTypes := map[string]struct{}{ - "": {}, - "b": {}, - "c": {}, - "u": {}, - "p": {}, - } - - if d.Path == "" { - return errors.New("invalid (empty) device path") - } - if _, ok := validTypes[d.Type]; !ok { - return fmt.Errorf("device %q: invalid type %q", d.Path, d.Type) - } - for _, bit := range d.Permissions { - if bit != 'r' && bit != 'w' && bit != 'm' { - return fmt.Errorf("device %q: invalid permissions %q", - d.Path, d.Permissions) - } - } - return nil + return producer.DefaultValidator.Validate(d.DeviceNode) } // Hook is a CDI Spec Hook wrapper, used for validating hooks. @@ -294,16 +213,7 @@ type Hook struct { // Validate a hook. func (h *Hook) Validate() error { - if _, ok := validHookNames[h.HookName]; !ok { - return fmt.Errorf("invalid hook name %q", h.HookName) - } - if h.Path == "" { - return fmt.Errorf("invalid hook %q with empty path", h.HookName) - } - if err := ValidateEnv(h.Env); err != nil { - return fmt.Errorf("invalid hook %q: %w", h.HookName, err) - } - return nil + return producer.DefaultValidator.Validate(h.Hook) } // Mount is a CDI Mount wrapper, used for validating mounts. @@ -313,13 +223,7 @@ type Mount struct { // Validate a mount. func (m *Mount) Validate() error { - if m.HostPath == "" { - return errors.New("invalid mount, empty host path") - } - if m.ContainerPath == "" { - return errors.New("invalid mount, empty container path") - } - return nil + return producer.DefaultValidator.Validate(m.Mount) } // IntelRdt is a CDI IntelRdt wrapper. @@ -337,11 +241,7 @@ func ValidateIntelRdt(i *cdi.IntelRdt) error { // Validate validates the IntelRdt configuration. func (i *IntelRdt) Validate() error { - // ClosID must be a valid Linux filename - if len(i.ClosID) >= 4096 || i.ClosID == "." || i.ClosID == ".." || strings.ContainsAny(i.ClosID, "/\n") { - return errors.New("invalid ClosID") - } - return nil + return producer.DefaultValidator.Validate(i.IntelRdt) } // Ensure OCI Spec hooks are not nil so we can add hooks. diff --git a/pkg/cdi/device.go b/pkg/cdi/device.go index 2e5fa57f..db67dbf8 100644 --- a/pkg/cdi/device.go +++ b/pkg/cdi/device.go @@ -17,10 +17,8 @@ package cdi import ( - "fmt" - oci "github.com/opencontainers/runtime-spec/specs-go" - "tags.cncf.io/container-device-interface/internal/validation" + "tags.cncf.io/container-device-interface/pkg/cdi/producer" "tags.cncf.io/container-device-interface/pkg/parser" cdi "tags.cncf.io/container-device-interface/specs-go" ) @@ -67,22 +65,5 @@ func (d *Device) edits() *ContainerEdits { // Validate the device. func (d *Device) validate() error { - if err := parser.ValidateDeviceName(d.Name); err != nil { - return err - } - name := d.Name - if d.spec != nil { - name = d.GetQualifiedName() - } - if err := validation.ValidateSpecAnnotations(name, d.Annotations); err != nil { - return err - } - edits := d.edits() - if edits.isEmpty() { - return fmt.Errorf("invalid device, empty device edits") - } - if err := edits.Validate(); err != nil { - return fmt.Errorf("invalid device %q: %w", d.Name, err) - } - return nil + return producer.DefaultValidator.Validate(d.Device) } diff --git a/pkg/cdi/producer/api.go b/pkg/cdi/producer/api.go index 8fbe4b1a..b2ca9816 100644 --- a/pkg/cdi/producer/api.go +++ b/pkg/cdi/producer/api.go @@ -27,3 +27,9 @@ const ( // SpecFormatYAML defines a CDI spec formatted as YAML. SpecFormatYAML = specFormat(".yaml") ) + +// Validators as constants. +const ( + DefaultValidator = defaultValidator("default") + DisabledValidator = disabledValidator("disabled") +) diff --git a/pkg/cdi/producer/options.go b/pkg/cdi/producer/options.go index 75d1746b..b83138b7 100644 --- a/pkg/cdi/producer/options.go +++ b/pkg/cdi/producer/options.go @@ -34,6 +34,17 @@ func WithSpecFormat(format specFormat) Option { } } +// WithSpecValidator sets a validator to be used when writing an output spec. +func WithSpecValidator(validator Validator) Option { + return func(p *Producer) error { + if validator == nil { + validator = DisabledValidator + } + p.validator = validator + return nil + } +} + // WithOverwrite specifies whether a producer should overwrite a CDI spec when // saving to file. func WithOverwrite(overwrite bool) Option { diff --git a/pkg/cdi/producer/producer.go b/pkg/cdi/producer/producer.go index dcecd8fb..4f9ae804 100644 --- a/pkg/cdi/producer/producer.go +++ b/pkg/cdi/producer/producer.go @@ -17,6 +17,7 @@ package producer import ( + "fmt" "path/filepath" cdi "tags.cncf.io/container-device-interface/specs-go" @@ -26,12 +27,14 @@ import ( type Producer struct { format specFormat failIfExists bool + validator Validator } // New creates a new producer with the supplied options. func New(opts ...Option) (*Producer, error) { p := &Producer{ - format: DefaultSpecFormat, + format: DefaultSpecFormat, + validator: DefaultValidator, } for _, opt := range opts { err := opt(p) @@ -47,8 +50,11 @@ func New(opts ...Option) (*Producer, error) { // extension takes precedence over the format with which the Producer was // configured. func (p *Producer) SaveSpec(s *cdi.Spec, filename string) (string, error) { - filename = p.normalizeFilename(filename) + if err := p.validator.Validate(s); err != nil { + return "", fmt.Errorf("spec validation failed: %w", err) + } + filename = p.normalizeFilename(filename) sp := spec{ Spec: s, format: p.specFormatFromFilename(filename), diff --git a/pkg/cdi/producer/validator.go b/pkg/cdi/producer/validator.go new file mode 100644 index 00000000..59c66d38 --- /dev/null +++ b/pkg/cdi/producer/validator.go @@ -0,0 +1,248 @@ +package producer + +import ( + "errors" + "fmt" + "strings" + + "tags.cncf.io/container-device-interface/internal/validation" + "tags.cncf.io/container-device-interface/pkg/parser" + cdi "tags.cncf.io/container-device-interface/specs-go" +) + +// A Validator provides and interface for validating CDI specifications. +type Validator interface { + Validate(interface{}) error + ValidateSpec(*cdi.Spec) error +} + +// A disabledValidator performs no validation. +type disabledValidator string + +// Validate always passes for a disabledValidator. +func (v disabledValidator) Validate(interface{}) error { + return nil +} + +// ValidateSpec always passes for a disabledValidator. +func (v disabledValidator) ValidateSpec(*cdi.Spec) error { + return nil +} + +type defaultValidator string + +// Validate implements a generic validation handler for the defaultValidator. +func (v defaultValidator) Validate(o interface{}) error { + switch o := o.(type) { + case *cdi.ContainerEdits: + return v.validateEdits(o) + case *cdi.Device: + return v.validateDevice("", "", o) + case *cdi.DeviceNode: + return v.validateDeviceNode(o) + case *cdi.Hook: + return v.validateHook(o) + case *cdi.IntelRdt: + return v.validateIntelRdt(o) + case *cdi.Mount: + return v.validateMount(o) + case *cdi.Spec: + return v.ValidateSpec(o) + default: + return fmt.Errorf("unsupported validation type: %T", o) + } +} + +// ValidateSpec performs a default validation. +func (v defaultValidator) ValidateSpec(s *cdi.Spec) error { + if err := cdi.ValidateVersion(s); err != nil { + return err + } + vendor, class := parser.ParseQualifier(s.Kind) + if err := parser.ValidateVendorName(vendor); err != nil { + return err + } + if err := parser.ValidateClassName(class); err != nil { + return err + } + if err := validation.ValidateSpecAnnotations(s.Kind, s.Annotations); err != nil { + return err + } + if err := v.validateEdits(&s.ContainerEdits); err != nil { + return err + } + + seen := make(map[string]bool) + for _, d := range s.Devices { + if seen[d.Name] { + return fmt.Errorf("invalid spec, multiple device %q", d.Name) + } + seen[d.Name] = true + if err := v.validateDevice(vendor, class, &d); err != nil { + return fmt.Errorf("invalid device %q: %w", d.Name, err) + } + } + return nil +} + +func (v defaultValidator) validateDevice(vendor string, class string, d *cdi.Device) error { + if err := parser.ValidateDeviceName(d.Name); err != nil { + return err + } + + name := parser.QualifiedName(vendor, class, d.Name) + if err := validation.ValidateSpecAnnotations(name, d.Annotations); err != nil { + return err + } + + if err := v.assertNonEmptyEdits(&d.ContainerEdits); err != nil { + return err + } + if err := v.validateEdits(&d.ContainerEdits); err != nil { + return err + } + return nil +} + +func (v defaultValidator) assertNonEmptyEdits(e *cdi.ContainerEdits) error { + if e == nil { + return nil + } + if len(e.Env) > 0 { + return nil + } + if len(e.DeviceNodes) > 0 { + return nil + } + if len(e.Hooks) > 0 { + return nil + } + if len(e.Mounts) > 0 { + return nil + } + if len(e.AdditionalGIDs) > 0 { + return nil + } + if e.IntelRdt != nil { + return nil + } + return errors.New("empty container edits") +} + +func (v defaultValidator) validateEdits(e *cdi.ContainerEdits) error { + if e == nil { + return nil + } + if err := v.validateEnv(e.Env); err != nil { + return fmt.Errorf("invalid container edits: %w", err) + } + for _, d := range e.DeviceNodes { + if err := v.validateDeviceNode(d); err != nil { + return err + } + } + for _, h := range e.Hooks { + if err := v.validateHook(h); err != nil { + return err + } + } + for _, m := range e.Mounts { + if err := v.validateMount(m); err != nil { + return err + } + } + if err := v.validateIntelRdt(e.IntelRdt); err != nil { + return err + } + return nil +} + +func (v defaultValidator) validateEnv(env []string) error { + for _, v := range env { + if strings.IndexByte(v, byte('=')) <= 0 { + return fmt.Errorf("invalid environment variable %q", v) + } + } + return nil +} + +func (v defaultValidator) validateDeviceNode(d *cdi.DeviceNode) error { + validTypes := map[string]struct{}{ + "": {}, + "b": {}, + "c": {}, + "u": {}, + "p": {}, + } + + if d.Path == "" { + return errors.New("invalid (empty) device path") + } + if _, ok := validTypes[d.Type]; !ok { + return fmt.Errorf("device %q: invalid type %q", d.Path, d.Type) + } + for _, bit := range d.Permissions { + if bit != 'r' && bit != 'w' && bit != 'm' { + return fmt.Errorf("device %q: invalid permissions %q", + d.Path, d.Permissions) + } + } + return nil +} + +func (v defaultValidator) validateHook(h *cdi.Hook) error { + const ( + // PrestartHook is the name of the OCI "prestart" hook. + PrestartHook = "prestart" + // CreateRuntimeHook is the name of the OCI "createRuntime" hook. + CreateRuntimeHook = "createRuntime" + // CreateContainerHook is the name of the OCI "createContainer" hook. + CreateContainerHook = "createContainer" + // StartContainerHook is the name of the OCI "startContainer" hook. + StartContainerHook = "startContainer" + // PoststartHook is the name of the OCI "poststart" hook. + PoststartHook = "poststart" + // PoststopHook is the name of the OCI "poststop" hook. + PoststopHook = "poststop" + ) + validHookNames := map[string]struct{}{ + PrestartHook: {}, + CreateRuntimeHook: {}, + CreateContainerHook: {}, + StartContainerHook: {}, + PoststartHook: {}, + PoststopHook: {}, + } + + if _, ok := validHookNames[h.HookName]; !ok { + return fmt.Errorf("invalid hook name %q", h.HookName) + } + if h.Path == "" { + return fmt.Errorf("invalid hook %q with empty path", h.HookName) + } + if err := v.validateEnv(h.Env); err != nil { + return fmt.Errorf("invalid hook %q: %w", h.HookName, err) + } + return nil +} + +func (v defaultValidator) validateMount(m *cdi.Mount) error { + if m.HostPath == "" { + return errors.New("invalid mount, empty host path") + } + if m.ContainerPath == "" { + return errors.New("invalid mount, empty container path") + } + return nil +} + +func (v defaultValidator) validateIntelRdt(i *cdi.IntelRdt) error { + if i == nil { + return nil + } + // ClosID must be a valid Linux filename + if len(i.ClosID) >= 4096 || i.ClosID == "." || i.ClosID == ".." || strings.ContainsAny(i.ClosID, "/\n") { + return errors.New("invalid ClosID") + } + return nil +} diff --git a/pkg/cdi/spec.go b/pkg/cdi/spec.go index d617046a..b0708b2c 100644 --- a/pkg/cdi/spec.go +++ b/pkg/cdi/spec.go @@ -26,7 +26,6 @@ import ( oci "github.com/opencontainers/runtime-spec/specs-go" "sigs.k8s.io/yaml" - "tags.cncf.io/container-device-interface/internal/validation" "tags.cncf.io/container-device-interface/pkg/cdi/producer" "tags.cncf.io/container-device-interface/pkg/parser" cdi "tags.cncf.io/container-device-interface/specs-go" @@ -176,22 +175,12 @@ func MinimumRequiredVersion(spec *cdi.Spec) (string, error) { // Validate the Spec. func (s *Spec) validate() (map[string]*Device, error) { - if err := cdi.ValidateVersion(s.Spec); err != nil { - return nil, err - } - if err := parser.ValidateVendorName(s.vendor); err != nil { - return nil, err - } - if err := parser.ValidateClassName(s.class); err != nil { - return nil, err - } - if err := validation.ValidateSpecAnnotations(s.Kind, s.Annotations); err != nil { - return nil, err - } - if err := s.edits().Validate(); err != nil { + if err := producer.DefaultValidator.Validate(s.Spec); err != nil { return nil, err } + // TODO: The validator above should perform the same validation as below but + // we still need to construct the device map. devices := make(map[string]*Device) for _, d := range s.Devices { dev, err := newDevice(s, d)