Skip to content

Commit

Permalink
Handle ColumnSchema target in serialization of SequenceTransform
Browse files Browse the repository at this point in the history
  • Loading branch information
oliverholworthy committed Oct 4, 2023
1 parent 96fccce commit 08f2e69
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion merlin/models/tf/transforms/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,10 @@ def compute_output_shape(self, input_shape):
def get_config(self):
"""Returns the config of the layer as a Python dictionary."""
config = super().get_config()
config["target"] = self.target
target = self.target
if isinstance(target, ColumnSchema):
target = schema_utils.schema_to_tensorflow_metadata_json(Schema([target]))
config["target"] = target

return config

Expand All @@ -193,6 +196,10 @@ def from_config(cls, config):
"""Creates layer from its config. Returning the instance."""
config = tf_utils.maybe_deserialize_keras_objects(config, ["pre", "post", "aggregation"])
config["schema"] = schema_utils.tensorflow_metadata_json_to_schema(config["schema"])
if config["target"].startswith("{"): # we have a schema
config["target"] = [
col for col in schema_utils.tensorflow_metadata_json_to_schema(config["target"])
][0]
schema = config.pop("schema")
target = config.pop("target")
return cls(schema, target, **config)
Expand Down

0 comments on commit 08f2e69

Please sign in to comment.