Now it’s time for some contrastive learning. When annotated labels are scarce but unlabelled data is plentiful, contrastive learning lets the backbone learn data representations without a specific task. For a given downstream task, the backbone can then be frozen and only a shallow network trained on the limited annotated dataset to achieve satisfactory results.
The most commonly used contrastive learning approaches include SimCLR, SimSiam, and MoCo (see my previous article on MoCo). Here, we compare SimCLR and SimSiam.
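As a rough illustration of the "frozen backbone plus shallow network" idea, here is a minimal sketch of one training step. The tiny `backbone`, the `head`, and the random batch are hypothetical stand-ins (assumptions for illustration), not the model trained in this article.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for a pretrained backbone (assumption: 1x28x28 inputs).
backbone = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=3, stride=2, padding=1),  # 28x28 -> 14x14
    nn.ReLU(),
    nn.Flatten(),
)

# Freeze the backbone so that only the head receives gradient updates.
for param in backbone.parameters():
    param.requires_grad = False
backbone.eval()

# Shallow classification head trained on the small labelled set.
head = nn.Linear(8 * 14 * 14, 10)
optimizer = torch.optim.SGD(head.parameters(), lr=0.01)

# One illustrative step on random data standing in for a labelled batch.
images = torch.randn(16, 1, 28, 28)
targets = torch.randint(0, 10, (16,))
with torch.no_grad():
    features = backbone(images)
logits = head(features)
loss = nn.functional.cross_entropy(logits, targets)
optimizer.zero_grad()
loss.backward()
optimizer.step()
```

Only `head.parameters()` is passed to the optimizer, so the backbone weights stay exactly as they were after pretraining.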
SimCLR computes its loss over positive and negative pairs within the data batch. It relies on many in-batch negatives and the NT-Xent loss (which extends the cosine similarity loss over a batch), and therefore needs a large batch size. The SimCLR paper also uses the LARS optimizer to keep training stable at large batch sizes.
SimSiam, however, uses a Siamese architecture with a stop-gradient operation, which removes the need for negative pairs and, with it, the need for large batch sizes. The differences between SimSiam and SimCLR are given in the table below.
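To make the NT-Xent loss concrete, here is a minimal sketch in PyTorch. The function name `nt_xent_loss` and the default `temperature=0.5` are assumptions for illustration and are not taken from any SimCLR reference implementation.

```python
import torch
import torch.nn.functional as F


def nt_xent_loss(z1, z2, temperature=0.5):
    """NT-Xent over a batch: z1[i] and z2[i] are two views of the same sample."""
    batch_size = z1.shape[0]
    z = F.normalize(torch.cat([z1, z2], dim=0), dim=1)  # (2N, D), unit length
    sim = z @ z.t() / temperature                       # scaled cosine similarities
    # Exclude each embedding's similarity with itself from the softmax.
    self_mask = torch.eye(2 * batch_size, dtype=torch.bool, device=z.device)
    sim = sim.masked_fill(self_mask, float("-inf"))
    # The positive for row i is its other view: i + N in the first half, i - N in the second.
    targets = torch.cat([
        torch.arange(batch_size, 2 * batch_size),
        torch.arange(0, batch_size),
    ]).to(z.device)
    # Every other row in the batch acts as a negative.
    return F.cross_entropy(sim, targets)


torch.manual_seed(0)
view1 = torch.randn(8, 16)
loss_aligned = nt_xent_loss(view1, view1)               # views agree perfectly
loss_random = nt_xent_loss(view1, torch.randn(8, 16))   # views are unrelated
```

Each sample is contrasted against the other 2N - 2 embeddings in the batch, which is why the number of negatives, and hence the batch size, matters for this loss.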
We can see from the figure above that the SimSiam architecture contains only two parts: the encoder/backbone and the predictor. During training, gradient propagation is stopped on one branch of the Siamese pair, and the cosine similarity is calculated between the predictor output of one view and the backbone output of the other view.
So, how do we implement this architecture in practice? Continuing from the supervised classification design, we keep the backbone the same and only modify the MLP layer. In the supervised learning architecture, the MLP outputs a 10-element vector of scores for the 10 classes. For SimSiam, the purpose is not to perform “classification” but to learn the “representation,” so the MLP output needs the same dimension as the backbone output for the loss calculation. The model and the negative cosine similarity loss with stop-gradient are given below:
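Written out, the objective described above can be stated as follows. This is a sketch following the notation of the SimSiam paper: $z_1, z_2$ denote the encoder outputs for the two augmented views (in this article, the flattened backbone output) and $p_1, p_2$ the corresponding predictor outputs. These symbols are introduced here for illustration.

$$
\mathcal{D}(p, z) = -\frac{p}{\lVert p \rVert_2} \cdot \frac{z}{\lVert z \rVert_2}
$$

$$
\mathcal{L} = \tfrac{1}{2}\,\mathcal{D}\big(p_1, \operatorname{stopgrad}(z_2)\big) + \tfrac{1}{2}\,\mathcal{D}\big(p_2, \operatorname{stopgrad}(z_1)\big)
$$

Because of the stop-gradient, $z$ is treated as a constant target in each term, so gradients reach the backbone only through the predictor branch. The loss is bounded below by $-1$, which is reached when each prediction points in the same direction as its target.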
import torch.nn as nn
import matplotlib.pyplot as plt


class SimSiam(nn.Module):
    def __init__(self):
        super(SimSiam, self).__init__()
        # Backbone: three stride-2 convolutions, 1x28x28 -> 128x4x4
        self.backbone = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(128),
        )
        # Predictor: bottleneck MLP whose output matches the backbone output size
        self.prediction_mlp = nn.Sequential(
            nn.Linear(128 * 4 * 4, 64),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Linear(64, 128 * 4 * 4),
        )

    def forward(self, x):
        x = self.backbone(x)
        x = x.view(-1, 128 * 4 * 4)
        pred_output = self.prediction_mlp(x)
        # Return the flattened backbone output and the predictor output
        return x, pred_output


cos = nn.CosineSimilarity(dim=1, eps=1e-6)


def negative_cosine_similarity_stopgradient(pred, proj):
    # detach() implements the stop-gradient on the target branch
    return -cos(pred, proj.detach()).mean()
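The `128 * 4 * 4` input size of the prediction MLP follows from applying the three stride-2 convolutions to a 28×28 FashionMNIST image. A quick way to check that arithmetic is the standard convolution output-size formula; the helper name `conv2d_out_size` is made up for this sketch.

```python
def conv2d_out_size(size, kernel_size=3, stride=2, padding=1):
    """Spatial output size of a Conv2d layer (dilation 1)."""
    return (size + 2 * padding - kernel_size) // stride + 1


size = 28  # FashionMNIST images are 28x28
sizes = [size]
for _ in range(3):  # three conv layers in the backbone
    size = conv2d_out_size(size)
    sizes.append(size)

flattened_dim = 128 * size * size  # 128 channels after the last conv
print(sizes, flattened_dim)  # [28, 14, 7, 4] 2048
```

If the input resolution or the number of convolution layers changes, the `view` call and the first `nn.Linear` of the predictor need to be updated to match.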
The pseudo-code for training SimSiam, taken from the original paper, is given below:
And we convert it into real training code:
import os

import torch
import torch.optim as optim
import tqdm
import wandb
from torch.utils.data import DataLoader
from torchvision.transforms import RandAugment

wandb_config = {
    "learning_rate": 0.0001,
    "architecture": "simsiam",
    "dataset": "FashionMNIST",
    "epochs": 100,
    "batch_size": 256,
}
wandb.init(
    # set the wandb project where this run will be logged
    project="simsiam",
    # track hyperparameters and run metadata
    config=wandb_config,
)

# Initialize model and optimizer
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
simsiam = SimSiam().to(device)
random_augmenter = RandAugment(num_ops=5)
optimizer = optim.SGD(
    simsiam.parameters(),
    lr=wandb_config["learning_rate"],
    momentum=0.9,
    weight_decay=1e-5,
)
# train_dataset is assumed to be the FashionMNIST training set used in the
# supervised section, with images as float tensors in [0, 1]
train_dataloader = DataLoader(train_dataset, batch_size=wandb_config["batch_size"], shuffle=True)
os.makedirs("weights", exist_ok=True)  # make sure the checkpoint folder exists


def augment(image):
    # RandAugment works on uint8 tensors: convert from [0, 1] floats and back.
    # Called on a batch, it samples one set of ops for the whole batch.
    image_uint8 = (image * 255).to(dtype=torch.uint8)
    return random_augmenter(image_uint8).to(dtype=torch.float32) / 255.0


# Training loop
for epoch in range(wandb_config["epochs"]):
    simsiam.train()
    print(f"Epoch {epoch}")
    train_loss = 0.0
    for batch_idx, (image, _) in enumerate(tqdm.tqdm(train_dataloader, total=len(train_dataloader))):
        optimizer.zero_grad()
        # Two independently augmented views of the same batch
        aug1, aug2 = augment(image).to(device), augment(image).to(device)
        proj1, pred1 = simsiam(aug1)
        proj2, pred2 = simsiam(aug2)
        # Symmetrized loss: each view's prediction is matched to the other view's target
        loss = negative_cosine_similarity_stopgradient(pred1, proj2) / 2 \
            + negative_cosine_similarity_stopgradient(pred2, proj1) / 2
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        wandb.log({"training loss": loss.item()})
    print(f"Mean training loss: {train_loss / len(train_dataloader):.4f}")
    if (epoch + 1) % 10 == 0:
        torch.save(simsiam.state_dict(), f"weights/simsiam_epoch{epoch+1}.pt")
We trained for 100 epochs as a fair comparison to the limited supervised training; the training loss is shown below. Note: Due to its Siamese design, SimSiam can be very sensitive to hyperparameters such as the learning rate and the size of the MLP hidden layers. The original SimSiam paper provides a detailed configuration for the ResNet50 backbone. For a ViT-based backbone, we recommend reading the MoCo v3 paper, which combines a SimSiam-style predictor with a momentum update scheme.
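One practical way to notice that a hyperparameter choice has gone wrong is to watch for representation collapse. The SimSiam paper monitors the standard deviation of the L2-normalized outputs: it drops to zero when all outputs collapse to a constant vector, and sits near 1/sqrt(d) when a d-dimensional output is spread out like an isotropic Gaussian. Below is a minimal sketch of that check; the helper name `normalized_output_std` and the synthetic tensors are assumptions for illustration, not outputs of the model above.

```python
import math

import torch
import torch.nn.functional as F


def normalized_output_std(features):
    """Mean per-dimension std of L2-normalized features across the batch."""
    z = F.normalize(features, dim=1)
    return z.std(dim=0).mean().item()


torch.manual_seed(0)
dim = 128 * 4 * 4
# Synthetic stand-ins: spread-out features vs. nearly constant (collapsed) features
healthy = torch.randn(256, dim)
collapsed = torch.ones(256, dim) + 1e-4 * torch.randn(256, dim)

healthy_std = normalized_output_std(healthy)
collapsed_std = normalized_output_std(collapsed)
reference = 1 / math.sqrt(dim)
print(healthy_std, collapsed_std, reference)
```

In a training loop, the same function could be applied to the backbone output of each batch and logged next to the loss; a value drifting toward zero would suggest collapse.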
Then, we run the trained SimSiam on the testing set and visualize the representations using UMAP reduction:
import numpy as np
import torch
import tqdm
from torch.utils.data import DataLoader

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
simsiam = SimSiam()
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)
# map_location lets a GPU-trained checkpoint load on a CPU-only machine
simsiam.load_state_dict(torch.load("weights/simsiam_epoch100.pt", map_location=device))
simsiam.eval()
simsiam.to(device)

# Collect the predictor outputs and labels for the whole test set
features = []
labels = []
for batch_idx, (image, target) in enumerate(tqdm.tqdm(test_dataloader, total=len(test_dataloader))):
    with torch.no_grad():
        proj, pred = simsiam(image.to(device))
    features.extend(pred.detach().cpu().numpy().tolist())
    labels.extend(target.detach().cpu().numpy().tolist())

import plotly.express as px
import umap.umap_ as umap

# Reduce to 3 UMAP components and plot the first two
reducer = umap.UMAP(n_components=3, n_neighbors=10, metric="cosine")
projections = reducer.fit_transform(np.array(features))
fig = px.scatter(
    projections, x=0, y=1,
    color=labels, labels={'color': 'Fashion MNIST Labels'}
)
fig.show()
It’s interesting to see two small islands in the reduced-dimension map above: classes 5, 7, 8, and some of class 9. From the FashionMNIST class list, these correspond to footwear and accessories: “Sandal,” “Sneaker,” “Bag,” and “Ankle boot.” The big purple cluster corresponds to clothing classes like “T-shirt/top,” “Trouser,” “Pullover,” “Dress,” “Coat,” and “Shirt.” This suggests that SimSiam has learned a meaningful representation of the images without using any labels.