Source code for deepctr.layers.utils

# -*- coding:utf-8 -*-
"""

Author:
    Weichen Shen,weichenswc@163.com

"""
import tensorflow as tf
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.layers import Flatten, Layer, Add
from tensorflow.python.ops.lookup_ops import TextFileInitializer

try:
    from tensorflow.python.ops.init_ops import Zeros, glorot_normal_initializer as glorot_normal
except ImportError:
    from tensorflow.python.ops.init_ops_v2 import Zeros, glorot_normal

from tensorflow.python.keras.regularizers import l2

try:
    from tensorflow.python.ops.lookup_ops import StaticHashTable
except ImportError:
    from tensorflow.python.ops.lookup_ops import HashTable as StaticHashTable


class NoMask(Layer):
    def __init__(self, **kwargs):
        super(NoMask, self).__init__(**kwargs)

    def build(self, input_shape):
        # Be sure to call this somewhere!
        super(NoMask, self).build(input_shape)

    def call(self, x, mask=None, **kwargs):
        return x

    def compute_mask(self, inputs, mask):
        return None
class Hash(Layer):
    """Looks up keys in a table when `vocabulary_path` is set, and outputs the corresponding values.
    If `vocabulary_path` is not set, `Hash` hashes the input to [0, num_buckets). When `mask_zero` = True,
    the input values `0` or `0.0` are mapped to `0`, and all other values are mapped into the range
    [1, num_buckets).

    The following snippet initializes a `Hash` from a `vocabulary_path` file whose first column holds the
    values and whose second column holds the keys:

    * `1,emerson`
    * `2,lake`
    * `3,palmer`

    >>> hash = Hash(
    ...     num_buckets=3 + 1,
    ...     vocabulary_path=filename,
    ...     default_value=0)
    >>> hash(tf.constant('lake')).numpy()
    2
    >>> hash(tf.constant('lakeemerson')).numpy()
    0

    Args:
        num_buckets: An `int` that is >= 1. The number of buckets, or the vocabulary size + 1 when
            `vocabulary_path` is set.
        mask_zero: default is False. When `True`, input `0` or `0.0` is hashed to the value `0`.
            `mask_zero` is not used when `vocabulary_path` is set.
        vocabulary_path: default `None`. Path of the `CSV` text file holding the vocabulary, which
            contains two columns separated by a comma: the first column is the value and the second
            is the key. The key data type is `string`, the value data type is `int`. The path must
            be accessible from wherever `Hash` is initialized.
        default_value: default `0`. The default value if a key is missing in the table.
        **kwargs: Additional keyword arguments.
    """

    def __init__(self, num_buckets, mask_zero=False, vocabulary_path=None, default_value=0, **kwargs):
        self.num_buckets = num_buckets
        self.mask_zero = mask_zero
        self.vocabulary_path = vocabulary_path
        self.default_value = default_value
        if self.vocabulary_path:
            initializer = TextFileInitializer(vocabulary_path, 'string', 1, 'int64', 0, delimiter=',')
            self.hash_table = StaticHashTable(initializer, default_value=self.default_value)
        super(Hash, self).__init__(**kwargs)

    def build(self, input_shape):
        # Be sure to call this somewhere!
        super(Hash, self).build(input_shape)

    def call(self, x, mask=None, **kwargs):
        if x.dtype != tf.string:
            zero = tf.as_string(tf.zeros([1], dtype=x.dtype))
            x = tf.as_string(x, )
        else:
            zero = tf.as_string(tf.zeros([1], dtype='int32'))

        if self.vocabulary_path:
            hash_x = self.hash_table.lookup(x)
            return hash_x

        num_buckets = self.num_buckets if not self.mask_zero else self.num_buckets - 1
        try:
            hash_x = tf.string_to_hash_bucket_fast(x, num_buckets, name=None)  # weak hash
        except AttributeError:
            hash_x = tf.strings.to_hash_bucket_fast(x, num_buckets, name=None)  # weak hash
        if self.mask_zero:
            mask = tf.cast(tf.not_equal(x, zero), dtype='int64')
            hash_x = (hash_x + 1) * mask

        return hash_x

    def compute_output_shape(self, input_shape):
        return input_shape

    def get_config(self, ):
        config = {'num_buckets': self.num_buckets, 'mask_zero': self.mask_zero,
                  'vocabulary_path': self.vocabulary_path, 'default_value': self.default_value}
        base_config = super(Hash, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
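# Illustrative sketch, not part of the original module: `Hash` in pure hashing mode
# (no `vocabulary_path`). Assumes TF 2.x eager execution; the layer name and shapes are
# made up for the example. With `mask_zero=True`, the string "0" maps to 0 and every
# other input is hashed into [1, num_buckets). Guarded so it only runs as a script.
if __name__ == "__main__":
    _example_hash = Hash(num_buckets=10, mask_zero=True)
    # "0" maps to 0; "emerson" and "lake" land somewhere in [1, 10).
    print(_example_hash(tf.constant([["0", "emerson", "lake"]])).numpy())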
class Linear(Layer):

    def __init__(self, l2_reg=0.0, mode=0, use_bias=False, seed=1024, **kwargs):
        self.l2_reg = l2_reg
        # self.l2_reg = tf.contrib.layers.l2_regularizer(float(l2_reg_linear))
        if mode not in [0, 1, 2]:
            raise ValueError("mode must be 0,1 or 2")
        self.mode = mode
        self.use_bias = use_bias
        self.seed = seed
        super(Linear, self).__init__(**kwargs)

    def build(self, input_shape):
        if self.use_bias:
            self.bias = self.add_weight(name='linear_bias',
                                        shape=(1,),
                                        initializer=Zeros(),
                                        trainable=True)
        if self.mode == 1:
            self.kernel = self.add_weight(
                'linear_kernel',
                shape=[int(input_shape[-1]), 1],
                initializer=glorot_normal(self.seed),
                regularizer=l2(self.l2_reg),
                trainable=True)
        elif self.mode == 2:
            self.kernel = self.add_weight(
                'linear_kernel',
                shape=[int(input_shape[1][-1]), 1],
                initializer=glorot_normal(self.seed),
                regularizer=l2(self.l2_reg),
                trainable=True)

        super(Linear, self).build(input_shape)  # Be sure to call this somewhere!

    def call(self, inputs, **kwargs):
        if self.mode == 0:
            sparse_input = inputs
            linear_logit = reduce_sum(sparse_input, axis=-1, keep_dims=True)
        elif self.mode == 1:
            dense_input = inputs
            fc = tf.tensordot(dense_input, self.kernel, axes=(-1, 0))
            linear_logit = fc
        else:
            sparse_input, dense_input = inputs
            fc = tf.tensordot(dense_input, self.kernel, axes=(-1, 0))
            linear_logit = reduce_sum(sparse_input, axis=-1, keep_dims=False) + fc
        if self.use_bias:
            linear_logit += self.bias

        return linear_logit

    def compute_output_shape(self, input_shape):
        return (None, 1)

    def compute_mask(self, inputs, mask):
        return None

    def get_config(self, ):
        config = {'mode': self.mode, 'l2_reg': self.l2_reg, 'use_bias': self.use_bias, 'seed': self.seed}
        base_config = super(Linear, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
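# Illustrative sketch, not part of the original module: `Linear` in mode 1, which applies
# a single learned weight vector to a dense feature tensor and returns one logit per sample.
# Assumes TF 2.x eager execution; the shapes below are made up for the example.
if __name__ == "__main__":
    _dense_feats = tf.random.normal([4, 8])  # (batch_size, num_dense_features)
    _linear_logit = Linear(l2_reg=1e-5, mode=1, use_bias=True)(_dense_feats)
    print(_linear_logit.shape)  # (4, 1)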
class Concat(Layer):
    def __init__(self, axis, supports_masking=True, **kwargs):
        super(Concat, self).__init__(**kwargs)
        self.axis = axis
        self.supports_masking = supports_masking

    def call(self, inputs):
        return tf.concat(inputs, axis=self.axis)

    def compute_mask(self, inputs, mask=None):
        if not self.supports_masking:
            return None
        if mask is None:
            mask = [inputs_i._keras_mask if hasattr(inputs_i, "_keras_mask") else None for inputs_i in inputs]
        if mask is None:
            return None
        if not isinstance(mask, list):
            raise ValueError('`mask` should be a list.')
        if not isinstance(inputs, list):
            raise ValueError('`inputs` should be a list.')
        if len(mask) != len(inputs):
            raise ValueError('The lists `inputs` and `mask` '
                             'should have the same length.')
        if all([m is None for m in mask]):
            return None
        # Make a list of masks while making sure
        # the dimensionality of each mask
        # is the same as the corresponding input.
        masks = []
        for input_i, mask_i in zip(inputs, mask):
            if mask_i is None:
                # Input is unmasked. Append all 1s to masks.
                masks.append(tf.ones_like(input_i, dtype='bool'))
            elif K.ndim(mask_i) < K.ndim(input_i):
                # Mask is smaller than the input, expand it.
                masks.append(tf.expand_dims(mask_i, axis=-1))
            else:
                masks.append(mask_i)
        concatenated = K.concatenate(masks, axis=self.axis)
        return K.all(concatenated, axis=-1, keepdims=False)

    def get_config(self, ):
        config = {'axis': self.axis, 'supports_masking': self.supports_masking}
        base_config = super(Concat, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
def concat_func(inputs, axis=-1, mask=False):
    if len(inputs) == 1:
        input = inputs[0]
        if not mask:
            input = NoMask()(input)
        return input
    return Concat(axis, supports_masking=mask)(inputs)
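# Illustrative sketch, not part of the original module: `concat_func` concatenates a list of
# tensors along the last axis, and strips the Keras mask via `NoMask` when it receives a
# single tensor with `mask=False`. Assumes TF 2.x eager execution; shapes are made up.
if __name__ == "__main__":
    _emb_a = tf.random.normal([2, 1, 4])
    _emb_b = tf.random.normal([2, 1, 4])
    print(concat_func([_emb_a, _emb_b]).shape)  # (2, 1, 8)
    print(concat_func([_emb_a]).shape)          # single input passes through unchanged: (2, 1, 4)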
# Compatibility wrappers: TF 1.x uses `keep_dims`/`reduction_indices`/`dim` and `tf.div`,
# while TF 2.x uses `keepdims`/`axis` and `tf.divide`. Each helper tries the TF 1.x
# signature first and falls back to the TF 2.x one.
def reduce_mean(input_tensor, axis=None, keep_dims=False, name=None, reduction_indices=None):
    try:
        return tf.reduce_mean(input_tensor, axis=axis, keep_dims=keep_dims,
                              name=name, reduction_indices=reduction_indices)
    except TypeError:
        return tf.reduce_mean(input_tensor, axis=axis, keepdims=keep_dims, name=name)


def reduce_sum(input_tensor, axis=None, keep_dims=False, name=None, reduction_indices=None):
    try:
        return tf.reduce_sum(input_tensor, axis=axis, keep_dims=keep_dims,
                             name=name, reduction_indices=reduction_indices)
    except TypeError:
        return tf.reduce_sum(input_tensor, axis=axis, keepdims=keep_dims, name=name)


def reduce_max(input_tensor, axis=None, keep_dims=False, name=None, reduction_indices=None):
    try:
        return tf.reduce_max(input_tensor, axis=axis, keep_dims=keep_dims,
                             name=name, reduction_indices=reduction_indices)
    except TypeError:
        return tf.reduce_max(input_tensor, axis=axis, keepdims=keep_dims, name=name)


def div(x, y, name=None):
    try:
        return tf.div(x, y, name=name)
    except AttributeError:
        return tf.divide(x, y, name=name)


def softmax(logits, dim=-1, name=None):
    try:
        return tf.nn.softmax(logits, dim=dim, name=name)
    except TypeError:
        return tf.nn.softmax(logits, axis=dim, name=name)
class _Add(Layer):
    def __init__(self, **kwargs):
        super(_Add, self).__init__(**kwargs)

    def build(self, input_shape):
        # Be sure to call this somewhere!
        super(_Add, self).build(input_shape)

    def call(self, inputs, **kwargs):
        if len(inputs) == 0:
            return tf.constant([[0.0]])
        return Add()(inputs)
def add_func(inputs):
    if not isinstance(inputs, list):
        return inputs
    if len(inputs) == 1:
        return inputs[0]
    return _Add()(inputs)
def combined_dnn_input(sparse_embedding_list, dense_value_list):
    if len(sparse_embedding_list) > 0 and len(dense_value_list) > 0:
        sparse_dnn_input = Flatten()(concat_func(sparse_embedding_list))
        dense_dnn_input = Flatten()(concat_func(dense_value_list))
        return concat_func([sparse_dnn_input, dense_dnn_input])
    elif len(sparse_embedding_list) > 0:
        return Flatten()(concat_func(sparse_embedding_list))
    elif len(dense_value_list) > 0:
        return Flatten()(concat_func(dense_value_list))
    else:
        raise NotImplementedError("dnn_feature_columns can not be empty list")
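# Illustrative sketch, not part of the original module: `combined_dnn_input` flattens the
# sparse embedding tensors and the dense value tensors separately, then concatenates them
# into a single 2-D DNN input. Assumes TF 2.x eager execution; shapes are made up.
if __name__ == "__main__":
    _sparse_embeddings = [tf.random.normal([2, 1, 4]), tf.random.normal([2, 1, 4])]
    _dense_values = [tf.random.normal([2, 1])]
    print(combined_dnn_input(_sparse_embeddings, _dense_values).shape)  # (2, 9)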