
Customized Feature Column


TF2.0 users generally define models using Keras. tf.keras.layers.Embedding is the native embedding layer in Keras, but it can only handle dense inputs. For sparse inputs, users often use tf.feature_column.embedding_column to convert them to a dense representation before feeding them to a DNN.
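
For illustration, here is a minimal sketch contrasting the two approaches. The feature names and sizes are made up; the native Keras layer consumes dense integer ids directly, while a sparse categorical feature goes through an embedding_column wrapped in a DenseFeatures layer.

import tensorflow as tf

# Dense integer ids: the native Keras embedding layer works directly.
ids = tf.constant([[1, 3], [2, 0]])
embedding_layer = tf.keras.layers.Embedding(input_dim=4, output_dim=8)
dense_output = embedding_layer(ids)  # shape (2, 2, 8)

# Sparse categorical input: convert it to a dense representation
# with an embedding_column wrapped in a DenseFeatures layer.
category = tf.feature_column.categorical_column_with_hash_bucket(
    'education', hash_bucket_size=4)
embedding = tf.feature_column.embedding_column(category, dimension=8)
feature_layer = tf.keras.layers.DenseFeatures([embedding])
dense_output = feature_layer(
    {'education': tf.constant([['Bachelors'], ['Master']])})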

In ElasticDL, users define the model using Keras too, and we provide elasticdl.layers.Embedding to interact with the ElasticDL parameter server and partition the embedding table among multiple PS instances. It can replace the native Keras embedding layer, but it cannot replace embedding_column.

This doc focuses on how to write a customized embedding feature column that interacts with the parameter server, and on how to replace the native feature column with ours.

How to write a customized embedding feature column

  1. Define a new class inheriting from FeatureColumn. Since we want to customize an embedding column, it also needs to inherit from DenseColumn.
  2. Implement all the abstract methods. We focus especially on the following two:
    create_state
    Creates the variables for this FeatureColumn inside the DenseFeatures layer, such as the embedding variables.
    get_dense_tensor
    While executing DenseFeatures.call, the layer iterates over all its feature columns and calls get_dense_tensor to get the transformed dense tensor from each one. Taking the native EmbeddingColumn as an example, it calls embedding_ops.safe_embedding_lookup_sparse to look up the embedding vectors for the sparse input.
import collections

import numpy as np
import tensorflow as tf
from tensorflow.python.feature_column import feature_column as fc_old
from tensorflow.python.feature_column import feature_column_v2 as fc_lib
from tensorflow.python.framework import tensor_shape

def customized_embedding_column(categorical_column,
                                dimension,
                                combiner='mean',
                                initializer=None,
                                ckpt_to_load_from=None,
                                tensor_name_in_ckpt=None,
                                max_norm=None,
                                trainable=True):
    if (dimension is None) or (dimension < 1):
        raise ValueError('Invalid dimension {}.'.format(dimension))
    if (ckpt_to_load_from is None) != (tensor_name_in_ckpt is None):
        raise ValueError('Must specify both `ckpt_to_load_from` and '
                         '`tensor_name_in_ckpt` or none of them.')

    if (initializer is not None) and (not callable(initializer)):
        raise ValueError('initializer must be callable if specified. '
                         'Embedding of column_name: {}'.format(
                             categorical_column.name))
    if initializer is None:
        initializer = tf.keras.initializers.Ones()

    return CustomizedEmbeddingColumn(
        categorical_column=categorical_column,
        dimension=dimension,
        combiner=combiner,
        initializer=initializer,
        ckpt_to_load_from=ckpt_to_load_from,
        tensor_name_in_ckpt=tensor_name_in_ckpt,
        max_norm=max_norm,
        trainable=trainable,
        extended={})

EmbeddingAndIds = collections.namedtuple(
    "EmbeddingAndIds", ["batch_embedding", "batch_ids"]
)

class CustomizedEmbeddingColumn(
    fc_lib.DenseColumn,
    fc_lib.SequenceDenseColumn,
    fc_old._DenseColumn,
    fc_old._SequenceDenseColumn,
    collections.namedtuple(
        'EmbeddingColumn',
        ('categorical_column', 'dimension', 'combiner', 'initializer',
         'ckpt_to_load_from', 'tensor_name_in_ckpt', 'max_norm', 'trainable', 'extended'))):

    @property
    def _is_v2_column(self):
        return (isinstance(self.categorical_column, fc_lib.FeatureColumn) and
            self.categorical_column._is_v2_column)

    @property
    def name(self):
        """See `FeatureColumn` base class."""
        return '{}_customized_embedding'.format(self.categorical_column.name)

    @property
    def parse_example_spec(self):
        """See `FeatureColumn` base class."""
        return self.categorical_column.parse_example_spec

    @property
    def variable_shape(self):
        """See `DenseColumn` base class."""
        return tensor_shape.TensorShape([self.dimension])

    def create_state(self, state_manager):
        """Creates the embedding lookup state."""
        default_num_buckets = (self.categorical_column.num_buckets
                               if self._is_v2_column
                               else self.categorical_column._num_buckets)  # pylint: disable=protected-access
        num_buckets = getattr(self.categorical_column, 'num_buckets',
                              default_num_buckets)
        embedding_shape = (num_buckets, self.dimension)

        # Mock the embedding table with a numpy array filled with a constant
        # value. In a real implementation it would live on the parameter
        # server. Use float32 so it matches the Tout of tf.py_function below.
        if 'np_embedding_table' not in self.extended:
            self.extended['np_embedding_table'] = np.full(
                embedding_shape, 10.0, dtype=np.float32)

        self.embedding_and_ids = []

    def get_dense_tensor(self, transformation_cache, state_manager):
        if isinstance(self.categorical_column, fc_lib.SequenceCategoricalColumn):
            raise ValueError(
                'In embedding_column: {}. '
                'categorical_column must not be of type SequenceCategoricalColumn. '
                'Suggested fix A: If you wish to use DenseFeatures, use a '
                'non-sequence categorical_column_with_*. '
                'Suggested fix B: If you wish to create sequence input, use '
                'SequenceFeatures instead of DenseFeatures. '
                'Given (type {}): {}'.format(self.name, type(self.categorical_column),
                                            self.categorical_column))
        # Get sparse IDs and weights.
        sparse_tensors = self.categorical_column.get_sparse_tensors(
            transformation_cache, state_manager)

        sparse_ids = sparse_tensors.id_tensor
        # Deduplicate ids so each embedding vector is looked up only once.
        unique_ids, idx = tf.unique(sparse_ids.values)

        # Look up embeddings with a python function so we can query an
        # external store (here, the mocked numpy table).
        batch_embedding = tf.py_function(
            self.lookup_embedding, inp=[unique_ids], Tout=tf.float32
        )

        if self.tape:
            # Watch the looked-up embeddings so that gradients w.r.t. them
            # can be computed and later sent to the parameter server.
            batch_embedding = self._record_gradients(batch_embedding, unique_ids)

        # The row indices of the sparse input identify which sample each id
        # belongs to; they serve as segment ids for the combiner.
        segment_ids = sparse_ids.indices[:, 0]
        if segment_ids.dtype != tf.int32:
            segment_ids = tf.cast(segment_ids, tf.int32)

        if self.combiner == "sum":
            batch_embedding = tf.sparse.segment_sum(
                batch_embedding, idx, segment_ids
            )
        elif self.combiner == "mean":
            batch_embedding = tf.sparse.segment_mean(
                batch_embedding, idx, segment_ids
            )
        elif self.combiner == "sqrtn":
            batch_embedding = tf.sparse.segment_sqrt_n(
                batch_embedding, idx, segment_ids
            )

        return batch_embedding

    def lookup_embedding(self, unique_ids):
        # Look up each unique id in the mocked embedding table. In a real
        # implementation this would query the ElasticDL parameter server.
        ids = unique_ids.numpy()
        batch_embedding = [self.extended['np_embedding_table'][i] for i in ids]
        batch_embedding = np.concatenate(batch_embedding, axis=0)
        return batch_embedding.reshape((len(ids), self.dimension))

    @property
    def tape(self):
        if 'tape' in self.extended:
            return self.extended['tape']
        return None

    @tape.setter
    def tape(self, tape):
        self.extended['tape'] = tape

    @property
    def embedding_and_ids(self):
        if 'embedding_and_ids' in self.extended:
            return self.extended['embedding_and_ids']
        return None

    @embedding_and_ids.setter
    def embedding_and_ids(self, embedding_and_ids):
        self.extended['embedding_and_ids'] = embedding_and_ids

    def _record_gradients(self, batch_embedding, ids):
        self.tape.watch(batch_embedding)
        self.embedding_and_ids.append(
            EmbeddingAndIds(batch_embedding, ids)
        )
        return batch_embedding

    def reset(self):
        self.embedding_and_ids = []
        self.tape = None

def set_tape_to_customized_embedding_columns(model, tape):
    for layer in model.layers:
        if isinstance(layer, tf.keras.layers.DenseFeatures):
            for column in layer._feature_columns:
                if isinstance(column, CustomizedEmbeddingColumn):
                    column.tape = tape

def reset_customized_embedding_columns(model):
    for layer in model.layers:
        if isinstance(layer, tf.keras.layers.DenseFeatures):
            for column in layer._feature_columns:
                if isinstance(column, CustomizedEmbeddingColumn):
                    column.reset()

def get_trainable_items(model):
    # Collect the batch embedding tensors recorded by the customized columns
    # so that tape.gradient can compute gradients w.r.t. them.
    batch_embedding_tensors = []
    for layer in model.layers:
        if isinstance(layer, tf.keras.layers.DenseFeatures):
            for column in layer._feature_columns:
                if isinstance(column, CustomizedEmbeddingColumn):
                    batch_embedding_tensors.extend(
                        [
                            batch_embedding
                            for (batch_embedding, _) in column.embedding_and_ids
                        ]
                    )

    return list(model.trainable_variables) + batch_embedding_tensors
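
The helpers above are intended for use in a custom training loop. Below is a minimal sketch of one training step, assuming model is a Keras model whose DenseFeatures layer holds CustomizedEmbeddingColumn instances; the optimizer, loss, features, and labels are placeholders for whatever the job actually uses.

optimizer = tf.keras.optimizers.SGD(0.1)
loss_fn = tf.keras.losses.BinaryCrossentropy()

def train_step(model, features, labels):
    with tf.GradientTape() as tape:
        # Let each customized column watch the embeddings it looks up.
        set_tape_to_customized_embedding_columns(model, tape)
        predictions = model(features)
        loss = loss_fn(labels, predictions)
    # Dense trainable variables first, then the watched batch embeddings.
    trainables = get_trainable_items(model)
    grads = tape.gradient(loss, trainables)
    num_vars = len(model.trainable_variables)
    # Dense gradients can be applied locally by the optimizer.
    optimizer.apply_gradients(zip(grads[:num_vars], trainables[:num_vars]))
    # grads[num_vars:] are gradients of the batch embedding tensors; paired
    # with the recorded ids, they would be pushed to the ElasticDL
    # parameter server to update the distributed embedding table.
    reset_customized_embedding_columns(model)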

How to replace the native feature column

The following sample code shows how to replace embedding_column with the customized embedding column.

def replace_embedding_column_with_customized_in_feature_layer(feature_layer):
    new_feature_columns = []
    for column in feature_layer._feature_columns:
        if isinstance(column, fc_lib.EmbeddingColumn):
            # Build a customized column with the same categorical column and
            # dimension, using the factory function defined above.
            new_column = customized_embedding_column(
                column.categorical_column,
                dimension=column.dimension)
            new_feature_columns.append(new_column)
        else:
            new_feature_columns.append(column)

    feature_layer._feature_columns = new_feature_columns

    return feature_layer
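
As a usage sketch, assuming model is a Keras model that contains a DenseFeatures layer built from native embedding_columns (such as the model defined in the next section), the replacement can be applied in place:

for layer in model.layers:
    if isinstance(layer, tf.keras.layers.DenseFeatures):
        replace_embedding_column_with_customized_in_feature_layer(layer)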

How to assign trained weights to embedding_column

The following code snippets show how to assign trained weights to embedding_column. Using this method, we can train embedding_columns with a customized lookup operator and then save the model with TensorFlow's native embedding_column for TensorFlow Serving.

First, we define a Keras model with two embedding_columns.

import numpy as np
import tensorflow as tf


def get_feature_columns():
    age = tf.feature_column.numeric_column("age", dtype=tf.int64)
    edu_embedding = tf.feature_column.embedding_column(
        tf.feature_column.categorical_column_with_hash_bucket(
            'education', hash_bucket_size=4),
        4
    )
    edu_embedding_1 = tf.feature_column.embedding_column(
        tf.feature_column.categorical_column_with_hash_bucket(
            'education_1', hash_bucket_size=4),
        1
    )
    return [age, edu_embedding, edu_embedding_1]


def custom_model(feature_columns):
    input_layers = {}
    input_layers['age'] = tf.keras.layers.Input(name='age', shape=(1,))
    input_layers['education'] = tf.keras.layers.Input(
        name='education', shape=(1,), dtype=tf.string)
    input_layers['education_1'] = tf.keras.layers.Input(
        name='education_1', shape=(1,), dtype=tf.string)

    dense_feature = tf.keras.layers.DenseFeatures(
        feature_columns=feature_columns)(input_layers)
    return tf.keras.models.Model(inputs=input_layers, outputs=dense_feature)

feat_cols = get_feature_columns()
model = custom_model(feat_cols)
output = model(
    {
        'age': tf.constant([[10], [16]]),
        'education': tf.constant([['Bachelors'], ['Master']]),
        'education_1': tf.constant([['Bachelors'], ['Master']])
    }
)
print(output)

Next, we mock trained weights and assign them to the embedding_columns in the model.

import numpy as np
from tensorflow.python.feature_column.feature_column_v2 import EmbeddingColumn

def mock_embedding_column_weights(feature_columns, dense_feature_name):
    embedding_column_weights = {}
    for fc in feature_columns:
        if isinstance(fc, EmbeddingColumn):
            variable_name = "/".join([dense_feature_name, fc.name, "embedding_weights:0"])
            weight_shape = (fc.categorical_column.num_buckets, fc.dimension)
            embedding_column_weights[variable_name] = np.ones(weight_shape)
    return embedding_column_weights


model_feature_columns = None
for layer in model.layers:
    if isinstance(layer, tf.keras.layers.DenseFeatures):
        model_feature_columns = list(layer._feature_columns)  
        embedding_column_weights = mock_embedding_column_weights(
            model_feature_columns, layer.name
        )

for weight in model.trainable_weights:
    if weight.name in embedding_column_weights:
        weight.assign(embedding_column_weights[weight.name])

output = model(
    {
        'age': tf.constant([[10], [16]]),
        'education': tf.constant([['Bachelors'], ['Master']]),
        'education_1': tf.constant([['Bachelors'], ['Master']])
    }
)
print(output)