Mashaan blog

MLP-Mixer in jax

Acknowledgment:

References

@article{tolstikhin2021mixer,
 title    = {MLP-Mixer: An all-MLP Architecture for Vision},
 author   = {Tolstikhin, Ilya and Houlsby, Neil and Kolesnikov, Alexander and Beyer, Lucas and Zhai, Xiaohua and Unterthiner, Thomas and Yung, Jessica and Steiner, Andreas and Keysers, Daniel and Uszkoreit, Jakob and Lucic, Mario and Dosovitskiy, Alexey},
 journal  = {arXiv preprint arXiv:2105.01601},
 year     = {2021}
}

Prepare libraries

!pip install einops ml_collections
import os
from absl import logging
import matplotlib.pyplot as plt
import numpy as np
import random
from typing import Any, Optional
import ml_collections

import jax
import jax.numpy as jnp
import jax.tools.colab_tpu
import flax
import flax.linen as nn
from flax.training import train_state
import optax
import einops

import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils import data

import tqdm
# Lists the available devices (the cell output shows one entry per device).
# In a CPU/GPU runtime this will be a single device.
# In a TPU runtime this will be 8 cores.
jax.local_devices()
[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
# Hyper-parameters and dataset statistics used by the cells below.
Config = dict(
    NUM_CLASSES=10,
    BATCH_SIZE=128,
    NUM_EPOCHS=100,
    LR=0.001,
    WIDTH=32,
    HEIGHT=32,
    # Mean of the CIFAR dataset, used for normalization.
    DATA_MEANS=np.array([0.49139968, 0.48215841, 0.44653091]),
    # Standard deviation of the CIFAR dataset, used for normalization.
    DATA_STD=np.array([0.24703223, 0.24348513, 0.26158784]),
    # Arguments forwarded to RandomResizedCrop (scale / ratio).
    CROP_SCALES=(0.8, 1.0),
    CROP_RATIO=(0.9, 1.1),
    SEED=42,
)

Import cifar-10

# A helper function that converts an image to a float array and standardizes it
# with the dataset mean and standard deviation stored in Config.
def image_to_numpy(img):
  """Divide the pixel values by 255, subtract Config['DATA_MEANS'], divide by Config['DATA_STD']."""
  scaled = np.array(img, dtype=np.float32) / 255.
  return (scaled - Config['DATA_MEANS']) / Config['DATA_STD']
# Collate function for the DataLoader that yields numpy arrays instead of the default torch tensors.
def numpy_collate(batch):
  """Stack a list of samples into numpy arrays.

  - samples that are numpy arrays are stacked along a new leading axis;
  - samples that are tuples/lists are collated field by field (recursively),
    returning a list with one entry per field;
  - anything else is handed to `np.array`.
  """
  first = batch[0]
  if isinstance(first, np.ndarray):
    return np.stack(batch)
  if isinstance(first, (tuple, list)):
    # zip(*batch) groups the i-th field of every sample together.
    return [numpy_collate(field) for field in zip(*batch)]
  return np.array(batch)
# images in the test set will only be converted into numpy arrays
test_transform = image_to_numpy
# images in the train set will be randomly flipped, cropped, and then converted to numpy arrays
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomResizedCrop((Config['HEIGHT'], Config['WIDTH']), scale=Config['CROP_SCALES'], ratio=Config['CROP_RATIO']),
    image_to_numpy
])

# Validation set should not use train_transform, so the CIFAR-10 training data is
# loaded twice: once with the augmenting transform and once with the plain one.
train_dataset = torchvision.datasets.CIFAR10('data', train=True, transform=train_transform, download=True)
val_dataset = torchvision.datasets.CIFAR10('data', train=True, transform=test_transform, download=True)
# Both splits use a generator seeded with the same Config['SEED'], so the 45000/5000
# partition is the same for both copies; the first part is kept for training and the
# second part for validation.
train_set, _ = torch.utils.data.random_split(train_dataset, [45000, 5000], generator=torch.Generator().manual_seed(Config['SEED']))
_, val_set = torch.utils.data.random_split(val_dataset, [45000, 5000], generator=torch.Generator().manual_seed(Config['SEED']))
test_set = torchvision.datasets.CIFAR10('data', train=False, transform=test_transform, download=True)

