Some of the code below is adapted from the DINOv2 repository; the relevant papers are:
@inproceedings{caron2021emerging,
  title     = {Emerging Properties in Self-Supervised Vision Transformers},
  author    = {Caron, Mathilde and Touvron, Hugo and Misra, Ishan and J{\'e}gou, Herv{\'e} and Mairal, Julien and Bojanowski, Piotr and Joulin, Armand},
  booktitle = {Proceedings of the International Conference on Computer Vision (ICCV)},
  year      = {2021}
}
@article{zhou2021ibot,
  title   = {iBOT: Image BERT Pre-Training with Online Tokenizer},
  author  = {Zhou, Jinghao and Wei, Chen and Wang, Huiyu and Shen, Wei and Xie, Cihang and Yuille, Alan and Kong, Tao},
  journal = {International Conference on Learning Representations (ICLR)},
  year    = {2022}
}
@misc{oquab2023dinov2,
  title   = {DINOv2: Learning Robust Visual Features without Supervision},
  author  = {Oquab, Maxime and Darcet, Timothée and Moutakanni, Theo and Vo, Huy V. and Szafraniec, Marc and Khalidov, Vasil and Fernandez, Pierre and Haziza, Daniel and Massa, Francisco and El-Nouby, Alaaeldin and Howes, Russell and Huang, Po-Yao and Xu, Hu and Sharma, Vasu and Li, Shang-Wen and Galuba, Wojciech and Rabbat, Mike and Assran, Mido and Ballas, Nicolas and Synnaeve, Gabriel and Misra, Ishan and Jegou, Herve and Mairal, Julien and Labatut, Patrick and Joulin, Armand and Bojanowski, Piotr},
  journal = {arXiv:2304.07193},
  year    = {2023}
}
DINO-style training relies on multi-crop augmentation: each image yields a couple of large global_crops and several smaller local_crops. The teacher network sees only the global_crops, while the student sees all of the crops; iBOT additionally masks some of the student's input patches. The student network is updated with standard SGD, while the teacher network uses an exponential moving average (EMA) of the student's weights, according to the following update rule:

θ_teacher ← λ · θ_teacher + (1 − λ) · θ_student,

where the momentum coefficient λ is close to 1.
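As a rough illustration (this is my own minimal sketch, not code from the DINOv2 repository; ema_update and momentum are names I introduce here, and the real recipe also schedules the momentum over training), the update can be written in a few lines of PyTorch:

import torch

@torch.no_grad()
def ema_update(teacher, student, momentum=0.996):
    # teacher_param <- momentum * teacher_param + (1 - momentum) * student_param
    for t_param, s_param in zip(teacher.parameters(), student.parameters()):
        t_param.mul_(momentum).add_(s_param, alpha=1 - momentum)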
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import random
import os
from sklearn.linear_model import LogisticRegression
Config = {
    'NUM_CLASSES': 10,
    'BATCH_SIZE': 128,
    'NUM_EPOCHS': 100,
    'LR': 0.001,
    'WIDTH': 224,
    'HEIGHT': 224,
    'DATA_MEANS': np.array([0.49139968, 0.48215841, 0.44653091]),  # per-channel mean of CIFAR-10, used for normalization
    'DATA_STD': np.array([0.24703223, 0.24348513, 0.26158784]),  # per-channel standard deviation of CIFAR-10, used for normalization
    'CROP_SCALES': (0.8, 1.0),
    'CROP_RATIO': (0.9, 1.1),
    'SEED': 42,
}
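Note that CROP_SCALES and CROP_RATIO are not consumed by the deterministic resize transform defined below; if you wanted random-crop augmentation, they would plug into torchvision's RandomResizedCrop, as in this sketch (augment_transform is illustrative only and unused in the rest of the post):

# Illustrative only: an augmentation pipeline driven by the crop settings above
augment_transform = transforms.Compose([
    transforms.RandomResizedCrop((Config['HEIGHT'], Config['WIDTH']),
                                 scale=Config['CROP_SCALES'],
                                 ratio=Config['CROP_RATIO']),
    transforms.ToTensor(),
    transforms.Normalize(Config['DATA_MEANS'], Config['DATA_STD'])])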
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
class_names = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
# Load CIFAR-10 dataset
transform = transforms.Compose([
    transforms.Resize((Config['HEIGHT'], Config['WIDTH'])),  # upscale CIFAR-10's 32x32 images to 224x224, a multiple of DINOv2's 14-pixel patch size
    transforms.ToTensor(),
    transforms.Normalize(Config['DATA_MEANS'], Config['DATA_STD'])])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=Config['BATCH_SIZE'], shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=Config['BATCH_SIZE'], shuffle=False, num_workers=2)
# Number of images to display
num_images = 10
# Create a figure and axes
fig, axes = plt.subplots(1, num_images, figsize=(20, 5))
# Select 10 random indices
random_indices = random.sample(range(len(trainloader.dataset)), num_images)
# Iterate through the random images and display them
for i, index in enumerate(random_indices):
    img, label = trainloader.dataset[index]
    # Unnormalize (reverse the normalization for display)
    mean = torch.tensor(Config['DATA_MEANS']).view(3, 1, 1)
    std = torch.tensor(Config['DATA_STD']).view(3, 1, 1)
    img = img * std + mean  # Undo normalization
    img = np.transpose(img.numpy(), (1, 2, 0))  # Convert from CHW to HWC
    axes[i].imshow(np.clip(img, 0, 1))
    axes[i].set_title(class_names[label])
    axes[i].axis('off')
plt.tight_layout()
plt.show()
# DINOv2
dinov2_vits14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14').to(device)
dinov2_vits14.eval() # Set to evaluation mode (no fine-tuning)
Using cache found in /root/.cache/torch/hub/facebookresearch_dinov2_main
DinoVisionTransformer(
(patch_embed): PatchEmbed(
(proj): Conv2d(3, 384, kernel_size=(14, 14), stride=(14, 14))
(norm): Identity()
)
(blocks): ModuleList(
(0-11): 12 x NestedTensorBlock(
(norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
(attn): MemEffAttention(
(qkv): Linear(in_features=384, out_features=1152, bias=True)
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Linear(in_features=384, out_features=384, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
)
(ls1): LayerScale()
(drop_path1): Identity()
(norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=384, out_features=1536, bias=True)
(act): GELU(approximate='none')
(fc2): Linear(in_features=1536, out_features=384, bias=True)
(drop): Dropout(p=0.0, inplace=False)
)
(ls2): LayerScale()
(drop_path2): Identity()
)
)
(norm): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
(head): Identity()
)
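Before extracting features in bulk, it is worth sanity-checking the output: the hub model's forward pass returns a single embedding per image, which is 384-dimensional for ViT-S/14. A quick check (the dummy batch is my addition, sized to match the 224x224 setting above):

# Verify the embedding dimensionality with a dummy batch
with torch.no_grad():
    dummy = torch.zeros(2, 3, Config['HEIGHT'], Config['WIDTH'], device=device)
    print(dinov2_vits14(dummy).shape)  # expected: torch.Size([2, 384])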
# Function to extract DINOv2 features for a whole dataloader, optionally caching them to disk
def extract_features(model, dataloader, save_path=None):
    features = []
    labels = []
    with torch.no_grad():
        for images, lbls in dataloader:
            images = images.to(device)
            feats = model(images)  # Shape: (batch_size, 384) for dinov2_vits14
            features.append(feats.cpu().numpy())
            labels.append(lbls.numpy())
    features = np.concatenate(features)
    labels = np.concatenate(labels)
    if save_path is not None:
        # Cache the features so later runs can skip extraction
        np.save(f"{save_path}_features.npy", features)
        np.save(f"{save_path}_labels.npy", labels)
    return features, labels
# Training features: load if cached, extract otherwise
if os.path.exists("train_features.npy") and os.path.exists("train_labels.npy"):
    print("Loading precomputed training features...")
    train_features = np.load("train_features.npy")
    train_labels = np.load("train_labels.npy")
else:
    print("Extracting training features...")
    train_features, train_labels = extract_features(dinov2_vits14, trainloader, save_path="train")
# Test features: load if cached, extract otherwise
if os.path.exists("test_features.npy") and os.path.exists("test_labels.npy"):
    print("Loading precomputed test features...")
    test_features = np.load("test_features.npy")
    test_labels = np.load("test_labels.npy")
else:
    print("Extracting test features...")
    test_features, test_labels = extract_features(dinov2_vits14, testloader, save_path="test")
# Train a simple classifier (logistic regression) on the features
classifier = LogisticRegression(max_iter=1000)
classifier.fit(train_features, train_labels)
# Evaluate the classifier
train_accuracy = classifier.score(train_features, train_labels)
test_accuracy = classifier.score(test_features, test_labels)
print(f"Training accuracy: {train_accuracy:.4f}")
print(f"Test accuracy: {test_accuracy:.4f}")
Training accuracy: 0.9832
Test accuracy: 0.9501
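Before looking at individual mistakes, a confusion matrix shows which classes the classifier mixes up. A minimal sketch using scikit-learn (the sklearn.metrics import is new relative to the cells above):

from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

# Confusion matrix over the test set: rows are true classes, columns are predictions
test_preds = classifier.predict(test_features)
cm = confusion_matrix(test_labels, test_preds)
ConfusionMatrixDisplay(cm, display_labels=class_names).plot(xticks_rotation=45)
plt.show()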
# classify_image: predict a single image and display it with its actual and predicted labels
def classify_image(image_tensor, actual_label, model, classifier, class_names):
    # Get prediction
    img = image_tensor.unsqueeze(0).to(device)
    with torch.no_grad():
        feature = model(img).cpu().numpy()
    pred = classifier.predict(feature)
    pred_class = class_names[pred[0]]
    actual_class = class_names[actual_label]
    # Unnormalize image for display
    mean = torch.tensor(Config['DATA_MEANS']).view(3, 1, 1)
    std = torch.tensor(Config['DATA_STD']).view(3, 1, 1)
    img_display = image_tensor * std + mean  # Undo normalization
    img_display = img_display.numpy().transpose(1, 2, 0)  # CHW to HWC
    img_display = np.clip(img_display, 0, 1)  # Ensure valid range
    # Display image with title
    plt.figure(figsize=(4, 4))
    plt.imshow(img_display)
    plt.title(f"Actual: {actual_class}\nPredicted: {pred_class}", fontsize=12)
    plt.axis('off')
    plt.show()
    return pred_class
# Test on a random test image (randint is inclusive on both ends, hence len(testset) - 1)
test_image, test_label = testset[random.randint(0, len(testset) - 1)]
pred_class = classify_image(test_image, test_label, dinov2_vits14, classifier, class_names)
Although the classifier reaches 95% test accuracy, I want to pull out some of its incorrect predictions to understand what confuses it; the sketch below surfaces a few of those mistakes.
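A minimal sketch reusing the cached features and classify_image from above (the choice of five samples is arbitrary; since testloader was built with shuffle=False, feature row i corresponds to testset[i]):

# Find indices where the logistic-regression prediction disagrees with the label
test_preds = classifier.predict(test_features)
wrong_idx = np.where(test_preds != test_labels)[0]
print(f"{len(wrong_idx)} of {len(test_labels)} test images are misclassified")
# Show a few of the misclassified images with their actual vs. predicted labels
for idx in random.sample(list(wrong_idx), 5):
    image, label = testset[idx]
    classify_image(image, label, dinov2_vits14, classifier, class_names)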