diff --git a/invokeai/app/invocations/primitives.py b/invokeai/app/invocations/primitives.py index 587cbaef514..f510d7415e4 100644 --- a/invokeai/app/invocations/primitives.py +++ b/invokeai/app/invocations/primitives.py @@ -1,6 +1,6 @@ # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) -from typing import Optional +from typing import Any, Optional import torch @@ -26,6 +26,7 @@ SD3ConditioningField, TensorField, UIComponent, + UIType, ) from invokeai.app.services.images.images_common import ImageDTO from invokeai.app.services.shared.invocation_context import InvocationContext @@ -535,3 +536,27 @@ def invoke(self, context: InvocationContext) -> BoundingBoxOutput: # endregion + + +@invocation_output("any_output") +class AnyOutput(BaseInvocationOutput): + value: Any = OutputField(description="The output value", ui_type=UIType.Any) + + +@invocation( + "switcher", + title="Switcher", + tags=["primitives", "switcher"], + category="primitives", + version="1.0.0", +) +class SwitcherInvocation(BaseInvocation): + a: Any = InputField(description="The first input", ui_type=UIType.Any) + b: Any = InputField(description="The second input", ui_type=UIType.Any) + switch: bool = InputField( + description="Switch between the two inputs. If false, the first input is returned. If true, the second input is returned." + ) + + def invoke(self, context: InvocationContext) -> AnyOutput: + value = self.b if self.switch else self.a + return AnyOutput(value=value) diff --git a/invokeai/frontend/web/src/features/nodes/store/util/validateConnectionTypes.ts b/invokeai/frontend/web/src/features/nodes/store/util/validateConnectionTypes.ts index 835bf83af03..1eb21eb585d 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/validateConnectionTypes.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/validateConnectionTypes.ts @@ -58,6 +58,7 @@ export const validateConnectionTypes = (sourceType: FieldType, targetType: Field const isSubTypeMatch = doesCardinalityMatch && (isIntToFloat || isIntToString || isFloatToString); const isTargetAnyType = targetType.name === 'AnyField'; + const isSourceAnyType = sourceType.name === 'AnyField'; // One of these must be true for the connection to be valid return ( @@ -67,6 +68,7 @@ export const validateConnectionTypes = (sourceType: FieldType, targetType: Field isGenericCollectionToAnyCollectionOrSingleOrCollection || isCollectionToGenericCollection || isSubTypeMatch || - isTargetAnyType + isTargetAnyType || + isSourceAnyType ); }; diff --git a/invokeai/frontend/web/src/features/nodes/util/schema/parseSchema.ts b/invokeai/frontend/web/src/features/nodes/util/schema/parseSchema.ts index 27dba8c57c2..6b0875e1c28 100644 --- a/invokeai/frontend/web/src/features/nodes/util/schema/parseSchema.ts +++ b/invokeai/frontend/web/src/features/nodes/util/schema/parseSchema.ts @@ -144,7 +144,7 @@ export const parseSchema = ( const fieldType = fieldTypeOverride ?? originalFieldType; if (!fieldType) { - log.trace({ node: type, field: propertyName, schema: parseify(property) }, 'Unable to parse field type'); + log.warn({ node: type, field: propertyName, schema: parseify(property) }, 'Unable to parse field type'); return inputsAccumulator; } @@ -214,7 +214,7 @@ export const parseSchema = ( const fieldType = fieldTypeOverride ?? originalFieldType; if (!fieldType) { - log.trace({ node: type, field: propertyName, schema: parseify(property) }, 'Unable to parse field type'); + log.warn({ node: type, field: propertyName, schema: parseify(property) }, 'Unable to parse field type'); return outputsAccumulator; } @@ -269,7 +269,7 @@ const getFieldType = ( } catch (e) { const tKey = kind === 'input' ? 'nodes.inputFieldTypeParseError' : 'nodes.outputFieldTypeParseError'; if (e instanceof FieldParseError) { - log.warn( + log.trace( { node: type, field: propertyName, @@ -282,7 +282,7 @@ const getFieldType = ( }) ); } else { - log.warn( + log.trace( { node: type, field: propertyName,