# Source code for deepctr.estimator.inputs
import tensorflow as tf
def input_fn_pandas(df, features, label=None, batch_size=256, num_epochs=1, shuffle=False, queue_capacity_factor=10,
                    num_threads=1):
    """Build an Estimator ``input_fn`` that feeds rows of a pandas DataFrame.

    Thin wrapper around TensorFlow's TF1-style ``pandas_input_fn``; it picks whichever
    implementation the installed TensorFlow exposes.

    :param df: pandas DataFrame holding the model inputs (and optionally the label column).
    :param features: column selector applied as ``df[features]`` to obtain the model inputs
        (presumably a list of column names — a single name would yield a Series; TODO confirm callers).
    :param label: optional name of the label column; ``None`` feeds the features without targets.
    :param batch_size: forwarded to ``pandas_input_fn``; also the unit for ``queue_capacity_factor``.
    :param num_epochs: forwarded unchanged to ``pandas_input_fn``.
    :param shuffle: forwarded unchanged to ``pandas_input_fn``.
    :param queue_capacity_factor: queue capacity expressed as a multiple of ``batch_size``.
    :param num_threads: forwarded unchanged to ``pandas_input_fn``.
    :return: the ``input_fn`` produced by ``pandas_input_fn``.
    """
    y = df[label] if label is not None else None

    # Select the implementation by API availability instead of comparing
    # ``tf.__version__`` as a string: lexicographic order misranks versions
    # (e.g. "10.0.0" < "2.0.0"). ``tf.estimator.inputs`` no longer exists in
    # TF 2.x, so ``tf.compat.v1`` is tried first; old TF 1.x releases without
    # ``tf.compat.v1`` fall back to the native TF 1.x location.
    try:
        pandas_input_fn = tf.compat.v1.estimator.inputs.pandas_input_fn
    except AttributeError:
        pandas_input_fn = tf.estimator.inputs.pandas_input_fn

    return pandas_input_fn(df[features], y,
                           batch_size=batch_size,
                           num_epochs=num_epochs,
                           shuffle=shuffle,
                           queue_capacity=batch_size * queue_capacity_factor,
                           num_threads=num_threads)
def input_fn_tfrecord(filenames, feature_description, label=None, batch_size=256, num_epochs=1, num_parallel_calls=8,
                      shuffle_factor=10, prefetch_factor=1,
                      ):
    """Build an Estimator ``input_fn`` that reads serialized examples from TFRecord files.

    :param filenames: file name(s) passed to ``tf.data.TFRecordDataset``.
    :param feature_description: feature spec passed as ``features=`` to ``parse_single_example``.
    :param label: optional key of the label feature; when given it is popped out of the parsed
        features and returned separately.
    :param batch_size: number of examples per batch.
    :param num_epochs: forwarded to ``Dataset.repeat``.
    :param num_parallel_calls: parallelism of the parsing ``Dataset.map``.
    :param shuffle_factor: shuffle buffer size as a multiple of ``batch_size`` examples;
        ``<= 0`` disables shuffling.
    :param prefetch_factor: number of batches to prefetch (i.e. ``batch_size * prefetch_factor``
        examples); ``<= 0`` disables prefetching.
    :return: a zero-argument ``input_fn`` whose call yields ``(features, labels)`` when ``label``
        is set, otherwise ``features`` only.
    """

    def _parse_examples(serial_exmp):
        # Resolve the parser before calling it, so that an AttributeError raised
        # *inside* parsing is not misread as "API missing". ``tf.parse_single_example``
        # is the TF 1.x name; TF 2.x only provides ``tf.io.parse_single_example``.
        try:
            parse_single_example = tf.parse_single_example
        except AttributeError:
            parse_single_example = tf.io.parse_single_example
        features = parse_single_example(serial_exmp, features=feature_description)
        if label is None:
            return features
        labels = features.pop(label)
        return features, labels

    def input_fn():
        dataset = tf.data.TFRecordDataset(filenames)
        dataset = dataset.map(_parse_examples, num_parallel_calls=num_parallel_calls)
        if shuffle_factor > 0:
            # Shuffling happens before batching, so the buffer counts individual examples.
            dataset = dataset.shuffle(buffer_size=batch_size * shuffle_factor)
        dataset = dataset.repeat(num_epochs).batch(batch_size)
        if prefetch_factor > 0:
            # Prefetch is applied after ``batch``, so its buffer counts *batches*;
            # scaling by batch_size here would buffer batch_size times too much data.
            dataset = dataset.prefetch(buffer_size=prefetch_factor)
        try:
            iterator = dataset.make_one_shot_iterator()
        except AttributeError:
            # TF 2.x removed the Dataset method; use the compat.v1 function instead.
            iterator = tf.compat.v1.data.make_one_shot_iterator(dataset)
        return iterator.get_next()

    return input_fn