@article{tolstikhin2021mixer,
title = {MLP-Mixer: An all-MLP Architecture for Vision},
author = {Tolstikhin, Ilya and Houlsby, Neil and Kolesnikov, Alexander and Beyer, Lucas and Zhai, Xiaohua and Unterthiner, Thomas and Yung, Jessica and Steiner, Andreas and Keysers, Daniel and Uszkoreit, Jakob and Lucic, Mario and Dosovitskiy, Alexey},
journal = {arXiv preprint arXiv:2105.01601},
year = {2021}
}
!pip install einops ml_collections
import os
from absl import logging
import matplotlib.pyplot as plt
import numpy as np
import random
from typing import Any, Optional
import ml_collections
import jax
import jax.numpy as jnp
import jax.tools.colab_tpu
import flax
import flax.linen as nn
from flax.training import train_state
import optax
import einops
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils import data
import tqdm
# Shows the number of available devices.
# In a CPU/GPU runtime this will be a single device.
# In a TPU runtime this will be 8 cores.
jax.local_devices()
[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
Config = {
'NUM_CLASSES': 10,
'BATCH_SIZE': 128,
'NUM_EPOCHS': 100,
'LR': 0.001,
'WIDTH': 32,
'HEIGHT': 32,
'DATA_MEANS': np.array([0.49139968, 0.48215841, 0.44653091]), # mean of the CIFAR dataset, used for normalization
'DATA_STD': np.array([0.24703223, 0.24348513, 0.26158784]), # standard deviation of the CIFAR dataset, used for normalization
'CROP_SCALES': (0.8, 1.0),
'CROP_RATIO': (0.9, 1.1),
'SEED': 42,
}
# A helper function that normalizes images with the CIFAR-10 channel means and standard deviations from the config.
def image_to_numpy(img):
img = np.array(img, dtype=np.float32)
img = (img / 255. - Config['DATA_MEANS']) / Config['DATA_STD']
return img
# A helper function that collates batches into numpy arrays instead of the default torch tensors.
def numpy_collate(batch):
if isinstance(batch[0], np.ndarray):
return np.stack(batch)
elif isinstance(batch[0], (tuple, list)):
transposed = zip(*batch)
return [numpy_collate(samples) for samples in transposed]
else:
return np.array(batch)
# images in the test set will only be converted into numpy arrays
test_transform = image_to_numpy
# images in the train set will be randomly flipped, cropped, and then converted to numpy arrays
train_transform = transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.RandomResizedCrop((Config['HEIGHT'], Config['WIDTH']), scale=Config['CROP_SCALES'], ratio=Config['CROP_RATIO']),
image_to_numpy
])
# Validation set should not use train_transform.
train_dataset = torchvision.datasets.CIFAR10('data', train=True, transform=train_transform, download=True)
val_dataset = torchvision.datasets.CIFAR10('data', train=True, transform=test_transform, download=True)
train_set, _ = torch.utils.data.random_split(train_dataset, [45000, 5000], generator=torch.Generator().manual_seed(Config['SEED']))
_, val_set = torch.utils.data.random_split(val_dataset, [45000, 5000], generator=torch.Generator().manual_seed(Config['SEED']))
test_set = torchvision.datasets.CIFAR10('data', train=False, transform=test_transform, download=True)
train_data_loader = torch.utils.data.DataLoader(
train_set, batch_size=Config['BATCH_SIZE'], shuffle=True, drop_last=True, num_workers=2, persistent_workers=True, collate_fn=numpy_collate,
)
val_data_loader = torch.utils.data.DataLoader(
val_set, batch_size=Config['BATCH_SIZE'], shuffle=False, drop_last=False, num_workers=2, persistent_workers=True, collate_fn=numpy_collate,
)
test_data_loader = torch.utils.data.DataLoader(
test_set, batch_size=Config['BATCH_SIZE'], shuffle=False, drop_last=False, num_workers=2, persistent_workers=True, collate_fn=numpy_collate,
)
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data/cifar-10-python.tar.gz
100%|██████████| 170M/170M [00:04<00:00, 36.0MB/s]
Extracting data/cifar-10-python.tar.gz to data
Files already downloaded and verified
Files already downloaded and verified
imgs, _ = next(iter(train_data_loader))
print("Batch mean", imgs.mean(axis=(0,1,2)))
print("Batch std", imgs.std(axis=(0,1,2)))
Batch mean [-0.03558055 -0.03113859 -0.0337262 ]
Batch std [0.92609125 0.9158829 0.94482037]
print(f'size of images in the first train batch: {next(iter(train_data_loader))[0].shape}')
print(f'type of images in the first train batch: {next(iter(train_data_loader))[0].dtype}')
print(f'size of labels in the first train batch: {next(iter(train_data_loader))[1].shape}')
print(f'type of labels in the first train batch: {next(iter(train_data_loader))[1].dtype}')
size of images in the first train batch: (128, 32, 32, 3)
type of images in the first train batch: float64
size of labels in the first train batch: (128,)
type of labels in the first train batch: int64
# Helper functions for images.
labelnames = dict(
    # https://www.cs.toronto.edu/~kriz/cifar.html
    cifar10=('airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'),
)
def make_label_getter(dataset):
"""Returns a function converting label indices to names."""
def getter(label):
if dataset in labelnames:
return labelnames[dataset][label]
return f'label={label}'
return getter
# Number of images to display
num_images = 10
# Create a figure and axes
fig, axes = plt.subplots(1, num_images, figsize=(20, 5))
# Get the label names for CIFAR-10
get_label_name = make_label_getter('cifar10')
# Iterate through the first 10 images and display them
for i in range(num_images):
img, label = train_data_loader.dataset[i]
axes[i].imshow(img * Config['DATA_STD'] + Config['DATA_MEANS'])
axes[i].set_title(get_label_name(label))
axes[i].axis('off')
plt.tight_layout()
plt.show()
# Create a figure and axes
fig, axes = plt.subplots(1, num_images, figsize=(20, 5))
# Get the label names for CIFAR-10
get_label_name = make_label_getter('cifar10')
# Select 10 random indices
random_indices = random.sample(range(len(train_data_loader.dataset)), num_images)
# Iterate through the random images and display them
for i, index in enumerate(random_indices):
img, label = train_data_loader.dataset[index]
axes[i].imshow(img * Config['DATA_STD'] + Config['DATA_MEANS'])
axes[i].set_title(get_label_name(label))
axes[i].axis('off')
plt.tight_layout()
plt.show()
class MlpBlock(nn.Module):
    mlp_dim: int

    @nn.compact
    def __call__(self, x):
        # Expand to mlp_dim, apply GELU, then project back to the input width.
        y = nn.Dense(self.mlp_dim)(x)
        y = nn.gelu(y)
        return nn.Dense(x.shape[-1])(y)
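As a quick sanity check (optional; a minimal sketch with dummy data, not part of the training pipeline), MlpBlock preserves the trailing dimension while expanding to mlp_dim internally:
# Sanity check: MlpBlock maps (..., d) -> (..., d).
block = MlpBlock(mlp_dim=96)
dummy = jnp.ones((1, 64, 192))
variables = block.init(jax.random.PRNGKey(0), dummy)
print(block.apply(variables, dummy).shape)  # (1, 64, 192)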
class MixerBlock(nn.Module):
    """Mixer block layer."""
    tokens_mlp_dim: int
    channels_mlp_dim: int

    @nn.compact
    def __call__(self, x):
        # Token mixing: transpose to (batch, channels, tokens) so the MLP
        # mixes information across patches, then transpose back.
        y = nn.LayerNorm()(x)
        y = jnp.swapaxes(y, 1, 2)
        y = MlpBlock(self.tokens_mlp_dim, name='token_mixing')(y)
        y = jnp.swapaxes(y, 1, 2)
        x = x + y
        # Channel mixing: the MLP mixes information across feature channels.
        y = nn.LayerNorm()(x)
        return x + MlpBlock(self.channels_mlp_dim, name='channel_mixing')(y)
class MlpMixer(nn.Module):
    """Mixer architecture."""
    patches: Any
    num_classes: int
    num_blocks: int
    hidden_dim: int
    tokens_mlp_dim: int
    channels_mlp_dim: int
    model_name: Optional[str] = None

    @nn.compact
    def __call__(self, inputs, *, train):
        del train
        # Patch embedding: a conv whose kernel size equals its stride (the patch size).
        x = nn.Conv(self.hidden_dim, self.patches.size,
                    strides=self.patches.size, name='stem')(inputs)
        # Flatten the patch grid into a sequence of tokens.
        x = einops.rearrange(x, 'n h w c -> n (h w) c')
        for _ in range(self.num_blocks):
            x = MixerBlock(self.tokens_mlp_dim, self.channels_mlp_dim)(x)
        x = nn.LayerNorm(name='pre_head_layer_norm')(x)
        # Global average pooling over tokens, then the classification head.
        x = jnp.mean(x, axis=1)
        if self.num_classes:
            x = nn.Dense(self.num_classes, kernel_init=nn.initializers.zeros,
                         name='head')(x)
        return x
@jax.jit
def train_step(state, batch):
    def loss_fn(params):
        logits = state.apply_fn(params, batch['image'], train=True)
        one_hot_labels = jax.nn.one_hot(batch['label'], logits.shape[1])
        loss = optax.softmax_cross_entropy(logits, one_hot_labels).mean()
        return loss, logits
    # has_aux=True because loss_fn returns (loss, logits).
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, logits), grads = grad_fn(state.params)
    state = state.apply_gradients(grads=grads)
    return state, loss, logits
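The validation loop further down calls state.apply_fn directly on each batch. As an optional tweak (a sketch, not used for the results shown here), the evaluation forward pass can be jitted the same way as train_step:
@jax.jit
def eval_step(state, batch):
    # Forward pass only: no gradients, no parameter updates.
    logits = state.apply_fn(state.params, batch['image'], train=False)
    one_hot_labels = jax.nn.one_hot(batch['label'], logits.shape[1])
    loss = optax.softmax_cross_entropy(logits, one_hot_labels).mean()
    return loss, logits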
def create_train_state(module, rng, learning_rate, image_shape):
params = module.init(rng, jnp.zeros((1, *image_shape)), train=True)
tx = optax.adam(learning_rate)
return train_state.TrainState.create(apply_fn=module.apply, params=params, tx=tx)
# Infer the per-image shape from one batch.
image_shape = next(iter(train_data_loader))[0].shape[1:]
# Define the model
model = MlpMixer(
patches=ml_collections.ConfigDict({'size': (4, 4)}),
num_classes=Config['NUM_CLASSES'],
num_blocks=8,
hidden_dim=192,
tokens_mlp_dim=96,
channels_mlp_dim=768)
# Create the train state.
rng = jax.random.PRNGKey(0)
state = create_train_state(model, rng, Config['LR'], image_shape)
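A quick, optional way to confirm the model was built as intended is to count the trainable parameters in the freshly created state:
# Count trainable parameters (leaves of the params pytree).
param_count = sum(p.size for p in jax.tree_util.tree_leaves(state.params))
print(f"Total parameters: {param_count:,}")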
# Training loop
num_epochs = Config['NUM_EPOCHS']
for epoch in range(num_epochs):
# Training
train_losses = []
for batch in tqdm.tqdm(train_data_loader):
batch = {'image': batch[0], 'label': batch[1]}
state, loss, _ = train_step(state, batch)
train_losses.append(loss)
# Print average training loss
avg_train_loss = np.mean(train_losses)
print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {avg_train_loss}")
# Validation (optional, but recommended)
val_losses = []
for batch in tqdm.tqdm(val_data_loader):
batch = {'image': batch[0], 'label': batch[1]}
# Forward pass only: no gradients are computed and no parameters are updated.
logits = state.apply_fn(state.params, batch['image'], train=False)
one_hot_labels = jax.nn.one_hot(batch['label'], logits.shape[1])
loss = optax.softmax_cross_entropy(logits, one_hot_labels).mean()
val_losses.append(loss)
# Print average validation loss
avg_val_loss = np.mean(val_losses)
print(f"Epoch {epoch+1}/{num_epochs}, Validation Loss: {avg_val_loss}")
I modified the MLP-Mixer code to add some print statements that check the tensor shapes. Here's the modified code:
class MlpBlock(nn.Module):
    mlp_dim: int

    @nn.compact
    def __call__(self, x):
        print("-----------in MlpBlock-----------")
        print(f"x.shape before MlpBlock: {x.shape}")
        y = nn.Dense(self.mlp_dim)(x)
        print(f"y.shape after Dense: {y.shape}")
        y = nn.gelu(y)
        res = nn.Dense(x.shape[-1])(y)
        # Print res.shape directly; calling nn.Dense(...) again just for the
        # print would create an extra, third Dense layer inside this
        # @nn.compact method and change the parameter tree.
        print(f"MlpBlock returned shape: {res.shape}")
        print("-----------out MlpBlock-----------")
        return res
class MixerBlock(nn.Module):
"""Mixer block layer."""
tokens_mlp_dim: int
channels_mlp_dim: int
@nn.compact
def __call__(self, x):
print("-----------in MixerBlock-----------")
print(f"x.shape before MixerBlock: {x.shape}")
y = nn.LayerNorm()(x)
y = jnp.swapaxes(y, 1, 2)
print(f"y.shape after jnp.swapaxes: {y.shape}")
y = MlpBlock(self.tokens_mlp_dim, name='token_mixing')(y)
print(f"y.shape after MlpBlock: {y.shape}")
y = jnp.swapaxes(y, 1, 2)
print(f"y.shape after jnp.swapaxes: {y.shape}")
x = x + y
y = nn.LayerNorm()(x)
res = x + MlpBlock(self.channels_mlp_dim, name='channel_mixing')(y)
print(f"MixerBlock returned shape: {res.shape}")
print("-----------out MixerBlock-----------")
return res
class MlpMixer(nn.Module):
"""Mixer architecture."""
patches: Any
num_classes: int
num_blocks: int
hidden_dim: int
tokens_mlp_dim: int
channels_mlp_dim: int
model_name: Optional[str] = None
@nn.compact
def __call__(self, inputs, *, train):
del train
print(f"inputs.shape: {inputs.shape}")
x = nn.Conv(self.hidden_dim, self.patches.size,
strides=self.patches.size, name='stem')(inputs)
print(f"x.shape after stem: {x.shape}")
x = einops.rearrange(x, 'n h w c -> n (h w) c')
print(f"x.shape after einops.rearrange: {x.shape}")
for i in range(self.num_blocks):
print(f"-----------block {i}-----------")
print(f"x.shape before MixerBlock: {x.shape}")
x = MixerBlock(self.tokens_mlp_dim, self.channels_mlp_dim)(x)
print(f"x.shape after MixerBlock: {x.shape}")
x = nn.LayerNorm(name='pre_head_layer_norm')(x)
print(f"x.shape after pre_head_layer_norm: {x.shape}")
x = jnp.mean(x, axis=1)
print(f"x.shape after jnp.mean: {x.shape}")
if self.num_classes:
x = nn.Dense(self.num_classes, kernel_init=nn.initializers.zeros,
name='head')(x)
print(f"x.shape after head: {x.shape}")
return x
Here’s the output of running one batch through the model with print statements:
inputs.shape: (128, 32, 32, 3)
x.shape after stem: (128, 8, 8, 192)
x.shape after einops.rearrange: (128, 64, 192)
-----------block 0-----------
x.shape before MixerBlock: (128, 64, 192)
-----------in MixerBlock-----------
x.shape before MixerBlock: (128, 64, 192)
y.shape after jnp.swapaxes: (128, 192, 64)
-----------in MlpBlock-----------
x.shape before MlpBlock: (128, 192, 64)
y.shape after Dense: (128, 192, 96)
MlpBlock returned shape: (128, 192, 64)
-----------out MlpBlock-----------
y.shape after MlpBlock: (128, 192, 64)
y.shape after jnp.swapaxes: (128, 64, 192)
-----------in MlpBlock-----------
x.shape before MlpBlock: (128, 64, 192)
y.shape after Dense: (128, 64, 768)
MlpBlock returned shape: (128, 64, 192)
-----------out MlpBlock-----------
MixerBlock returned shape: (128, 64, 192)
-----------out MixerBlock-----------
x.shape after MixerBlock: (128, 64, 192)
... (blocks 1 through 7 print the same shapes as block 0) ...
x.shape after pre_head_layer_norm: (128, 64, 192)
x.shape after jnp.mean: (128, 192)
x.shape after head: (128, 10)
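The trace matches the patch arithmetic: 4x4 patches on a 32x32 image give (32/4) x (32/4) = 64 tokens, each embedded into hidden_dim = 192 channels. Token mixing therefore runs over the 64-token axis (expanded to tokens_mlp_dim = 96), channel mixing runs over the 192-channel axis (expanded to channels_mlp_dim = 768), and the head maps the pooled 192-dimensional vector to 10 class logits.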
100%|██████████| 351/351 [00:18<00:00, 19.38it/s]
Epoch 1/100, Train Loss: 1.7528928518295288
100%|██████████| 40/40 [00:15<00:00, 2.58it/s]
Epoch 1/100, Validation Loss: 1.4253473281860352
100%|██████████| 351/351 [00:07<00:00, 44.37it/s]
Epoch 2/100, Train Loss: 1.3267101049423218
100%|██████████| 40/40 [00:08<00:00, 4.48it/s]
Epoch 2/100, Validation Loss: 1.1775858402252197
100%|██████████| 351/351 [00:07<00:00, 45.42it/s]
Epoch 3/100, Train Loss: 1.1273109912872314
100%|██████████| 40/40 [00:08<00:00, 4.70it/s]
Epoch 3/100, Validation Loss: 1.0781110525131226
100%|██████████| 351/351 [00:08<00:00, 43.53it/s]
Epoch 4/100, Train Loss: 0.9951978921890259
100%|██████████| 40/40 [00:08<00:00, 4.77it/s]
Epoch 4/100, Validation Loss: 0.9886428713798523
100%|██████████| 351/351 [00:07<00:00, 45.54it/s]
Epoch 5/100, Train Loss: 0.896670937538147
100%|██████████| 40/40 [00:08<00:00, 4.74it/s]
Epoch 5/100, Validation Loss: 0.8942521810531616
...
...
...
100%|██████████| 351/351 [00:07<00:00, 47.66it/s]
Epoch 96/100, Train Loss: 0.03413056954741478
100%|██████████| 40/40 [00:08<00:00, 4.95it/s]
Epoch 96/100, Validation Loss: 1.4042478799819946
100%|██████████| 351/351 [00:07<00:00, 47.08it/s]
Epoch 97/100, Train Loss: 0.036916933953762054
100%|██████████| 40/40 [00:08<00:00, 4.93it/s]
Epoch 97/100, Validation Loss: 1.2745997905731201
100%|██████████| 351/351 [00:07<00:00, 47.86it/s]
Epoch 98/100, Train Loss: 0.03660571575164795
100%|██████████| 40/40 [00:08<00:00, 4.76it/s]
Epoch 98/100, Validation Loss: 1.2667434215545654
100%|██████████| 351/351 [00:07<00:00, 47.73it/s]
Epoch 99/100, Train Loss: 0.029568077996373177
100%|██████████| 40/40 [00:08<00:00, 4.86it/s]
Epoch 99/100, Validation Loss: 1.3174177408218384
100%|██████████| 351/351 [00:07<00:00, 47.50it/s]
Epoch 100/100, Train Loss: 0.03492546081542969
100%|██████████| 40/40 [00:08<00:00, 4.80it/s]
Epoch 100/100, Validation Loss: 1.397871971130371
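Note the gap between the final train loss (~0.035) and validation loss (~1.4): with no regularization beyond the crop-and-flip augmentation, the model overfits the 45,000 training images well before epoch 100.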
# Number of random images to pull
num_images = 10
# Select random indices from the test set
random_indices = random.sample(range(len(test_data_loader.dataset)), num_images)
# Create figure and axes
fig, axes = plt.subplots(1, num_images, figsize=(20, 5))
# Get label names function
get_label_name = make_label_getter('cifar10')
for i, index in enumerate(random_indices):
img, label = test_data_loader.dataset[index]
# Display the image
axes[i].imshow(img * Config['DATA_STD'] + Config['DATA_MEANS'])
axes[i].axis('off')
# Get model prediction
image_input = jnp.expand_dims(img, axis=0) # Add batch dimension
logits = state.apply_fn(state.params, image_input, train=False)
predicted_label = jnp.argmax(logits, axis=1)[0]
# Display the true and predicted labels
axes[i].set_title(f"True: {get_label_name(label)}\nPred: {get_label_name(predicted_label)}")
plt.tight_layout()
plt.show()
from sklearn.metrics import confusion_matrix
import seaborn as sns
# Initialize lists to store true and predicted labels
true_labels = []
predicted_labels = []
# Iterate through the test data loader
for batch in tqdm.tqdm(test_data_loader):
images, labels = batch
# Get model predictions
image_input = jnp.array(images)
logits = state.apply_fn(state.params, image_input, train=False)
predicted_batch_labels = jnp.argmax(logits, axis=1)
# Append true and predicted labels to lists
true_labels.extend(labels.tolist())
predicted_labels.extend(predicted_batch_labels.tolist())
# Calculate overall accuracy
correct_predictions = sum(1 for true, pred in zip(true_labels, predicted_labels) if true == pred)
total_predictions = len(true_labels)
accuracy = correct_predictions / total_predictions
print(f"Test Accuracy: {accuracy * 100:.2f}%")
# Compute the confusion matrix
cm = confusion_matrix(true_labels, predicted_labels)
# Plot the confusion matrix using seaborn
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
xticklabels=labelnames['cifar10'], yticklabels=labelnames['cifar10'])
plt.xlabel("Predicted Labels")
plt.ylabel("True Labels")
plt.title("Confusion Matrix")
plt.show()
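Per-class accuracy falls out of the same confusion matrix (an optional follow-up; rows of cm are true labels, so the diagonal divided by the row sums gives recall per class):
# Per-class accuracy (recall): diagonal over row sums of the confusion matrix.
per_class_acc = cm.diagonal() / cm.sum(axis=1)
for name, acc in zip(labelnames['cifar10'], per_class_acc):
    print(f"{name:>10}: {acc * 100:.1f}%")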
optax.softmax_cross_entropy_with_integer_labels

optax.softmax_cross_entropy_with_integer_labels computes softmax cross entropy directly between logits and integer class labels, so you don't have to one-hot encode the labels first:

loss = optax.softmax_cross_entropy_with_integer_labels(logits, batch['label']).mean()

However, when I used this function I got NaN as the loss. Switching to one-hot encoded labels with optax.softmax_cross_entropy fixed it:

one_hot_labels = jax.nn.one_hot(batch['label'], logits.shape[1])
loss = optax.softmax_cross_entropy(logits, one_hot_labels).mean()
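For reference, on well-formed integer labels the two optax functions compute the same value, which can be verified standalone (a toy check, separate from the training loop):
# Toy check: both losses agree on valid int32 labels.
toy_logits = jnp.array([[2.0, 0.5, -1.0], [0.1, 0.2, 0.3]])
toy_labels = jnp.array([0, 2], dtype=jnp.int32)
print(optax.softmax_cross_entropy_with_integer_labels(toy_logits, toy_labels).mean())
print(optax.softmax_cross_entropy(toy_logits, jax.nn.one_hot(toy_labels, toy_logits.shape[1])).mean())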
torchvision instead of tensorflow_datasets

When I tried to load the CIFAR-10 dataset from tensorflow_datasets, it caused errors that forced the Colab notebook to restart. So I loaded it from torchvision instead and changed the data types to be compatible with JAX.
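One dtype detail worth knowing (a minimal check; see the float64 printout after the first batch earlier): the collated numpy batches come out float64 because DATA_MEANS and DATA_STD are float64 arrays, but JAX runs in 32-bit mode by default and downcasts on conversion, so no explicit cast is needed:
# numpy batches are float64 on the host; JAX's default 32-bit mode
# downcasts them to float32 when they are converted to device arrays.
batch_np = next(iter(train_data_loader))[0]
print(batch_np.dtype)               # float64 (host numpy)
print(jnp.asarray(batch_np).dtype)  # float32 (JAX default precision)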