VisionTransformer class
param_dtype
dtype
jax.lax.Precision
jax.default_matmul_precision
These resources were helpful in preparing this post:
We need ml_collections to prepare the configs and grain for the dataloaders.
pip install ml_collections grain
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P, NamedSharding # For data and model parallelism (explained in more detail later)
from jax.experimental import mesh_utils
import jax.profiler
from flax import nnx
import optax
import grain.python as grain
import torch
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
from functools import partial
from ml_collections import ConfigDict
import time
data_config = ConfigDict(dict(
    batch_size=128,
    img_size=32,
    seed=12,
    crop_scales=(0.9, 1.0),
    crop_ratio=(0.9, 1.1),
    data_means=(0.4914, 0.4822, 0.4465),
    data_std=(0.2023, 0.1994, 0.2010)
))

model_config = ConfigDict(dict(
    num_epochs=1,
    learning_rate=1e-3,
    patch_size=4,
    num_classes=10,
    embed_dim=384,
    mlp_dim=1536,
    num_heads=8,
    num_layers=6,
    dropout_rate=0.1,
    dtype=jnp.bfloat16,
    param_dtype=jnp.float32,
    rngs=nnx.Rngs(0)
))
We need torchvision to load CIFAR-10 and to perform random flipping and cropping. We also need the images as NumPy arrays so that JAX can accept them.
def image_to_numpy(img):
    img = np.array(img, dtype=np.float32)
    img = (img / 255. - np.array(data_config.data_means)) / np.array(data_config.data_std)
    return img

def numpy_collate(batch):
    if isinstance(batch[0], np.ndarray):
        return np.stack(batch)
    elif isinstance(batch[0], (tuple, list)):
        transposed = zip(*batch)
        return [numpy_collate(samples) for samples in transposed]
    else:
        return np.array(batch)
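As a quick sanity check (with made-up arrays, not from the original post), numpy_collate turns a list of (image, label) pairs into a pair of stacked NumPy arrays:
# Hypothetical sanity check: collate two fake (image, label) pairs.
fake_batch = [
    (np.zeros((32, 32, 3), dtype=np.float32), 3),
    (np.ones((32, 32, 3), dtype=np.float32), 7),
]
images, labels = numpy_collate(fake_batch)
print(images.shape, labels.shape)  # (2, 32, 32, 3) (2,)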
# images in the test set will only be converted into numpy arrays
test_transform = image_to_numpy
# images in the train set will be randomly flipped, cropped, and then converted to numpy arrays
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomResizedCrop(data_config.img_size, scale=data_config.crop_scales, ratio=data_config.crop_ratio),
    image_to_numpy
])
train_dataset = torchvision.datasets.CIFAR10('data', train=True, transform=train_transform, download=True)
val_dataset = torchvision.datasets.CIFAR10('data', train=True, transform=test_transform, download=True)
# Seed both splits identically so the validation indices do not overlap the training indices
train_set, _ = torch.utils.data.random_split(train_dataset, [45000, 5000], generator=torch.Generator().manual_seed(data_config.seed))
_, val_set = torch.utils.data.random_split(val_dataset, [45000, 5000], generator=torch.Generator().manual_seed(data_config.seed))
test_set = torchvision.datasets.CIFAR10('data', train=False, transform=test_transform, download=True)
train_sampler = grain.IndexSampler(
    len(train_dataset),
    shuffle=True,
    seed=data_config.seed,
    shard_options=grain.NoSharding(),  # No sharding here, because it can be done inside the training loop
    num_epochs=model_config.num_epochs,
)
val_sampler = grain.IndexSampler(
    len(val_dataset),
    shuffle=False,
    seed=data_config.seed,
    shard_options=grain.NoSharding(),
    num_epochs=model_config.num_epochs,
)
train_loader = grain.DataLoader(
    data_source=train_dataset,
    sampler=train_sampler,
    operations=[
        grain.Batch(data_config.batch_size, drop_remainder=True),
    ]
)
# Validation dataset `grain.DataLoader`.
val_loader = grain.DataLoader(
    data_source=val_dataset,
    sampler=val_sampler,
    operations=[
        grain.Batch(2 * data_config.batch_size),
    ]
)
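As a quick check before training, here is a sketch (not in the original post) of pulling a single batch from the grain loader to verify shapes and dtypes:
# Sketch: peek at one batch from the grain DataLoader defined above.
batch_imgs, batch_labels = next(iter(train_loader))
print(batch_imgs.shape, batch_imgs.dtype)      # expected: (128, 32, 32, 3) float32
print(batch_labels.shape, batch_labels.dtype)  # expected: (128,) with an integer dtype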
Here I’m creating a device mesh to distribute the data across 8 devices. I’m not using model parallelism.
# Create a `Mesh` object representing TPU device arrangement.
mesh = Mesh(mesh_utils.create_device_mesh((8, 1)), ('batch', 'model'))
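To see what data parallelism looks like with this mesh, here is a sketch (not from the original post, and it assumes 8 devices are actually available) of placing a dummy batch along the 'batch' axis:
# Sketch: shard a dummy (batch, height, width, channels) array along the 'batch' mesh axis.
data_sharding = NamedSharding(mesh, P('batch', None, None, None))
dummy_batch = jnp.ones((data_config.batch_size, data_config.img_size, data_config.img_size, 3))
sharded_batch = jax.device_put(dummy_batch, data_sharding)
print(sharded_batch.sharding)  # each device holds a (16, 32, 32, 3) slice of the batch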
VisionTransformer class
⚠️ CAUTION: ⚠️
I did not measure the loss or accuracy on CIFAR-10, so I do not know how this code performs in terms of loss optimization. The sole purpose of this class is to monitor memory usage as we change the data type.
⚠️ CAUTION: ⚠️
I commented out all calls to nnx.with_partitioning because it won’t work with nnx.tabulate, which I’m going to call later. For data parallelism, uncomment the nnx.with_partitioning calls.
class PatchEmbedding(nnx.Module):
    def __init__(self, img_size, patch_size, embed_dim, dtype, param_dtype, rngs: nnx.Rngs):
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2
        fh, fw = patch_size, patch_size  # Filter height/width are the patch size
        self.conv_proj = nnx.Conv(
            in_features=3,
            out_features=embed_dim,
            kernel_size=(fh, fw),
            strides=(fh, fw),
            padding='VALID',
            # kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), NamedSharding(mesh, P(None, 'model'))),
            # bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(), NamedSharding(mesh, P('model'))),
            dtype=dtype,
            param_dtype=param_dtype,
            rngs=rngs,
        )

    def __call__(self, x: jax.Array) -> jax.Array:
        x = self.conv_proj(x)
        n, h, w, c = x.shape
        x = jnp.reshape(x, [n, h * w, c])
        return x
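With img_size=32 and patch_size=4, an image is split into (32/4)² = 64 patches, each projected to a 384-dimensional embedding. A quick shape check (a sketch, not in the original post):
# Sketch: verify the patch-embedding output shape on a dummy image.
patch_embed = PatchEmbedding(img_size=32, patch_size=4, embed_dim=384,
                             dtype=jnp.bfloat16, param_dtype=jnp.float32, rngs=nnx.Rngs(0))
out = patch_embed(jnp.ones((1, 32, 32, 3)))
print(out.shape)  # (1, 64, 384)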
class EncoderBlock(nnx.Module):
    def __init__(self, embed_dim, num_heads, mlp_dim, dropout_rate, dtype, param_dtype, rngs: nnx.Rngs):
        self.norm1 = nnx.LayerNorm(epsilon=1e-6,
                                   num_features=embed_dim,
                                   # scale_init=nnx.with_partitioning(nnx.initializers.ones_init(), NamedSharding(mesh, P(None, 'model'))),
                                   # bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(), NamedSharding(mesh, P(None, 'model'))),
                                   dtype=dtype,
                                   param_dtype=param_dtype,
                                   rngs=rngs)
        self.attn = nnx.MultiHeadAttention(num_heads=num_heads,
                                           in_features=embed_dim,
                                           # kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), NamedSharding(mesh, P(None, 'model'))),
                                           # bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(), NamedSharding(mesh, P('model'))),
                                           dtype=dtype,
                                           param_dtype=param_dtype,
                                           rngs=rngs)
        self.norm2 = nnx.LayerNorm(epsilon=1e-6,
                                   num_features=embed_dim,
                                   # scale_init=nnx.with_partitioning(nnx.initializers.ones_init(), NamedSharding(mesh, P(None, 'model'))),
                                   # bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(), NamedSharding(mesh, P(None, 'model'))),
                                   dtype=dtype,
                                   param_dtype=param_dtype,
                                   rngs=rngs)
        self.linear1 = nnx.Linear(
            in_features=embed_dim,
            out_features=mlp_dim,
            # kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), NamedSharding(mesh, P(None, 'model'))),
            # bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(), NamedSharding(mesh, P('model'))),
            dtype=dtype,
            param_dtype=param_dtype,
            rngs=rngs
        )
        self.dropout1 = nnx.Dropout(rate=dropout_rate, rngs=rngs)
        self.linear2 = nnx.Linear(
            in_features=mlp_dim,
            out_features=embed_dim,
            # kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), NamedSharding(mesh, P(None, 'model'))),
            # bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(), NamedSharding(mesh, P('model'))),
            dtype=dtype,
            param_dtype=param_dtype,
            rngs=rngs
        )
        self.dropout2 = nnx.Dropout(rate=dropout_rate, rngs=rngs)

    def __call__(self, x):
        x = x + self.attn(self.norm1(x), decode=False)
        mlp_out = self.norm2(x)
        mlp_out = self.linear1(mlp_out)
        mlp_out = nnx.gelu(mlp_out)
        mlp_out = self.dropout1(mlp_out)
        mlp_out = self.linear2(mlp_out)
        mlp_out = self.dropout2(mlp_out)
        x = x + mlp_out
        return x
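Each encoder block keeps the token sequence shape unchanged, which a quick sketch (not in the original post) can confirm:
# Sketch: an encoder block maps (batch, tokens, embed_dim) to the same shape.
block = EncoderBlock(embed_dim=384, num_heads=8, mlp_dim=1536, dropout_rate=0.1,
                     dtype=jnp.bfloat16, param_dtype=jnp.float32, rngs=nnx.Rngs(0))
tokens = jnp.ones((1, 65, 384))  # 64 patches + 1 CLS token
print(block(tokens).shape)  # (1, 65, 384)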
class VisionTransformer(nnx.Module):
    def __init__(self, img_size, patch_size, num_classes, embed_dim, num_layers, num_heads, mlp_dim, dropout_rate, dtype, param_dtype, rngs: nnx.Rngs = nnx.Rngs(0)):
        self.patch_embed = PatchEmbedding(img_size, patch_size, embed_dim, dtype, param_dtype, rngs=rngs)
        num_patches = (img_size // patch_size) ** 2
        self.cls_token = nnx.Param(jnp.zeros((1, 1, embed_dim)), dtype=dtype)
        self.pos_embed = nnx.Param(jax.random.normal(rngs.params(), (1, num_patches + 1, embed_dim)), dtype=dtype)
        self.encoder_blocks = [
            EncoderBlock(embed_dim, num_heads, mlp_dim, dropout_rate, dtype, param_dtype, rngs=rngs)
            for _ in range(num_layers)
        ]
        self.norm = nnx.LayerNorm(epsilon=1e-6,
                                  num_features=embed_dim,
                                  # scale_init=nnx.with_partitioning(nnx.initializers.ones_init(), NamedSharding(mesh, P(None, 'model'))),
                                  # bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(), NamedSharding(mesh, P(None, 'model'))),
                                  dtype=dtype,
                                  param_dtype=param_dtype,
                                  rngs=rngs)
        self.head = nnx.Linear(in_features=embed_dim,
                               out_features=num_classes,
                               # kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), NamedSharding(mesh, P(None, 'model'))),
                               # bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(), NamedSharding(mesh, P('model'))),
                               dtype=dtype,
                               param_dtype=param_dtype,
                               rngs=rngs)
        self.dtype = dtype

    def __call__(self, x):
        x = x.astype(self.dtype)
        x = self.patch_embed(x)
        batch_size = x.shape[0]
        cls_tokens = jnp.tile(self.cls_token.value, (batch_size, 1, 1))
        x = jnp.concatenate((cls_tokens, x), axis=1)
        x = x + self.pos_embed.value
        for block in self.encoder_blocks:
            x = block(x)
        cls_output = self.norm(x[:, 0])
        return self.head(cls_output)
param_dtype
Now let’s change our config dict and run the VisionTransformer class with each of the following choices:
| param_dtype  |
|--------------|
| jnp.float32  |
| jnp.float16  |
| jnp.bfloat16 |
Now let’s run nnx.tabulate to check the parameter sizes.
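A minimal sketch of building the model from the configs above and tabulating it (the exact instantiation is not shown in the original post):
# Sketch: instantiate the model from the configs and print its parameter summary.
model = VisionTransformer(
    img_size=data_config.img_size,
    patch_size=model_config.patch_size,
    num_classes=model_config.num_classes,
    embed_dim=model_config.embed_dim,
    num_layers=model_config.num_layers,
    num_heads=model_config.num_heads,
    mlp_dim=model_config.mlp_dim,
    dropout_rate=model_config.dropout_rate,
    dtype=model_config.dtype,
    param_dtype=model_config.param_dtype,
    rngs=model_config.rngs,
)
print(nnx.tabulate(model, jnp.ones((1, 32, 32, 3))))
Here is an example output of nnx.tabulate: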
VisionTransformer Summary
┏━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━┓
┃ path ┃ type ┃ inputs ┃ outputs ┃ Param ┃ RngState ┃
┡━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━┩
│ │ VisionTransformer │ float32[1,32,32,… │ bfloat16[1,10] │ cls_token: │ 2 (12 B) │
│ │ │ │ │ dtype: type │ │
│ │ │ │ │ value: │ │
│ │ │ │ │ float32[1,1,384] │ │
│ │ │ │ │ pos_embed: │ │
│ │ │ │ │ dtype: type │ │
│ │ │ │ │ value: │ │
│ │ │ │ │ float32[1,65,384] │ │
│ │ │ │ │ │ │
│ │ │ │ │ 10,695,562 (42.8 │ │
│ │ │ │ │ MB) │ │
├────────────────────┼────────────────────┼───────────────────┼────────────────────┼───────────────────┼──────────┤
│ patch_embed │ PatchEmbedding │ bfloat16[1,32,32… │ bfloat16[1,64,384] │ 18,816 (75.3 KB) │ │
├────────────────────┼────────────────────┼───────────────────┼────────────────────┼───────────────────┼──────────┤
│ patch_embed/conv_… │ Conv │ bfloat16[1,32,32… │ bfloat16[1,8,8,38… │ bias: │ │
│ │ │ │ │ float32[384] │ │
│ │ │ │ │ kernel: │ │
│ │ │ │ │ float32[4,4,3,38… │ │
│ │ │ │ │ │ │
│ │ │ │ │ 18,816 (75.3 KB) │ │
├────────────────────┼────────────────────┼───────────────────┼────────────────────┼───────────────────┼──────────┤
│ encoder_blocks/0 │ EncoderBlock │ float32[1,65,384] │ float32[1,65,384] │ 1,774,464 (7.1 │ 2 (12 B) │
│ │ │ │ │ MB) │ │
├────────────────────┼────────────────────┼───────────────────┼────────────────────┼───────────────────┼──────────┤
│ encoder_blocks/0/… │ LayerNorm │ float32[1,65,384] │ bfloat16[1,65,384] │ bias: │ │
│ │ │ │ │ float32[384] │ │
│ │ │ │ │ scale: │ │
│ │ │ │ │ float32[384] │ │
│ │ │ │ │ │ │
│ │ │ │ │ 768 (3.1 KB) │ │
├────────────────────┼────────────────────┼───────────────────┼────────────────────┼───────────────────┼──────────┤
│ encoder_blocks/0/… │ MultiHeadAttention │ - │ bfloat16[1,65,384] │ 591,360 (2.4 MB) │ │
│ │ │ bfloat16[1,65,38… │ │ │ │
│ │ │ decode: false │ │ │ │
├────────────────────┼────────────────────┼───────────────────┼────────────────────┼───────────────────┼──────────┤
This is a comparison of parameter sizes based on param_dtype:

According to this plot, it is an easy decision to go with param_dtype=jnp.bfloat16. But if we go with this choice, we run into this error:
XlaRuntimeError: UNIMPLEMENTED: Dot algorithm ALG_DOT_F16_F16_F32 is not supported.
dtype
Since we’re restricted to param_dtype=jnp.float32, let’s play with dtype and test four different combinations:
| batch_size | dtype        |
|------------|--------------|
| 128        | jnp.float32  |
| 128        | jnp.bfloat16 |
| 4096       | jnp.float32  |
| 4096       | jnp.bfloat16 |
I was expecting a difference with the larger batch size, but the two dtypes were very close with batch_size=128. After digging into the HLO Op stats in TensorBoard and searching for the text 16,32,32,3, which is the per-device input shape after splitting the batch across the 8 devices (128 / 8 = 16), I found a copy operation that converts the input tensor to bfloat16:

But that copy operation was not performed with batch_size=4096:

Honestly, I don’t know what causes this copy operation, but it is the reason why the two dtypes perform the same with batch_size=128.
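For reference, the profile that TensorBoard’s HLO Op stats page reads can be captured with jax.profiler.trace; here is a sketch (the log directory and the number of profiled steps are made up):
# Sketch: record a few forward passes so TensorBoard can show HLO Op stats.
with jax.profiler.trace('/tmp/jax-trace'):  # hypothetical log directory
    for _, (imgs, labels) in zip(range(5), train_loader):
        logits = model(imgs)           # forward pass only, for illustration
        logits.block_until_ready()     # ensure the async work lands inside the trace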
jax.lax.Precision
jax.lax.Precision has three options that control the computation precision: DEFAULT, HIGH, and HIGHEST. Here is an example of how to use it with a linear layer:
self.head = nnx.Linear(in_features=embed_dim,
                       out_features=num_classes,
                       kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), NamedSharding(mesh, P(None, 'model'))),
                       bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(), NamedSharding(mesh, P('model'))),
                       dtype=dtype,
                       # param_dtype=param_dtype,
                       precision=jax.lax.Precision.HIGHEST,
                       rngs=rngs)
In my experiments, all three options had the same impact. This might indicate that XLA was already using a fixed precision that jax.lax.Precision could not override.
jax.default_matmul_precision
I wrapped the training loop in a with jax.default_matmul_precision('float32') block. But testing the two choices, 'float32' and 'bfloat16', the performance was the same:
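For context, the wrapping looks like the sketch below (train_step and optimizer are hypothetical stand-ins for the actual training step, which is not shown in this post):
# Sketch: force the matmul precision for everything traced inside the context.
with jax.default_matmul_precision('float32'):   # or 'bfloat16'
    for imgs, labels in train_loader:
        loss = train_step(model, optimizer, imgs, labels)  # hypothetical train_step / optimizer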