# Only the training loader shuffles and drops the last incomplete batch;
# every loader returns numpy arrays via numpy_collate.
train_data_loader = torch.utils.data.DataLoader(
    train_set, batch_size=Config['BATCH_SIZE'], shuffle=True, drop_last=True, num_workers=2, persistent_workers=True, collate_fn=numpy_collate,
)
val_data_loader = torch.utils.data.DataLoader(
    val_set, batch_size=Config['BATCH_SIZE'], shuffle=False, drop_last=False, num_workers=2, persistent_workers=True, collate_fn=numpy_collate,
)
test_data_loader = torch.utils.data.DataLoader(
    test_set, batch_size=Config['BATCH_SIZE'], shuffle=False, drop_last=False, num_workers=2, persistent_workers=True, collate_fn=numpy_collate,
)
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data/cifar-10-python.tar.gz
100%|██████████| 170M/170M [00:04<00:00, 36.0MB/s]
Extracting data/cifar-10-python.tar.gz to data
Files already downloaded and verified
Files already downloaded and verified
# Sanity-check the normalization on one training batch: reducing over axes (0, 1, 2)
# leaves one mean/std value per entry of the last axis.
imgs, _ = next(iter(train_data_loader))
print("Batch mean", imgs.mean(axis=(0,1,2)))
print("Batch std", imgs.std(axis=(0,1,2)))
Batch mean [-0.03558055 -0.03113859 -0.0337262 ]
Batch std [0.92609125 0.9158829  0.94482037]
# Fetch ONE batch and report its shapes and dtypes. Calling next(iter(...)) inside
# every print would build a new iterator each time and, because the training loader
# shuffles, describe four different batches.
first_images, first_labels = next(iter(train_data_loader))
print(f'size of images in the first train batch: {first_images.shape}')
print(f'type of images in the first train batch: {first_images.dtype}')
print(f'size of labels in the first train batch: {first_labels.shape}')
print(f'type of labels in the first train batch: {first_labels.dtype}')
size of images in the first train batch: (128, 32, 32, 3)
type of images in the first train batch: float64
size of labels in the first train batch: (128,)
type of labels in the first train batch: int64

Displaying images from cifar-10

# Helper functions for images.

# Class names per dataset; the position in each tuple is the integer label index.
labelnames = dict(
  # https://www.cs.toronto.edu/~kriz/cifar.html
  cifar10=('airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'),
)  # closing parenthesis was missing, which made the whole cell a SyntaxError


def make_label_getter(dataset):
  """Returns a function converting label indices to names.

  Args:
    dataset: key into `labelnames` (e.g. 'cifar10').

  Returns:
    A function `getter(label)` that returns the class name stored at index
    `label` for a known dataset, or the string 'label=<label>' when the
    dataset has no entry in `labelnames`.
  """
  def getter(label):
    if dataset in labelnames:
      return labelnames[dataset][label]
    return f'label={label}'
  return getter
# How many images each preview figure shows
num_images = 10

# One row of subplots, one axis per image
fig, axes = plt.subplots(1, num_images, figsize=(20, 5))

# Maps CIFAR-10 label indices to class names
get_label_name = make_label_getter('cifar10')

# Show the first `num_images` samples of the training set; the normalization
# is undone (times std, plus mean) so the colours are displayable.
for i, ax in zip(range(num_images), axes):
  img, label = train_data_loader.dataset[i]
  ax.imshow(img * Config['DATA_STD'] + Config['DATA_MEANS'])
  ax.set_title(get_label_name(label))
  ax.axis('off')

plt.tight_layout()
plt.show()

image

# One row of subplots, one axis per image
fig, axes = plt.subplots(1, num_images, figsize=(20, 5))

# Maps CIFAR-10 label indices to class names
get_label_name = make_label_getter('cifar10')

# Draw `num_images` distinct random positions in the training set
random_indices = random.sample(range(len(train_data_loader.dataset)), num_images)

# Show the sampled images; the normalization is undone (times std, plus mean) for display.
for ax, index in zip(axes, random_indices):
    img, label = train_data_loader.dataset[index]
    ax.imshow(img * Config['DATA_STD'] + Config['DATA_MEANS'])
    ax.set_title(get_label_name(label))
    ax.axis('off')

