Part III - Profiling a Pytorch Forward Pass
A walk-through of analyzing the original Diffusion Policy code release to identify bottlenecks in inference and come up with a plan to tackle them.
Published
Feb 18th, 2024

Code for this post can be found here if you want to follow along. It includes trace.json, which you can load directly into chrome://tracing to explore.

To optimize inference, we must first understand where time is being spent and what the bottlenecks are. At the highest level, Diffusion Policy involves a forward pass through ResNet-50 to obtain image embeddings, some number of denoising iterations that use a U-Net and the Huggingface Diffusers library to obtain an action, and time spent in a CPU simulation of a toy task to obtain the next image observation. We can use a few simple time.perf_counter() calls in Python to figure out how time is distributed among these buckets. We call torch.cuda.synchronize() before every measurement so that all queued GPU work has finished before we record the time taken by that piece of Python code. This is important because Pytorch launches GPU kernels asynchronously with respect to the Python code that issues them. If you use the Python time module without synchronizing, you may end the timing window prematurely while the computations associated with that chunk of code are still in progress on the GPU. After this coarse-grained profiling, we find the following breakdown:
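A minimal sketch of this kind of coarse timing follows. The `timed` context manager and the `timings` dict are hypothetical helpers, not the post's actual code. The sketch synchronizes both before starting and before stopping the clock, so kernels still running from earlier work are not charged to the current bucket and this bucket's own kernels are not cut off. The PyTorch import is guarded so the sketch also runs on a machine without PyTorch or a GPU.

```python
import time
from collections import defaultdict
from contextlib import contextmanager

try:
    import torch
    _CUDA = torch.cuda.is_available()
except ImportError:  # lets the sketch run without PyTorch installed
    torch = None
    _CUDA = False

# Accumulated wall-clock seconds per bucket (hypothetical bookkeeping).
timings = defaultdict(float)


def _sync():
    # Block until every kernel queued on the current GPU has finished; no-op on CPU.
    if _CUDA:
        torch.cuda.synchronize()


@contextmanager
def timed(bucket):
    _sync()  # don't charge earlier, still-running kernels to this bucket
    start = time.perf_counter()
    try:
        yield
    finally:
        _sync()  # wait for this bucket's kernels before stopping the clock
        timings[bucket] += time.perf_counter() - start
```

To use it, wrap each stage of the rollout loop, e.g. `with timed("unet"): ...`, and print `timings` at the end.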

Time spent in vision encoder: 0.123s
Time spent in U-Net: 8.362s
Time spent in Diffusers: 0.617s
Time spent outside of those: 0.424s
Total time: 9.528s

From this rough first pass at profiling, we see that the U-Net and the Diffusers library are responsible for the bulk (~94%) of total program runtime. The results above use 100 denoising iterations per step and 26 total steps. Because we run 100 U-Net forward passes for every ResNet forward pass or physics simulation step, this breakdown seems reasonable. To better understand how time is spent within a U-Net forward pass, we can use the Pytorch profiler and HTA (Meta’s open-source Holistic Trace Analysis tool).
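The arithmetic behind these numbers is worth spelling out, because it also gives a per-call baseline to optimize against. The sketch below assumes every one of the 26 steps runs exactly 100 U-Net forward passes, as stated above. The four buckets sum to 9.526s, about 2 ms short of the reported total, so it divides by the reported total.

```python
# Coarse timings reported above, in seconds.
buckets = {
    "vision_encoder": 0.123,
    "unet": 8.362,
    "diffusers": 0.617,
    "other": 0.424,
}
total = 9.528
steps = 26
denoise_iters = 100

# Share of runtime spent in the denoising loop (U-Net + Diffusers).
share_denoising = (buckets["unet"] + buckets["diffusers"]) / total
# Total number of U-Net forward passes over the whole rollout.
unet_calls = steps * denoise_iters
# Mean wall-clock time of one U-Net forward pass, in milliseconds.
ms_per_unet_call = 1000 * buckets["unet"] / unet_calls

print(f"U-Net + Diffusers share: {share_denoising:.1%}")
print(f"U-Net forward passes: {unet_calls}")
print(f"Mean time per U-Net forward: {ms_per_unet_call:.2f} ms")
```

This works out to about 94.2% of runtime in the denoising loop, spread over 2600 U-Net calls of roughly 3.2 ms each.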

Tangent

Horace He's blog post does a much better job of explaining this than I do. You should give it a read here!

Deep learning models in Pytorch eager mode can spend time during a forward pass in two buckets:

1) Pytorch Land - on the CPU, running Pytorch code. In eager mode, every operation in the forward pass has to be dispatched to the appropriate GPU kernel. The logic associated with this dispatch (argument checking, memory allocation, kernel launch, etc.) can be quite time-consuming, since each operation passes through the Python interpreter and Pytorch's dispatcher before a kernel is ever launched.

2) CUDA Land - on the GPU, running kernels.

As mentioned earlier, Pytorch operates asynchronously with respect to CUDA, so kernel dispatch on the CPU can overlap with kernel execution on the GPU. For now, we mostly care about how much time is spent in each bucket and which kernels are most responsible for the time spent in CUDA land. Profiling a single forward pass through the U-Net/Diffusers library reveals several interesting things.
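A sketch of how a trace like the one shipped with this post could be captured with `torch.profiler` follows. The `profile_forward` helper, the warm-up count, and the stand-in model used in testing are illustrative assumptions, not the post's exact code. The warm-up passes keep one-time costs, such as CUDA context and library initialization, out of the recorded window.

```python
try:
    import torch
    from torch.profiler import ProfilerActivity, profile
except ImportError:  # the helper below needs PyTorch to do anything useful
    torch = None


def profile_forward(model, example_input, trace_path="trace.json", warmup=3):
    """Profile one forward pass and write a Chrome-trace JSON file (hypothetical helper)."""
    if torch is None:
        raise RuntimeError("PyTorch is required for profiling")
    activities = [ProfilerActivity.CPU]
    if torch.cuda.is_available():
        activities.append(ProfilerActivity.CUDA)  # also record GPU kernels
    model.eval()
    with torch.inference_mode():
        for _ in range(warmup):  # keep one-time initialization out of the trace
            model(example_input)
        with profile(activities=activities, record_shapes=True, with_stack=True) as prof:
            model(example_input)
            if torch.cuda.is_available():
                torch.cuda.synchronize()  # let the pass's kernels finish inside the window
    prof.export_chrome_trace(trace_path)  # viewable in chrome://tracing
    return prof
```

The resulting JSON is what gets loaded into chrome://tracing or handed to HTA.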

1) Loading the profiler's output trace into HTA, we can break down GPU time into compute, non-compute, and idle time. Doing so, we find that the GPU is idle for over 50% of the time! This is not uncommon for inference with batch_size=1 in Pytorch eager mode: the GPU can ‘outrun’ Pytorch, finishing work faster than the CPU/Pytorch can assign it.
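HTA computes this breakdown for you, but a rough stand-in helps make the definition concrete. The sketch below reads a Chrome-format trace such as trace.json. It treats events in the `kernel` category as compute and `gpu_memcpy`/`gpu_memset` as non-compute, which is an assumption about the category labels in Pytorch's traces that you should check against your own file. Any gap between GPU activity counts as idle. This is not HTA's implementation, and it ignores per-stream and per-device distinctions.

```python
import json

COMPUTE_CATS = {"kernel"}
NON_COMPUTE_CATS = {"gpu_memcpy", "gpu_memset"}  # assumed trace category names


def load_trace(path):
    with open(path) as f:
        return json.load(f)


def _union_length(intervals):
    """Total length covered by (start, end) intervals, counting overlaps once."""
    total, cur_start, cur_end = 0.0, None, None
    for start, end in sorted(intervals):
        if cur_end is None or start > cur_end:
            if cur_end is not None:
                total += cur_end - cur_start
            cur_start, cur_end = start, end
        else:
            cur_end = max(cur_end, end)
    if cur_end is not None:
        total += cur_end - cur_start
    return total


def gpu_time_breakdown(trace):
    """Fractions of the GPU's active window spent in compute, non-compute, and idle."""
    events = trace["traceEvents"] if isinstance(trace, dict) else trace
    gpu = [e for e in events
           if e.get("ph") == "X" and e.get("cat") in COMPUTE_CATS | NON_COMPUTE_CATS]
    spans = [(e["ts"], e["ts"] + e["dur"]) for e in gpu]
    if not spans:
        return {}
    window = max(end for _, end in spans) - min(start for start, _ in spans)
    if window <= 0:
        return {}
    busy = _union_length(spans)
    compute = _union_length(
        [(e["ts"], e["ts"] + e["dur"]) for e in gpu if e["cat"] in COMPUTE_CATS])
    return {
        "compute": compute / window,
        "non_compute": (busy - compute) / window,
        "idle": (window - busy) / window,
    }
```

For example, `gpu_time_breakdown(load_trace("trace.json"))` returns the three fractions for a trace on disk.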

2) Using HTA, we can also track down which CUDA kernels are the most time-consuming. Doing so, we find that three kernels account for 72% of compute time.
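A comparable ranking can be sketched directly from the raw trace by summing the duration of every kernel event per kernel name. As in the previous sketch, treating the `kernel` category as compute is an assumption, and `top_kernels` is a hypothetical helper rather than an HTA function.

```python
from collections import Counter


def top_kernels(trace, k=3):
    """The k kernels with the most total GPU time, with their share of all kernel time."""
    events = trace["traceEvents"] if isinstance(trace, dict) else trace
    per_kernel = Counter()
    for e in events:
        if e.get("ph") == "X" and e.get("cat") == "kernel":
            per_kernel[e["name"]] += e["dur"]
    total = sum(per_kernel.values())
    if total == 0:
        return []
    return [(name, dur / total) for name, dur in per_kernel.most_common(k)]
```

Summing the shares, e.g. `sum(share for _, share in top_kernels(trace, k=3))`, gives the combined fraction claimed by the top three kernels.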

These kernels are responsible for converting tensors between the [N, C, H, W] and [N, H, W, C] layouts and for performing implicit GEMMs. Using Chrome’s tracing tool (chrome://tracing), we can search for these kernels in our trace and use the event flow to trace them back to the layers/operations that launched them.

We find that the Conv1D block is responsible for these. More specifically, Pytorch implements the 1D convolutions using the following kernels:

void cudnn::ops::nchwToNhwcKernel<float, float, float, false, true, (cudnnKernelDataType_t)2>(cudnn::ops::nchw2nhwc_params_t<float>, float const*, float*)
sm86_xmma_fprop_implicit_gemm_tf32f32_tf32f32_f32_nhwckrsc_nhwc_tilesize128x64x32_stage4_warpsize2x2x1_g1_tensor16x8x8_execute_kernel_cudnn_infer
void cudnn::ops::nhwcToNchwKernel<float, float, float, true, false, (cudnnKernelDataType_t)0>(cudnn::ops::nhwc2nchw_params_t<float>, float const*, float*)
void at::native::elementwise_kernel<128, 2, at::native::gpu_kernel_impl<at::native::CUDAFunctor_add<float> >(at::TensorIteratorBase&, at::native::CUDAFunctor_add<float> const&)::{lambda(int)#1}>(int, at::native::gpu_kernel_impl<at::native::CUDAFunctor_add<float> >(at::TensorIteratorBase&, at::native::CUDAFunctor_add<float> const&)::{lambda(int)#1})

What's likely happening here is that the convolution is implemented as a matrix multiplication (implicit GEMM), which unlocks the use of tensor cores for these ops. In doing so, however, we incur non-trivial memory overhead from converting tensors to and from the NHWC layout those GEMM kernels expect.
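To make "convolution as matrix multiplication" concrete, here is a pure-Python sketch that computes a small 1D convolution two ways: directly, and as a single GEMM over an explicit im2col matrix. cuDNN's implicit GEMM kernels avoid materializing that matrix and gather the same elements on the fly. This sketch materializes it only for clarity, and it handles a single batch element with no padding, stride, or bias.

```python
def conv1d_direct(x, w):
    """Reference 'valid' 1D convolution (cross-correlation, as in Pytorch).
    x: [C_in][L] input, w: [C_out][C_in][K] weights -> [C_out][L - K + 1]."""
    c_in, length, k = len(x), len(x[0]), len(w[0][0])
    l_out = length - k + 1
    return [[sum(w[o][c][j] * x[c][t + j] for c in range(c_in) for j in range(k))
             for t in range(l_out)]
            for o in range(len(w))]


def conv1d_as_gemm(x, w):
    """The same convolution expressed as one [C_out, C_in*K] @ [C_in*K, L_out] matmul."""
    c_in, length, k = len(x), len(x[0]), len(w[0][0])
    l_out = length - k + 1
    # im2col: row (c, j) holds input channel c shifted by tap j, for every output position t.
    cols = [[x[c][t + j] for t in range(l_out)] for c in range(c_in) for j in range(k)]
    # Flatten each filter into one row, using the same (c, j) ordering as `cols`.
    rows = [[w[o][c][j] for c in range(c_in) for j in range(k)] for o in range(len(w))]
    # Plain GEMM: each output element is a dot product of a filter row and an im2col column.
    return [[sum(r[i] * cols[i][t] for i in range(len(r))) for t in range(l_out)]
            for r in rows]
```

Both functions return identical results. The GEMM form is what lets the hardware treat the convolution as one large, tensor-core-friendly matrix multiply.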

3) This is trickier to show with screenshots in a blog post, but when playing with the trace in Chrome it becomes pretty apparent that the Conv1D block (consisting of 1D Conv → GroupNorm → Mish activation) is responsible for the vast majority of neural net runtime.
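Because this block is the main target for custom kernels, it helps to pin down exactly what its normalization and activation compute. The code below is a plain-Python reference for GroupNorm followed by Mish on a single sample of shape [C][L]. It is written to mirror `torch.nn.GroupNorm` (biased variance, per-channel affine weight and bias, default eps of 1e-5) and `torch.nn.Mish` (x * tanh(softplus(x))). It is a slow correctness oracle that a fused kernel could be compared against, not the model's actual code, and it omits the 1D convolution in front of the normalization.

```python
import math


def group_norm(x, num_groups, gamma, beta, eps=1e-5):
    """GroupNorm for one sample x: [C][L]; statistics span each group's channels and positions."""
    c = len(x)
    if c % num_groups != 0:
        raise ValueError("channels must be divisible by num_groups")
    per_group = c // num_groups
    out = [row[:] for row in x]
    for g in range(num_groups):
        chans = range(g * per_group, (g + 1) * per_group)
        vals = [v for ch in chans for v in x[ch]]
        mean = sum(vals) / len(vals)
        var = sum((v - mean) ** 2 for v in vals) / len(vals)  # biased variance
        inv_std = 1.0 / math.sqrt(var + eps)
        for ch in chans:
            out[ch] = [(v - mean) * inv_std * gamma[ch] + beta[ch] for v in x[ch]]
    return out


def mish(v):
    # softplus(v) = log(1 + e^v); for large v it equals v to float precision and
    # computing e^v directly could overflow, so use v itself above a threshold of 20.
    softplus = v if v > 20 else math.log1p(math.exp(v))
    return v * math.tanh(softplus)


def norm_mish(x, num_groups, gamma, beta):
    """GroupNorm followed by elementwise Mish: the epilogue after the 1D conv."""
    return [[mish(v) for v in row] for row in group_norm(x, num_groups, gamma, beta)]
```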

Before getting into next steps for optimization, I’d like to note an important fact about the U-Net in question: all tensor shapes and network nodes in the compute graph are *static*. We know exactly which operations and input shapes will be required before inference even starts. Pytorch eager mode is great for flexible model iteration, but in our case the network is already trained and there is no reason to pay Pytorch's per-operation overhead. Ideally, we would send an input to the U-Net on the GPU, all 100 denoising iterations would happen entirely on the GPU, and we would get a denoised action back to feed into the CPU simulation. With that in mind, let's go over next steps for optimizing Diffusion Policy inference.

1) Develop custom CUDA kernels for the Conv1D block that outperform the current implicit GEMM cuDNN kernels.

2) Develop custom CUDA kernels for the functionality currently enabled by the Diffusers library.

3) Use CUDA graphs to capture the entire U-Net and denoising loop in a single graph that can be replayed without Pytorch overhead.