
TensorFlow Train: Fine-Tuning Models with Pretrained Weights

Last updated: December 21, 2024

Fine-tuning a machine learning model means adapting a pretrained model to a new problem. Reusing pretrained weights can significantly shorten training time and often yields better accuracy than training from scratch, especially when labeled data is limited. TensorFlow, an end-to-end open-source platform for machine learning, provides tools for comprehensive model training and fine-tuning. In this article, we'll explore how to use TensorFlow to fine-tune models with pretrained weights.

Getting Started with TensorFlow

Before diving into fine-tuning, ensure you have TensorFlow properly installed. You can install it via pip if it's not already available in your Python environment:

pip install tensorflow

This command will install the appropriate version of TensorFlow compatible with your system’s configuration.
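
Once installed, you can confirm that TensorFlow imports correctly and check its version from Python:

import tensorflow as tf

print(tf.__version__)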

Loading a Pretrained Model

TensorFlow offers a variety of pretrained models. The tf.keras.applications module is an excellent resource for this purpose. Here’s an example of how to load a pretrained MobileNetV2 model:

import tensorflow as tf

# Load the MobileNetV2 model with pretrained weights from ImageNet
base_model = tf.keras.applications.MobileNetV2(weights='imagenet', include_top=False)

The above snippet initializes MobileNetV2 without its top (classification) layer, which lets you attach a new head suited to your own task.

Preparing Your Dataset

A well-prepared dataset is essential for fine-tuning. Here's how to load your own image dataset using TensorFlow's tf.keras.utils.image_dataset_from_directory:

train_dataset = tf.keras.utils.image_dataset_from_directory(
    'path_to_data/train',
    image_size=(224, 224),
    batch_size=32
)

validation_dataset = tf.keras.utils.image_dataset_from_directory(
    'path_to_data/validation',
    image_size=(224, 224),
    batch_size=32
)

This loads the images as batched tf.data.Dataset objects in a form the model can consume during training and validation.
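
One caveat worth noting: image_dataset_from_directory yields raw pixel values in the [0, 255] range, while MobileNetV2 was pretrained on inputs scaled to [-1, 1]. Here is a minimal sketch that applies the matching preprocessing function to both datasets (and prefetches batches for throughput):

preprocess = tf.keras.applications.mobilenet_v2.preprocess_input

# Scale pixels to [-1, 1] and overlap data loading with training
train_dataset = train_dataset.map(
    lambda images, labels: (preprocess(images), labels)
).prefetch(tf.data.AUTOTUNE)

validation_dataset = validation_dataset.map(
    lambda images, labels: (preprocess(images), labels)
).prefetch(tf.data.AUTOTUNE)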

Freezing Base Layers

Freezing the base model's layers preserves the pretrained features and ensures that only the new layers learn during the initial training phase:

# Freeze every layer of the base model so its pretrained weights stay fixed
for layer in base_model.layers:
    layer.trainable = False

This loop disables training on all layers of the base model, preventing any updates to their weights.

Customizing the Model

Now, create a new classification head that matches the requirements of your dataset. Suppose you need to classify 10 different categories:

inputs = tf.keras.Input(shape=(224, 224, 3))

x = base_model(inputs, training=False)  # training=False keeps BatchNorm layers in inference mode
x = tf.keras.layers.GlobalAveragePooling2D()(x)
outputs = tf.keras.layers.Dense(10, activation='softmax')(x)

model = tf.keras.Model(inputs, outputs)

The resulting model combines the frozen pretrained base with a new head sized for your 10 categories.
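
To verify that only the new head will be trained, you can inspect the model — with the base frozen, all of its weights should be listed under non-trainable parameters:

# Print layer structure and trainable vs. non-trainable parameter counts
model.summary()

# Only the new head's weights should be trainable at this point
print(len(model.trainable_variables))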

Compiling and Training

After adjusting the model, compile it before training. Compiling sets the optimizer, loss function, and evaluation metrics:

model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

Now, execute the training:

history = model.fit(
    train_dataset,
    validation_data=validation_dataset,
    epochs=10
)

The above code trains the model, evaluates it on the validation dataset after each epoch, and records metrics across epochs, giving insight into how performance improves.
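
The returned history object stores each metric per epoch in history.history, keyed by the names set at compile time — here 'accuracy' and 'val_accuracy' alongside the losses. For example, to pull out the best validation accuracy:

# Find the highest validation accuracy reached across all epochs
best_val_acc = max(history.history['val_accuracy'])
print(f"Best validation accuracy: {best_val_acc:.3f}")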

Unfreezing and Fine-tuning

Once you achieve satisfactory initial results, you can unfreeze some of the base model's layers for deeper training. Be sure to recompile with a lower learning rate so the pretrained weights are adjusted only gently:

# Unfreeze the base model, then re-freeze its first 100 layers so only
# the later, more task-specific layers are updated during fine-tuning
base_model.trainable = True
for layer in base_model.layers[:100]:
    layer.trainable = False

model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

fine_tune_history = model.fit(
    train_dataset,
    validation_data=validation_dataset,
    epochs=5
)

This step often improves performance further, as the unfrozen layers adapt their pretrained features to the specifics of your dataset.
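
As a final check, you can evaluate the fine-tuned model on the validation set; evaluate() returns the loss followed by each metric specified at compile time:

# Report loss and accuracy on the validation data
loss, accuracy = model.evaluate(validation_dataset)
print(f"Validation accuracy after fine-tuning: {accuracy:.3f}")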

Fine-tuning pretrained weights with TensorFlow offers an efficient path to strong model performance on custom datasets. The modular design of TensorFlow and its library of pretrained models makes it quick to adapt general-purpose networks to specific tasks across many domains.