plt.tight_layout()
plt.show()

image

MLP-Mixer Architecture

class MlpBlock(nn.Module):
  """Two-layer MLP: Dense(mlp_dim) -> GELU -> Dense back to the input's last-axis size."""
  mlp_dim: int  # width of the hidden Dense layer

  @nn.compact
  def __call__(self, x):
    out_features = x.shape[-1]
    hidden = nn.gelu(nn.Dense(self.mlp_dim)(x))
    # Project back so the output has the same shape as `x`.
    return nn.Dense(out_features)(hidden)


class MixerBlock(nn.Module):
  """Mixer block layer: a token-mixing MLP followed by a channel-mixing MLP,
  each applied to a layer-normalized input and added back as a residual."""
  tokens_mlp_dim: int    # hidden width of the token-mixing MLP
  channels_mlp_dim: int  # hidden width of the channel-mixing MLP

  @nn.compact
  def __call__(self, x):
    # Token mixing: swap axes 1 and 2 so the MLP acts along what was axis 1,
    # then swap back before adding the residual.
    normed = nn.LayerNorm()(x)
    transposed = jnp.swapaxes(normed, 1, 2)
    mixed = MlpBlock(self.tokens_mlp_dim, name='token_mixing')(transposed)
    x = x + jnp.swapaxes(mixed, 1, 2)

    # Channel mixing: the MLP acts along the last axis directly.
    normed = nn.LayerNorm()(x)
    channel_out = MlpBlock(self.channels_mlp_dim, name='channel_mixing')(normed)
    return x + channel_out


class MlpMixer(nn.Module):
  """Mixer architecture: conv patch stem -> `num_blocks` MixerBlocks ->
  LayerNorm -> mean over axis 1 -> optional Dense classification head."""
  patches: Any                      # object exposing `.size`, used as conv kernel size and stride
  num_classes: int                  # head output size; a falsy value skips the head
  num_blocks: int                   # number of MixerBlock layers
  hidden_dim: int                   # number of features produced by the stem
  tokens_mlp_dim: int               # forwarded to every MixerBlock
  channels_mlp_dim: int             # forwarded to every MixerBlock
  model_name: Optional[str] = None  # not read inside this class

  @nn.compact
  def __call__(self, inputs, *, train):
    del train  # unused: the argument is discarded immediately
    patch_size = self.patches.size
    # Kernel size == stride, so the stem embeds non-overlapping patches.
    stem = nn.Conv(self.hidden_dim, patch_size, strides=patch_size, name='stem')
    x = stem(inputs)

    # Flatten the spatial grid into a single axis.
    x = einops.rearrange(x, 'n h w c -> n (h w) c')
    for _ in range(self.num_blocks):
      x = MixerBlock(self.tokens_mlp_dim, self.channels_mlp_dim)(x)
    x = nn.LayerNorm(name='pre_head_layer_norm')(x)
    x = jnp.mean(x, axis=1)
    if not self.num_classes:
      return x
    # The head kernel starts at zero.
    head = nn.Dense(self.num_classes, kernel_init=nn.initializers.zeros,
                    name='head')
    return head(x)

Initializing the model

@jax.jit
def train_step(state, batch):
  """Run one optimization step.

  Args:
    state: train state providing `apply_fn`, `params` and `apply_gradients`.
    batch: mapping with the keys 'image' and 'label'.

  Returns:
    A tuple `(new_state, loss, logits)`, where `loss` is the mean softmax
    cross-entropy over the batch.
  """
  def loss_fn(params):
    logits = state.apply_fn(params, batch['image'], train=True)
    # Integer labels are one-hot encoded to match the logits' class axis.
    targets = jax.nn.one_hot(batch['label'], logits.shape[1])
    per_example = optax.softmax_cross_entropy(logits, targets)
    return per_example.mean(), logits

  # has_aux=True: loss_fn returns (loss, logits) and only the loss is differentiated.
  (loss, logits), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
  new_state = state.apply_gradients(grads=grads)
  return new_state, loss, logits

