
Using TensorFlow’s Batch Normalization Correctly
Update: This guide applies to TF1. For TF2, use the tf.keras.layers.BatchNormalization layer.
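For reference, here is a minimal TF2 sketch (assuming TensorFlow 2.x; not part of the original post). The Keras layer keeps track of its own moving averages, and the training argument switches between batch statistics and the stored averages:

import tensorflow as tf

bn = tf.keras.layers.BatchNormalization()
x = tf.constant([[-10.0], [0.0], [10.0]])

y_train = bn(x, training=True)   # normalizes with the current batch statistics
y_infer = bn(x, training=False)  # normalizes with the stored moving averages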
The TensorFlow library’s layers API contains a function for batch normalization: tf.layers.batch_normalization. It is supposedly as easy to use as all the other tf.layers functions; however, it has some pitfalls. This post explains how to use tf.layers.batch_normalization correctly. It does not delve into what batch normalization is, which can be looked up in the paper “Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift” by Ioffe and Szegedy (2015).
Quick link: tf.layers.batch_normalization API docs
Summary
- Use the training parameter of the batch_normalization function.
- Update the moving averages by evaluating the ops manually or by adding them as a control dependency.
- The final code can be found in this Jupyter notebook.
We start off by defining a simple computational graph. The input is a placeholder which takes a batch of scalar values.
import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 1], 'x')
The input is fed through a batch normalization layer:
y = tf.layers.batch_normalization(x)
With this we have a basic batch normalization setup. We can create a session and feed in a sample batch, here [[-10], [0], [10]]:
sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)
y_out = sess.run(y, feed_dict={x: [[-10], [0], [10]]})
sess.close()
Surprisingly, the value of y_out is not normalized at all. In fact it is almost identical to the input, roughly [[-9.995], [0.], [9.995]].
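A quick back-of-the-envelope sketch shows why, assuming the layer's default initial values (gamma = 1, beta = 0, moving mean = 0, moving variance = 1, epsilon = 0.001):

# The layer normalizes with its stored statistics instead of the batch:
# y = gamma * (x - moving_mean) / sqrt(moving_variance + epsilon) + beta
import math
print(-10 / math.sqrt(1 + 1e-3))  # ~ -9.995, i.e. the input barely changes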
The batch normalization layer does not normalize based on the current batch if its training parameter is not set to True. Heading back to the definition of y, we add this parameter:
y = tf.layers.batch_normalization(x, training=True)
After making this change, the output for the same input batch is normalized as expected, roughly [[-1.22], [0.], [1.22]].
With this setup, the batch normalization layer looks at the current batch and normalizes it based on that batch's statistics. That might not always be desired, in particular at inference time: consider a case where there is only one sample in the batch, e.g. x = [[-10]]. Normalized against itself, such a sample always comes out as roughly [[0.]], no matter what its value is. This is where the moving averages come into play.
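A quick sketch of that failure mode, using the training=True graph defined above:

# With training=True, a single-sample batch is normalized against itself,
# so the output collapses to ~0 regardless of the input value.
sess = tf.Session()
sess.run(tf.global_variables_initializer())
print(sess.run(y, feed_dict={x: [[-10]]}))  # ~[[0.]]
sess.close()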
In order to update the two moving average variables (mean and variance), which the tf.layers.batch_normalization
function call creates automatically, two operations must be evaluated while feeding a batch through the layer. The operations can be found in the collection tf.GraphKeys.UPDATE_OPS
. In the example above, tf.get_collection(tf.GraphKeys.UPDATE_OPS)
yields
[<tf.Operation 'batch_normalization/AssignMovingAvg' type=AssignSub>,
<tf.Operation 'batch_normalization/AssignMovingAvg_1' type=AssignSub>]
If we change the setup and evaluate the update operations alongside the forward pass,
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
y_1 = sess.run([y, update_ops], feed_dict={x: [[-10], [0], [10]]})[0]
y_2 = sess.run(y, feed_dict={x: [[-10]]})
the moving averages get updated with the statistics of the larger batch. The small batch (y_2), however, is only normalized with these moving averages if the training parameter is set to False. We can control it with a placeholder (here a placeholder with a default value)
is_training = tf.placeholder_with_default(False, (), 'is_training')
y = tf.layers.batch_normalization(x, training=is_training)
and set it to True when feeding the larger batch (and to False for the smaller one; strictly speaking not necessary, because that is the placeholder's default value anyway):
y_1 = sess.run([y, update_ops], feed_dict={x: [[-10], [0], [10]], is_training: True})[0]
y_2 = sess.run(y, feed_dict={x: [[-10]], is_training: False})
The output we get is:
- y_1: normalized as desired, roughly [[-1.22], [0.], [1.22]]; the moving averages were updated based on this normalization.
- y_2: kind of weird. It is neither [[0.]], which it was without moving averages, nor roughly [[-1.22]], which it should be if it were normalized with the same statistics as the larger batch.
The reason for the wrong normalization of the small batch is that the moving averages update slowly. If we were to feed the larger batch multiple times, the second batch would be properly normalized:
for _ in range(1000):
    y_1 = sess.run([y, update_ops], feed_dict={x: [[-10], [0], [10]], is_training: True})[0]
y_2 = sess.run(y, feed_dict={x: [[-10]], is_training: False})
Here, we feed the larger batch 1,000 times and run the update operations every time. The result is that y_2 is now normalized as expected, roughly [[-1.22]]. With the momentum parameter of tf.layers.batch_normalization the pace of the moving average update can be adjusted.
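For example (a sketch; this line is not part of the post's graph), a lower momentum lets the moving averages follow the batch statistics more quickly; the default is 0.99:

# momentum close to 1 means slow, stable updates of the moving averages;
# lower values adapt faster but are noisier.
y = tf.layers.batch_normalization(x, training=is_training, momentum=0.9)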
Right now we have to call sess.run and pass the update_ops manually. It is more convenient to add them as a control dependency, such that TensorFlow always executes them whenever the tensor y is evaluated. The new graph definition looks like this:
x = tf.placeholder(tf.float32, [None, 1], 'x')
is_training = tf.placeholder_with_default(False, (), 'is_training')
y = tf.layers.batch_normalization(x, training=is_training)
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    y = tf.identity(y)
The forward passes are now much cleaner:
x_1 = [[-10], [0], [10]]
x_2 = [[-10]]
for _ in range(1000):
    y_1 = sess.run(y, feed_dict={x: x_1, is_training: True})
y_2 = sess.run(y, feed_dict={x: x_2})
Typically, is_training
should be set to True
during training and False
when performing inference.
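In a full training setup, the same control-dependency pattern is typically applied to the training op instead of to y. A minimal sketch, assuming a hypothetical loss tensor named loss:

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    # Running train_op now also updates the moving averages on every step.
    train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)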
The values stored by the batch normalization layer can be examined. In order to do so, we list the graph's variables with tf.global_variables(), which outputs
[<tf.Variable 'batch_normalization/gamma:0' shape=(1,) dtype=float32_ref>,
<tf.Variable 'batch_normalization/beta:0' shape=(1,) dtype=float32_ref>,
<tf.Variable 'batch_normalization/moving_mean:0' shape=(1,) dtype=float32_ref>,
<tf.Variable 'batch_normalization/moving_variance:0' shape=(1,) dtype=float32_ref>]
in our toy example. The last two variables contain the moving averages of mean and variance of the past batches (for which the update ops were evaluated and training
was set to True
). The actual values can be queried as follows:
with tf.variable_scope("", reuse=tf.AUTO_REUSE):
    out = sess.run([tf.get_variable('batch_normalization/moving_mean'),
                    tf.get_variable('batch_normalization/moving_variance')])
moving_average, moving_variance = out
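As a rough sanity check (continuing the toy example; the expectation below follows from the batch we fed repeatedly, not from the post itself), the moving mean should have converged towards that batch's mean of 0:

# After the 1,000 update steps above, the moving mean should sit near the
# mean of the [[-10], [0], [10]] batch, i.e. close to 0, and the moving
# variance near that batch's variance.
print(moving_average)   # expected: ~[0.]
print(moving_variance)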