# -*- coding:utf-8 -*-
"""
Author:
Weichen Shen,weichenswc@163.com
"""
import tensorflow as tf
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.layers import Flatten, Layer, Add
from tensorflow.python.ops.lookup_ops import TextFileInitializer
try:
from tensorflow.python.ops.init_ops import Zeros, glorot_normal_initializer as glorot_normal
except ImportError:
from tensorflow.python.ops.init_ops_v2 import Zeros, glorot_normal
from tensorflow.python.keras.regularizers import l2
try:
from tensorflow.python.ops.lookup_ops import StaticHashTable
except ImportError:
from tensorflow.python.ops.lookup_ops import HashTable as StaticHashTable
class NoMask(Layer):
    """Identity layer that strips any attached Keras mask.

    Useful in front of layers that cannot consume masks produced by
    masking-aware layers (e.g. an ``Embedding`` with ``mask_zero=True``).
    """

    def __init__(self, **kwargs):
        super(NoMask, self).__init__(**kwargs)

    def build(self, input_shape):
        # No weights to create; just mark the layer as built.
        super(NoMask, self).build(input_shape)

    def call(self, x, mask=None, **kwargs):
        # Pass the tensor through untouched.
        return x

    def compute_mask(self, inputs, mask):
        # Swallow the incoming mask so it does not propagate downstream.
        return None