def create_train_state(module, rng, learning_rate, image_shape):
  """Initialize `module` and wrap it in a TrainState with an Adam optimizer.

  Args:
    module: the model whose `init` and `apply` are used.
    rng: PRNG key passed to `module.init`.
    learning_rate: learning rate handed to `optax.adam`.
    image_shape: shape of one input without the batch axis.

  Returns:
    A `train_state.TrainState` holding the parameters and optimizer.
  """
  # A single all-zeros example is enough to initialize the parameters.
  dummy_batch = jnp.zeros((1, *image_shape))
  variables = module.init(rng, dummy_batch, train=True)
  optimizer = optax.adam(learning_rate)
  return train_state.TrainState.create(
      apply_fn=module.apply, params=variables, tx=optimizer)
# Define data batch shape: the shape of one image, i.e. the batch shape without its leading axis.
image_shape = next(iter(train_data_loader))[0].shape[1:]

# Define the model
model = MlpMixer(
    patches=ml_collections.ConfigDict({'size': (4, 4)}),  # 4x4 kernel and stride for the stem
    num_classes=Config['NUM_CLASSES'],
    num_blocks=8,
    hidden_dim=192,
    tokens_mlp_dim=96,
    channels_mlp_dim=768)

# Create the train state (parameters + Adam optimizer) from a fixed PRNG key.
rng = jax.random.PRNGKey(0)
state = create_train_state(model, rng, Config['LR'], image_shape)

Training loop

# Jitted evaluation step: forward pass and mean softmax cross-entropy only, with no
# gradient computation or parameter update. Compiling it once avoids dispatching
# every layer op-by-op on each validation batch.
@jax.jit
def eval_step(state, batch):
  logits = state.apply_fn(state.params, batch['image'], train=False)
  one_hot_labels = jax.nn.one_hot(batch['label'], logits.shape[1])
  return optax.softmax_cross_entropy(logits, one_hot_labels).mean()


# Training loop
num_epochs = Config['NUM_EPOCHS']
for epoch in range(num_epochs):
  # Training
  train_losses = []
  for batch in tqdm.tqdm(train_data_loader):
    batch = {'image': batch[0], 'label': batch[1]}
    state, loss, _ = train_step(state, batch)
    train_losses.append(loss)

  # Print average training loss (mean of the per-batch means)
  avg_train_loss = np.mean(train_losses)
  print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {avg_train_loss}")

  # Validation (optional, but recommended)
  val_losses = []
  for batch in tqdm.tqdm(val_data_loader):
    batch = {'image': batch[0], 'label': batch[1]}
    # The validation loader keeps its last, smaller batch, so eval_step is
    # compiled for two batch shapes in total; both are cached after the first epoch.
    val_losses.append(eval_step(state, batch))

  # Print average validation loss (mean of the per-batch means)
  avg_val_loss = np.mean(val_losses)
  print(f"Epoch {epoch+1}/{num_epochs}, Validation Loss: {avg_val_loss}")

Running one batch to check the tensor shapes

I modified the MLP-Mixer code to add some print statements to check the tensor shapes. Here’s the modified code:

class MlpBlock(nn.Module):
  """Debug variant of the MLP block: Dense(mlp_dim) -> GELU -> Dense back to the
  input's last-axis size, printing the tensor shape after each step."""
  mlp_dim: int  # width of the hidden Dense layer

  @nn.compact
  def __call__(self, x):
    print("-----------in MlpBlock-----------")
    print(f"x.shape before MlpBlock: {x.shape}")
    y = nn.Dense(self.mlp_dim)(x)
    print(f"y.shape after Dense: {y.shape}")
    y = nn.gelu(y)
    res = nn.Dense(x.shape[-1])(y)
    # Print the shape of the tensor that is actually returned. Building another
    # nn.Dense inside the f-string would register a third, unused Dense layer
    # (extra parameters and an extra matmul) in this compact module.
    print(f"MlpBlock returned shape: {res.shape}")
    print("-----------out MlpBlock-----------")
    return res


