Donate. I desperately need donations to survive due to my health

Get paid by answering surveys Click here

Click here to donate

Remote/Work from Home jobs

Tensorflow: collecting tf.summary infrequently for performance optimization

I built a model graph where I create histograms (tf.summary.histogram(...)) of the model weights. I noticed that this heavily slows down training.

Is there an easy way to only collect histogram summaries every once in a while, e.g. every 1000s step?

Here is a basic example of how it works. When building the computation graph, I add a tf.summary.histogram(...).

def dense(input_tensor, d, name):
    with tf.variable_scope(name):
    w = tf.get_variable(
        name, [input_tensor.shape[1], d],
        initializer=tf.random_normal_initializer(
            mean=0, stddev=0.002, dtype=tf.float32))
    z = tf.matmul(input_tensor, w)

    tf.summary.histogram("weights", w)
    return z


def make_optimizer(loss, variables, name='Adam'):
    learning_step = (
    tf.train.AdamOptimizer(0.001, name=name)
        .minimize(loss, var_list=variables)
    )
    return learning_step


def get_loss(iterator):
    with tf.variable_scope("example"):
        input_batch = iterator.get_next()
        input_features = input_batch['features']
        input_target = input_batch['target']
        output = dense(input_features, d=1000, name="dense")
        loss = tf.reduce_mean(tf.squared_difference(output, input_target))

        variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                  scope="example")
    return loss, variables


def get_iterator():
    dataset = ...  # a tf.data.Dataset
    return dataset.make_one_shot_iterator()


def train():
    graph = tf.Graph()

    with tf.Session(graph=graph) as sess:
    with graph.as_default():
        iterator = get_iterator()

        loss, variables = get_loss(iterator)
        optimizer = make_optimizer(loss, variables, "Adam")

        # Summaries and saver
        summary_op = tf.summary.merge_all()
        train_writer = tf.summary.FileWriter("../checkpoints/", graph)

    coord = tf.train.Coordinator()
    try:
        step = 0
        while step < 10000 and coord.should_stop():
            _, loss_value, summary = (
                sess.run(
                    [optimizer, loss, summary_op],
                )
            )

            if step % 1000 == 0:
                train_writer.add_summary(summary, step)
                train_writer.flush()
            step += 1

    except KeyboardInterrupt:
        logging.info('Interrupted')
        coord.request_stop()


def main(unused_argv):
    train()


if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    tf.app.run()

Comments