Mashaan blog

Profiling JAX with TensorBoard


Acknowledgment

These resources were helpful in preparing this post:

What is XLA actually doing?

In a previous post, I wrote about running a vision transformer (ViT) on a JAX device mesh. I tested different mesh setups and batch sizes, but I only looked at runtime and memory consumption. In this post, I want to dive deeper into the time consumed by individual HLO Ops and see whether the machine spends its time doing computations or communicating between devices. These figures show the XLA compilation workflow and its optimization steps:

XLA-compilation

source: Intel® Extension for TensorFlow

XLA-HLO-fusion

source: Frostig et al. (2018)
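To look at what XLA does to your own code, JAX's ahead-of-time API lets you print the program before and after XLA's optimization passes. Below is a minimal sketch using a toy MLP made up for illustration (it is not the ViT from the previous post). In recent JAX versions, `jax.jit(...).lower(...)` returns the IR that JAX hands to XLA (StableHLO text), and `.compile()` returns the optimized HLO, where fused ops show up as `fusion` instructions.

```python
import jax
import jax.numpy as jnp


def mlp(x, w1, w2):
    # Toy two-layer MLP (illustrative only): matmul -> GELU -> matmul
    return jax.nn.gelu(x @ w1) @ w2


x = jnp.ones((8, 16), jnp.float32)
w1 = jnp.ones((16, 64), jnp.float32)
w2 = jnp.ones((64, 16), jnp.float32)

# Stage 1: trace and lower; this is the program before XLA optimizes it
lowered = jax.jit(mlp).lower(x, w1, w2)
pre_opt_text = lowered.as_text()

# Stage 2: compile; the text is now HLO after XLA passes such as fusion
compiled = lowered.compile()
post_opt_text = compiled.as_text()

print(pre_opt_text[:400])
print(post_opt_text[:400])
```

To see every intermediate pass rather than only the final module, XLA can also dump HLO to disk when `XLA_FLAGS=--xla_dump_to=<dir>` is set before JAX starts; the file names in that directory vary between XLA versions.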

Which HLO Ops are taking time?

Here’s a plot of the Average Step Time (i.e., the runtime per step) of a parallel vision transformer using an 8 by 1 device mesh in JAX:

runtime
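For readers who want to reproduce the setup, here is a minimal sketch of an 8 by 1 mesh with the batch split across the first axis. The axis names `"data"` and `"model"` and the input shape are my own hypothetical choices, not necessarily what the original code used. The `XLA_FLAGS` line fakes 8 CPU devices so the sketch also runs without a TPU; on an 8-device TPU host you can drop it.

```python
import os

# Hypothetical convenience: emulate 8 devices on CPU.
# Must be set before JAX initializes its backends.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=8"
)

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

n = jax.device_count()
# An n-by-1 mesh: all devices along "data", a single device along "model"
mesh = Mesh(np.array(jax.devices()).reshape(n, 1), axis_names=("data", "model"))

batch_size = 4096
images = jnp.zeros((batch_size, 32), jnp.float32)  # stand-in for a batch of inputs

# Split the batch dimension across "data"; replicate the feature dimension
batch = jax.device_put(images, NamedSharding(mesh, P("data", None)))

per_device = batch.addressable_shards[0].data.shape
print(dict(mesh.shape), per_device)
```

With 8 devices, each shard holds 4096/8 = 512 examples, which is the per-device batch size that comes up again in the %fusion.253 section.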

📌 NOTE:
I used TensorBoard for profiling.
I already covered how to run the TensorBoard profiler locally here.
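For completeness, here is a minimal sketch of capturing a trace from JAX itself with `jax.profiler.trace`. The toy `train_step` is hypothetical. Wrapping each iteration in `jax.profiler.StepTraceAnnotation` marks step boundaries, which, as far as I understand, is what lets the profiler report per-step numbers such as Average Step Time.

```python
import os
import tempfile

import jax
import jax.numpy as jnp


@jax.jit
def train_step(w, x):
    # Hypothetical "training step": one gradient-descent update on a quadratic loss
    loss = lambda w: jnp.mean((x @ w) ** 2)
    return w - 0.1 * jax.grad(loss)(w)


log_dir = tempfile.mkdtemp()  # point TensorBoard's --logdir here afterwards
w = jnp.ones((64, 64))
x = jnp.ones((128, 64))
w = train_step(w, x).block_until_ready()  # compile once, outside the trace

with jax.profiler.trace(log_dir):
    for step in range(5):
        # Mark each iteration as one step for the profiler's step statistics
        with jax.profiler.StepTraceAnnotation("train", step_num=step):
            w = train_step(w, x)
    w.block_until_ready()  # make sure device work finishes inside the trace
```

Afterwards, run `tensorboard --logdir <log_dir>` with the `tensorboard-plugin-profile` package installed to open HLO Op Stats, Trace Viewer and Graph Viewer.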

If we open HLO Op Stats in TensorBoard, we can see a breakdown of which HLO Ops are taking time:

HLO Op Stats pie chart

With batch_size=128, 7.1% of the time was spent on the %all-reduce.104 HLO op. With batch_size=4096, 3.1% of the time was spent on the %fusion.253 HLO op.

Finding Time Consuming HLO Ops in Trace Viewer

Now, let’s have a look at the positions of these two operations in the Trace Viewer:

Trace Viewer batch_size 128


Trace Viewer batch_size 4096

Understanding the AllReduce HLO Operation

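The %all-reduce.104 operation took 7.1% of the time with batch_size=128. In fact, the AllReduce operation was the same for both batch sizes, 128 and 4096, because it synchronizes parameters across devices (in data-parallel training these are usually the gradients), and their size does not depend on the batch size. What changes is the amount of computation: with a small batch, each step does less compute, so the synchronization done by AllReduce takes up a larger share of the step time.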
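The claim that AllReduce is the same for both batch sizes can be checked on a toy model. A minimal sketch, assuming plain data parallelism with replicated weights and a batch sharded along a hypothetical `"data"` axis: the gradient contracts over the sharded batch dimension, so each device only holds a partial sum, and XLA has to insert an `all-reduce` whose operand has the shape of the weights, regardless of batch size. The model and shapes are made up for illustration.

```python
import os

# Hypothetical convenience: emulate 8 devices on CPU (set before importing jax)
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=8"
)

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

n = jax.device_count()
mesh = Mesh(np.array(jax.devices()).reshape(n, 1), ("data", "model"))
replicated = NamedSharding(mesh, P())


def loss(w, x):
    # Toy linear model: w is replicated, x is sharded along the batch
    return jnp.mean((x @ w) ** 2)


# Ask for a replicated gradient, as data-parallel training would
grad_fn = jax.jit(jax.grad(loss), out_shardings=replicated)
w = jax.device_put(jnp.ones((256, 256)), replicated)

hlo = {}
for batch_size in (128, 4096):
    x = jax.device_put(
        jnp.ones((batch_size, 256)), NamedSharding(mesh, P("data", None))
    )
    # Compiled (partitioned) HLO; search it for the collective XLA inserted
    hlo[batch_size] = grad_fn.lower(w, x).compile().as_text()
    print(batch_size, "all-reduce" in hlo[batch_size])
```

With more than one device, both compiled programs should contain an `all-reduce` over a `f32[256,256]` operand; only the matmuls before it grow with the batch size.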

Here’s the AllReduce syntax, annotated with what I found in my profiling:

breaking down allreduce

We can confirm this syntax in the Graph Viewer:

allreduce graph viewer
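To make the semantics concrete, here is a host-side NumPy simulation of what AllReduce computes: every device in a replica group contributes its operand, the group's values are combined with the reduction computation (the `to_apply` computation in HLO text, typically an add), and every device in the group receives the same result. This illustrates the semantics only, not how XLA implements the collective over TPU interconnects; `all_reduce` is a hypothetical helper.

```python
import numpy as np


def all_reduce(per_device_values, replica_groups, reduce_fn=np.add):
    """Simulate AllReduce semantics on the host (illustrative only).

    per_device_values: list where entry i is the operand held by device i
    replica_groups: list of device-id lists, e.g. [[0, 1, 2, 3], [4, 5, 6, 7]]
    reduce_fn: binary reduction, playing the role of `to_apply`
    """
    result = [None] * len(per_device_values)
    for group in replica_groups:
        # Combine the operands of every device in the group...
        acc = per_device_values[group[0]]
        for d in group[1:]:
            acc = reduce_fn(acc, per_device_values[d])
        # ...and give each device in the group its own copy of the result
        for d in group:
            result[d] = acc.copy()
    return result


# Eight devices, each holding a local "gradient" for the same parameter shape
local = [np.full((2, 2), float(i)) for i in range(8)]
out = all_reduce(local, [list(range(8))])
print(out[0])  # every entry is 28.0 (= 0 + 1 + ... + 7), on every device
```

With two replica groups, `[[0, 1, 2, 3], [4, 5, 6, 7]]`, devices 0 to 3 would each get 6.0 and devices 4 to 7 would each get 22.0.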

%fusion.253 HLO Operation

The %fusion.253 operation took 3.1% of the time with batch_size=4096. It is a large fusion operation, as shown in the Graph Viewer:

fusion 253 zoom 1

Here’s a zoomed-in view of the same operation:

fusion 253 zoom 2

Inside %fusion.253, the batch dimension was 256. I was expecting 512, because that is the batch size per TPU device (4096/8 = 512). However, the batch dimension of the output tensor was 512, suggesting that XLA split the batch along the way. Also, the output tensor has 1536 channels, suggesting that this operation occurs inside the MLP block of the vision transformer.

fusion 253 zoom 3