class MixerBlock(nn.Module):
  """Mixer block layer (debug variant that prints the tensor shape after each step)."""
  tokens_mlp_dim: int    # hidden width of the token-mixing MLP
  channels_mlp_dim: int  # hidden width of the channel-mixing MLP

  @nn.compact
  def __call__(self, x):
    print("-----------in MixerBlock-----------")
    print(f"x.shape before MixerBlock: {x.shape}")
    # Token mixing: normalize, swap axes 1 and 2 so the MLP acts along what was
    # axis 1, swap back, then add the residual.
    y = nn.LayerNorm()(x)
    y = jnp.swapaxes(y, 1, 2)
    print(f"y.shape after jnp.swapaxes: {y.shape}")
    y = MlpBlock(self.tokens_mlp_dim, name='token_mixing')(y)
    print(f"y.shape after MlpBlock: {y.shape}")
    y = jnp.swapaxes(y, 1, 2)
    print(f"y.shape after jnp.swapaxes: {y.shape}")
    x = x + y
    # Channel mixing: normalize, apply the MLP along the last axis, add the residual.
    y = nn.LayerNorm()(x)
    res = x + MlpBlock(self.channels_mlp_dim, name='channel_mixing')(y)
    print(f"MixerBlock returned shape: {res.shape}")
    print("-----------out MixerBlock-----------")
    return res


class MlpMixer(nn.Module):
  """Mixer architecture (debug variant that prints the tensor shape after each step)."""
  patches: Any                      # object exposing `.size`, used as conv kernel size and stride
  num_classes: int                  # head output size; a falsy value skips the head
  num_blocks: int                   # number of MixerBlock layers
  hidden_dim: int                   # number of features produced by the stem
  tokens_mlp_dim: int               # forwarded to every MixerBlock
  channels_mlp_dim: int             # forwarded to every MixerBlock
  model_name: Optional[str] = None  # not read inside this class

  @nn.compact
  def __call__(self, inputs, *, train):
    del train  # unused: the argument is discarded immediately
    print(f"inputs.shape: {inputs.shape}")
    # Kernel size == stride, so the stem embeds non-overlapping patches.
    x = nn.Conv(self.hidden_dim, self.patches.size,
                strides=self.patches.size, name='stem')(inputs)

    print(f"x.shape after stem: {x.shape}")
    # Flatten the spatial grid into a single axis.
    x = einops.rearrange(x, 'n h w c -> n (h w) c')
    print(f"x.shape after einops.rearrange: {x.shape}")
    for i in range(self.num_blocks):
      print(f"-----------block {i}-----------")
      print(f"x.shape before MixerBlock: {x.shape}")
      x = MixerBlock(self.tokens_mlp_dim, self.channels_mlp_dim)(x)
      print(f"x.shape after MixerBlock: {x.shape}")
    x = nn.LayerNorm(name='pre_head_layer_norm')(x)
    print(f"x.shape after pre_head_layer_norm: {x.shape}")
    # Average over axis 1.
    x = jnp.mean(x, axis=1)
    print(f"x.shape after jnp.mean: {x.shape}")
    if self.num_classes:
      x = nn.Dense(self.num_classes, kernel_init=nn.initializers.zeros,
                   name='head')(x)
    # NOTE: this line is printed even when the head is skipped (num_classes falsy).
    print(f"x.shape after head: {x.shape}")
    return x

Here’s the output of running one batch through the model with print statements:

