Content excerpted from: the TensorFlow official website
Customize the training step
If you need more flexibility and control, you can have it by implementing your own training loop. There are three steps:
- Iterate over a Python generator or `tf.data.Dataset` to get batches of examples (a minimal dataset sketch follows this list).
- Use `tf.GradientTape` to collect gradients.
- Use one of the `tf.keras.optimizers` to apply weight updates to the model’s variables.
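This excerpt does not define `NUM_EPOCHS` or `train_data`; the full guide builds them earlier, presumably from MNIST given the `(28, 28, 1)` input shape. A minimal sketch with placeholder random data, shaped to match the model defined below:

```python
import tensorflow as tf

# Placeholder stand-ins for the guide's real data pipeline: random images
# and integer labels shaped to match the (28, 28, 1) model below.
NUM_EPOCHS = 5
images = tf.random.normal([1024, 28, 28, 1])
labels = tf.random.uniform([1024], maxval=10, dtype=tf.int64)
train_data = tf.data.Dataset.from_tensor_slices((images, labels)).batch(32)
```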
Remember:
- Always include a `training` argument on the `call` method of subclassed layers and models (see the sketch after this list).
- Make sure to call the model with the `training` argument set correctly.
- Depending on usage, model variables may not exist until the model is run on a batch of data.
- You need to manually handle things like regularization losses for the model.
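To make the first two points concrete, here is a minimal sketch of the `training`-argument convention; `NoiseLayer` is a hypothetical layer invented for this illustration, not part of the original guide:

```python
class NoiseLayer(tf.keras.layers.Layer):
    """Hypothetical layer: adds noise during training, is a no-op at inference."""

    def call(self, inputs, training=False):
        if training:
            return inputs + tf.random.normal(tf.shape(inputs), stddev=0.1)
        return inputs

layer = NoiseLayer()
x = tf.ones([2, 3])
noisy = layer(x, training=True)    # noise applied
clean = layer(x, training=False)   # passthrough
```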
Note the simplifications relative to v1:
- There is no need to run variable initializers. Variables are initialized on creation (see the short example after this list).
- There is no need to add manual control dependencies. Even in `tf.function`, operations act as in eager mode.
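As a quick illustration of the first point, a `tf.Variable` is ready to use the moment it is constructed:

```python
v = tf.Variable(1.0)   # usable immediately; no initializer op to run
v.assign_add(1.0)
print(v.numpy())       # 2.0
```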
```python
model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, 3, activation='relu',
                           kernel_regularizer=tf.keras.regularizers.l2(0.02),
                           input_shape=(28, 28, 1)),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dropout(0.1),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Dense(10)
])
optimizer = tf.keras.optimizers.Adam(0.001)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

@tf.function
def train_step(inputs, labels):
    with tf.GradientTape() as tape:
        predictions = model(inputs, training=True)
        regularization_loss = tf.math.add_n(model.losses)
        pred_loss = loss_fn(labels, predictions)
        total_loss = pred_loss + regularization_loss
    gradients = tape.gradient(total_loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

for epoch in range(NUM_EPOCHS):
    for inputs, labels in train_data:
        train_step(inputs, labels)
    print("Finished epoch", epoch)
```
```
Finished epoch 0
Finished epoch 1
Finished epoch 2
Finished epoch 3
Finished epoch 4
```
New-style metrics and losses
In TensorFlow 2.0, metrics and losses are objects. These work both eagerly and in `tf.function`s.
A loss object is callable, and expects `(y_true, y_pred)` as arguments:
```python
cce = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
cce([[1, 0]], [[-1.0, 3.0]]).numpy()
```
```
4.01815
```
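As a sanity check on that number: with logits `[-1.0, 3.0]` and the true class at index 0, the loss is `-log(softmax([-1, 3])[0]) = log(1 + e**4) ≈ 4.018`, which matches the output above.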
A metric object has the following methods:
- `Metric.update_state()`: add new observations.
- `Metric.result()`: get the current result of the metric, given the observed values.
- `Metric.reset_states()`: clear all observations.
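A minimal sketch of the three methods, using the `Mean` metric:

```python
m = tf.keras.metrics.Mean()
m.update_state([1.0, 3.0])   # add observations
print(m.result().numpy())    # 2.0
m.reset_states()             # clear all observations
print(m.result().numpy())    # 0.0
```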
The object itself is callable. Calling updates the state with new observations, as with `update_state`, and returns the new result of the metric.
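For example, with a fresh `Mean` metric:

```python
m = tf.keras.metrics.Mean()
print(m([1.0, 3.0]).numpy())  # 2.0
print(m([5.0]).numpy())       # 3.0 (running mean of 1, 3, and 5)
```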
You don’t have to manually initialize a metric’s variables, and because TensorFlow 2.0 has automatic control dependencies, you don’t need to worry about those either.
The code below uses a metric to keep track of the mean loss observed within a custom training loop.
```python
# Create the metrics
loss_metric = tf.keras.metrics.Mean(name='train_loss')
accuracy_metric = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')

@tf.function
def train_step(inputs, labels):
    with tf.GradientTape() as tape:
        predictions = model(inputs, training=True)
        regularization_loss = tf.math.add_n(model.losses)
        pred_loss = loss_fn(labels, predictions)
        total_loss = pred_loss + regularization_loss
    gradients = tape.gradient(total_loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    # Update the metrics
    loss_metric.update_state(total_loss)
    accuracy_metric.update_state(labels, predictions)

for epoch in range(NUM_EPOCHS):
    # Reset the metrics
    loss_metric.reset_states()
    accuracy_metric.reset_states()
    for inputs, labels in train_data:
        train_step(inputs, labels)
    # Get the metric results
    mean_loss = loss_metric.result()
    mean_accuracy = accuracy_metric.result()
    print('Epoch: ', epoch)
    print('  loss:     {:.3f}'.format(mean_loss))
    print('  accuracy: {:.3f}'.format(mean_accuracy))
```
```
Epoch:  0
  loss:     0.207
  accuracy: 0.991
Epoch:  1
  loss:     0.167
  accuracy: 0.994
Epoch:  2
  loss:     0.147
  accuracy: 0.997
Epoch:  3
  loss:     0.123
  accuracy: 0.997
Epoch:  4
  loss:     0.109
  accuracy: 0.997
```