Excerpted from the TensorFlow official website
Customize the training step
If you need more flexibility and control, you can have it by implementing your own training loop. There are three steps:
- Iterate over a Python generator or tf.data.Dataset to get batches of examples.
- Use tf.GradientTape to collect gradients.
- Use one of the tf.keras.optimizers to apply weight updates to the model's variables.
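The three steps can be seen without TensorFlow in the framework-free analogy below. It is plain Python, not TensorFlow code. A generator yields mini-batches in place of a tf.data.Dataset. A hand-derived gradient of the mean squared error stands in for tf.GradientTape. A plain SGD update stands in for a tf.keras.optimizers optimizer. The functions batches, gradients and train, and the toy target y = 2x + 1, are hypothetical.

```python
import random

def batches(xs, ys, batch_size):
    """Step 1: yield mini-batches, like iterating over a tf.data.Dataset."""
    for start in range(0, len(xs), batch_size):
        yield xs[start:start + batch_size], ys[start:start + batch_size]

def gradients(w, b, xb, yb):
    """Step 2: gradient of the mean squared error of y = w*x + b.

    Derived by hand here; tf.GradientTape would compute it automatically.
    """
    n = len(xb)
    errors = [(w * x + b) - y for x, y in zip(xb, yb)]
    dw = 2.0 / n * sum(e * x for e, x in zip(errors, xb))
    db = 2.0 / n * sum(errors)
    return dw, db

def train(xs, ys, epochs=200, batch_size=8, learning_rate=0.05):
    w, b = 0.0, 0.0
    for _ in range(epochs):
        for xb, yb in batches(xs, ys, batch_size):
            dw, db = gradients(w, b, xb, yb)
            # Step 3: apply the update (plain SGD), as optimizer.apply_gradients would.
            w -= learning_rate * dw
            b -= learning_rate * db
    return w, b

random.seed(0)
xs = [random.uniform(-1.0, 1.0) for _ in range(64)]
ys = [2.0 * x + 1.0 for x in xs]  # noiseless target y = 2x + 1
w, b = train(xs, ys)
print(round(w, 3), round(b, 3))
```

The data is noiseless, so every mini-batch has the same minimum. The learned w and b should therefore settle at the true slope and intercept.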
Remember:
- Always include a training argument on the call method of subclassed layers and models.
- Make sure to call the model with the training argument set correctly (True while training, False for inference).
- Depending on usage, model variables may not exist until the model is run on a batch of data.
- You need to manually handle things like regularization losses for the model.
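The first and last reminders are easier to see in a small example. The sketch below is plain Python, not the Keras API. It builds a toy layer whose `__call__` takes a `training` argument, similar to a Keras `call(inputs, training=...)`. The layer applies inverted dropout only when `training=True`. It also exposes an L2 penalty, `l2 * sum(w**2)`, in a `losses` list. The caller must add that penalty to the prediction loss by hand, the way the training step that follows sums `model.losses` with `tf.math.add_n`. The class `ToyDenseWithDropout` and its parameters are hypothetical.

```python
import random

class ToyDenseWithDropout:
    """Hypothetical 'layer': y = sum(w_i * x_i), with dropout on the inputs."""

    def __init__(self, weights, rate=0.5, l2=0.02, seed=0):
        self.weights = list(weights)
        self.rate = rate
        self.l2 = l2
        self._rng = random.Random(seed)

    @property
    def losses(self):
        # L2 penalty: l2 * sum(w^2), the quantity an l2(0.02) kernel regularizer adds.
        return [self.l2 * sum(w * w for w in self.weights)]

    def __call__(self, inputs, training=False):
        if training:
            # Inverted dropout: zero each input with probability `rate`
            # and scale the survivors by 1 / (1 - rate).
            keep = 1.0 - self.rate
            inputs = [x / keep if self._rng.random() < keep else 0.0 for x in inputs]
        return sum(w * x for w, x in zip(self.weights, inputs))

layer = ToyDenseWithDropout([1.0, 2.0, 3.0])
x = [1.0, 1.0, 1.0]
target = 5.0

prediction = layer(x, training=True)        # stochastic while training
pred_loss = (prediction - target) ** 2
total_loss = pred_loss + sum(layer.losses)  # regularization added by hand
print(layer(x, training=False))             # deterministic at inference: 6.0
```

If the training flag is forgotten, the layer silently runs in inference mode. If the `losses` term is forgotten, the regularizer has no effect on the gradients.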
Note the simplifications relative to TensorFlow 1.x:
- There is no need to run variable initializers. Variables are initialized on creation.
- There is no need to add manual control dependencies. Even inside a tf.function, operations with side effects run in the order they are written, as in eager mode.
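  import tensorflow as tf

  # Assumption: NUM_EPOCHS and train_data come from an earlier part of the guide.
  # A minimal MNIST stand-in is defined here so that the example runs on its own.
  NUM_EPOCHS = 5
  (x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
  x_train = x_train[..., tf.newaxis].astype('float32') / 255.0  # shape (N, 28, 28, 1), values in [0, 1]
  train_data = (tf.data.Dataset.from_tensor_slices((x_train, y_train))
                .shuffle(10000)
                .batch(64))

  model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, 3, activation='relu',
                           kernel_regularizer=tf.keras.regularizers.l2(0.02),
                           input_shape=(28, 28, 1)),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dropout(0.1),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Dense(10)
  ])

  optimizer = tf.keras.optimizers.Adam(0.001)
  # The last layer outputs raw logits (no softmax), hence from_logits=True.
  loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

  @tf.function
  def train_step(inputs, labels):
    with tf.GradientTape() as tape:
      # training=True enables Dropout and batch statistics in BatchNormalization.
      predictions = model(inputs, training=True)
      # model.losses holds the L2 penalty added by kernel_regularizer.
      regularization_loss = tf.math.add_n(model.losses)
      pred_loss = loss_fn(labels, predictions)
      total_loss = pred_loss + regularization_loss
    gradients = tape.gradient(total_loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

  for epoch in range(NUM_EPOCHS):
    for inputs, labels in train_data:
      train_step(inputs, labels)
    print("Finished epoch", epoch)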
Finished epoch 0
Finished epoch 1
Finished epoch 2
Finished epoch 3
Finished epoch 4
New-style metrics and losses
In TensorFlow 2.0, metrics and losses are objects. They work both eagerly and inside a tf.function.
A loss object is callable and expects (y_true, y_pred) as arguments:
  cce = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
  cce([[1, 0]], [[-1.0, 3.0]]).numpy()
4.01815
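The value 4.01815 can be recomputed by hand. The plain-Python sketch below is not TensorFlow's implementation. Like `CategoricalCrossentropy(from_logits=True)`, it is a callable object that takes `(y_true, y_pred)`. It applies a numerically stable log-softmax to the logits and averages the per-example cross-entropy over the batch. The class name is hypothetical. For the example above, the loss is log(e^-1 + e^3) - (-1) = 3 + log(1 + e^-4) + 1 ≈ 4.01815.

```python
import math

class CategoricalCrossentropyFromLogits:
    """Hypothetical stand-in for a loss object, callable as loss(y_true, y_pred)."""

    def __call__(self, y_true, y_pred):
        per_example = []
        for target, logits in zip(y_true, y_pred):
            # Numerically stable log-sum-exp: subtract the largest logit first.
            m = max(logits)
            log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
            # Cross-entropy: -sum_i t_i * log_softmax(z)_i
            per_example.append(-sum(t * (z - log_sum_exp) for t, z in zip(target, logits)))
        # Average over the batch.
        return sum(per_example) / len(per_example)

cce = CategoricalCrossentropyFromLogits()
print(round(cce([[1, 0]], [[-1.0, 3.0]]), 5))  # 4.01815
```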
A metric object has the following methods:
- Metric.update_state(): add new observations.
- Metric.result(): get the current result of the metric, given the observed values.
- Metric.reset_state(): clear all observations (older TensorFlow 2.x releases call this reset_states()).
The object itself is also callable: calling it updates the state with new observations, as update_state() does, and returns the new result of the metric.
You don’t have to manually initialize a metric’s variables, and because TensorFlow 2.0 has automatic control dependencies, you don’t need to worry about those either.
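The metric protocol described above can be imitated in plain Python. The sketch below defines a hypothetical `RunningMean` class. It is not the Keras source, only an analogue of `tf.keras.metrics.Mean`. `update_state()` accumulates a running total and count. `result()` divides them and returns 0.0 before any observation. `reset_state()` clears both. Calling the object updates the state and then returns the new result.

```python
class RunningMean:
    """Hypothetical plain-Python analogue of tf.keras.metrics.Mean."""

    def __init__(self, name="mean"):
        self.name = name
        self.reset_state()

    def update_state(self, values):
        # Accept a single number or an iterable of numbers (one batch).
        if isinstance(values, (int, float)):
            values = [values]
        for v in values:
            self.total += v
            self.count += 1

    def result(self):
        # Guard against division by zero before any observation.
        return self.total / self.count if self.count else 0.0

    def reset_state(self):
        self.total = 0.0
        self.count = 0

    def __call__(self, values):
        # Calling the object updates the state and returns the new result.
        self.update_state(values)
        return self.result()

loss_metric = RunningMean(name="train_loss")
loss_metric.update_state([0.5, 0.3])
print(loss_metric(0.1))      # mean of 0.5, 0.3 and 0.1
loss_metric.reset_state()
print(loss_metric.result())  # 0.0
```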
The code below uses metrics to keep track of the mean loss and the accuracy observed within a custom training loop.
  # Reuses model, optimizer, loss_fn, NUM_EPOCHS and train_data from the previous example.

  # Create the metrics
  loss_metric = tf.keras.metrics.Mean(name='train_loss')
  accuracy_metric = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')

  @tf.function
  def train_step(inputs, labels):
    with tf.GradientTape() as tape:
      predictions = model(inputs, training=True)
      regularization_loss = tf.math.add_n(model.losses)
      pred_loss = loss_fn(labels, predictions)
      total_loss = pred_loss + regularization_loss
    gradients = tape.gradient(total_loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    # Update the metrics
    loss_metric.update_state(total_loss)
    accuracy_metric.update_state(labels, predictions)

  for epoch in range(NUM_EPOCHS):
    # Reset the metrics at the start of each epoch
    # (older TensorFlow 2.x releases name this method reset_states())
    loss_metric.reset_state()
    accuracy_metric.reset_state()
    for inputs, labels in train_data:
      train_step(inputs, labels)
    # Get the metric results
    mean_loss = loss_metric.result()
    mean_accuracy = accuracy_metric.result()
    print('Epoch: ', epoch)
    print('  loss:     {:.3f}'.format(mean_loss))
    print('  accuracy: {:.3f}'.format(mean_accuracy))
Epoch:  0
  loss:     0.207
  accuracy: 0.991
Epoch:  1
  loss:     0.167
  accuracy: 0.994
Epoch:  2
  loss:     0.147
  accuracy: 0.997
Epoch:  3
  loss:     0.123
  accuracy: 0.997
Epoch:  4
  loss:     0.109
  accuracy: 0.997
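SparseCategoricalAccuracy compares integer labels against the index of the largest prediction. The plain-Python sketch below shows that bookkeeping across batches. It uses a hypothetical class, not the Keras source. The sketch also shows why the loop above can pass raw logits to `update_state(labels, predictions)`: softmax is monotonic, so the arg-max is the same either way. The counts accumulate over every batch until the metric is reset, which is why the loop resets it once per epoch.

```python
class RunningSparseAccuracy:
    """Hypothetical analogue of tf.keras.metrics.SparseCategoricalAccuracy."""

    def __init__(self):
        self.reset_state()

    def update_state(self, labels, predictions):
        for label, scores in zip(labels, predictions):
            # Predicted class = index of the largest score (logit or probability).
            predicted = max(range(len(scores)), key=scores.__getitem__)
            self.correct += int(predicted == label)
            self.total += 1

    def result(self):
        return self.correct / self.total if self.total else 0.0

    def reset_state(self):
        self.correct = 0
        self.total = 0

acc = RunningSparseAccuracy()
acc.update_state([0, 2], [[2.0, -1.0, 0.5], [0.1, 0.3, 0.2]])  # batch 1: 1 of 2 correct
acc.update_state([1], [[-3.0, 4.0, 0.0]])                      # batch 2: 1 of 1 correct
print(acc.result())  # 2 correct out of 3 seen so far
```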
