Mashaan blog

Self-Supervised Learning with I-JEPA

Acknowledgment

Thanks to the authors for making their code available. Whenever something in the paper was unclear to me, I checked the code to confirm my understanding.

References

@InProceedings{Assran_2023_CVPR,
    author    = {Assran, Mahmoud and Duval, Quentin and Misra, Ishan and Bojanowski, Piotr and Vincent, Pascal and Rabbat, Michael and LeCun, Yann and Ballas, Nicolas},
    title     = {Self-Supervised Learning From Images With a Joint-Embedding Predictive Architecture},
    booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
    month     = {June},
    year      = {2023},
    pages     = {15619-15629}
}

Main Architecture

drawings-02 001


drawings-02 002
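
Before going through each module separately, here is a minimal sketch (in PyTorch) of how the three modules are wired together during one forward pass, as I understand it from the paper and the diagrams above. The function names and the apply_masks helper are my own stand-ins, not the exact identifiers from the official repository:

import torch

def apply_masks(x, masks):
    # x: [B, N, D] patch tokens; masks: list of [B, K] index tensors.
    # For each mask, keep only the tokens whose indices it lists,
    # then stack the results along the batch dimension.
    kept = []
    for m in masks:
        idx = m.unsqueeze(-1).repeat(1, 1, x.size(-1))     # [B, K, D]
        kept.append(torch.gather(x, dim=1, index=idx))
    return torch.cat(kept, dim=0)                           # [len(masks)*B, K, D]

def ijepa_forward(imgs, masks_enc, masks_pred,
                  context_encoder, target_encoder, predictor):
    # Targets: the full image goes through the target encoder without gradients,
    # then only the patches belonging to the target blocks are kept.
    with torch.no_grad():
        h = target_encoder(imgs)                    # [B, 256, 768]
        h = apply_masks(h, masks_pred)              # [num_masks*B, K, 768]

    # Context: only the visible context patches go through the context encoder.
    z = context_encoder(imgs, masks_enc)            # [B, N_ctxt, 768]

    # The predictor fills in the masked target blocks from the context.
    pred = predictor(z, masks_enc, masks_pred)      # [num_masks*B, K, 768]

    # pred and h are compared with a regression loss (see the Loss section below).
    return pred, h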

Target encoder

drawings-02 002


Here is a printout of the architecture for the target encoder:

VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 768, kernel_size=(14, 14), stride=(14, 14))
  )
  (blocks): ModuleList(
    (0-11): 12 x Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): MLP(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU(approximate='none')
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
  )
  (norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
)
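
This is a ViT-Base-sized encoder with 14x14 patches on 224x224 images (so a 16x16 grid of 256 patch tokens) and 12 transformer blocks at width 768. The important point is that the target encoder is not trained by backpropagation; its weights are an exponential moving average (EMA) of the context encoder, as described in the paper. Here is a minimal sketch of that update, with an illustrative momentum value rather than the exact schedule used for training:

import torch

@torch.no_grad()
def ema_update(context_encoder, target_encoder, momentum=0.996):
    # Move every target-encoder parameter a small step toward the
    # corresponding context-encoder parameter.
    for p_c, p_t in zip(context_encoder.parameters(), target_encoder.parameters()):
        p_t.data.mul_(momentum).add_((1.0 - momentum) * p_c.data)

# The target encoder can start as a frozen copy of the context encoder, e.g.
#   target_encoder = copy.deepcopy(context_encoder)
#   for p in target_encoder.parameters():
#       p.requires_grad = False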

Context encoder

drawings-02 003


Here is a printout of the architecture for the context encoder:

VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 768, kernel_size=(14, 14), stride=(14, 14))
  )
  (blocks): ModuleList(
    (0-11): 12 x Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): MLP(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU(approximate='none')
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
  )
  (norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
)
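
The context encoder has exactly the same layout as the target encoder; the difference is only in the forward pass, where the context mask is applied right after the positional embedding, so only the visible context patches go through the transformer blocks (in the iteration printout further down, the tokens shrink from [128, 256, 768] to [128, 57, 768] at that point). Here is a rough sketch of that forward pass, re-using the apply_masks helper from the sketch above; the argument names are mine:

def encode_context(patch_embed, pos_embed, blocks, norm, imgs, masks_enc):
    x = patch_embed(imgs)            # [B, 256, 768] patch tokens
    x = x + pos_embed                # [1, 256, 768] fixed sin-cos positions, broadcast over the batch
    x = apply_masks(x, masks_enc)    # [B, 57, 768]: only the visible context patches remain
    for blk in blocks:
        x = blk(x)                   # 12 transformer blocks at width 768
    return norm(x)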

Predictor

drawings-02 004

drawings-02 005


Here is a printout of the architecture for the predictor:

VisionTransformerPredictor(
  (predictor_embed): Linear(in_features=768, out_features=384, bias=True)
  (predictor_blocks): ModuleList(
    (0-11): 12 x Block(
      (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=384, out_features=1152, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=384, out_features=384, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
      (mlp): MLP(
        (fc1): Linear(in_features=384, out_features=1536, bias=True)
        (act): GELU(approximate='none')
        (fc2): Linear(in_features=1536, out_features=384, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
  )
  (predictor_norm): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
  (predictor_proj): Linear(in_features=384, out_features=768, bias=True)
)
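
The predictor is a narrower transformer: it projects the 768-dimensional context tokens down to 384, runs 12 blocks at that width, and projects back up to 768 so its outputs can be compared with the target encoder. Below is a rough sketch of its forward pass matching the shapes in the iteration printout further down; it re-uses the apply_masks helper from above, and mask_token / pos_embed_384 are my names for the predictor's learnable mask token and positional-embedding table:

import torch

def predict_targets(predictor, x_ctxt, masks_x, masks_pred, pos_embed_384, mask_token):
    B = x_ctxt.size(0)                                       # 128
    num_masks = len(masks_pred)                               # 4 target blocks

    # Map the context tokens from encoder width 768 to predictor width 384.
    x = predictor.predictor_embed(x_ctxt)                     # [128, 57, 384]
    x = x + apply_masks(pos_embed_384.repeat(B, 1, 1), masks_x)       # positions of the context patches

    # Build mask tokens that carry only the positions of the target patches.
    pos = apply_masks(pos_embed_384.repeat(B, 1, 1), masks_pred)      # [512, 42, 384]
    pred_tokens = mask_token.repeat(pos.size(0), pos.size(1), 1) + pos

    # Repeat the context once per target mask and append the mask tokens.
    x = x.repeat(num_masks, 1, 1)                             # [512, 57, 384]
    N_ctxt = x.size(1)
    x = torch.cat([x, pred_tokens], dim=1)                    # [512, 99, 384]

    for blk in predictor.predictor_blocks:
        x = blk(x)                                            # 12 blocks at width 384

    x = predictor.predictor_norm(x)[:, N_ctxt:]               # keep only the mask-token outputs: [512, 42, 384]
    return predictor.predictor_proj(x)                        # back to encoder width: [512, 42, 768]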

Loss

drawings-02 006


Here is a printout of the patch indices selected for the masks:

INFO:root:Epoch 1
itr: 0
unsupervised_loader length 8
imgs.shape torch.Size([128, 3, 224, 224])
masks_enc[0].shape torch.Size([128, 87])
masks_enc length 1
CurrentMask[0][:] tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 16, 17, 18,
        19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 32, 33, 34, 35, 36, 37,
        38, 39, 40, 41, 42, 43, 44, 45, 46, 48, 49, 50, 51, 52, 53, 54, 55, 56,
        57, 58, 59, 60, 61, 62, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75,
        76, 77, 78, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91])
masks_pred[0].shape torch.Size([128, 35])
masks_pred length 4
CurrentMask[0][:] tensor([151, 152, 153, 154, 155, 156, 157, 167, 168, 169, 170, 171, 172, 173,
        183, 184, 185, 186, 187, 188, 189, 199, 200, 201, 202, 203, 204, 205,
        215, 216, 217, 218, 219, 220, 221])
CurrentMask[0][:] tensor([133, 134, 135, 136, 137, 138, 139, 149, 150, 151, 152, 153, 154, 155,
        165, 166, 167, 168, 169, 170, 171, 181, 182, 183, 184, 185, 186, 187,
        197, 198, 199, 200, 201, 202, 203])
CurrentMask[0][:] tensor([118, 119, 120, 121, 122, 123, 124, 134, 135, 136, 137, 138, 139, 140,
        150, 151, 152, 153, 154, 155, 156, 166, 167, 168, 169, 170, 171, 172,
        182, 183, 184, 185, 186, 187, 188])
CurrentMask[0][:] tensor([160, 161, 162, 163, 164, 165, 166, 176, 177, 178, 179, 180, 181, 182,
        192, 193, 194, 195, 196, 197, 198, 208, 209, 210, 211, 212, 213, 214,
        224, 225, 226, 227, 228, 229, 230])
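
With 224x224 images and 14x14 patches, the grid is 16x16 = 256 patches, so the indices above run from 0 to 255. Each of the four masks_pred tensors is a contiguous rectangular block; the first one, for example, is 7 patches wide and 5 patches tall (rows of 7 consecutive indices, 16 apart). The single masks_enc tensor is a larger block from which any patches overlapping the target blocks are removed, and as far as I can tell the collator also trims all masks in the batch to a common length, which is why its size changes between runs (87 here, 57 in the next printout). As an illustration (not the repository's mask collator), here is how a rectangular block of indices like the ones above can be sampled:

import torch

def sample_block_indices(grid_size=16, block_h=5, block_w=7, generator=None):
    # Sample a contiguous block of patch indices on a grid_size x grid_size grid.
    # With the defaults this gives 35 indices laid out like the masks_pred
    # printout above: rows of 7 consecutive indices, 16 apart.
    top = torch.randint(0, grid_size - block_h + 1, (1,), generator=generator).item()
    left = torch.randint(0, grid_size - block_w + 1, (1,), generator=generator).item()
    rows = torch.arange(top, top + block_h)
    cols = torch.arange(left, left + block_w)
    return (rows[:, None] * grid_size + cols[None, :]).flatten()

# For example, top=9 and left=7 reproduce the first target block above:
# 151, 152, ..., 157, 167, ..., 173, ..., 215, ..., 221.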

Here is a printout of one training iteration, showing the tensor shapes at each step:

INFO:root:Epoch 1
itr: 0
unsupervised_loader length 8
imgs.shape torch.Size([128, 3, 224, 224])
masks_enc[0].shape torch.Size([128, 57])
masks_enc length 1
masks_pred[0].shape torch.Size([128, 42])
masks_pred length 4
inside forward_target
---------------encoder start-----------------
x.shape:  torch.Size([128, 3, 224, 224])
x.shape:  torch.Size([128, 256, 768])
pos_embed.shape:  torch.Size([1, 256, 768])
x.shape after pos_embed:  torch.Size([128, 256, 768])
x.shape after mask:  torch.Size([128, 256, 768])
x.shape after blocks:  torch.Size([128, 256, 768])
---------------encoder end-------------------
inside forward_context
---------------encoder start-----------------
x.shape:  torch.Size([128, 3, 224, 224])
x.shape:  torch.Size([128, 256, 768])
pos_embed.shape:  torch.Size([1, 256, 768])
x.shape after pos_embed:  torch.Size([128, 256, 768])
x.shape after mask:  torch.Size([128, 57, 768])
x.shape after blocks:  torch.Size([128, 57, 768])
---------------encoder end-------------------
---------------predictor start---------------
x.shape: torch.Size([128, 57, 768])
masks_x[0].shape torch.Size([128, 57])
masks_x length 1
masks[0].shape torch.Size([128, 42])
masks length 4
Batch Size: 128
x.shape after predictor_embed: torch.Size([128, 57, 384])
x_pos_embed.shape: torch.Size([128, 256, 384])
x.shape after adding positional embedding: torch.Size([128, 57, 384])
N_ctxt: 57
pos_embs.shape: torch.Size([512, 42, 384])
pred_tokens.shape: torch.Size([512, 42, 384])
x.shape after concat mask tokens: torch.Size([512, 99, 384])
x.shape after predictor_blocks: torch.Size([512, 99, 384])
x.shape after pulling mask tokens: torch.Size([512, 42, 384])
x.shape after predictor_proj: torch.Size([512, 42, 768])
---------------predictor end-----------------
INFO:root:[1,     0] loss: 0.468 masks: 57.0 42.0 [wd: 4.00e-02] [lr: 2.03e-04] [mem: 0.00e+00] (-1.0 ms)
INFO:root:[1,     0] grad_stats: [2.91e-02 1.51e-02] (1.46e-02, 3.33e-02)
itr: 1
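
A few things in this trace are worth spelling out. The leading dimension 512 that appears inside the predictor is the batch size times the number of target masks (128 x 4): the context tokens are repeated once per target block, and the outputs for all four blocks are stacked along the batch dimension. The targets produced by the target encoder are arranged the same way by apply_masks, so the loss at the end is an element-wise regression between two [512, 42, 768] tensors. Here is a minimal sketch of that loss, assuming pred is the predictor output and h the masked target-encoder output; the paper describes the loss as an L2 distance, while as far as I can tell the released code uses a smooth L1 loss and layer-normalizes the targets first:

import torch.nn.functional as F

def ijepa_loss(pred, h):
    # pred, h: [num_masks * B, K, D] = [512, 42, 768] in the trace above.
    # Normalize the targets over the feature dimension (as the released code
    # appears to do), then regress the predictions onto them.
    h = F.layer_norm(h, (h.size(-1),))
    return F.smooth_l1_loss(pred, h)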