inputs.shape: (128, 32, 32, 3)
x.shape after stem: (128, 8, 8, 192)
x.shape after einops.rearrange: (128, 64, 192)
-----------block 0-----------
x.shape before MixerBlock: (128, 64, 192)
-----------in MixerBlock-----------
x.shape before MixerBlock: (128, 64, 192)
y.shape after jnp.swapaxes: (128, 192, 64)
-----------in MlpBlock-----------
x.shape before MlpBlock: (128, 192, 64)
y.shape after Dense: (128, 192, 96)
MlpBlock returned shape: (128, 192, 64)
-----------out MlpBlock-----------
y.shape after MlpBlock: (128, 192, 64)
y.shape after jnp.swapaxes: (128, 64, 192)
-----------in MlpBlock-----------
x.shape before MlpBlock: (128, 64, 192)
y.shape after Dense: (128, 64, 768)
MlpBlock returned shape: (128, 64, 192)
-----------out MlpBlock-----------
MixerBlock returned shape: (128, 64, 192)
-----------out MixerBlock-----------
x.shape after MixerBlock: (128, 64, 192)
-----------block 1-----------
x.shape before MixerBlock: (128, 64, 192)
-----------in MixerBlock-----------
x.shape before MixerBlock: (128, 64, 192)
y.shape after jnp.swapaxes: (128, 192, 64)
-----------in MlpBlock-----------
x.shape before MlpBlock: (128, 192, 64)
y.shape after Dense: (128, 192, 96)
MlpBlock returned shape: (128, 192, 64)
-----------out MlpBlock-----------
y.shape after MlpBlock: (128, 192, 64)
y.shape after jnp.swapaxes: (128, 64, 192)
-----------in MlpBlock-----------
x.shape before MlpBlock: (128, 64, 192)
y.shape after Dense: (128, 64, 768)
MlpBlock returned shape: (128, 64, 192)
-----------out MlpBlock-----------
MixerBlock returned shape: (128, 64, 192)
-----------out MixerBlock-----------
x.shape after MixerBlock: (128, 64, 192)
-----------block 2-----------
x.shape before MixerBlock: (128, 64, 192)
-----------in MixerBlock-----------
x.shape before MixerBlock: (128, 64, 192)
y.shape after jnp.swapaxes: (128, 192, 64)
-----------in MlpBlock-----------
x.shape before MlpBlock: (128, 192, 64)
y.shape after Dense: (128, 192, 96)
MlpBlock returned shape: (128, 192, 64)
-----------out MlpBlock-----------
y.shape after MlpBlock: (128, 192, 64)
y.shape after jnp.swapaxes: (128, 64, 192)
-----------in MlpBlock-----------
x.shape before MlpBlock: (128, 64, 192)
y.shape after Dense: (128, 64, 768)
MlpBlock returned shape: (128, 64, 192)
-----------out MlpBlock-----------
MixerBlock returned shape: (128, 64, 192)
-----------out MixerBlock-----------
x.shape after MixerBlock: (128, 64, 192)
-----------block 3-----------
x.shape before MixerBlock: (128, 64, 192)
-----------in MixerBlock-----------
x.shape before MixerBlock: (128, 64, 192)
y.shape after jnp.swapaxes: (128, 192, 64)
-----------in MlpBlock-----------
x.shape before MlpBlock: (128, 192, 64)
y.shape after Dense: (128, 192, 96)
MlpBlock returned shape: (128, 192, 64)
-----------out MlpBlock-----------
y.shape after MlpBlock: (128, 192, 64)
y.shape after jnp.swapaxes: (128, 64, 192)
-----------in MlpBlock-----------
x.shape before MlpBlock: (128, 64, 192)
y.shape after Dense: (128, 64, 768)
MlpBlock returned shape: (128, 64, 192)
-----------out MlpBlock-----------
MixerBlock returned shape: (128, 64, 192)
-----------out MixerBlock-----------
x.shape after MixerBlock: (128, 64, 192)
-----------block 4-----------
x.shape before MixerBlock: (128, 64, 192)
-----------in MixerBlock-----------
x.shape before MixerBlock: (128, 64, 192)
y.shape after jnp.swapaxes: (128, 192, 64)
-----------in MlpBlock-----------
x.shape before MlpBlock: (128, 192, 64)
y.shape after Dense: (128, 192, 96)
MlpBlock returned shape: (128, 192, 64)
-----------out MlpBlock-----------
y.shape after MlpBlock: (128, 192, 64)
y.shape after jnp.swapaxes: (128, 64, 192)
-----------in MlpBlock-----------
x.shape before MlpBlock: (128, 64, 192)
y.shape after Dense: (128, 64, 768)
MlpBlock returned shape: (128, 64, 192)
-----------out MlpBlock-----------
MixerBlock returned shape: (128, 64, 192)
-----------out MixerBlock-----------
x.shape after MixerBlock: (128, 64, 192)
-----------block 5-----------
x.shape before MixerBlock: (128, 64, 192)
-----------in MixerBlock-----------
x.shape before MixerBlock: (128, 64, 192)
y.shape after jnp.swapaxes: (128, 192, 64)
-----------in MlpBlock-----------
x.shape before MlpBlock: (128, 192, 64)
y.shape after Dense: (128, 192, 96)
MlpBlock returned shape: (128, 192, 64)
-----------out MlpBlock-----------
y.shape after MlpBlock: (128, 192, 64)
y.shape after jnp.swapaxes: (128, 64, 192)
-----------in MlpBlock-----------
x.shape before MlpBlock: (128, 64, 192)
y.shape after Dense: (128, 64, 768)
MlpBlock returned shape: (128, 64, 192)
-----------out MlpBlock-----------
MixerBlock returned shape: (128, 64, 192)
-----------out MixerBlock-----------
x.shape after MixerBlock: (128, 64, 192)
-----------block 6-----------
x.shape before MixerBlock: (128, 64, 192)
-----------in MixerBlock-----------
x.shape before MixerBlock: (128, 64, 192)
y.shape after jnp.swapaxes: (128, 192, 64)
-----------in MlpBlock-----------
x.shape before MlpBlock: (128, 192, 64)
y.shape after Dense: (128, 192, 96)
MlpBlock returned shape: (128, 192, 64)
-----------out MlpBlock-----------
y.shape after MlpBlock: (128, 192, 64)
y.shape after jnp.swapaxes: (128, 64, 192)
-----------in MlpBlock-----------
x.shape before MlpBlock: (128, 64, 192)
y.shape after Dense: (128, 64, 768)
MlpBlock returned shape: (128, 64, 192)
-----------out MlpBlock-----------
MixerBlock returned shape: (128, 64, 192)
-----------out MixerBlock-----------
x.shape after MixerBlock: (128, 64, 192)
-----------block 7-----------
x.shape before MixerBlock: (128, 64, 192)
-----------in MixerBlock-----------
x.shape before MixerBlock: (128, 64, 192)
y.shape after jnp.swapaxes: (128, 192, 64)
-----------in MlpBlock-----------
x.shape before MlpBlock: (128, 192, 64)
y.shape after Dense: (128, 192, 96)
MlpBlock returned shape: (128, 192, 64)
-----------out MlpBlock-----------
y.shape after MlpBlock: (128, 192, 64)
y.shape after jnp.swapaxes: (128, 64, 192)
-----------in MlpBlock-----------
x.shape before MlpBlock: (128, 64, 192)
y.shape after Dense: (128, 64, 768)
MlpBlock returned shape: (128, 64, 192)
-----------out MlpBlock-----------
MixerBlock returned shape: (128, 64, 192)
-----------out MixerBlock-----------
x.shape after MixerBlock: (128, 64, 192)
x.shape after pre_head_layer_norm: (128, 64, 192)
x.shape after jnp.mean: (128, 192)
x.shape after head: (128, 10)

