Have you ever wondered how to improve the performance of your ML models without building new ones from scratch? That is where transfer learning comes into play. In this article, we provide an overview of transfer learning along with its benefits and challenges.
What is Transfer Learning?
Transfer learning means that a model trained for one task can be reused for another, similar task: you take a pre-trained model and adapt it to the task at hand. Let’s discuss the stages in transfer learning; a short code sketch after the list shows how they fit together.
- Choose a Pre-trained model: Select a model that has been trained on a large dataset for a similar task to the one you want to work on.
- Modify model architecture: Adjust the final layers of the pre-trained model according to your specific task. Also, add new layers if needed.
- Re-train the model: Train the modified model on your new dataset. This allows the model to learn the details of your specific task. It also benefits from the features it learned during the original training.
- Fine-tune the model: Unfreeze some of the pre-trained layers and continue training your model. This allows the model to better adapt to the new task by fine-tuning its weights.
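To make the stages concrete, here is a minimal Keras sketch of the full workflow. It assumes TensorFlow is installed; the dataset train_ds, the number of target classes, and the layer sizes are illustrative placeholders rather than recommendations.

import tensorflow as tf
from tensorflow.keras.applications import VGG16
from tensorflow.keras import layers, models

# 1. Choose a pre-trained model (ImageNet weights, without its classification head)
base_model = VGG16(weights="imagenet", include_top=False, input_shape=(224, 224, 3))

# 2. Modify the architecture: freeze the base and add new task-specific layers
base_model.trainable = False
model = models.Sequential([
    base_model,
    layers.GlobalAveragePooling2D(),
    layers.Dense(128, activation="relu"),
    layers.Dense(5, activation="softmax"),  # placeholder: 5 target classes
])

# 3. Re-train only the new layers on your dataset (train_ds is a placeholder)
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])
# model.fit(train_ds, epochs=5)

# 4. Fine-tune: unfreeze the top of the base model and continue training with a lower learning rate
base_model.trainable = True
for layer in base_model.layers[:-4]:  # keep everything except the last block frozen
    layer.trainable = False
model.compile(optimizer=tf.keras.optimizers.Adam(1e-5),
              loss="sparse_categorical_crossentropy", metrics=["accuracy"])
# model.fit(train_ds, epochs=5)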
Benefits of Transfer Learning
Transfer learning offers several significant advantages:
- Saves Time and Resources: Fine-tuning needs less time and fewer computational resources because the pre-trained model has already been trained for many iterations on a large dataset. That initial training has already captured essential features, which reduces the workload for the new task.
- Improves Performance: Pre-trained models have learned from extensive datasets, so they generalize better. This leads to improved performance on new tasks, even when the new dataset is relatively small. The knowledge gained from the initial training helps in achieving higher accuracy and better results.
- Needs Less Data: One of the major benefits of transfer learning is its effectiveness with smaller datasets. The pre-trained model has already acquired useful patterns and features, so it can perform reasonably well even with only a small amount of new data.
Types of Transfer Learning
Transfer learning can be classified into three types:
Feature extraction
Feature extraction means reusing the representations a pre-trained model has learned as inputs for a new task, without retraining the model itself. For instance, in image classification, we can use a pre-trained Convolutional Neural Network to extract meaningful features from images. Here’s an example using a pre-trained VGG16 model from Keras for image feature extraction:
import numpy as np
from tensorflow.keras.applications import VGG16
from tensorflow.keras.preprocessing import image
from tensorflow.keras.applications.vgg16 import preprocess_input
# Load pre-trained VGG16 model (without the top layers for classification)
base_model = VGG16(weights="imagenet", include_top=False)
# Function to extract features from an image
def extract_features(img_path):
    img = image.load_img(img_path, target_size=(224, 224))  # Load image and resize
    x = image.img_to_array(img)  # Convert image to numpy array
    x = np.expand_dims(x, axis=0)  # Add batch dimension
    x = preprocess_input(x)  # Preprocess input according to model's requirements
    features = base_model.predict(x)  # Extract features using VGG16 model
    return features.flatten()  # Flatten to a 1D array for simplicity

# Example usage
image_path = "path_to_your_image.jpg"
image_features = extract_features(image_path)
print(f"Extracted features shape: {image_features.shape}")
Fine-tuning
Fine-tuning goes a step further than feature extraction: some or all of the pre-trained model's weights are updated so that the model matches the specific task. This method is most useful when you have a mid-sized dataset and want to improve the model's ability on a particular task. For example, in NLP, a standard BERT model might be further trained on a small collection of medical texts to perform medical entity recognition better. Here’s an example using BERT for sentiment analysis with fine-tuning on a custom dataset:
from transformers import BertTokenizer, BertForSequenceClassification
import torch
from torch.optim import AdamW
from torch.utils.data import DataLoader, TensorDataset
# Example data (replace with your dataset)
texts = ["I love this product!", "This is not what I expected.", ...]
labels = [1, 0, ...] # 1 for positive sentiment, 0 for negative sentiment, etc.
# Load pre-trained BERT model and tokenizer
model_name = "bert-base-uncased"
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertForSequenceClassification.from_pretrained(model_name, num_labels=2) # Example: binary classification
# Tokenize input texts and create DataLoader
inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
dataset = TensorDataset(inputs['input_ids'], inputs['attention_mask'], torch.tensor(labels))
dataloader = DataLoader(dataset, batch_size=16, shuffle=True)
# Fine-tuning parameters
optimizer = AdamW(model.parameters(), lr=1e-5)
# Fine-tune BERT model
model.train()
for epoch in range(3):  # Example: 3 epochs
    for batch in dataloader:
        optimizer.zero_grad()
        input_ids, attention_mask, target = batch
        outputs = model(input_ids, attention_mask=attention_mask, labels=target)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
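After fine-tuning, the model can be switched to evaluation mode and used to score new text. A minimal inference sketch; the example sentence is arbitrary.

# Run inference with the fine-tuned model
model.eval()
with torch.no_grad():
    new_text = "The product works exactly as advertised."
    encoded = tokenizer(new_text, return_tensors="pt")
    logits = model(**encoded).logits
    predicted_label = torch.argmax(logits, dim=1).item()
print(f"Predicted sentiment label: {predicted_label}")  # 1 = positive, 0 = negative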
Domain adaptation
Domain adaptation is about transferring knowledge from the source domain the pre-trained model was trained on to a different target domain. It is needed when the source and target domains differ in their features, data distribution, or even language. For instance, in sentiment analysis, a classifier trained on product reviews may need adapting before being applied to social media posts, because the two use very different language. Here’s a simplified example that adapts social media text toward the style of product reviews before applying a review-trained sentiment classifier:
# Function to adapt text style
def adapt_text_style(text):
    # Example: replace social media language with product review-like language
    # Lowercase first so the replacements also catch capitalized words
    adapted_text = text.lower().replace("excited", "positive").replace("#innovation", "new technology")
    return adapted_text
# Example usage of domain adaptation
social_media_post = "Excited about the new tech! #innovation"
adapted_text = adapt_text_style(social_media_post)
print(f"Adapted text: {adapted_text}")
# Use sentiment classifier trained on product reviews
# Example: sentiment_score = sentiment_classifier.predict(adapted_text)
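Simple rule-based rewriting like this only goes so far. A common alternative is to continue fine-tuning the review-trained classifier on a small amount of labeled data from the target domain. Here is a hedged sketch that reuses the BERT model, tokenizer, and optimizer from the fine-tuning example above; the social media texts and labels are hypothetical.

# Hypothetical small labeled sample from the target domain (social media posts)
social_texts = ["Loving the new update!", "This app keeps crashing..."]
social_labels = [1, 0]

# Tokenize target-domain data and briefly continue training the review-trained model
social_inputs = tokenizer(social_texts, padding=True, truncation=True, return_tensors="pt")
social_dataset = TensorDataset(social_inputs["input_ids"], social_inputs["attention_mask"],
                               torch.tensor(social_labels))
social_loader = DataLoader(social_dataset, batch_size=2, shuffle=True)

model.train()
for batch in social_loader:
    optimizer.zero_grad()
    input_ids, attention_mask, target = batch
    loss = model(input_ids, attention_mask=attention_mask, labels=target).loss
    loss.backward()
    optimizer.step()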
Pre-trained Models
Pre-trained models are models that have already been trained on large datasets. They capture knowledge and patterns from extensive data and serve as a starting point for other tasks. Let’s discuss some of the common pre-trained models used in machine learning applications.
VGG (Visual Geometry Group)
The architecture of VGG includes multiple layers of 3×3 convolutional filters and pooling layers. It can identify detailed features like edges and shapes in images. By training on large datasets, VGG learns to recognize different objects within images. It can be used for tasks such as object detection and image segmentation.
ResNet (Residual Network)
ResNet uses residual (skip) connections that make it easier for gradients to flow through the network. This mitigates the vanishing gradient problem and helps the network train effectively, which is why ResNet can successfully train models with hundreds of layers. ResNet is excellent for tasks such as image classification and face recognition.
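As a quick illustration, a pre-trained ResNet50 can be loaded from Keras and used for image classification out of the box; the image path below is a placeholder.

import numpy as np
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.applications.resnet50 import preprocess_input, decode_predictions
from tensorflow.keras.preprocessing import image

# Load ResNet50 with ImageNet weights, including its classification head
resnet_model = ResNet50(weights="imagenet")

# Prepare an image and predict its ImageNet class
img = image.load_img("path_to_your_image.jpg", target_size=(224, 224))
x = preprocess_input(np.expand_dims(image.img_to_array(img), axis=0))
preds = resnet_model.predict(x)
print(decode_predictions(preds, top=3)[0])  # Top-3 predicted labels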
BERT (Bidirectional Encoder Representations from Transformers)
BERT is used for natural language processing applications. It uses a transformer-based model to understand the context of words in a sentence. It learns to guess missing words and understand sentence meanings. BERT can be used for sentiment analysis, question answering and named entity recognition.
Fine-tuning Techniques
Layer Freezing
Layer freezing means choosing certain layers of a pre-trained model and preventing them from changing during training with new data. This is done to preserve the useful patterns and features the model learned from its original training. Typically, we freeze early layers that capture general features like edges in images or basic structures in text.
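Here is a minimal sketch of layer freezing with the Keras VGG16 model used earlier; the split point of 15 layers is an arbitrary choice for illustration.

from tensorflow.keras.applications import VGG16

base_model = VGG16(weights="imagenet", include_top=False)

# Freeze the early layers so their general-purpose features are preserved
for layer in base_model.layers[:15]:
    layer.trainable = False

# Leave the last convolutional block trainable so it can adapt to the new task
for layer in base_model.layers[15:]:
    layer.trainable = True

print([(layer.name, layer.trainable) for layer in base_model.layers])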
Learning Rate Adjustment
Tuning the learning rate is important to balance preserving what the model has already learned with adapting to the new data. Fine-tuning usually uses a lower learning rate than the initial training on large datasets. This helps the model adapt to new data while preserving most of its learned weights.
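One way to apply this in PyTorch is to give the pre-trained layers a smaller learning rate than the newly added classification head. The sketch below uses the BERT classifier from earlier; the learning-rate values are illustrative, not prescriptive.

import torch
from transformers import BertForSequenceClassification

model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)

# Small learning rate for the pre-trained encoder, larger one for the freshly initialized head
optimizer = torch.optim.AdamW([
    {"params": model.bert.parameters(), "lr": 1e-5},        # pre-trained layers: gentle updates
    {"params": model.classifier.parameters(), "lr": 1e-3},  # new head: larger updates
])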
Challenges and Considerations
Let’s discuss the challenges of transfer learning and how to address them.
- Dataset Size and Domain Shift: Fine-tuning works best when there is enough data for the target task. If the new dataset is small or differs significantly from the data the model was originally trained on, performance can suffer. To deal with this, you can collect or augment data that is more relevant to what the model has already learned.
- Hyperparameter Tuning: Choosing hyperparameters carefully is important when working with pre-trained models. These parameters interact with each other and largely determine how well the model performs. Techniques such as grid search, or automated tuning tools, can be used to find the settings that yield the best performance on validation data (a small sketch follows this list).
- Computational Resources: Fine-tuning deep neural networks is computationally demanding because such models can have millions of parameters. Powerful accelerators such as GPUs or TPUs are needed for training and inference. These demands are often met by cloud computing platforms.
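To illustrate the hyperparameter search mentioned above, here is a small grid-search sketch; train_and_evaluate is a hypothetical helper that fine-tunes the model with the given settings and returns accuracy on a held-out validation set.

from itertools import product

# Hypothetical grid of fine-tuning hyperparameters to try
learning_rates = [1e-5, 3e-5, 5e-5]
batch_sizes = [16, 32]

best_config, best_score = None, float("-inf")
for lr, bs in product(learning_rates, batch_sizes):
    # train_and_evaluate is a hypothetical helper: fine-tune with these settings,
    # then return accuracy on the validation set
    score = train_and_evaluate(lr=lr, batch_size=bs)
    if score > best_score:
        best_config, best_score = (lr, bs), score

print(f"Best config: {best_config}, validation accuracy: {best_score:.3f}")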
Wrapping Up
In conclusion, transfer learning stands as a cornerstone in the quest to enhance model performance across diverse applications of artificial intelligence. By leveraging pretrained models like VGG, ResNet, BERT, and others, practitioners can efficiently harness existing knowledge to tackle complex tasks in image classification, natural language processing, healthcare, autonomous systems, and beyond.
Jayita Gulati is a machine learning enthusiast and technical writer driven by her passion for building machine learning models. She holds a Master’s degree in Computer Science from the University of Liverpool.