Part VIII - Integrating a Custom CUDA Kernel & CUDA Graphs in Pytorch
Integration of custom CUDA kernels into Pytorch, and subsequent fusing of all kernel launches into a CUDA graph to eliminate CPU overhead.
Published
Feb 18th, 2024

Integrating CUDA Kernels into Pytorch

Now that we have some CUDA kernels that are faster than the Pytorch equivalents, we can integrate them into our forward pass! Thankfully, Pytorch makes C++ extensions very straightforward to integrate. I should note that the official Pytorch tutorial for this step is pretty comprehensive and probably a better resource than this post, but I'm writing this up anyway for the sake of completeness of the overall series. There are two main methods for passing a Torch tensor into a GPU kernel. One involves just-in-time compilation of the kernels using the torch.utils.cpp_extension.load() function, and the other uses Python's setuptools to build an extension module ahead of time that can be imported into our Python script. Both require a .cpp file that wraps around a driver function in the actual .cu kernel. The .cpp file exposes a Python interface to our kernel using pybind11. In the .cpp file we also declare the function (denoise_cuda) that we will implement inside our .cu file.


#include <torch/extension.h>

torch::Tensor denoise_cuda(
    torch::Tensor& model_output,
    torch::Tensor& sample,
    torch::Tensor& diffusion_constants,
    torch::Tensor& timestep,
    torch::Tensor& diffusion_noise
);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("denoise", &denoise_cuda, "Denoise wrapper function",
          py::arg("model_output"), py::arg("sample"), py::arg("diffusion_constants"),
          py::arg("timestep"), py::arg("diffusion_noise"));
}

The driver function (shown below) inside the .cu file is what actually implements the function that was declared in the .cpp file. The driver function takes in the relevant tensors, ensures they are contiguous, creates a new empty Torch tensor for the output, shoots off a CUDA kernel launch to perform the relevant compute, and finally returns the output tensor at the end. Take careful note of the fact that we capture the current CUDA stream using the at::cuda::getCurrentCUDAStream() call and include it as the fourth kernel launch parameter! This is not well documented in the Pytorch APIs, and not including the current stream as a kernel launch parameter caused a hard-to-chase-down bug with CUDA graphs. Peter Yuen also has a great blog post showing how to write driver functions that handle different data types more gracefully here.



#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>

torch::Tensor denoise_cuda(
    torch::Tensor& model_output,
    torch::Tensor& sample,
    torch::Tensor& diffusion_constants,
    torch::Tensor& timestep,
    torch::Tensor& diffusion_noise
){
  model_output = model_output.contiguous();
  sample = sample.contiguous();
  diffusion_constants = diffusion_constants.contiguous();
  timestep = timestep.contiguous();
  diffusion_noise = diffusion_noise.contiguous();

  float* d_model_output = model_output.data_ptr<float>();
  float* d_sample = sample.data_ptr<float>();
  int64_t* d_timestep = timestep.data_ptr<int64_t>();
  float* d_diffusion_noise = diffusion_noise.data_ptr<float>();

  // Copy diffusion constants to constant memory
  cudaMemcpyToSymbol(const_diffusion_constants, diffusion_constants.data_ptr<float>(),
                     diffusion_constants.numel() * sizeof(float));

  auto options = sample.options();
  auto out = torch::empty({1, 2, 16}, options);
  // Launch on the current Pytorch stream so that graph capture sees this kernel.
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  denoise<<<1, 32, 0, stream>>>(d_model_output, d_sample, d_timestep, d_diffusion_noise,
                                out.data_ptr<float>());

  return out;
}
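
For reference, the driver above assumes two declarations exist alongside it in the .cu file: the __constant__ array that cudaMemcpyToSymbol targets, and the denoise kernel itself, whose implementation was covered earlier in this series. A minimal sketch of what those declarations might look like is shown below; the constant-array size and the kernel body are placeholders for illustration, not the real implementation.


// Assumed declarations in denoise_simple.cu; the real kernel body and constant-array
// size come from the earlier parts of this series.
__constant__ float const_diffusion_constants[64];  // size chosen for illustration only

__global__ void denoise(const float* model_output,
                        const float* sample,
                        const int64_t* timestep,
                        const float* diffusion_noise,
                        float* out) {
    // The driver launches one block of 32 threads for a 1x2x16 output tensor,
    // so roughly one thread per output element with an upper-bound guard.
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < 2 * 16) {
        // ... per-element denoising math using const_diffusion_constants ...
        out[idx] = 0.0f;  // placeholder; the real update is derived in the earlier kernel posts
    }
}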

As mentioned previously, we have two options for integrating our completed .cu and .cpp files with Pytorch. The first approach is more amenable to rapid kernel iteration, and is the one I used for most of this project. It involves using the Ninja build system and the Pytorch load() API to compile kernels at run-time. To go this route we simply point the load() API at our files and get back an object that exposes our earlier Python binding as a function. In the example below, we import the load API and use it to generate the denoise_module.


from torch.utils.cpp_extension import load

denoise_module = load(
    name="denoise",
    sources=["denoise.cpp", "denoise_simple.cu"],
    verbose=True
)

Then in our forward pass we can call our denoise function from earlier very easily.


x = denoise_module.denoise(x, sample, self.diffusion_constants, timestep, diffusion_noise)

The JIT-compiled approach is really straightforward but incurs a small amount of Ninja build overhead every time the script runs, as Ninja checks for a cached version of the compiled extension. For something more production friendly, the Pytorch guide also introduces an approach that leverages Python's setuptools module to compile the kernels ahead of time as part of the library build process. The end result of this approach is a set of shared-library objects that can be imported into a Python script as regular modules and called similarly to the JIT-compiled approach. To go down this route, we define a setup.py file that calls the setup() API with the .cpp and .cu files as sources in the ext_modules parameter. Note the use of the CUDAExtension and BuildExtension classes from torch.utils.cpp_extension in the setup command. These eliminate the boilerplate that would otherwise be required to link against Torch, compile the .cu file with nvcc, generate the output shared object, and so on. Below is the setup.py from the final cleaned-up library for this project, including commands to build extensions for the denoising and conv1d/group norm + Mish kernels.


from setuptools import setup, find_packages  
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(  
    name='diffusion_policy_accelerated',  
    version='1.0',  
    packages=find_packages(),  
    ext_modules=[  
        CUDAExtension(  
            name='diffusion_policy_accelerated.conv1d_gnm',  
            sources=['csrc/conv1d_gnm.cpp', 'csrc/conv1d_gnm_kernel.cu'],  
        ),  
        CUDAExtension(  
            name='diffusion_policy_accelerated.denoise',  
            sources=['csrc/denoise.cpp', 'csrc/denoise_kernel.cu'],  
        ),  
    ],  
    cmdclass={  
        'build_ext': BuildExtension  
    }  
)

We can run the setup process using the 'pip install -e .' command, and we'll find two new shared-object files inside the 'diffusion_policy_accelerated' folder. These are compiled Python extensions that can be imported (from within the diffusion_policy_accelerated package) just like any other module. Example usage for the denoising kernel is shown below.


from diffusion_policy_accelerated import denoise
x = denoise.denoise(x, sample, self.diffusion_constants, timestep, diffusion_noise)

Using CUDA Graphs in Pytorch

Great, now we have our kernels integrated into Pytorch! The last, and biggest, optimization we'll make is eliminating almost all CPU overhead with the use of a CUDA graph. CUDA graphs were introduced by NVIDIA in 2019 to eliminate CPU-side kernel launch overhead by allowing the programmer to record a stream of kernel launches and then replay a cached graph of the recorded kernel sequence.


The nodes in the graph consist of a kernel and hard-coded launch arguments (mostly scalars and pointers). Since the arguments in a graph are constant, flexibility in input is achieved by copying data in and out of the memory that was allocated during the initial graph capture; that way, the pointers remain valid for future replays of the static graph. Due to the brittleness of graph tracing, CUDA graphs cannot easily support input-dependent graph modification at run-time. Thankfully, the U-Net we are working with is fully static, with tensor shapes and layers that are input-independent. Pytorch makes this tracing and replay process pretty easy to implement. In the examples below, I show how we can use CUDA graphs to fuse the U-Net in Diffusion Policy.

We start off by creating the input tensors for the U-Net.


static_noisy_action = torch.randn((1, config.PRED_HORIZON, config.ACTION_DIM), requires_grad=False, device=config.DEVICE)
static_obs_cond = torch.randn((1, config.IMG_EMBEDDING_DIM), requires_grad=False, device=config.DEVICE)
static_k = torch.tensor(0, requires_grad=False, device=config.DEVICE)

Next, we run a few forward passes through the net on a 'side stream' to warm up Pytorch's execution of the kernels for the network.


s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s), torch.no_grad():
    for _ in range(3):
        _ = u_net(static_noisy_action,
                  static_k,
                  static_obs_cond)
torch.cuda.current_stream().wait_stream(s)

Next, create a CUDA graph object and run a forward pass through the network with the graph as context.


u_net_graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(u_net_graph), torch.no_grad():
    static_model_output = u_net(static_noisy_action,
                                static_k,
                                static_obs_cond)

Finally, we can use the CUDA graph by copying the actual inputs into our static tensors and then replaying it. To elaborate: a CUDA graph launches its kernels with the same memory addresses baked in every time, so we can't pass new Pytorch tensor objects into the graph; instead we have to copy the data we want as input into the pre-existing tensors.


noisy_action = torch.randn((1, config.PRED_HORIZON, config.ACTION_DIM), requires_grad=False, device=config.DEVICE)
obs_cond = torch.randn((1, config.IMG_EMBEDDING_DIM), requires_grad=False, device=config.DEVICE)
k = torch.tensor(0, requires_grad=False, device=config.DEVICE)
diffusion_noise = torch.randn((1, config.PRED_HORIZON, config.ACTION_DIM), device=config.DEVICE)

static_k.copy_(k)
static_noisy_action.copy_(noisy_action)
static_obs_cond.copy_(obs_cond)
static_diffusion_noise.copy_(diffusion_noise)
u_net_graph.replay()

Now when we replay the graph, 'static_model_output' will contain the result of the forward pass without any Pytorch overhead! For reasons I don't fully understand, kernels in a CUDA graph also run faster than when they are launched individually from the CPU. I imagine it has something to do with overlap of launch overhead and computation that can only happen when the next kernel launch is known ahead of time.
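
If you want to sanity-check that speedup on your own setup, one simple way is to time eager execution against graph replay with CUDA events. The sketch below reuses the u_net, static tensors, and u_net_graph objects from above; the time_fn helper is defined here just for this measurement.


import torch

def time_fn(fn, iters=100):
    # CUDA events time the GPU work itself rather than host-side wall clock.
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # milliseconds per iteration

with torch.no_grad():
    eager_ms = time_fn(lambda: u_net(static_noisy_action, static_k, static_obs_cond))
graph_ms = time_fn(lambda: u_net_graph.replay())
print(f"eager: {eager_ms:.3f} ms/iter | graph replay: {graph_ms:.3f} ms/iter")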