All code snippets in this blog were created by prompting Gemini in Google Colab. Thanks to the developers who made their code available.
prompt:
load cifar10 from torchvision
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
# Load CIFAR-10 dataset
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                         shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
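Before moving on, it is worth sanity-checking what the loaders return. This short snippet is my own addition, not part of the Gemini output; it just confirms the dataset sizes and the batch shape:

# Quick sanity check (hand-written addition, not generated by Gemini)
print(len(trainset), len(testset))         # 50000 10000
images, labels = next(iter(trainloader))
print(images.shape, labels.shape)          # torch.Size([4, 3, 32, 32]) torch.Size([4])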
prompt:
display images in the first batch from testloader
def imshow(img):
    img = img / 2 + 0.5  # unnormalize
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

# get the first batch of test images
dataiter = iter(testloader)
images, labels = next(dataiter)  # next() replaces the deprecated dataiter.next()
# show images
imshow(torchvision.utils.make_grid(images))
# print labels
print(' '.join(f'{classes[labels[j]]:5s}' for j in range(4)))
prompt:
write the mlp-mixer class in pytorch
import torch
import torch.nn as nn
class MLPMixerLayer(nn.Module):
    def __init__(self, dim, num_patches, token_dim, channel_dim):
        super().__init__()
        # The input of this layer has shape [batch_size, num_patches, dim].
        # Token mixing mixes information across patches, so its MLP acts on
        # the num_patches dimension; channel mixing acts on the dim dimension.
        self.token_mixing = nn.Sequential(
            nn.Linear(num_patches, token_dim),
            nn.GELU(),
            nn.Linear(token_dim, num_patches)
        )
        self.channel_mixing = nn.Sequential(
            nn.Linear(dim, channel_dim),
            nn.GELU(),
            nn.Linear(channel_dim, dim)
        )
        # LayerNorm is applied over the last (channel) dimension.
        self.layer_norm1 = nn.LayerNorm(dim)
        self.layer_norm2 = nn.LayerNorm(dim)

    def forward(self, x):
        # x: [batch_size, num_patches, dim]
        # Token mixing: transpose so the MLP runs across the patch dimension
        residual = x
        x = self.layer_norm1(x)
        x = x.permute(0, 2, 1)   # [batch_size, dim, num_patches]
        x = self.token_mixing(x)
        x = x.permute(0, 2, 1)   # [batch_size, num_patches, dim]
        x = x + residual

        # Channel mixing: the MLP runs across the feature dimension
        residual = x
        x = self.layer_norm2(x)
        x = self.channel_mixing(x)
        x = x + residual
        return x

class MLPMixer(nn.Module):
    def __init__(self, image_size, patch_size, num_classes, dim, depth, token_dim, channel_dim):
        super().__init__()
        assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
        self.num_patches = (image_size // patch_size) ** 2
        self.patch_size = patch_size
        # A strided convolution cuts the image into non-overlapping patches
        # and projects each patch to a dim-dimensional embedding.
        self.to_patch_embedding = nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size)
        self.mixer_blocks = nn.Sequential(
            *[MLPMixerLayer(dim, self.num_patches, token_dim, channel_dim) for _ in range(depth)])
        self.layer_norm = nn.LayerNorm(dim)
        self.mlp_head = nn.Linear(dim, num_classes)

    def forward(self, x):
        patches = self.to_patch_embedding(x)   # [batch_size, dim, H/p, W/p]
        batch_size = patches.shape[0]
        patches = patches.permute(0, 2, 3, 1)  # [batch_size, H/p, W/p, dim]
        patches = patches.reshape(batch_size, self.num_patches, -1)  # [batch_size, num_patches, dim]
        x = self.mixer_blocks(patches)
        x = self.layer_norm(x)
        x = torch.mean(x, dim=1)               # global average pooling over patches
        x = self.mlp_head(x)
        return x
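A quick way to check the plumbing before any training is to push a dummy batch through the model and confirm the output shape. This check is my own addition, not one of the Gemini prompts:

# Sanity check with a dummy batch (hand-written addition)
m = MLPMixer(image_size=32, patch_size=4, num_classes=10,
             dim=64, depth=2, token_dim=32, channel_dim=256)
print(m(torch.randn(2, 3, 32, 32)).shape)  # expected: torch.Size([2, 10])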
prompt:
initialize the model using MLPMixer class
image_size = 32
patch_size = 4
num_classes = 10
dim = 64
depth = 2
token_dim = 32
channel_dim = 256
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = MLPMixer(image_size, patch_size, num_classes, dim, depth, token_dim, channel_dim).to(device)
# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
print("Model initialized.")
prompt:
write the training loop, use 10 epochs
num_epochs = 10
for epoch in range(num_epochs):  # loop over the dataset multiple times
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data[0].to(device), data[1].to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 2000 == 1999:  # print every 2000 mini-batches
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}')
            running_loss = 0.0
print('Finished Training')
[1, 2000] loss: 1.904
[1, 4000] loss: 1.652
[1, 6000] loss: 1.519
[1, 8000] loss: 1.442
[1, 10000] loss: 1.377
[1, 12000] loss: 1.342
[2, 2000] loss: 1.261
[2, 4000] loss: 1.263
[2, 6000] loss: 1.221
[2, 8000] loss: 1.181
[2, 10000] loss: 1.167
[2, 12000] loss: 1.174
[3, 2000] loss: 1.118
[3, 4000] loss: 1.097
[3, 6000] loss: 1.094
[3, 8000] loss: 1.078
[3, 10000] loss: 1.073
[3, 12000] loss: 1.059
[4, 2000] loss: 1.036
[4, 4000] loss: 1.016
[4, 6000] loss: 1.008
[4, 8000] loss: 1.008
[4, 10000] loss: 1.024
[4, 12000] loss: 1.008
[5, 2000] loss: 0.951
[5, 4000] loss: 0.978
[5, 6000] loss: 0.967
[5, 8000] loss: 0.957
[5, 10000] loss: 0.968
[5, 12000] loss: 0.966
[6, 2000] loss: 0.939
[6, 4000] loss: 0.924
[6, 6000] loss: 0.923
[6, 8000] loss: 0.912
[6, 10000] loss: 0.931
[6, 12000] loss: 0.920
[7, 2000] loss: 0.887
[7, 4000] loss: 0.877
[7, 6000] loss: 0.896
[7, 8000] loss: 0.889
[7, 10000] loss: 0.910
[7, 12000] loss: 0.883
[8, 2000] loss: 0.828
[8, 4000] loss: 0.845
[8, 6000] loss: 0.861
[8, 8000] loss: 0.864
[8, 10000] loss: 0.843
[8, 12000] loss: 0.890
[9, 2000] loss: 0.811
[9, 4000] loss: 0.829
[9, 6000] loss: 0.832
[9, 8000] loss: 0.839
[9, 10000] loss: 0.843
[9, 12000] loss: 0.828
[10, 2000] loss: 0.777
[10, 4000] loss: 0.795
[10, 6000] loss: 0.815
[10, 8000] loss: 0.825
[10, 10000] loss: 0.810
[10, 12000] loss: 0.803
Finished Training
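The loss falls steadily from about 1.9 to 0.8 over the ten epochs. If you want to reuse the trained weights later, save a checkpoint; this step is my own addition and the filename is arbitrary:

# Save the trained weights (hand-written addition; path is arbitrary)
torch.save(model.state_dict(), './mlp_mixer_cifar10.pth')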
prompt:
pull a random batch from testloader and test it using the trained model
# Get a random batch from the testloader
dataiter = iter(testloader)
images, labels = next(dataiter)
images = images.to(device)
labels = labels.to(device)
# Run the batch through the trained model (no gradients needed for inference)
with torch.no_grad():
    outputs = model(images)
# Get predicted labels
_, predicted = torch.max(outputs, 1)
# Print predictions and actual labels
print('Predicted: ', ' '.join(f'{classes[predicted[j]]:5s}' for j in range(4)))
print('Actual: ', ' '.join(f'{classes[labels[j]]:5s}' for j in range(4)))
Predicted: cat ship ship plane
Actual: cat ship ship plane
prompt:
test all batches in testloader and create a confusion matrix
import matplotlib.pyplot as plt
import numpy as np

# Initialize the confusion matrix
confusion_mat = np.zeros((10, 10), dtype=int)

# Iterate over all batches in the testloader
with torch.no_grad():
    for data in testloader:
        images, labels = data[0].to(device), data[1].to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        # Update the confusion matrix
        for i in range(len(labels)):
            confusion_mat[labels[i], predicted[i]] += 1
# Plot the confusion matrix
plt.figure(figsize=(10, 8))
plt.imshow(confusion_mat, interpolation='nearest', cmap=plt.cm.Blues)
plt.title('Confusion Matrix')
plt.colorbar()
tick_marks = np.arange(len(classes))
plt.xticks(tick_marks, classes, rotation=45)
plt.yticks(tick_marks, classes)
thresh = confusion_mat.max() / 2.
for i, j in np.ndindex(confusion_mat.shape):
    plt.text(j, i, confusion_mat[i, j],
             horizontalalignment="center",
             color="white" if confusion_mat[i, j] > thresh else "black")
plt.tight_layout()
plt.ylabel('True label')
plt.xlabel('Predicted label')
plt.show()
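Since the confusion matrix is already assembled, per-class accuracy falls out of it for free. This follow-up snippet is my own addition rather than a Gemini response:

# Per-class accuracy from the confusion matrix (hand-written addition)
per_class_acc = confusion_mat.diagonal() / confusion_mat.sum(axis=1)
for name, acc in zip(classes, per_class_acc):
    print(f'{name:5s}: {100 * acc:.1f} %')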
prompt:
print the test accuracy for testloader
correct = 0
total = 0
# since we're not training, we don't need to calculate the gradients for our outputs
with torch.no_grad():
    for data in testloader:
        images, labels = data[0].to(device), data[1].to(device)
        # calculate outputs by running images through the network
        outputs = model(images)
        # the class with the highest energy is what we choose as prediction
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
print(f'Accuracy of the network on the 10000 test images: {100 * correct // total} %')
Accuracy of the network on the 10000 test images: 67 %
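As a cross-check (my own addition, not from a prompt), the same overall figure can be read off the confusion matrix computed earlier:

# Overall accuracy recomputed from the confusion matrix (hand-written addition)
print(f'{100 * confusion_mat.trace() / confusion_mat.sum():.1f} %')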