Content excerpted from the TensorFlow official website.
Customize the training step
If you need more flexibility and control, you can have it by implementing your own training loop. There are three steps:
- Iterate over a Python generator or tf.data.Dataset to get batches of examples.
- Use tf.GradientTape to collect gradients.
- Use one of the tf.keras.optimizers to apply weight updates to the model's variables.
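The code later in this section assumes a batched dataset train_data and a constant NUM_EPOCHS, which the full guide sets up earlier. A minimal stand-in with random MNIST-shaped data (the names, shapes, and values here are assumptions, not from this excerpt):

import tensorflow as tf

NUM_EPOCHS = 5

# Toy stand-in for the dataset built earlier in the full guide:
# random 28x28x1 "images" with integer class labels in [0, 10).
images = tf.random.normal([1000, 28, 28, 1])
labels = tf.random.uniform([1000], maxval=10, dtype=tf.int64)
train_data = (tf.data.Dataset.from_tensor_slices((images, labels))
              .shuffle(1000)
              .batch(32))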
Remember:
- Always include a
trainingargument on thecallmethod of subclassed layers and models. - Make sure to call the model with the
trainingargument set correctly. - Depending on usage, model variables may not exist until the model is run on a batch of data.
- You need to manually handle things like regularization losses for the model.
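A minimal sketch of these points (not from the guide): a subclassed model that declares and forwards a training argument, and whose variables only exist after the first call on data:

class MyModel(tf.keras.Model):
  def __init__(self):
    super().__init__()
    self.dropout = tf.keras.layers.Dropout(0.5)
    self.dense = tf.keras.layers.Dense(10)

  def call(self, inputs, training=False):
    # Forward the flag so Dropout is only active while training.
    x = self.dropout(inputs, training=training)
    return self.dense(x)

my_model = MyModel()
print(len(my_model.trainable_variables))         # 0: no variables before the first call
out = my_model(tf.zeros([2, 4]), training=True)  # variables are created here
print(len(my_model.trainable_variables))         # 2: the Dense kernel and bias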
Note the simplifications relative to v1:
- There is no need to run variable initializers. Variables are initialized on creation.
- There is no need to add manual control dependencies. Even in a tf.function, operations act as in eager mode.
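Both simplifications can be seen in a few lines (a sketch, not from the guide): a tf.Variable is usable as soon as it is created, and operations inside a tf.function run in program order without explicit control dependencies:

v = tf.Variable(1.0)       # initialized on creation; nothing to run
print(v.numpy())           # 1.0

@tf.function
def bump():
  v.assign_add(1.0)        # no manual control dependency needed;
  return v.read_value()    # this read is ordered after the assignment

print(bump().numpy())      # 2.0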
model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, 3, activation='relu',
                           kernel_regularizer=tf.keras.regularizers.l2(0.02),
                           input_shape=(28, 28, 1)),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dropout(0.1),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Dense(10)
])
optimizer = tf.keras.optimizers.Adam(0.001)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
@tf.function
def train_step(inputs, labels):
  with tf.GradientTape() as tape:
    predictions = model(inputs, training=True)
    regularization_loss = tf.math.add_n(model.losses)
    pred_loss = loss_fn(labels, predictions)
    total_loss = pred_loss + regularization_loss
  gradients = tape.gradient(total_loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))

for epoch in range(NUM_EPOCHS):
  for inputs, labels in train_data:
    train_step(inputs, labels)
  print("Finished epoch", epoch)
Finished epoch 0
Finished epoch 1
Finished epoch 2
Finished epoch 3
Finished epoch 4
New-style metrics and losses
In TensorFlow 2.0, metrics and losses are objects. These work both eagerly and in tf.functions.
A loss object is callable, and expects (y_true, y_pred) as arguments:
cce = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
cce([[1, 0]], [[-1.0, 3.0]]).numpy()
4.01815
A metric object has the following methods:
- Metric.update_state(): add new observations.
- Metric.result(): get the current result of the metric, given the observed values.
- Metric.reset_states(): clear all observations.
The object itself is callable. Calling updates the state with new observations, as with update_state, and returns the new result of the metric.
You don’t have to manually initialize a metric’s variables, and because TensorFlow 2.0 has automatic control dependencies, you don’t need to worry about those either.
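For example, a minimal sketch of the full lifecycle with a Mean metric (the values are illustrative, not from the guide):

m = tf.keras.metrics.Mean()
m.update_state(2.0)            # add observations
m.update_state(4.0)
print(m.result().numpy())      # 3.0, the mean of the observations so far
print(m(6.0).numpy())          # 4.0: calling updates the state and returns the result
m.reset_states()               # clear all observations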
The code below uses a metric to keep track of the mean loss observed within a custom training loop.
# Create the metrics
loss_metric = tf.keras.metrics.Mean(name='train_loss')
accuracy_metric = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')
@tf.function
def train_step(inputs, labels):
  with tf.GradientTape() as tape:
    predictions = model(inputs, training=True)
    regularization_loss = tf.math.add_n(model.losses)
    pred_loss = loss_fn(labels, predictions)
    total_loss = pred_loss + regularization_loss
  gradients = tape.gradient(total_loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))
  # Update the metrics
  loss_metric.update_state(total_loss)
  accuracy_metric.update_state(labels, predictions)
for epoch in range(NUM_EPOCHS):
  # Reset the metrics
  loss_metric.reset_states()
  accuracy_metric.reset_states()
  for inputs, labels in train_data:
    train_step(inputs, labels)
  # Get the metric results
  mean_loss = loss_metric.result()
  mean_accuracy = accuracy_metric.result()
  print('Epoch: ', epoch)
  print('  loss:     {:.3f}'.format(mean_loss))
  print('  accuracy: {:.3f}'.format(mean_accuracy))
Epoch:  0
  loss:     0.207
  accuracy: 0.991
Epoch:  1
  loss:     0.167
  accuracy: 0.994
Epoch:  2
  loss:     0.147
  accuracy: 0.997
Epoch:  3
  loss:     0.123
  accuracy: 0.997
Epoch:  4
  loss:     0.109
  accuracy: 0.997