I borrowed some code from the official Swin Transformer GitHub repository:
@inproceedings{liu2021Swin,
title={Swin Transformer: Hierarchical Vision Transformer using Shifted Windows},
author={Liu, Ze and Lin, Yutong and Cao, Yue and Hu, Han and Wei, Yixuan and Zhang, Zheng and Lin, Stephen and Guo, Baining},
booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
year={2021}
}
To run the notebook, I copied the following classes and functions from that repository. To save space in this post, I'm only showing each function/class signature.
class Mlp(nn.Module):
def window_partition(x, window_size):
def window_reverse(windows, window_size, H, W):
class WindowAttention(nn.Module):
class SwinTransformerBlock(nn.Module):
class PatchMerging(nn.Module):
class BasicLayer(nn.Module):
class PatchEmbed(nn.Module):
class SwinTransformer(nn.Module):
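The original post doesn't show an imports cell, so here is my reconstruction of what the code below relies on (the copied Swin classes additionally import DropPath, to_2tuple, and trunc_normal_ from timm.models.layers):
import itertools

import matplotlib.pyplot as plt
import numpy as np
import scipy.ndimage
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from tqdm.notebook import tqdm as tqdm_notebook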
model = SwinTransformer(img_size=96,
                        patch_size=4,
                        in_chans=3,
                        num_classes=10,
                        embed_dim=48,
                        depths=[2, 2, 6, 2],
                        num_heads=[3, 6, 12, 24],
                        window_size=6,
                        mlp_ratio=4,
                        qkv_bias=True,
                        qk_scale=None,
                        drop_rate=0.0,
                        drop_path_rate=0.1,
                        ape=False,
                        norm_layer=nn.LayerNorm,
                        patch_norm=True,
                        use_checkpoint=False,
                        fused_window_process=False)
# select the GPU if one is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Transfer to GPU
model.to(device)
# setup the loss function
criterion = nn.CrossEntropyLoss()
# setup the optimizer with the learning rate
model_optimizer = optim.AdamW(model.parameters(), lr=5e-4)
# print the model architecture
model
SwinTransformer(
(patch_embed): PatchEmbed(
(proj): Conv2d(3, 48, kernel_size=(4, 4), stride=(4, 4))
(norm): LayerNorm((48,), eps=1e-05, elementwise_affine=True)
)
(pos_drop): Dropout(p=0.0, inplace=False)
(layers): ModuleList(
(0): BasicLayer(
dim=48, input_resolution=(24, 24), depth=2
(blocks): ModuleList(
(0): SwinTransformerBlock(
dim=48, input_resolution=(24, 24), num_heads=3, window_size=6, shift_size=0, mlp_ratio=4
(norm1): LayerNorm((48,), eps=1e-05, elementwise_affine=True)
(attn): WindowAttention(
dim=48, window_size=(6, 6), num_heads=3
(qkv): Linear(in_features=48, out_features=144, bias=True)
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Linear(in_features=48, out_features=48, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
(softmax): Softmax(dim=-1)
)
(drop_path): Identity()
(norm2): LayerNorm((48,), eps=1e-05, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=48, out_features=192, bias=True)
(act): GELU(approximate='none')
(fc2): Linear(in_features=192, out_features=48, bias=True)
(drop): Dropout(p=0.0, inplace=False)
)
)
(1): SwinTransformerBlock(
dim=48, input_resolution=(24, 24), num_heads=3, window_size=6, shift_size=3, mlp_ratio=4
(norm1): LayerNorm((48,), eps=1e-05, elementwise_affine=True)
(attn): WindowAttention(
dim=48, window_size=(6, 6), num_heads=3
(qkv): Linear(in_features=48, out_features=144, bias=True)
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Linear(in_features=48, out_features=48, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
(softmax): Softmax(dim=-1)
)
(drop_path): DropPath(drop_prob=0.009)
(norm2): LayerNorm((48,), eps=1e-05, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=48, out_features=192, bias=True)
(act): GELU(approximate='none')
(fc2): Linear(in_features=192, out_features=48, bias=True)
(drop): Dropout(p=0.0, inplace=False)
)
)
)
(downsample): PatchMerging(
input_resolution=(24, 24), dim=48
(reduction): Linear(in_features=192, out_features=96, bias=False)
(norm): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
)
)
(1): BasicLayer(
dim=96, input_resolution=(12, 12), depth=2
(blocks): ModuleList(
(0): SwinTransformerBlock(
dim=96, input_resolution=(12, 12), num_heads=6, window_size=6, shift_size=0, mlp_ratio=4
(norm1): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
(attn): WindowAttention(
dim=96, window_size=(6, 6), num_heads=6
(qkv): Linear(in_features=96, out_features=288, bias=True)
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Linear(in_features=96, out_features=96, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
(softmax): Softmax(dim=-1)
)
(drop_path): DropPath(drop_prob=0.018)
(norm2): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=96, out_features=384, bias=True)
(act): GELU(approximate='none')
(fc2): Linear(in_features=384, out_features=96, bias=True)
(drop): Dropout(p=0.0, inplace=False)
)
)
(1): SwinTransformerBlock(
dim=96, input_resolution=(12, 12), num_heads=6, window_size=6, shift_size=3, mlp_ratio=4
(norm1): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
(attn): WindowAttention(
dim=96, window_size=(6, 6), num_heads=6
(qkv): Linear(in_features=96, out_features=288, bias=True)
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Linear(in_features=96, out_features=96, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
(softmax): Softmax(dim=-1)
)
(drop_path): DropPath(drop_prob=0.027)
(norm2): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=96, out_features=384, bias=True)
(act): GELU(approximate='none')
(fc2): Linear(in_features=384, out_features=96, bias=True)
(drop): Dropout(p=0.0, inplace=False)
)
)
)
(downsample): PatchMerging(
input_resolution=(12, 12), dim=96
(reduction): Linear(in_features=384, out_features=192, bias=False)
(norm): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
)
)
(2): BasicLayer(
dim=192, input_resolution=(6, 6), depth=6
(blocks): ModuleList(
(0): SwinTransformerBlock(
dim=192, input_resolution=(6, 6), num_heads=12, window_size=6, shift_size=0, mlp_ratio=4
(norm1): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
(attn): WindowAttention(
dim=192, window_size=(6, 6), num_heads=12
(qkv): Linear(in_features=192, out_features=576, bias=True)
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Linear(in_features=192, out_features=192, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
(softmax): Softmax(dim=-1)
)
(drop_path): DropPath(drop_prob=0.036)
(norm2): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=192, out_features=768, bias=True)
(act): GELU(approximate='none')
(fc2): Linear(in_features=768, out_features=192, bias=True)
(drop): Dropout(p=0.0, inplace=False)
)
)
(1): SwinTransformerBlock(
dim=192, input_resolution=(6, 6), num_heads=12, window_size=6, shift_size=0, mlp_ratio=4
(norm1): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
(attn): WindowAttention(
dim=192, window_size=(6, 6), num_heads=12
(qkv): Linear(in_features=192, out_features=576, bias=True)
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Linear(in_features=192, out_features=192, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
(softmax): Softmax(dim=-1)
)
(drop_path): DropPath(drop_prob=0.045)
(norm2): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=192, out_features=768, bias=True)
(act): GELU(approximate='none')
(fc2): Linear(in_features=768, out_features=192, bias=True)
(drop): Dropout(p=0.0, inplace=False)
)
)
(2): SwinTransformerBlock(
dim=192, input_resolution=(6, 6), num_heads=12, window_size=6, shift_size=0, mlp_ratio=4
(norm1): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
(attn): WindowAttention(
dim=192, window_size=(6, 6), num_heads=12
(qkv): Linear(in_features=192, out_features=576, bias=True)
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Linear(in_features=192, out_features=192, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
(softmax): Softmax(dim=-1)
)
(drop_path): DropPath(drop_prob=0.055)
(norm2): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=192, out_features=768, bias=True)
(act): GELU(approximate='none')
(fc2): Linear(in_features=768, out_features=192, bias=True)
(drop): Dropout(p=0.0, inplace=False)
)
)
(3): SwinTransformerBlock(
dim=192, input_resolution=(6, 6), num_heads=12, window_size=6, shift_size=0, mlp_ratio=4
(norm1): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
(attn): WindowAttention(
dim=192, window_size=(6, 6), num_heads=12
(qkv): Linear(in_features=192, out_features=576, bias=True)
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Linear(in_features=192, out_features=192, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
(softmax): Softmax(dim=-1)
)
(drop_path): DropPath(drop_prob=0.064)
(norm2): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=192, out_features=768, bias=True)
(act): GELU(approximate='none')
(fc2): Linear(in_features=768, out_features=192, bias=True)
(drop): Dropout(p=0.0, inplace=False)
)
)
(4): SwinTransformerBlock(
dim=192, input_resolution=(6, 6), num_heads=12, window_size=6, shift_size=0, mlp_ratio=4
(norm1): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
(attn): WindowAttention(
dim=192, window_size=(6, 6), num_heads=12
(qkv): Linear(in_features=192, out_features=576, bias=True)
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Linear(in_features=192, out_features=192, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
(softmax): Softmax(dim=-1)
)
(drop_path): DropPath(drop_prob=0.073)
(norm2): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=192, out_features=768, bias=True)
(act): GELU(approximate='none')
(fc2): Linear(in_features=768, out_features=192, bias=True)
(drop): Dropout(p=0.0, inplace=False)
)
)
(5): SwinTransformerBlock(
dim=192, input_resolution=(6, 6), num_heads=12, window_size=6, shift_size=0, mlp_ratio=4
(norm1): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
(attn): WindowAttention(
dim=192, window_size=(6, 6), num_heads=12
(qkv): Linear(in_features=192, out_features=576, bias=True)
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Linear(in_features=192, out_features=192, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
(softmax): Softmax(dim=-1)
)
(drop_path): DropPath(drop_prob=0.082)
(norm2): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=192, out_features=768, bias=True)
(act): GELU(approximate='none')
(fc2): Linear(in_features=768, out_features=192, bias=True)
(drop): Dropout(p=0.0, inplace=False)
)
)
)
(downsample): PatchMerging(
input_resolution=(6, 6), dim=192
(reduction): Linear(in_features=768, out_features=384, bias=False)
(norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)
)
(3): BasicLayer(
dim=384, input_resolution=(3, 3), depth=2
(blocks): ModuleList(
(0): SwinTransformerBlock(
dim=384, input_resolution=(3, 3), num_heads=24, window_size=3, shift_size=0, mlp_ratio=4
(norm1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
(attn): WindowAttention(
dim=384, window_size=(3, 3), num_heads=24
(qkv): Linear(in_features=384, out_features=1152, bias=True)
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Linear(in_features=384, out_features=384, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
(softmax): Softmax(dim=-1)
)
(drop_path): DropPath(drop_prob=0.091)
(norm2): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=384, out_features=1536, bias=True)
(act): GELU(approximate='none')
(fc2): Linear(in_features=1536, out_features=384, bias=True)
(drop): Dropout(p=0.0, inplace=False)
)
)
(1): SwinTransformerBlock(
dim=384, input_resolution=(3, 3), num_heads=24, window_size=3, shift_size=0, mlp_ratio=4
(norm1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
(attn): WindowAttention(
dim=384, window_size=(3, 3), num_heads=24
(qkv): Linear(in_features=384, out_features=1152, bias=True)
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Linear(in_features=384, out_features=384, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
(softmax): Softmax(dim=-1)
)
(drop_path): DropPath(drop_prob=0.100)
(norm2): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=384, out_features=1536, bias=True)
(act): GELU(approximate='none')
(fc2): Linear(in_features=1536, out_features=384, bias=True)
(drop): Dropout(p=0.0, inplace=False)
)
)
)
)
)
(norm): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
(avgpool): AdaptiveAvgPool1d(output_size=1)
(head): Linear(in_features=384, out_features=10, bias=True)
)
# set the preprocess operations to be performed on train/val/test samples
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
# download the STL10 training set (5,000 labeled training images)
train_set = torchvision.datasets.STL10(root='./data', split='train', download=True, transform=transform)
# download STL10 test set
test_set = torchvision.datasets.STL10(root='./data', split='test', download=True, transform=transform)
# define the data loaders using the datasets
train_loader = torch.utils.data.DataLoader(dataset=train_set, batch_size=128, shuffle=True, drop_last=True)
test_loader = torch.utils.data.DataLoader(dataset=test_set, batch_size=128, shuffle=False)
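As a quick sanity check (my addition, not in the original notebook), the split sizes and class names can be printed; STL10 ships 5,000 labeled training images and 8,000 test images across 10 classes:
# 5000 training images, 8000 test images
print(len(train_set), len(test_set))
print(train_set.classes)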
Files already downloaded and verified
Files already downloaded and verified
# Make sure gradient tracking is on, and do a pass over the data
model.train(True)
# Training loop
num_of_epochs = 200
for epoch in range(num_of_epochs):
    for imgs, labels in tqdm_notebook(train_loader, desc='epoch '+str(epoch)):
        # Transfer to GPU
        imgs, labels = imgs.to(device), labels.to(device)
        # zero the parameter gradients
        model_optimizer.zero_grad()
        # Make predictions for this batch
        preds = model(imgs)
        # Compute the loss and its gradients
        loss = criterion(preds, labels)
        # backpropagate the loss
        loss.backward()
        # adjust parameters based on the calculated gradients
        model_optimizer.step()
# save the trained weights once training is done
torch.save(model.state_dict(), 'model_'+str(num_of_epochs)+'.pth')
Uncomment the following line to load previously trained weights instead of training from scratch.
# model.load_state_dict(torch.load('model_STL10_200_embed_dim_48.pth', map_location=torch.device('cpu')))
<All keys matched successfully>
all_labels, all_pred_labels = [], []
model.eval()
acc_total = 0
with torch.inference_mode():
    for imgs, labels in test_loader:
        imgs, labels = imgs.to(device), labels.to(device)
        preds = model(imgs)
        # the class with the highest score is the prediction
        pred_cls = preds.data.max(1)[1]
        all_labels.append(labels.data.tolist())
        all_pred_labels.append(pred_cls.data.tolist())
        acc_total += pred_cls.eq(labels.data).cpu().sum()
# flatten the per-batch label lists
all_labels_flat = list(itertools.chain.from_iterable(all_labels))
all_pred_labels_flat = list(itertools.chain.from_iterable(all_pred_labels))
acc = acc_total.item()/len(test_loader.dataset)
print(f'Accuracy on test set = {acc*100:.2f}%')
print(f'Accuracy on test set = {acc*100:.2f}%')
Accuracy on test set = 47.46%
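The flattened label lists aren't used above; they're handy for a confusion matrix. A minimal sketch with scikit-learn (my addition; the original post doesn't show this step):
from sklearn.metrics import confusion_matrix

# rows are true classes, columns are predicted classes
cm = confusion_matrix(all_labels_flat, all_pred_labels_flat)
print(cm)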
# reset the list that will collect one attention map per Swin block
global_attention = []
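The post doesn't show how global_attention gets filled during the forward pass. A minimal sketch, assuming the stock WindowAttention implementation where self.softmax produces the (num_windows·B, num_heads, N, N) attention weights, is to register a forward hook on each block's softmax before running the image through the model (this is my reconstruction, not necessarily the author's approach):
# append every attention map (the softmax output of each WindowAttention)
# to the global list during the forward pass
def save_attention(module, inputs, output):
    global_attention.append(output)

for module in model.modules():
    if isinstance(module, WindowAttention):
        module.softmax.register_forward_hook(save_attention)
With 2 + 2 + 6 + 2 blocks, a single forward pass appends twelve maps.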
# grab a raw image from the training set; STL10's .data is (N, 3, 96, 96) uint8
img = train_loader.dataset.data[2,:,:,:]
print(img.shape)
# channels-last copy for plotting
img_plot = np.transpose(img, (1, 2, 0))
plt.imshow(img_plot)
print(img.shape)
# add a batch dimension and run a forward pass to populate global_attention
# (note: this feeds raw 0-255 pixel values, bypassing the Normalize transform)
img = torch.Tensor(img)
img = img.unsqueeze(0)
img = img.to(device)
print(img.shape)
pred = model(img)
(3, 96, 96)
(3, 96, 96)
torch.Size([1, 3, 96, 96])
for i in global_attention:
    print(i.shape)
torch.Size([16, 3, 36, 36])
torch.Size([16, 3, 36, 36])
torch.Size([4, 6, 36, 36])
torch.Size([4, 6, 36, 36])
torch.Size([1, 12, 36, 36])
torch.Size([1, 12, 36, 36])
torch.Size([1, 12, 36, 36])
torch.Size([1, 12, 36, 36])
torch.Size([1, 12, 36, 36])
torch.Size([1, 12, 36, 36])
torch.Size([1, 24, 9, 9])
torch.Size([1, 24, 9, 9])
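Each map has shape (num_windows·batch, num_heads, tokens, tokens). Layer 0 splits the 24×24 patch grid into sixteen 6×6 windows (36 tokens, 3 heads), layer 1 into four windows with 6 heads, layer 2 is a single 6×6 window with 12 heads, and layer 3 a single 3×3 window (9 tokens) with 24 heads.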
img_attn = global_attention[0][:,:,:,:].cpu().detach()
print(img_attn.shape)
# each row of an attention matrix is a softmax output, so it sums to 1
print(torch.sum(img_attn[0,0,:,:], dim=1))
torch.Size([16, 3, 36, 36])
tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000])
plt.imshow(global_attention[0][0,0,:,:].cpu().detach().numpy())
# histogram of each query patch's attention distribution (36 patches per window)
img_attn = global_attention[0][0,0,:,:].cpu().detach().numpy()
fig, axs = plt.subplots(img_attn.shape[0]//6, 6, figsize=(14, 12), layout="constrained")
for i, ax in enumerate(axs.ravel()):
    ax.hist(img_attn[i,:], bins=10)
plt.show()
The plot_attention function takes an attention tensor of shape (num_windows, num_heads, num_patches, num_patches) and renders each head as an image of size (√num_windows·window_height, √num_windows·window_width). For every patch in a window, I take the maximum of its attention row, i.e. its strongest attention weight to any patch in the same window. The window_reverse function then reassembles these per-patch values into the original image layout.
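For reference, this is window_reverse from the official repository (one of the functions I copied by name only above); it undoes window_partition by reshaping the stack of windows back into a (B, H, W, C) image:
def window_reverse(windows, window_size, H, W):
    # windows: (num_windows*B, window_size, window_size, C)
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x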
def plot_attention(img_plot, img_attn, plot_title):
    window_size = 6
    # for each query patch, keep its strongest attention weight
    img_attn,_ = img_attn.max(axis=-1)
    num_windows, num_heads, num_patches = img_attn.shape
    # reshape the per-patch values into (window_height, window_width) grids
    img_attn = img_attn.reshape(num_windows, num_heads, int(num_patches**.5), int(num_patches**.5))
    if num_heads <= 6:
        fig, axs = plt.subplots(num_heads//3, 3, figsize=(12,6))
    else:
        fig, axs = plt.subplots(num_heads//3, 3, figsize=(12,12))
    fig.suptitle(plot_title)
    for i, ax in enumerate(axs.ravel()):
        img_attn_plot = img_attn[:,i,:,:]
        img_attn_plot = img_attn_plot.unsqueeze(-1)
        # side length of the reassembled attention image
        img_H = int(num_windows**.5) * int(num_patches**.5)
        img_attn_plot = window_reverse(img_attn_plot, window_size, img_H, img_H)
        img_attn_plot = img_attn_plot.squeeze(0).squeeze(-1).numpy()
        # upscale to the input image size for display
        ax.imshow(scipy.ndimage.zoom(img_attn_plot, img_plot.shape[0]//img_attn_plot.shape[0]))
        ax.axis("off")
    plt.show()
plt.imshow(img_plot)
plt.axis("off")
plot_attention(img_plot, global_attention[0][:,:,:,:].cpu().detach(), 'Layer 00, SwinBlock 00')
plot_attention(img_plot, global_attention[2][:,:,:,:].cpu().detach(), 'Layer 01, SwinBlock 00')
plot_attention(img_plot, global_attention[4][:,:,:,:].cpu().detach(), 'Layer 02, SwinBlock 00')
plot_attention(img_plot, global_attention[6][:,:,:,:].cpu().detach(), 'Layer 02, SwinBlock 02')
plot_attention(img_plot, global_attention[8][:,:,:,:].cpu().detach(), 'Layer 02, SwinBlock 04')
plt.imshow(img_plot)
plt.axis("off")
plot_attention(img_plot, global_attention[1][:,:,:,:].cpu().detach(), 'Layer 00, SwinBlock 01')
plot_attention(img_plot, global_attention[3][:,:,:,:].cpu().detach(), 'Layer 01, SwinBlock 01')
plot_attention(img_plot, global_attention[5][:,:,:,:].cpu().detach(), 'Layer 02, SwinBlock 01')
plot_attention(img_plot, global_attention[7][:,:,:,:].cpu().detach(), 'Layer 02, SwinBlock 03')
plot_attention(img_plot, global_attention[9][:,:,:,:].cpu().detach(), 'Layer 02, SwinBlock 05')