I built a model graph in which I create histograms of the model weights with `tf.summary.histogram(...)`.
I noticed that this heavily slows down training.
Is there an easy way to collect histogram summaries only every once in a while, e.g. every 1000th step?
Here is a basic example of how it works. When building the computation graph, I add a `tf.summary.histogram(...)` op for each weight matrix:
def dense(input_tensor, d, name):
    """Apply a bias-free fully connected layer and log its weights.

    A weight matrix named ``name`` is created inside a variable scope that is
    also called ``name``. The input is multiplied by it, and a histogram
    summary called ``"weights"`` is recorded for the matrix.

    Args:
        input_tensor: Input tensor whose second dimension is the fan-in.
            NOTE(review): presumably shaped ``(batch, features)``; confirm.
        d: Number of output units (columns of the weight matrix).
        name: Name shared by the variable scope and the weight variable.

    Returns:
        The tensor ``tf.matmul(input_tensor, w)``.
    """
    fan_in = input_tensor.shape[1]
    weight_init = tf.random_normal_initializer(
        mean=0, stddev=0.002, dtype=tf.float32)
    with tf.variable_scope(name):
        kernel = tf.get_variable(name, [fan_in, d], initializer=weight_init)
        projected = tf.matmul(input_tensor, kernel)
        # Every evaluation of a merged summary op recomputes this histogram.
        tf.summary.histogram("weights", kernel)
    return projected
def make_optimizer(loss, variables, name='Adam', learning_rate=0.001):
    """Build an Adam training op that minimizes ``loss`` over ``variables``.

    Args:
        loss: Loss tensor to minimize.
        variables: Variables the optimizer is allowed to update (passed as
            ``var_list``).
        name: Name given to the ``AdamOptimizer``.
        learning_rate: Adam step size. Defaults to 0.001, the value that was
            previously hard-coded, so existing callers are unaffected.

    Returns:
        The op returned by ``AdamOptimizer.minimize``. Running it applies one
        update step.
    """
    optimizer = tf.train.AdamOptimizer(learning_rate, name=name)
    return optimizer.minimize(loss, var_list=variables)
def get_loss(iterator):
    """Build the example model on one batch and return its mean squared error.

    The batch's ``'features'`` are fed through a 1000-unit ``dense`` layer
    inside the ``"example"`` variable scope. The output is compared against
    the batch's ``'target'``.

    Args:
        iterator: Iterator whose ``get_next()`` yields a dict with
            ``'features'`` and ``'target'`` entries.

    Returns:
        A ``(loss, variables)`` pair:
        - ``loss`` is the scalar ``reduce_mean`` of the squared differences.
        - ``variables`` are the trainable variables matched by
          ``scope="example"``.
    """
    with tf.variable_scope("example"):
        batch = iterator.get_next()
        prediction = dense(batch['features'], d=1000, name="dense")
        mse = tf.reduce_mean(
            tf.squared_difference(prediction, batch['target']))
    trainable = tf.get_collection(
        tf.GraphKeys.TRAINABLE_VARIABLES, scope="example")
    return mse, trainable
def get_iterator():
    """Return a one-shot iterator over the training dataset.

    NOTE(review): ``...`` below is a placeholder left by the example and must
    be replaced with a real ``tf.data.Dataset``. As written, ``Ellipsis`` has
    no ``make_one_shot_iterator`` attribute, so this raises ``AttributeError``.
    Its elements should be dicts with ``'features'`` and ``'target'`` keys,
    because ``get_loss`` reads those keys.
    """
    dataset = ...  # a tf.data.Dataset
    return dataset.make_one_shot_iterator()
def train(num_steps=10000, summary_every=1000, logdir="../checkpoints/"):
    """Build the example graph and run the training loop.

    Histogram summaries are costly to compute. The merged summary op is
    therefore fetched only on steps that are multiples of ``summary_every``;
    all other steps run just the optimizer and the loss.

    Args:
        num_steps: Maximum number of optimizer steps to run.
        summary_every: Interval, in steps, at which summaries are evaluated
            and written to ``logdir``.
        logdir: Directory for the ``tf.summary.FileWriter`` event files.
    """
    graph = tf.Graph()
    with graph.as_default():
        iterator = get_iterator()
        loss, variables = get_loss(iterator)
        optimizer = make_optimizer(loss, variables, "Adam")
        # merge_all() returns None when no summaries were registered.
        summary_op = tf.summary.merge_all()
        # Created after make_optimizer so that Adam's slot variables are
        # initialized as well as the model weights.
        init_op = tf.global_variables_initializer()

    train_writer = tf.summary.FileWriter(logdir, graph)
    coord = tf.train.Coordinator()
    with tf.Session(graph=graph) as sess:
        sess.run(init_op)
        step = 0
        try:
            # The original loop tested `coord.should_stop()` without `not`,
            # so the loop body never executed.
            while step < num_steps and not coord.should_stop():
                if summary_op is not None and step % summary_every == 0:
                    _, loss_value, summary = sess.run(
                        [optimizer, loss, summary_op])
                    train_writer.add_summary(summary, step)
                    train_writer.flush()
                    logging.info('step %d: loss %s', step, loss_value)
                else:
                    # Skip the expensive histogram computation on this step.
                    sess.run([optimizer, loss])
                step += 1
        except KeyboardInterrupt:
            logging.info('Interrupted')
            coord.request_stop()
        except tf.errors.OutOfRangeError:
            # The one-shot iterator ran out of data before num_steps.
            logging.info('Input exhausted after %d steps', step)
        finally:
            train_writer.close()
def main(unused_argv):
    """Entry point invoked via ``tf.app.run()``; starts training."""
    del unused_argv  # Command-line arguments are not used.
    train()
if __name__ == '__main__':
    # Enable INFO-level logging so messages such as train()'s 'Interrupted'
    # notice are shown.
    logging.basicConfig(level=logging.INFO)
    # Hand control to TensorFlow's app runner, which dispatches to main().
    tf.app.run()
Comments
Post a Comment