[Notes] More flexible and controllable model training in TensorFlow 2.0

Content excerpted from the official TensorFlow documentation.

Customize the training step

If you need more flexibility and control, you can have it by implementing your own training loop. There are three steps:

  1. Iterate over a Python generator or tf.data.Dataset to get batches of examples.
  2. Use tf.GradientTape to collect gradients.
  3. Use one of the tf.keras.optimizers to apply weight updates to the model’s variables.

Remember:

  • Always include a training argument on the call method of subclassed layers and models (see the sketch after this list).
  • Make sure to call the model with the training argument set correctly.
  • Depending on usage, model variables may not exist until the model is run on a batch of data (also shown in the sketch below).
  • You need to manually handle things like regularization losses for the model.
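
Both the training argument and the lazy creation of variables are easy to check with a minimal sketch (this example is not from the original guide; the class name and layer sizes are arbitrary):

  import tensorflow as tf

  class DemoModel(tf.keras.Model):
    def __init__(self):
      super().__init__()
      self.dense = tf.keras.layers.Dense(10)
      self.dropout = tf.keras.layers.Dropout(0.5)

    def call(self, inputs, training=False):
      # The training flag controls layers such as Dropout and BatchNormalization.
      x = self.dense(inputs)
      return self.dropout(x, training=training)

  demo_model = DemoModel()
  print(len(demo_model.variables))  # 0 -- no variables before the first call
  demo_model(tf.zeros([1, 4]), training=False)
  print(len(demo_model.variables))  # 2 -- kernel and bias created on first use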

Note the simplifications relative to v1:

  • There is no need to run variable initializers. Variables are initialized on creation.
  • There is no need to add manual control dependencies. Even inside a tf.function, operations act as they do in eager mode.
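
For instance (a quick illustration, not from the original guide), a tf.Variable is ready to use as soon as it is created:

  v = tf.Variable(3.0)
  print(v.numpy())  # prints 3.0; no initializer or session.run required

The full example from the guide puts the three steps together: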

  import tensorflow as tf

  model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, 3, activation='relu',
                           kernel_regularizer=tf.keras.regularizers.l2(0.02),
                           input_shape=(28, 28, 1)),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dropout(0.1),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Dense(10)
  ])

  optimizer = tf.keras.optimizers.Adam(0.001)
  loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

  @tf.function
  def train_step(inputs, labels):
    with tf.GradientTape() as tape:
      predictions = model(inputs, training=True)
      # model.losses holds the regularization losses created by the layers
      regularization_loss = tf.math.add_n(model.losses)
      pred_loss = loss_fn(labels, predictions)
      total_loss = pred_loss + regularization_loss

    gradients = tape.gradient(total_loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

  # train_data is assumed to be a tf.data.Dataset yielding (image, label)
  # batches; it is built earlier in the original guide and omitted here.
  NUM_EPOCHS = 5  # matches the five epochs in the output below

  for epoch in range(NUM_EPOCHS):
    for inputs, labels in train_data:
      train_step(inputs, labels)
    print("Finished epoch", epoch)
Finished epoch 0
Finished epoch 1
Finished epoch 2
Finished epoch 3
Finished epoch 4

New-style metrics and losses

In TensorFlow 2.0, metrics and losses are objects. These work both eagerly and in tf.functions.

A loss object is callable, and expects (y_true, y_pred) as arguments:

cce = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
cce([[1, 0]], [[-1.0, 3.0]]).numpy()

4.01815
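
To see where this number comes from, here is the softmax cross-entropy computed by hand (a sketch using NumPy; not part of the original guide):

import numpy as np

logits = np.array([-1.0, 3.0])
# Softmax probability assigned to the true class (index 0)
p_true = np.exp(logits[0]) / np.exp(logits).sum()
print(-np.log(p_true))  # 4.01815...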

A metric object has the following methods:

  • update_state(): add new observations.
  • result(): get the current result of the metric, given the observed values.
  • reset_states(): clear all observations.

The object itself is also callable: calling it updates the state with new observations, as with update_state, and returns the new result of the metric.

You don’t have to manually initialize a metric’s variables, and because TensorFlow 2.0 has automatic control dependencies, you don’t need to worry about those either.
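
As a quick standalone illustration (not from the original guide), a Mean metric can be driven entirely through these methods:

  m = tf.keras.metrics.Mean()
  m.update_state([1, 2, 3])  # add observations; the running mean is now 2.0
  print(m(4).numpy())        # calling also updates the state, then returns 2.5
  print(m.result().numpy())  # 2.5
  m.reset_states()           # clear all observations
  print(m.result().numpy())  # 0.0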

The code below uses metrics to keep track of the mean loss and the accuracy observed within a custom training loop.

  # Create the metrics
  loss_metric = tf.keras.metrics.Mean(name='train_loss')
  accuracy_metric = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')

  @tf.function
  def train_step(inputs, labels):
    with tf.GradientTape() as tape:
      predictions = model(inputs, training=True)
      regularization_loss = tf.math.add_n(model.losses)
      pred_loss = loss_fn(labels, predictions)
      total_loss = pred_loss + regularization_loss

    gradients = tape.gradient(total_loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    # Update the metrics
    loss_metric.update_state(total_loss)
    accuracy_metric.update_state(labels, predictions)


  for epoch in range(NUM_EPOCHS):
    # Reset the metrics at the start of each epoch
    loss_metric.reset_states()
    accuracy_metric.reset_states()

    for inputs, labels in train_data:
      train_step(inputs, labels)
    # Get the metric results
    mean_loss = loss_metric.result()
    mean_accuracy = accuracy_metric.result()

    print('Epoch: ', epoch)
    print('  loss:     {:.3f}'.format(mean_loss))
    print('  accuracy: {:.3f}'.format(mean_accuracy))
Epoch:  0
  loss:     0.207
  accuracy: 0.991
Epoch:  1
  loss:     0.167
  accuracy: 0.994
Epoch:  2
  loss:     0.147
  accuracy: 0.997
Epoch:  3
  loss:     0.123
  accuracy: 0.997
Epoch:  4
  loss:     0.109
  accuracy: 0.997