class Hash(Layer):
    """Look up or hash input values into the range ``[0, num_buckets)``.

    When `vocabulary_path` is set, keys are looked up in a static table built
    from that CSV file and the corresponding values are returned. Otherwise
    the input is hashed to ``[0, num_buckets)``; when `mask_zero` is True,
    input value `0` or `0.0` maps to `0` and every other value falls in
    ``[1, num_buckets)``.

    The following snippet initializes a `Hash` with a `vocabulary_path` file
    whose first column holds the values and second column the keys:

    * `1,emerson`
    * `2,lake`
    * `3,palmer`

    >>> hash = Hash(
    ...   num_buckets=3+1,
    ...   vocabulary_path=filename,
    ...   default_value=0)
    >>> hash(tf.constant('lake')).numpy()
    2
    >>> hash(tf.constant('lakeemerson')).numpy()
    0

    Args:
        num_buckets: An `int` that is >= 1. The number of buckets, or the
            vocabulary size + 1 when `vocabulary_path` is set.
        mask_zero: default False. When True, input `0` or `0.0` hashes to
            value `0`. Not used when `vocabulary_path` is set.
        vocabulary_path: default `None`. Path of a CSV vocabulary file with
            two columns separated by a comma: the first column is the `int64`
            value and the second the `string` key. The path must be
            accessible from wherever `Hash` is initialized.
        default_value: default 0. The value returned when a key is missing
            from the table.
        **kwargs: Additional keyword arguments.
    """

    def __init__(self, num_buckets, mask_zero=False, vocabulary_path=None, default_value=0, **kwargs):
        self.num_buckets = num_buckets
        self.mask_zero = mask_zero
        self.vocabulary_path = vocabulary_path
        self.default_value = default_value
        if self.vocabulary_path:
            # Keys come from column 1 (string), values from column 0 (int64).
            initializer = TextFileInitializer(vocabulary_path, 'string', 1, 'int64', 0, delimiter=',')
            self.hash_table = StaticHashTable(initializer, default_value=self.default_value)
        super(Hash, self).__init__(**kwargs)

    def build(self, input_shape):
        # No trainable weights; just mark the layer as built.
        super(Hash, self).build(input_shape)

    def call(self, x, mask=None, **kwargs):
        if x.dtype == tf.string:
            zero = tf.as_string(tf.zeros([1], dtype='int32'))
        else:
            # Stringify numeric inputs so both code paths operate on strings.
            zero = tf.as_string(tf.zeros([1], dtype=x.dtype))
            x = tf.as_string(x)
        if self.vocabulary_path:
            # Exact table lookup; missing keys yield `default_value`.
            return self.hash_table.lookup(x)
        # Reserve bucket 0 for the masked value when mask_zero is on.
        num_buckets = self.num_buckets - 1 if self.mask_zero else self.num_buckets
        try:
            hash_x = tf.string_to_hash_bucket_fast(x, num_buckets,
                                                   name=None)  # weak hash
        except AttributeError:
            # TF 2.x moved the op under tf.strings.
            hash_x = tf.strings.to_hash_bucket_fast(x, num_buckets,
                                                    name=None)  # weak hash
        if self.mask_zero:
            # Shift hashes into [1, num_buckets) and zero out masked entries.
            keep = tf.cast(tf.not_equal(x, zero), dtype='int64')
            hash_x = (hash_x + 1) * keep
        return hash_x

    def compute_output_shape(self, input_shape):
        return input_shape

    def get_config(self, ):
        config = {'num_buckets': self.num_buckets, 'mask_zero': self.mask_zero,
                  'vocabulary_path': self.vocabulary_path, 'default_value': self.default_value}
        base_config = super(Hash, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
class Linear(Layer):
    """Linear (order-1) logit layer.

    Args:
        l2_reg: float. L2 regularization strength applied to the kernel.
        mode: int in {0, 1, 2}.
            0 -- input is sparse (embedded) features only: sum-pool them;
            1 -- input is dense features only: project with a kernel;
            2 -- input is a ``[sparse, dense]`` pair: combine both paths.
        use_bias: bool. Whether to add a learned scalar bias.
        seed: int. Seed for the kernel initializer.
    """

    def __init__(self, l2_reg=0.0, mode=0, use_bias=False, seed=1024, **kwargs):
        self.l2_reg = l2_reg
        if mode not in [0, 1, 2]:
            raise ValueError("mode must be 0,1 or 2")
        self.mode = mode
        self.use_bias = use_bias
        self.seed = seed
        super(Linear, self).__init__(**kwargs)

    def build(self, input_shape):
        if self.use_bias:
            self.bias = self.add_weight(name='linear_bias',
                                        shape=(1,),
                                        initializer=Zeros(),
                                        trainable=True)
        if self.mode in (1, 2):
            # Mode 1 projects the whole input; mode 2 only the dense part
            # (second element of the [sparse, dense] input pair).
            dense_shape = input_shape if self.mode == 1 else input_shape[1]
            self.kernel = self.add_weight(
                'linear_kernel',
                shape=[int(dense_shape[-1]), 1],
                initializer=glorot_normal(self.seed),
                regularizer=l2(self.l2_reg),
                trainable=True)
        super(Linear, self).build(input_shape)  # Be sure to call this somewhere!

    def call(self, inputs, **kwargs):
        if self.mode == 0:
            # Sparse-only: sum the embedded features along the last axis.
            linear_logit = reduce_sum(inputs, axis=-1, keep_dims=True)
        elif self.mode == 1:
            # Dense-only: linear projection.
            linear_logit = tf.tensordot(inputs, self.kernel, axes=(-1, 0))
        else:
            sparse_input, dense_input = inputs
            dense_logit = tf.tensordot(dense_input, self.kernel, axes=(-1, 0))
            linear_logit = reduce_sum(sparse_input, axis=-1, keep_dims=False) + dense_logit
        if self.use_bias:
            linear_logit += self.bias
        return linear_logit

    def compute_output_shape(self, input_shape):
        return (None, 1)

    def compute_mask(self, inputs, mask):
        return None

    def get_config(self, ):
        config = {'mode': self.mode, 'l2_reg': self.l2_reg, 'use_bias': self.use_bias, 'seed': self.seed}
        base_config = super(Linear, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
class Concat(Layer):
    """Concatenate a list of tensors along ``axis``, with mask support."""

    def __init__(self, axis, supports_masking=True, **kwargs):
        super(Concat, self).__init__(**kwargs)
        self.axis = axis
        self.supports_masking = supports_masking

    def call(self, inputs):
        return tf.concat(inputs, axis=self.axis)

    def compute_mask(self, inputs, mask=None):
        if not self.supports_masking:
            return None
        if mask is None:
            # Collect per-input masks attached by upstream masking layers.
            mask = [getattr(inputs_i, "_keras_mask", None) for inputs_i in inputs]
        if not isinstance(mask, list):
            raise ValueError('`mask` should be a list.')
        if not isinstance(inputs, list):
            raise ValueError('`inputs` should be a list.')
        if len(mask) != len(inputs):
            raise ValueError('The lists `inputs` and `mask` '
                             'should have the same length.')
        if all(m is None for m in mask):
            return None
        # Normalize every mask to the rank of its input before concatenating.
        masks = []
        for input_i, mask_i in zip(inputs, mask):
            if mask_i is None:
                # Unmasked input: treat every position as valid.
                masks.append(tf.ones_like(input_i, dtype='bool'))
            elif K.ndim(mask_i) < K.ndim(input_i):
                # Mask has fewer dims than the input; expand the last axis.
                masks.append(tf.expand_dims(mask_i, axis=-1))
            else:
                masks.append(mask_i)
        concatenated = K.concatenate(masks, axis=self.axis)
        return K.all(concatenated, axis=-1, keepdims=False)

    def get_config(self, ):
        config = {'axis': self.axis, 'supports_masking': self.supports_masking}
        base_config = super(Concat, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
def concat_func(inputs, axis=-1, mask=False):
    """Concatenate a list of tensors, short-circuiting single-element lists.

    Args:
        inputs: list of tensors to concatenate.
        axis: axis to concatenate along; defaults to the last axis.
        mask: whether masks should be propagated. When False, a single
            input is wrapped in ``NoMask`` to strip any attached mask.

    Returns:
        A single tensor.
    """
    if len(inputs) == 1:
        # Nothing to concatenate. (Local renamed from `input`, which
        # shadowed the builtin of the same name.)
        tensor = inputs[0]
        if not mask:
            tensor = NoMask()(tensor)
        return tensor
    return Concat(axis, supports_masking=mask)(inputs)
def reduce_mean(input_tensor,
                axis=None,
                keep_dims=False,
                name=None,
                reduction_indices=None):
    """Version-agnostic wrapper around ``tf.reduce_mean``.

    TF 1.x accepts ``keep_dims``/``reduction_indices``; TF 2.x renamed the
    keyword to ``keepdims`` and dropped ``reduction_indices``. Try the
    legacy signature first and fall back on ``TypeError``.
    """
    try:
        # TF 1.x signature.
        return tf.reduce_mean(input_tensor, axis=axis, keep_dims=keep_dims,
                              name=name, reduction_indices=reduction_indices)
    except TypeError:
        # TF 2.x signature.
        return tf.reduce_mean(input_tensor, axis=axis, keepdims=keep_dims, name=name)
def reduce_sum(input_tensor,
               axis=None,
               keep_dims=False,
               name=None,
               reduction_indices=None):
    """Version-agnostic wrapper around ``tf.reduce_sum``.

    TF 1.x accepts ``keep_dims``/``reduction_indices``; TF 2.x renamed the
    keyword to ``keepdims`` and dropped ``reduction_indices``. Try the
    legacy signature first and fall back on ``TypeError``.
    """
    try:
        # TF 1.x signature.
        return tf.reduce_sum(input_tensor, axis=axis, keep_dims=keep_dims,
                             name=name, reduction_indices=reduction_indices)
    except TypeError:
        # TF 2.x signature.
        return tf.reduce_sum(input_tensor, axis=axis, keepdims=keep_dims, name=name)
def reduce_max(input_tensor,
               axis=None,
               keep_dims=False,
               name=None,
               reduction_indices=None):
    """Version-agnostic wrapper around ``tf.reduce_max``.

    TF 1.x accepts ``keep_dims``/``reduction_indices``; TF 2.x renamed the
    keyword to ``keepdims`` and dropped ``reduction_indices``. Try the
    legacy signature first and fall back on ``TypeError``.
    """
    try:
        # TF 1.x signature.
        return tf.reduce_max(input_tensor, axis=axis, keep_dims=keep_dims,
                             name=name, reduction_indices=reduction_indices)
    except TypeError:
        # TF 2.x signature.
        return tf.reduce_max(input_tensor, axis=axis, keepdims=keep_dims, name=name)
def div(x, y, name=None):
    """Element-wise division compatible with both TF 1.x and TF 2.x."""
    try:
        # tf.div exists only in TF 1.x.
        return tf.div(x, y, name=name)
    except AttributeError:
        return tf.divide(x, y, name=name)
def softmax(logits, dim=-1, name=None):
    """Softmax compatible with both TF 1.x (``dim``) and TF 2.x (``axis``)."""
    try:
        # TF 1.x keyword.
        return tf.nn.softmax(logits, dim=dim, name=name)
    except TypeError:
        # TF 2.x renamed the keyword to `axis`.
        return tf.nn.softmax(logits, axis=dim, name=name)
class _Add(Layer):
    """Sum a list of tensors; an empty list yields a 1x1 zero tensor."""

    def __init__(self, **kwargs):
        super(_Add, self).__init__(**kwargs)

    def build(self, input_shape):
        # Stateless layer; just mark it as built.
        super(_Add, self).build(input_shape)

    def call(self, inputs, **kwargs):
        if len(inputs) == 0:
            # Degenerate case: no logits to add.
            return tf.constant([[0.0]])
        # Delegate to Keras Add, which supports broadcasting across inputs.
        return Add()(inputs)
def add_func(inputs):
    """Add a list of tensors, avoiding a layer for the trivial cases.

    A non-list argument is assumed to be a single tensor and returned
    unchanged; a one-element list returns its only element.
    """
    if not isinstance(inputs, list):
        return inputs
    if len(inputs) == 1:
        return inputs[0]
    return _Add()(inputs)