# -*- coding:utf-8 -*-
"""
Author:
Weichen Shen,weichenswc@163.com
"""
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.layers import Layer
try:
from tensorflow.python.ops.init_ops import Zeros, Ones
except ImportError:
from tensorflow.python.ops.init_ops_v2 import Zeros, Ones
class LayerNormalization(Layer):
    """Normalize inputs to zero mean and unit variance along ``axis``.

    The normalized tensor is optionally multiplied by a learned ``gamma``
    and shifted by a learned ``beta``.

    Arguments:
        axis: Axis (or axes) over which the mean and variance are computed.
            Defaults to the last axis.
        eps: Small constant added to the variance before taking the square
            root, so the division never hits zero.
        center: If True, add ``beta`` to the normalized output.
        scale: If True, multiply the normalized output by ``gamma``.
        **kwargs: Forwarded unchanged to the base ``Layer`` constructor.

    Note:
        ``gamma`` and ``beta`` are always created with the size of the last
        input dimension, whatever ``axis``, ``center`` and ``scale`` are;
        the two flags only control whether the weights are applied in
        ``call``.
    """

    def __init__(self, axis=-1, eps=1e-9, center=True,
                 scale=True, **kwargs):
        self.axis = axis
        self.eps = eps
        self.center = center
        self.scale = scale
        super(LayerNormalization, self).__init__(**kwargs)

    def build(self, input_shape):
        """Create the ``gamma`` (ones) and ``beta`` (zeros) weights."""
        # Both weights span only the last input dimension.
        param_shape = input_shape[-1:]
        self.gamma = self.add_weight(name='gamma', shape=param_shape,
                                     initializer=Ones(), trainable=True)
        self.beta = self.add_weight(name='beta', shape=param_shape,
                                    initializer=Zeros(), trainable=True)
        super(LayerNormalization, self).build(input_shape)

    def call(self, inputs):
        """Return ``inputs`` normalized along ``self.axis``."""
        mean = K.mean(inputs, axis=self.axis, keepdims=True)
        centered = inputs - mean
        # The variance must be reduced over the same axis as the mean;
        # reducing it over a hard-coded last axis would centre along one
        # axis and scale by the spread of another whenever axis != -1.
        variance = K.mean(K.square(centered), axis=self.axis, keepdims=True)
        std = K.sqrt(variance + self.eps)
        outputs = centered / std
        if self.scale:
            outputs *= self.gamma
        if self.center:
            outputs += self.beta
        return outputs

    def compute_output_shape(self, input_shape):
        """Return ``input_shape`` unchanged."""
        return input_shape

    def get_config(self, ):
        """Return the base layer config extended with this layer's options."""
        config = {'axis': self.axis, 'eps': self.eps, 'center': self.center, 'scale': self.scale}
        base_config = super(LayerNormalization, self).get_config()
        # Entries from ``config`` take precedence over same-named base entries.
        return dict(list(base_config.items()) + list(config.items()))