IMG_2653

IMG_2654

Training for 100 epochs

100%|██████████| 351/351 [00:18<00:00, 19.38it/s]
Epoch 1/100, Train Loss: 1.7528928518295288
100%|██████████| 40/40 [00:15<00:00,  2.58it/s]
Epoch 1/100, Validation Loss: 1.4253473281860352
100%|██████████| 351/351 [00:07<00:00, 44.37it/s]
Epoch 2/100, Train Loss: 1.3267101049423218
100%|██████████| 40/40 [00:08<00:00,  4.48it/s]
Epoch 2/100, Validation Loss: 1.1775858402252197
100%|██████████| 351/351 [00:07<00:00, 45.42it/s]
Epoch 3/100, Train Loss: 1.1273109912872314
100%|██████████| 40/40 [00:08<00:00,  4.70it/s]
Epoch 3/100, Validation Loss: 1.0781110525131226
100%|██████████| 351/351 [00:08<00:00, 43.53it/s]
Epoch 4/100, Train Loss: 0.9951978921890259
100%|██████████| 40/40 [00:08<00:00,  4.77it/s]
Epoch 4/100, Validation Loss: 0.9886428713798523
100%|██████████| 351/351 [00:07<00:00, 45.54it/s]
Epoch 5/100, Train Loss: 0.896670937538147
100%|██████████| 40/40 [00:08<00:00,  4.74it/s]
Epoch 5/100, Validation Loss: 0.8942521810531616
...
...
...
100%|██████████| 351/351 [00:07<00:00, 47.66it/s]
Epoch 96/100, Train Loss: 0.03413056954741478
100%|██████████| 40/40 [00:08<00:00,  4.95it/s]
Epoch 96/100, Validation Loss: 1.4042478799819946
100%|██████████| 351/351 [00:07<00:00, 47.08it/s]
Epoch 97/100, Train Loss: 0.036916933953762054
100%|██████████| 40/40 [00:08<00:00,  4.93it/s]
Epoch 97/100, Validation Loss: 1.2745997905731201
100%|██████████| 351/351 [00:07<00:00, 47.86it/s]
Epoch 98/100, Train Loss: 0.03660571575164795
100%|██████████| 40/40 [00:08<00:00,  4.76it/s]
Epoch 98/100, Validation Loss: 1.2667434215545654
100%|██████████| 351/351 [00:07<00:00, 47.73it/s]
Epoch 99/100, Train Loss: 0.029568077996373177
100%|██████████| 40/40 [00:08<00:00,  4.86it/s]
Epoch 99/100, Validation Loss: 1.3174177408218384
100%|██████████| 351/351 [00:07<00:00, 47.50it/s]
Epoch 100/100, Train Loss: 0.03492546081542969
100%|██████████| 40/40 [00:08<00:00,  4.80it/s]
Epoch 100/100, Validation Loss: 1.397871971130371

