In a previous post I wrote about running a vision transformer (ViT) using a JAX device mesh. I tested different mesh setups and batch sizes, but I was only looking at the runtime and memory consumption. In this post I want to dive deeper into the time consumed by certain HLO Ops and see whether the machine is spending its time doing computations or communicating between devices.

This figure shows the XLA compilation workflow and optimization steps:
source: Intel® Extension for TensorFlow
source: Frostig et al. (2018)
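To make the compile pipeline above a bit more concrete, here is a minimal sketch (my own toy example, not from the profiled ViT) of how you can look at the HLO that XLA produces for a jitted JAX function; the layer function and its shapes are placeholders, while lower(), compile() and as_text() are the standard JAX staging API:

```python
import jax
import jax.numpy as jnp

# Toy stand-in for a transformer layer; the shapes are arbitrary.
def layer(x, w):
    return jax.nn.gelu(x @ w)

x = jnp.ones((128, 384))
w = jnp.ones((384, 1536))

lowered = jax.jit(layer).lower(x, w)
print(lowered.as_text())      # HLO as emitted by JAX, before XLA optimizations
compiled = lowered.compile()
print(compiled.as_text())     # optimized HLO, where ops like %fusion.* appear
```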
Here’s a plot of the Average Step Time (i.e., runtime) of a parallel vision transformer using an 8 by 1 device mesh on JAX:
📌 NOTE:
I used TensorBoard for profiling.
I already covered how to run the TensorBoard profiler locally here.
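For reference, here is a minimal sketch of how such a profile can be captured; the 8 by 1 mesh and the toy step function are assumptions for illustration, but jax.profiler.trace is the standard way to write a trace that TensorBoard can read:

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Assumed 8 by 1 mesh: 8 devices along a "data" axis, 1 along a "model" axis.
devices = np.array(jax.devices()[:8]).reshape(8, 1)
mesh = Mesh(devices, axis_names=("data", "model"))
batch_sharding = NamedSharding(mesh, PartitionSpec("data"))

@jax.jit
def step(x):
    # Placeholder for one ViT training step.
    return jnp.sum(x ** 2)

# Global batch of 4096, sharded across the 8 "data" devices.
x = jax.device_put(jnp.ones((4096, 384)), batch_sharding)

# Everything executed inside this context ends up in the TensorBoard trace.
with jax.profiler.trace("/tmp/tensorboard"):
    for _ in range(10):
        step(x).block_until_ready()
```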
If we go to HLO Op Stats in TensorBoard, we can see a breakdown of which HLO Ops are taking the time:
With batch_size=128, 7.1% of the time was spent performing the %all-reduce.104 HLO op. But with batch_size=4096, 3.1% of the time was spent performing the %fusion.253 HLO op.
Now, let’s have a look at the positions of these two operations in the Trace Viewer:
%all-reduce.104 HLO Operation

The %all-reduce.104 operation took 7.1% of the time with batch_size=128. In fact, the AllReduce operation was the same for both batch sizes, 128 and 4096. But when the batch size was small, a larger share of the step went to synchronizing parameters, because that is what AllReduce does.
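To see where an all-reduce like this comes from in JAX code, here is a minimal data-parallel sketch (a toy loss, not the ViT from the post): averaging per-device gradients with jax.lax.pmean under pmap is exactly the kind of pattern that lowers to an AllReduce HLO op.

```python
from functools import partial
import jax
import jax.numpy as jnp

# Hypothetical toy loss; the shapes below are made up for illustration.
def loss_fn(params, batch):
    return jnp.mean((batch @ params) ** 2)

@partial(jax.pmap, axis_name="data")
def train_step(params, batch):
    grads = jax.grad(loss_fn)(params, batch)
    # This pmean lowers to an all-reduce: every device exchanges its local
    # gradients so that all replicas apply the same averaged update.
    grads = jax.lax.pmean(grads, axis_name="data")
    return params - 0.01 * grads

n = jax.local_device_count()
params = jnp.stack([jnp.zeros((384, 10))] * n)   # replicated parameters
batch = jnp.ones((n, 16, 384))                   # one shard per device
params = train_step(params, batch)
```

The cost of that all-reduce depends on the parameter size rather than the batch size, which is why the same operation weighs more heavily in a small-batch step.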
Here’s the syntax of AllReduce linked with what I found in my profiling:
We can confirm this syntax from Graph Viewer:
%fusion.253 HLO Operation

The %fusion.253 operation took 3.1% of the time with batch_size=4096. It is a big fusion operation, as shown in Graph Viewer:
Here’s a zoomed-in view of the same operation:
Inside %fusion.253 the batch dimension was 256. I was expecting 512, because that was the batch size per TPU device (4096 / 8 = 512). But eventually the batch dimension of the output tensor was 512, suggesting that XLA split the batch along the way. Also, the output tensor has 1536 on the channel dimension, suggesting that this operation occurs inside the MLP block of the vision transformer.
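As a rough sanity check of that reading, here is a sketch of a ViT-style MLP block; the 1536-wide hidden layer is an assumption (it matches the usual mlp_ratio of 4 with an embedding dimension of 384), and the bias-adds and GELU around the two matmuls are the kind of elementwise work XLA typically groups into a fusion like %fusion.253:

```python
import jax
import jax.numpy as jnp

EMBED_DIM, HIDDEN_DIM = 384, 1536   # assumed dims: 384 * 4 = 1536

def mlp_block(x, w1, b1, w2, b2):
    # (batch, tokens, 384) -> (batch, tokens, 1536) -> back to (batch, tokens, 384)
    h = jax.nn.gelu(x @ w1 + b1)
    return h @ w2 + b2

key = jax.random.PRNGKey(0)
x = jnp.ones((512, 196, EMBED_DIM))          # 512 = per-device batch (4096 / 8)
w1 = jax.random.normal(key, (EMBED_DIM, HIDDEN_DIM)) * 0.02
b1 = jnp.zeros(HIDDEN_DIM)
w2 = jax.random.normal(key, (HIDDEN_DIM, EMBED_DIM)) * 0.02
b2 = jnp.zeros(EMBED_DIM)

out = jax.jit(mlp_block)(x, w1, b1, w2, b2)  # out.shape == (512, 196, 384)
```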