Mashaan blog

Inside the DINOv2 Architecture

Table of Contents

Acknowledgment:

I borrowed some code from the DINOv2 repository.

References:

@inproceedings{caron2021emerging,
  title     = {Emerging Properties in Self-Supervised Vision Transformers},
  author    = {Caron, Mathilde and Touvron, Hugo and Misra, Ishan and J{\'e}gou, Herv{\'e} and Mairal, Julien and Bojanowski, Piotr and Joulin, Armand},
  booktitle = {Proceedings of the International Conference on Computer Vision (ICCV)},
  year      = {2021}
}
@article{zhou2021ibot,
  title     = {iBOT: Image BERT Pre-Training with Online Tokenizer},
  author    = {Zhou, Jinghao and Wei, Chen and Wang, Huiyu and Shen, Wei and Xie, Cihang and Yuille, Alan and Kong, Tao},
  journal   = {International Conference on Learning Representations (ICLR)},
  year      = {2022}
}
@misc{oquab2023dinov2,
  title   = {DINOv2: Learning Robust Visual Features without Supervision},
  author  = {Oquab, Maxime and Darcet, Timothée and Moutakanni, Theo and Vo, Huy V. and Szafraniec, Marc and Khalidov, Vasil and Fernandez, Pierre and Haziza, Daniel and Massa, Francisco and El-Nouby, Alaaeldin and Howes, Russell and Huang, Po-Yao and Xu, Hu and Sharma, Vasu and Li, Shang-Wen and Galuba, Wojciech and Rabbat, Mike and Assran, Mido and Ballas, Nicolas and Synnaeve, Gabriel and Misra, Ishan and Jegou, Herve and Mairal, Julien and Labatut, Patrick and Joulin, Armand and Bojanowski, Piotr},
  journal = {arXiv:2304.07193},
  year    = {2023}
}

DINOv2 architecture

global_crops and local_crops

DINOv2-001


DINOv2-002

Masking out global_crops patches

DINOv2-003


ViT backbone

DINOv2-004


DINOv2-005

DINO head

DINOv2-006

iBOT head

DINOv2-008

DINO loss

DINOv2-007

iBOT loss

DINOv2-009

Weight updates

The student network is updated with standard SGD, while the teacher network uses an exponential moving average (EMA) according to the following update rule:

Untitled001

Putting it all together

drawings-02 005

Testing DINOv2 on CIFAR-10

Importing libraries and preparing the dataset

import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import random
import os
from sklearn.linear_model import LogisticRegression
# Experiment configuration for the CIFAR-10 linear-probe test.
# NOTE(review): NUM_EPOCHS, LR, CROP_SCALES and CROP_RATIO are not referenced
# anywhere in this notebook (the backbone is frozen and the probe is sklearn
# logistic regression) — presumably leftovers from a training setup; confirm
# before relying on them.
Config = {
    'NUM_CLASSES': 10,
    'BATCH_SIZE': 128,
    'NUM_EPOCHS': 100,
    'LR': 0.001,
    'WIDTH': 224,   # DINOv2 input resolution (divisible by the 14px patch size)
    'HEIGHT': 224,
    'DATA_MEANS': np.array([0.49139968, 0.48215841, 0.44653091]), # mean of the CIFAR dataset, used for normalization
    'DATA_STD': np.array([0.24703223, 0.24348513, 0.26158784]),   # standard deviation of the CIFAR dataset, used for normalization
    'CROP_SCALES': (0.8, 1.0),
    'CROP_RATIO': (0.9, 1.1),
    'SEED': 42,
}

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class_names = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
# Load CIFAR-10 dataset
# Each 32x32 CIFAR image is upscaled to 224x224 so the ViT-S/14 backbone sees
# its expected resolution (224/14 = 16x16 patches), then normalized with
# CIFAR-10 channel statistics.
transform = transforms.Compose([
    transforms.Resize((Config['HEIGHT'], Config['WIDTH'])),  # DINOv2 requires 224x224 input
    transforms.ToTensor(),
    transforms.Normalize(Config['DATA_MEANS'], Config['DATA_STD'])])

# download=True fetches the dataset to ./data on first run; later runs reuse it.
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=Config['BATCH_SIZE'], shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
# No shuffle for the test split: feature/label order stays reproducible.
testloader = torch.utils.data.DataLoader(testset, batch_size=Config['BATCH_SIZE'], shuffle=False, num_workers=2)
# Display a handful of randomly chosen training images with their labels.
num_images = 10

fig, axes = plt.subplots(1, num_images, figsize=(20, 5))

# Normalization constants reshaped for broadcasting over CHW tensors;
# hoisted out of the loop since they never change.
mean = torch.tensor(Config['DATA_MEANS']).view(3, 1, 1)
std = torch.tensor(Config['DATA_STD']).view(3, 1, 1)

# Pick distinct random indices, then render one image per axis.
random_indices = random.sample(range(len(trainloader.dataset)), num_images)
for ax, index in zip(axes, random_indices):
    img, label = trainloader.dataset[index]
    # Undo the normalization and convert CHW -> HWC for matplotlib.
    img = np.transpose((img * std + mean).numpy(), (1, 2, 0))
    ax.imshow(np.clip(img, 0, 1))
    ax.set_title(class_names[label])
    ax.axis('off')

plt.tight_layout()
plt.show()

image

Downloading the DINO model and getting the features out

# DINOv2
# Download the pretrained ViT-S/14 DINOv2 backbone (384-dim CLS features)
# from torch hub and move it to the selected device. It stays frozen and is
# used purely as a feature extractor.
dinov2_vits14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14').to(device)
dinov2_vits14.eval()  # Set to evaluation mode (no fine-tuning)
Using cache found in /root/.cache/torch/hub/facebookresearch_dinov2_main
DinoVisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 384, kernel_size=(14, 14), stride=(14, 14))
    (norm): Identity()
  )
  (blocks): ModuleList(
    (0-11): 12 x NestedTensorBlock(
      (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
      (attn): MemEffAttention(
        (qkv): Linear(in_features=384, out_features=1152, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=384, out_features=384, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): LayerScale()
      (drop_path1): Identity()
      (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=384, out_features=1536, bias=True)
        (act): GELU(approximate='none')
        (fc2): Linear(in_features=1536, out_features=384, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
      (ls2): LayerScale()
      (drop_path2): Identity()
    )
  )
  (norm): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
  (head): Identity()
)
# Function to extract features using DINOv2
def extract_features(model, dataloader, save_path=None, device=None):
    """Run *model* over every batch in *dataloader* and collect its outputs.

    Args:
        model: callable mapping a batch of images to a feature tensor,
            e.g. the DINOv2 backbone — shape (batch_size, 384) for
            dinov2_vits14.
        dataloader: iterable yielding (images, labels) batches.
        save_path: optional filename prefix. When given, the arrays are
            cached to '<save_path>_features.npy' / '<save_path>_labels.npy'
            so later runs can load them instead of re-running the backbone.
            (The original signature lacked this parameter, so the callers
            passing save_path= raised TypeError and nothing was ever saved.)
        device: device to move each batch to. When None it is inferred from
            the model's parameters, falling back to CPU for parameter-less
            callables.

    Returns:
        (features, labels): two stacked numpy arrays covering the whole
        dataloader.
    """
    if device is None:
        try:
            device = next(model.parameters()).device
        except (AttributeError, StopIteration):
            device = torch.device('cpu')
    features = []
    labels = []
    with torch.no_grad():  # inference only — no autograd bookkeeping
        for images, lbls in dataloader:
            images = images.to(device)
            feats = model(images)  # Shape: (batch_size, 384) for dinov2_vits14
            features.append(feats.cpu().numpy())
            labels.append(lbls.numpy())
    features = np.concatenate(features)
    labels = np.concatenate(labels)
    if save_path is not None:
        # Cache to disk so a later run can skip the (slow) extraction pass.
        np.save(f"{save_path}_features.npy", features)
        np.save(f"{save_path}_labels.npy", labels)
    return features, labels
# Training features: Load if exists, extract if not
# Caching avoids re-running the frozen backbone over 50k images on every run.
if os.path.exists("train_features.npy") and os.path.exists("train_labels.npy"):
    print("Loading precomputed training features...")
    train_features = np.load("train_features.npy")
    train_labels = np.load("train_labels.npy")
else:
    print("Extracting training features...")
    # NOTE(review): extract_features() as defined above takes no save_path
    # parameter, so this call raises TypeError — extend extract_features to
    # accept a save prefix (and np.save the arrays) or drop the argument.
    train_features, train_labels = extract_features(dinov2_vits14, trainloader, save_path="train")

# Test features: Load if exists, extract if not
if os.path.exists("test_features.npy") and os.path.exists("test_labels.npy"):
    print("Loading precomputed test features...")
    test_features = np.load("test_features.npy")
    test_labels = np.load("test_labels.npy")
else:
    print("Extracting test features...")
    # NOTE(review): same save_path mismatch as the training branch above.
    test_features, test_labels = extract_features(dinov2_vits14, testloader, save_path="test")

Training a logistic regression classifier

# Train a simple classifier (logistic regression) on the features
# The backbone stays frozen; only this linear probe is fit — the standard
# protocol for evaluating self-supervised representations.
classifier = LogisticRegression(max_iter=1000)
classifier.fit(train_features, train_labels)

# Evaluate the classifier
# score() returns mean accuracy; the train/test gap indicates probe overfit.
train_accuracy = classifier.score(train_features, train_labels)
test_accuracy = classifier.score(test_features, test_labels)
print(f"Training accuracy: {train_accuracy:.4f}")
print(f"Test accuracy: {test_accuracy:.4f}")
Training accuracy: 0.9832
Test accuracy: 0.9501
# Updated classify_image to display image with labels
def classify_image(image_tensor, actual_label, model, classifier, class_names):
    """Classify one normalized image and display it with both labels.

    Runs the frozen backbone on the image, feeds the resulting feature
    vector to the linear probe, then shows the un-normalized image titled
    with the ground-truth and predicted class names.

    Returns the predicted class name as a string.
    """
    # Feature extraction + linear-probe prediction (single-image batch).
    batch = image_tensor.unsqueeze(0).to(device)
    with torch.no_grad():
        embedding = model(batch).cpu().numpy()
    pred_class = class_names[classifier.predict(embedding)[0]]
    actual_class = class_names[actual_label]

    # Invert the normalization so the image renders with true colors,
    # convert CHW -> HWC, and clamp to the valid [0, 1] display range.
    mean = torch.tensor(Config['DATA_MEANS']).view(3, 1, 1)
    std = torch.tensor(Config['DATA_STD']).view(3, 1, 1)
    img_display = np.clip((image_tensor * std + mean).numpy().transpose(1, 2, 0), 0, 1)

    # Render the image with both labels in the title.
    plt.figure(figsize=(4, 4))
    plt.imshow(img_display)
    plt.title(f"Actual: {actual_class}\nPredicted: {pred_class}", fontsize=12)
    plt.axis('off')
    plt.show()

    return pred_class

# Test on a random test image
# random.randrange(len(testset)) samples 0..len-1. The original
# random.randint(0, len(testset)) is inclusive of its upper bound, so it
# could sample len(testset) itself and raise IndexError.
test_image, test_label = testset[random.randrange(len(testset))]
pred_class = classify_image(test_image, test_label, dinov2_vits14, classifier, class_names)

Test results

Although the classifier reaches 95% test accuracy, it is instructive to examine some of its incorrect predictions to see what confuses it.

DINOv2-CIFAR10-wrong-predict-005 DINOv2-CIFAR10-wrong-predict-004 DINOv2-CIFAR10-wrong-predict-003 DINOv2-CIFAR10-wrong-predict-002 DINOv2-CIFAR10-wrong-predict-001