Testing with random images

# Number of random images to pull
num_images = 10

# Draw distinct random positions from the test set
random_indices = random.sample(range(len(test_data_loader.dataset)), num_images)

# One row of subplots, one axis per image
fig, axes = plt.subplots(1, num_images, figsize=(20, 5))

# Maps CIFAR-10 label indices to class names
get_label_name = make_label_getter('cifar10')


for ax, index in zip(axes, random_indices):
    img, label = test_data_loader.dataset[index]

    # Undo the normalization (times std, plus mean) so the image is displayable
    ax.imshow(img * Config['DATA_STD'] + Config['DATA_MEANS'])
    ax.axis('off')

    # The model expects a batch, so add a leading axis of size 1 before predicting
    image_input = jnp.expand_dims(img, axis=0)
    logits = state.apply_fn(state.params, image_input, train=False)
    predicted_label = jnp.argmax(logits, axis=1)[0]

    # Title shows the ground-truth class and the model's prediction
    ax.set_title(f"True: {get_label_name(label)}\nPred: {get_label_name(predicted_label)}")

plt.tight_layout()
plt.show()

image

Plotting the confusion matrix

from sklearn.metrics import confusion_matrix
import seaborn as sns

# Ground-truth and predicted labels collected over the whole test set
true_labels, predicted_labels = [], []

# Run the model on every test batch and record its argmax predictions
for images, labels in tqdm.tqdm(test_data_loader):
    logits = state.apply_fn(state.params, jnp.array(images), train=False)
    predicted_batch_labels = jnp.argmax(logits, axis=1)

    true_labels.extend(labels.tolist())
    predicted_labels.extend(predicted_batch_labels.tolist())

# Overall accuracy: fraction of positions where prediction and ground truth agree
total_predictions = len(true_labels)
correct_predictions = sum(true == pred for true, pred in zip(true_labels, predicted_labels))
accuracy = correct_predictions / total_predictions
print(f"Test Accuracy: {accuracy * 100:.2f}%")

# Confusion matrix, called as confusion_matrix(true_labels, predicted_labels)
cm = confusion_matrix(true_labels, predicted_labels)

# Draw it as an annotated heatmap with the class names on both axes
class_names = labelnames['cifar10']
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
            xticklabels=class_names, yticklabels=class_names)
plt.xlabel("Predicted Labels")
plt.ylabel("True Labels")
plt.title("Confusion Matrix")
plt.show()

Training for 10 epochs

epochs-10

Training for 100 epochs

epochs-100

Implementation details

optax.softmax_cross_entropy_with_integer_labels

optax.softmax_cross_entropy_with_integer_labels computes the softmax cross entropy between the logits and integer labels, so you don’t have to one-hot encode the labels. However, when I used this function I got NaN as the loss. Switching to one-hot encoded labels (with optax.softmax_cross_entropy) fixed it.

Importing cifar-10 from torchvision instead of tensorflow_datasets

When I tried to import the cifar-10 dataset from tensorflow_datasets, it caused errors that forced the colab notebook to restart. So I imported it from torchvision and changed the data types to be compatible with jax.