Thanks to the authors for making their code available. I borrowed some code from the DINOv2 and I-JEPA repositories.
@article{chen2020simsiam,
  author  = {Xinlei Chen and Kaiming He},
  title   = {Exploring Simple Siamese Representation Learning},
  journal = {arXiv preprint arXiv:2011.10566},
  year    = {2020}
}
@article{Jing2021UnderstandingDC,
  author  = {Li Jing and Pascal Vincent and Yann LeCun and Yuandong Tian},
  title   = {Understanding Dimensional Collapse in Contrastive Self-supervised Learning},
  journal = {arXiv preprint arXiv:2110.09348},
  year    = {2021}
}
@inproceedings{wang2022asym,
  author    = {Xiao Wang and Haoqi Fan and Yuandong Tian and Daisuke Kihara and Xinlei Chen},
  title     = {On the Importance of Asymmetry for Siamese Representation Learning},
  booktitle = {IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
  year      = {2022}
}
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import random
import os
import types
Config = {
'NUM_CLASSES': 10,
'BATCH_SIZE': 128,
'EPOCHS': 10,
'LR': 0.001, #3e-4,
'IMG_SIZE': 32,
'DATA_MEANS': np.array([0.4914, 0.4822, 0.4465]), # mean of the CIFAR dataset, used for normalization
'DATA_STD': np.array([0.2023, 0.1994, 0.2010]), # standard deviation of the CIFAR dataset, used for normalization
'CROP_SCALES': (0.8, 1.0),
'SEED': 42,
'PATCH_SIZE': 4,
'IN_CHANNELS': 3,
'EMBED_DIM': 256,
'DEPTH': 6,
'NUM_HEADS': 4,
'MLP_DIM': 1024,
'DROPOUT': 0.1,
'HEAD_MLP_DIM': 2048,
'HEAD_DIM': 128,
'DEVICE': 'cuda' if torch.cuda.is_available() else 'cpu'
}
# Convert to SimpleNamespace
config = types.SimpleNamespace(**Config)
class_names = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
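The config carries a SEED, so for reproducibility here is a minimal sketch (my addition; CUDA determinism flags are left at their defaults) that seeds Python, NumPy, and PyTorch:
# Seed all RNGs from the config for reproducible runs
random.seed(config.SEED)
np.random.seed(config.SEED)
torch.manual_seed(config.SEED)
torch.cuda.manual_seed_all(config.SEED)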
Here, I’m passing light augmentations to the teacher dataloader and stronger augmentations to the student dataloader. I’m also setting shuffle=False on the trainloaders so that the teacher and student networks see the same underlying image in each batch.
# Transform 1: Light augmentations
transform_teacher = transforms.Compose([
# transforms.RandomCrop(32, padding=4),
# transforms.RandomHorizontalFlip(),
# transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
transforms.ToTensor(),
transforms.Normalize(Config['DATA_MEANS'], Config['DATA_STD'])
])
# Transform 2: Stronger augmentations
transform_student = transforms.Compose([
transforms.RandomResizedCrop(32, scale=Config['CROP_SCALES']),
transforms.RandomHorizontalFlip(),
transforms.RandomRotation(15),
transforms.ToTensor(),
transforms.Normalize(Config['DATA_MEANS'], Config['DATA_STD']),
transforms.RandomErasing(p=1, scale=(0.4, 0.5), ratio=(0.3, 3.3), value=0)
])
# Create datasets
trainset_teacher = torchvision.datasets.CIFAR10(
root='./data', train=True, download=True, transform=transform_teacher)
trainset_student = torchvision.datasets.CIFAR10(
root='./data', train=True, download=True, transform=transform_student)
testset_teacher = torchvision.datasets.CIFAR10(
root='./data', train=False, download=True, transform=transform_teacher)
testset_student = torchvision.datasets.CIFAR10(
root='./data', train=False, download=True, transform=transform_student)
# Create DataLoaders
trainloader_teacher = torch.utils.data.DataLoader(
trainset_teacher, batch_size=Config['BATCH_SIZE'], shuffle=False, num_workers=2)
trainloader_student = torch.utils.data.DataLoader(
trainset_student, batch_size=Config['BATCH_SIZE'], shuffle=False, num_workers=2)
testloader_teacher = torch.utils.data.DataLoader(
testset_teacher, batch_size=Config['BATCH_SIZE'], shuffle=False, num_workers=2)
testloader_student = torch.utils.data.DataLoader(
testset_student, batch_size=Config['BATCH_SIZE'], shuffle=False, num_workers=2)
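Since both trainloaders use shuffle=False over the same CIFAR-10 split, corresponding batches should contain the same underlying images. A quick sanity check (my own snippet, not part of the training code) is to compare the labels coming out of the two loaders:
# With shuffle=False, both loaders traverse CIFAR-10 in the same order,
# so the labels of corresponding batches must match exactly
images_t, labels_t = next(iter(trainloader_teacher))
images_s, labels_s = next(iter(trainloader_student))
assert torch.equal(labels_t, labels_s), "loaders are out of sync"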
# Unnormalize function
def unnormalize(img, mean, std):
img = img.numpy().transpose((1, 2, 0))
img = std * img + mean
img = np.clip(img, 0, 1)
return img
# Function to show images from a trainloader
def show_augmented_images(trainloader, mean, std, title):
images, labels = next(iter(trainloader)) # one batch
indices = list(range(0, 10))
plt.figure(figsize=(12, 4))
for i, idx in enumerate(indices):
img = unnormalize(images[idx], np.array(mean), np.array(std))
plt.subplot(2, 5, i + 1)
plt.imshow(img)
plt.title(class_names[labels[idx].item()])
plt.axis('off')
plt.suptitle(title)
plt.tight_layout()
plt.show()
show_augmented_images(trainloader_teacher, Config['DATA_MEANS'], Config['DATA_STD'], title="Augmentations from trainloader_teacher")
show_augmented_images(trainloader_student, Config['DATA_MEANS'], Config['DATA_STD'], title="Augmentations from trainloader_student")
Here is what the augmentations look like:
# Patch Embedding
class PatchEmbedding(nn.Module):
def __init__(self, img_size, patch_size, in_channels, embed_dim):
super().__init__()
self.patch_size = patch_size
self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
self.num_patches = (img_size // patch_size) ** 2
def forward(self, x):
x = self.proj(x) # (B, E, H/P, W/P)
x = x.flatten(2).transpose(1, 2) # (B, N, E)
return x
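As a quick shape check (my own snippet) under the config above, a 32×32 image with 4×4 patches should produce 64 tokens of dimension 256:
# (B, 3, 32, 32) -> (B, 64, 256): an 8x8 grid of 4x4 patches, each projected to embed_dim
pe = PatchEmbedding(config.IMG_SIZE, config.PATCH_SIZE, config.IN_CHANNELS, config.EMBED_DIM)
print(pe(torch.randn(2, 3, 32, 32)).shape)  # torch.Size([2, 64, 256])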
# Transformer Encoder Block
class TransformerBlock(nn.Module):
def __init__(self, dim, heads, mlp_dim, dropout):
super().__init__()
self.norm1 = nn.LayerNorm(dim)
self.attn = nn.MultiheadAttention(dim, heads, dropout=dropout, batch_first=True)
self.norm2 = nn.LayerNorm(dim)
self.mlp = nn.Sequential(
nn.Linear(dim, mlp_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(mlp_dim, dim),
nn.Dropout(dropout)
)
    def forward(self, x):
        # Pre-norm residual attention: normalize once and reuse for q, k, and v
        x_norm = self.norm1(x)
        x = x + self.attn(x_norm, x_norm, x_norm)[0]
        x = x + self.mlp(self.norm2(x))
        return x
# ViT Model
class ViT(nn.Module):
def __init__(self, img_size, patch_size, in_channels, embed_dim,
depth, num_heads, mlp_dim, dropout, head_mlp_dim, head_dim):
super().__init__()
self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
self.pos_embed = nn.Parameter(torch.randn(1, self.patch_embed.num_patches + 1, embed_dim))
        self.dropout = nn.Dropout(dropout)  # use the constructor argument, not the global config
self.blocks = nn.Sequential(*[
TransformerBlock(embed_dim, num_heads, mlp_dim, dropout)
for _ in range(depth)
])
self.norm = nn.LayerNorm(embed_dim)
self.head = nn.Sequential(
nn.Linear(embed_dim, head_mlp_dim),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(head_mlp_dim, head_mlp_dim),
nn.Dropout(dropout),
nn.Linear(head_mlp_dim, head_dim)
)
def forward(self, x):
B = x.size(0)
x = self.patch_embed(x)
cls_tokens = self.cls_token.expand(B, -1, -1) # (B, 1, D)
x = torch.cat((cls_tokens, x), dim=1)
x = x + self.pos_embed
x = self.dropout(x)
x = self.blocks(x)
x = self.norm(x)
return self.head(x)
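One thing to note: the head is applied to every token, so the model returns one head_dim vector per token, with the cls token at index 0; the training loops below drop it with [:, 1:, :]. A quick shape check (my own snippet):
# Output is (B, num_patches + 1, head_dim); position 0 along dim 1 is the cls token
vit = ViT(32, 4, 3, 256, 6, 4, 1024, 0.1, 2048, 128)
print(vit(torch.randn(2, 3, 32, 32)).shape)  # torch.Size([2, 65, 128])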
model_student = ViT(
img_size = Config['IMG_SIZE'],
patch_size = Config['PATCH_SIZE'],
in_channels = Config['IN_CHANNELS'],
embed_dim = Config['EMBED_DIM'],
depth = Config['DEPTH'],
num_heads = Config['NUM_HEADS'],
mlp_dim = Config['MLP_DIM'],
dropout = Config['DROPOUT'],
head_mlp_dim = Config['HEAD_MLP_DIM'],
head_dim = Config['HEAD_DIM']
).to(config.DEVICE)
model_teacher = ViT(
img_size = Config['IMG_SIZE'],
patch_size = Config['PATCH_SIZE'],
in_channels = Config['IN_CHANNELS'],
embed_dim = Config['EMBED_DIM'],
depth = Config['DEPTH'],
num_heads = Config['NUM_HEADS'],
mlp_dim = Config['MLP_DIM'],
dropout = Config['DROPOUT'],
head_mlp_dim = Config['HEAD_MLP_DIM'],
head_dim = Config['HEAD_DIM']
).to(config.DEVICE)
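For reference, here is a one-liner (my addition) to count trainable parameters; the two networks are architecturally identical, so the counts match:
# Both networks share the same architecture, so this count applies to each
n_params = sum(p.numel() for p in model_student.parameters() if p.requires_grad)
print(f"Trainable parameters per network: {n_params:,}")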
optimizer_student = torch.optim.AdamW(model_student.parameters(), lr=config.LR)
optimizer_teacher = torch.optim.AdamW(model_teacher.parameters(), lr=config.LR)
total_loss_plot = []
for epoch in range(config.EPOCHS):
model_student.train()
model_teacher.train()
total_loss = 0
batch_idx = 1
for (data_student, data_teacher) in zip(trainloader_student, trainloader_teacher):
inputs_student, _ = data_student
inputs_teacher, _ = data_teacher
inputs_student, inputs_teacher = inputs_student.to(config.DEVICE), inputs_teacher.to(config.DEVICE)
optimizer_student.zero_grad()
optimizer_teacher.zero_grad()
outputs_student = model_student(inputs_student)
outputs_student = outputs_student[:, 1:, :] # pulling all tokens except the cls_token
outputs_student = F.log_softmax(outputs_student, dim=-1)
outputs_teacher = model_teacher(inputs_teacher)
outputs_teacher = outputs_teacher[:, 1:, :] # pulling all tokens except the cls_token
outputs_teacher = F.softmax(outputs_teacher, dim=-1)
loss_pointwise = -1 * torch.sum(outputs_teacher * outputs_student, dim=-1)
loss = loss_pointwise.mean()
total_loss += loss.item()
        loss.backward()  # gradients flow into both the student and the teacher
optimizer_student.step()
optimizer_teacher.step()
print(f"Epoch [{epoch+1}/{config.EPOCHS}] Batch [{batch_idx}/{len(trainloader_student)}] Loss: {loss.item():.8f}")
batch_idx += 1
total_loss_plot.append(total_loss/len(trainloader_student))
print(f"Epoch [{epoch+1}/{config.EPOCHS}], Loss: {total_loss/len(trainloader_student):.8f}")
COLLAPSE!!
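One quick way to confirm the collapse (a diagnostic sketch I added, not part of the training code) is to check how much the student’s output distributions vary across images; after collapse they are essentially identical, so the standard deviation across the batch drops to nearly zero:
# After collapse every image yields (nearly) the same distribution,
# so the per-dimension std across the batch goes to ~0
model_student.eval()
with torch.no_grad():
    images, _ = next(iter(testloader_student))
    probs = F.softmax(model_student(images.to(config.DEVICE))[:, 1:, :], dim=-1)
    print("std across batch:", probs.std(dim=0).mean().item())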
Here is the loss when gradient updates are applied to both the student and the teacher networks:
optimizer_student = torch.optim.AdamW(model_student.parameters(), lr=config.LR)
optimizer_teacher = torch.optim.AdamW(model_teacher.parameters(), lr=config.LR)
total_loss_plot = []
for epoch in range(config.EPOCHS):
model_student.train()
model_teacher.train()
total_loss = 0
batch_idx = 1
for (data_student, data_teacher) in zip(trainloader_student, trainloader_teacher):
inputs_student, _ = data_student
inputs_teacher, _ = data_teacher
inputs_student, inputs_teacher = inputs_student.to(config.DEVICE), inputs_teacher.to(config.DEVICE)
optimizer_student.zero_grad()
# optimizer_teacher.zero_grad()
outputs_student = model_student(inputs_student)
outputs_student = outputs_student[:, 1:, :] # pulling all tokens except the cls_token
outputs_student = F.log_softmax(outputs_student, dim=-1)
with torch.no_grad():
outputs_teacher = model_teacher(inputs_teacher)
outputs_teacher = outputs_teacher[:, 1:, :] # pulling all tokens except the cls_token
outputs_teacher = F.softmax(outputs_teacher, dim=-1)
loss_pointwise = -1 * torch.sum(outputs_teacher * outputs_student, dim=-1)
loss = loss_pointwise.mean()
total_loss += loss.item()
        loss.backward()  # gradients flow into the student only; the teacher runs under no_grad
optimizer_student.step()
# optimizer_teacher.step()
print(f"Epoch [{epoch+1}/{config.EPOCHS}] Batch [{batch_idx}/{len(trainloader_student)}] Loss: {loss.item():.8f}")
batch_idx += 1
total_loss_plot.append(total_loss/len(trainloader_student))
print(f"Epoch [{epoch+1}/{config.EPOCHS}], Loss: {total_loss/len(trainloader_student):.8f}")
Here is the loss after stopping gradient updates on the teacher network. The loss didn’t collapse to zero, but it also stopped improving, because the teacher parameters never changed.
optimizer_student = torch.optim.AdamW(model_student.parameters(), lr=config.LR)
optimizer_teacher = torch.optim.AdamW(model_teacher.parameters(), lr=config.LR)
total_loss_plot = []
for epoch in range(config.EPOCHS):
model_student.train()
model_teacher.train()
total_loss = 0
batch_idx = 1
for (data_student, data_teacher) in zip(trainloader_student, trainloader_teacher):
inputs_student, _ = data_student
inputs_teacher, _ = data_teacher
inputs_student, inputs_teacher = inputs_student.to(config.DEVICE), inputs_teacher.to(config.DEVICE)
optimizer_student.zero_grad()
# optimizer_teacher.zero_grad()
outputs_student = model_student(inputs_student)
outputs_student = outputs_student[:, 1:, :] # pulling all tokens except the cls_token
outputs_student = F.log_softmax(outputs_student, dim=-1)
with torch.no_grad():
outputs_teacher = model_teacher(inputs_teacher)
outputs_teacher = outputs_teacher[:, 1:, :] # pulling all tokens except the cls_token
outputs_teacher = F.softmax(outputs_teacher, dim=-1)
loss_pointwise = -1 * torch.sum(outputs_teacher * outputs_student, dim=-1)
loss = loss_pointwise.mean()
total_loss += loss.item()
        loss.backward()  # gradients flow into the student only; the teacher runs under no_grad
optimizer_student.step()
# optimizer_teacher.step()
print(f"Epoch [{epoch+1}/{config.EPOCHS}] Batch [{batch_idx}/{len(trainloader_student)}] Loss: {loss.item():.8f}")
batch_idx += 1
    # Copy the student weights into the teacher at the end of each epoch.
    # copy_ writes in place, so the teacher parameters are overwritten directly;
    # the load_state_dict call is then redundant but kept for clarity.
    state_dict_student = model_student.state_dict()
    state_dict_teacher = model_teacher.state_dict()
    for name in state_dict_teacher:
        state_dict_teacher[name].copy_(state_dict_student[name])
    model_teacher.load_state_dict(state_dict_teacher)
total_loss_plot.append(total_loss/len(trainloader_student))
print(f"Epoch [{epoch+1}/{config.EPOCHS}], Loss: {total_loss/len(trainloader_student):.8f}")
Here is the loss after copying the student parameters to the teacher network after each epoch. The loss is heading in the right direction. The DINO practice, though, is not to copy the parameters over; instead the new teacher parameters are an exponential moving average of the two, $\theta_t \leftarrow \lambda\,\theta_t + (1 - \lambda)\,\theta_s$, where $\lambda$ goes from 0.996 to 1 according to a cosine scheduler.
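Here is a minimal sketch of that EMA update (my paraphrase of the DINO recipe; the function names are mine):
import math

def cosine_momentum(step, total_steps, base=0.996, final=1.0):
    # lambda ramps from base to final following a cosine schedule
    return final - (final - base) * (math.cos(math.pi * step / total_steps) + 1) / 2

def update_teacher_ema(student, teacher, lam):
    # teacher <- lam * teacher + (1 - lam) * student, parameter by parameter
    with torch.no_grad():
        for p_s, p_t in zip(student.parameters(), teacher.parameters()):
            p_t.mul_(lam).add_((1 - lam) * p_s)

# usage inside the training loop, right after optimizer_student.step():
# lam = cosine_momentum(global_step, config.EPOCHS * len(trainloader_student))
# update_teacher_ema(model_student, model_teacher, lam)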
Here I attempted to recreate Figure 2 from the DirectCLR paper. Unfortunately, I was not successful: I did not get the drop in the singular value spectrum that they report. This might be because they were using contrastive learning with negative pairs, and I was not using negative pairs.
Here is Figure 2 from the DirectCLR paper:
source: Jing et al. (2021)
Here is what I’ve done:
# inside the training loop
all_outputs_student = []
all_outputs_teacher = []
model_student.eval()
model_teacher.eval()
with torch.no_grad():
    for (data_student, data_teacher) in zip(testloader_student, testloader_teacher):
        inputs_student, _ = data_student
        inputs_teacher, _ = data_teacher
        inputs_student, inputs_teacher = inputs_student.to(config.DEVICE), inputs_teacher.to(config.DEVICE)
        outputs_student = model_student(inputs_student)
        outputs_student = outputs_student[:, 1:, :]  # keep only the patch tokens
        outputs_teacher = model_teacher(inputs_teacher)
        outputs_teacher = outputs_teacher[:, 1:, :]
        all_outputs_student.append(outputs_student.flatten(start_dim=0, end_dim=1).cpu())
        all_outputs_teacher.append(outputs_teacher.flatten(start_dim=0, end_dim=1).cpu())
# Concatenate all outputs and save as NumPy array
all_outputs_student_np = torch.cat(all_outputs_student).numpy()
all_outputs_teacher_np = torch.cat(all_outputs_teacher).numpy()
X_centered = all_outputs_student_np - np.mean(all_outputs_student_np, axis=0)
cov_matrix = np.cov(X_centered, rowvar=False)
_, S_student, _ = np.linalg.svd(cov_matrix)
X_centered = all_outputs_teacher_np - np.mean(all_outputs_teacher_np, axis=0)
cov_matrix = np.cov(X_centered, rowvar=False)
_, S_teacher, _ = np.linalg.svd(cov_matrix)
fig, axes = plt.subplots(1, 2, figsize=(16, 8))
axes[0].plot(np.log(S_student), marker='o', linewidth=3)
axes[0].set_title('Student')
axes[0].set_xlabel('Singular Value Index')
axes[0].set_ylabel('Log of Singular Value')
axes[1].plot(np.log(S_teacher), marker='o', linewidth=3)
axes[1].set_title('Teacher')
axes[1].set_xlabel('Singular Value Index')
axes[1].set_ylabel('Log of Singular Value')
plt.show()
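To compare the spectra with a single number, one option (my addition, borrowed from the effective-rank measure of Roy and Vetterli, not something the DirectCLR paper uses) is the exponential of the entropy of the normalized singular values:
def effective_rank(S, eps=1e-12):
    # exp of the Shannon entropy of the normalized spectrum;
    # ranges from 1 (rank-1 collapse) to len(S) (perfectly flat spectrum)
    p = S / (S.sum() + eps)
    return float(np.exp(-(p * np.log(p + eps)).sum()))

print("student effective rank:", effective_rank(S_student))
print("teacher effective rank:", effective_rank(S_teacher))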
But unfortunately, in all three cases above I got a similar singular value spectrum.
Teacher gradient updates on:
Teacher gradient updates off:
Student parameters copied to the teacher: