diff --git a/schema/object.go b/schema/object.go index 99a5d12..3c347ba 100644 --- a/schema/object.go +++ b/schema/object.go @@ -338,10 +338,16 @@ func (o *ObjectSchema) validateStruct(data any) error { return o.validateFieldInterdependencies(rawData) } -func (o *ObjectSchema) convertToObjectSchema(typeOrData any) (*ObjectSchema, bool) { - schemaType, ok := typeOrData.(*ObjectSchema) +func (o *ObjectSchema) convertToObjectSchema(typeOrData any) (Object, bool) { + // Try plain object schema + objectSchemaType, ok := typeOrData.(*ObjectSchema) if ok { - return schemaType, true + return objectSchemaType, true + } + // Next, try ref schema + refSchemaType, ok := typeOrData.(*RefSchema) + if ok { + return refSchemaType.referencedObjectCache, true } // Try getting the inlined ObjectSchema for objects, like TypedObjectSchema, that do that. value := reflect.ValueOf(typeOrData) @@ -351,15 +357,15 @@ func (o *ObjectSchema) convertToObjectSchema(typeOrData any) (*ObjectSchema, boo fieldAsInterface := field.Interface() objectType, ok2 := fieldAsInterface.(ObjectSchema) if ok2 { - schemaType = &objectType + objectSchemaType = &objectType ok = true } } } - return schemaType, ok + return objectSchemaType, ok } -func (o *ObjectSchema) validateSchemaCompatibility(schemaType *ObjectSchema) error { +func (o *ObjectSchema) validateSchemaCompatibility(schemaType Object) error { fieldData := map[string]any{} // Validate IDs. This is important because the IDs should match. if schemaType.ID() != o.ID() { diff --git a/schema/object_test.go b/schema/object_test.go index 5cf2c55..d4adb41 100644 --- a/schema/object_test.go +++ b/schema/object_test.go @@ -426,10 +426,18 @@ func TestTypedObjectSchema_Any(t *testing.T) { assert.Error(t, err) } +var testStructScope = schema.NewScopeSchema(&testStructSchema.ObjectSchema) + func TestObjectSchema_ValidateCompatibility(t *testing.T) { // Schema validation assert.NoError(t, testStructSchema.ValidateCompatibility(testStructSchema)) assert.Error(t, testStructSchema.ValidateCompatibility(testOptionalFieldSchema)) // Not the same + // Schema validation with ref + objectTestRef := schema.NewRefSchema("testStruct", nil) + objectTestRef.ApplyScope(testStructScope) + assert.NoError(t, objectTestRef.ValidateCompatibility(testStructSchema)) + assert.NoError(t, testStructSchema.ValidateCompatibility(objectTestRef)) + // map verification validData := map[string]any{ "Field1": 42,