Anyone who works with PyTorch model code starts asking the same questions:
Why is this taking so long? How do I make my training loop faster?
Whether you’re an ML engineer, a researcher or someone who just decided to play around with a random ML repository over the weekend, you will eventually try to figure out how to speed up your code.
However, before we can do that, we need to learn how to measure performance correctly and how to draw the right conclusions from those measurements. This article is about exactly that: properly benchmarking CUDA or PyTorch code.
Execution flow
Let’s use matrix multiplication on an H100 as the running example throughout.
Before we think about measurements, let’s get the terminology straight. By “kernel” we mean a function (or, loosely, any operation) that runs on a GPU. Your code might launch a lot of these, from tens to thousands during each iteration of the training loop, for example:
- torch.zeros((16, 16), device="cuda")
- a + b, where a and b are tensors in GPU memory
- a @ b
- flash_attention(q, k, v)
- and so on
Even though the kernel executes on a GPU, its launch is controlled by the CPU. If we imagine a Python file with the training code, it alternates between “preparatory” work, such as for-loops or if-statements, and actual kernel launches. Kernel launches are asynchronous: the CPU runs ahead, and whenever it hits a kernel, it schedules that kernel for execution on the GPU by adding it to a queue and moves on without waiting. By the time the next kernel is added to the queue, there is no guarantee that the previous one has finished.
This allows the GPU to avoid sitting idle: it can always find work in that kernel queue (except when synchronization is explicitly called or when execution is CPU-bound).
The naïve approach and why it’s wrong
When we first think about measuring time in Python, the natural instinct is to reach for the time module and run something like this:
import time
import torch

def matmul(a, b):
    return a @ b

def get_data(m, n, k):
    a = torch.randn(m, n, device="cuda", dtype=torch.bfloat16)
    b = torch.randn(n, k, device="cuda", dtype=torch.bfloat16)
    return a, b

def benchmark_naive(m, n, k, n_iters: int = 100):
    a, b = get_data(m, n, k)
    times = []
    for _ in range(n_iters):
        start = time.perf_counter_ns()
        matmul(a, b)
        end = time.perf_counter_ns()
        times.append(end - start)
    # perf_counter_ns returns nanoseconds, convert to ms
    return sum(times) / n_iters / 1e6

small_shapes_time = benchmark_naive(16, 32, 16)
large_shapes_time = benchmark_naive(4096, 8192, 4096)
print(f"(16, 32) by (32, 16) matmul time: {small_shapes_time:.5f} ms")
print(f"(4096, 8192) by (8192, 4096) matmul time: {large_shapes_time:.5f} ms")
# (16, 32) by (32, 16) matmul time: 0.01415 ms
# (4096, 8192) by (8192, 4096) matmul time: 0.01350 ms
If we run this code on two different sets of shapes, (16, 32) by (32, 16) and (4096, 8192) by (8192, 4096), we see basically no difference in matmul time. In fact, in my run, multiplying matrices that are 256 times larger along each dimension is slightly faster. This result is very suspicious and should make us ask what this code actually measures.
The problem is that we never wait for the GPU matmul to finish. Since the timing code runs on the CPU and kernel launches are asynchronous, time.perf_counter_ns() only measures the time it takes to put the matmul kernel into the GPU queue and move on with our lives.
We could add torch.cuda.synchronize() after the matmul to force the CPU to wait for the GPU. But that’s still not ideal: we would be measuring wall-clock time on the CPU, which includes launch and synchronization overhead, so we still would not isolate the actual GPU execution time.
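For reference, the synchronized wall-clock approach described above might look like the minimal sketch below. The helper name benchmark_wall_clock and its sync parameter are hypothetical; on a GPU you would pass torch.cuda.synchronize as sync, while the no-op default only keeps the sketch runnable without one.

```python
import time
from typing import Callable


def benchmark_wall_clock(
    fn: Callable[[], object],
    sync: Callable[[], None] = lambda: None,
    n_iters: int = 100,
) -> float:
    """Average wall-clock time of fn() in ms, synchronizing around each call.

    Hypothetical helper: pass sync=torch.cuda.synchronize when fn launches
    CUDA kernels so the CPU waits for the GPU before reading the clock.
    """
    times = []
    for _ in range(n_iters):
        sync()  # drain previously queued work so it is not attributed to fn
        start = time.perf_counter_ns()
        fn()
        sync()  # block until the work launched by fn has actually finished
        end = time.perf_counter_ns()
        times.append(end - start)
    # perf_counter_ns returns nanoseconds, convert to ms
    return sum(times) / n_iters / 1e6
```

With the functions defined earlier, this would be called as benchmark_wall_clock(lambda: matmul(a, b), sync=torch.cuda.synchronize). The result still contains the cost of launching the kernel and of the synchronize call itself, which is the overhead we want to get rid of.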
Measuring with CUDA events
The correct way to measure the GPU time instead of the CPU time is with CUDA events. These are markers we can insert directly into the GPU’s execution stream. The GPU records timestamps whenever it reaches each CUDA event, giving us the actual execution time on the device.
import torch

def matmul(a, b):
    return a @ b

def get_data(m, n, k):
    a = torch.randn(m, n, device="cuda", dtype=torch.bfloat16)
    b = torch.randn(n, k, device="cuda", dtype=torch.bfloat16)
    return a, b

def benchmark_cuda_events(m, n, k, n_iters: int = 100):
    a, b = get_data(m, n, k)
    start_events = [torch.cuda.Event(enable_timing=True) for _ in range(n_iters)]
    end_events = [torch.cuda.Event(enable_timing=True) for _ in range(n_iters)]
    for it in range(n_iters):
        start_events[it].record()
        matmul(a, b)
        end_events[it].record()
    # wait for all queued kernels and events before reading the timestamps
    torch.cuda.synchronize()
    # elapsed_time returns milliseconds
    times = [start.elapsed_time(end) for start, end in zip(start_events, end_events)]
    return sum(times) / n_iters

small_shapes_time = benchmark_cuda_events(16, 32, 16)
large_shapes_time = benchmark_cuda_events(4096, 8192, 4096)
print(f"(16, 32) by (32, 16) matmul time: {small_shapes_time:.5f} ms")
print(f"(4096, 8192) by (8192, 4096) matmul time: {large_shapes_time:.5f} ms")
# (16, 32) by (32, 16) matmul time: 0.02093 ms
# (4096, 8192) by (8192, 4096) matmul time: 0.34025 ms
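The events are inserted into the same CUDA stream as our matmul kernel, so the GPU records each timestamp when it reaches that point in the stream. Setting enable_timing=True ensures that timestamps are recorded. We call torch.cuda.synchronize() once after the loop because .elapsed_time can only be queried after both events have actually been reached on the GPU; it then returns the time between the start and end timestamps in milliseconds.
Now we get realistic numbers instead of the absurdly fast results of the naive approach: the matmul with large shapes takes roughly 16x longer than the one with small shapes.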
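One way to sanity-check these numbers is to convert them into achieved throughput. An (m, n) by (n, k) matmul produces m·k output elements, each a dot product of length n costing n multiplies and n adds, so it performs 2·m·n·k floating-point operations. The helper below (our own name) applies that arithmetic to the CUDA event timings from this run.

```python
def matmul_tflops(m: int, n: int, k: int, time_ms: float) -> float:
    """Achieved TFLOP/s for an (m, n) @ (n, k) matmul that took time_ms."""
    # m * k output elements, each a dot product of length n:
    # n multiplies + n adds = 2 * n FLOPs per output element.
    flops = 2 * m * n * k
    return flops / (time_ms * 1e-3) / 1e12


large = matmul_tflops(4096, 8192, 4096, 0.34025)  # timing from the run above
small = matmul_tflops(16, 32, 16, 0.02093)
print(f"large: {large:.1f} TFLOP/s, small: {small:.6f} TFLOP/s")
```

The large matmul comes out at roughly 808 TFLOP/s, below the dense BF16 Tensor Core peak NVIDIA lists for the H100 SXM (about 989 TFLOP/s), so the measurement is at least plausible. The small one reaches well under 0.001 TFLOP/s, which suggests its measured time is dominated by fixed per-launch costs rather than by the arithmetic itself.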
It is also worth including warmup iterations at the beginning, because the first launches often carry one-time costs, such as JIT compilation or library initialization. We will include warmup in the next measurements.
The L2 cache
Another issue with our current measurement is the lack of an L2 cache flush. When we run the same kernel repeatedly on the same data (which is exactly what we’re doing), that data can stay in the GPU’s L2 cache between iterations, as long as it fits.
NVIDIA GPUs have a memory hierarchy: HBM is the largest and slowest level, followed by progressively smaller and faster ones (L2 cache → L1 cache/shared memory → registers). The L2 cache is a relatively large chunk of on-chip memory that reduces the number of accesses to main memory (HBM).
| GPU | Generation | L2 per GPU | HBM per GPU |
|---|---|---|---|
| V100 | Volta | 6MB | 32GB |
| A100 | Ampere | 40MB | 80GB |
| H100 | Hopper | 50MB | 80GB |
| B200 | Blackwell | 126MB | 192GB |
Because of this cache, our measurements might look better than what we’d see in production, where the data changes between iterations. To get more realistic measurements, we need to flush the cache. One way to do that is to allocate a buffer larger than the L2 cache and write to it at the beginning of each iteration, evicting whatever was cached before. This is, of course, a heuristic rather than a guaranteed cache flush:
import torch

def flush_l2_cache():
    # On H100, the L2 cache is 50MB, so we allocate something a bit bigger
    cache_size = 60 * 1024 * 1024
    buffer = torch.zeros(cache_size // 4, dtype=torch.float32, device="cuda")
    buffer += 1  # write to every element so the buffer occupies the L2
    del buffer

def matmul(a, b):
    return a @ b

def get_data(m, n, k):
    a = torch.randn(m, n, device="cuda", dtype=torch.bfloat16)
    b = torch.randn(n, k, device="cuda", dtype=torch.bfloat16)
    return a, b

def benchmark_with_cache_flush(m, n, k, n_iters: int = 100):
    a, b = get_data(m, n, k)
    for _ in range(10):
        matmul(a, b)  # warmup
    torch.cuda.synchronize()
    start_events = [torch.cuda.Event(enable_timing=True) for _ in range(n_iters)]
    end_events = [torch.cuda.Event(enable_timing=True) for _ in range(n_iters)]
    for it in range(n_iters):
        flush_l2_cache()  # flush the L2 between iterations (outside the timed region)
        start_events[it].record()
        matmul(a, b)
        end_events[it].record()
    torch.cuda.synchronize()
    times = [start.elapsed_time(end) for start, end in zip(start_events, end_events)]
    return sum(times) / n_iters

small_shapes_time = benchmark_with_cache_flush(16, 32, 16)
large_shapes_time = benchmark_with_cache_flush(4096, 8192, 4096)
print(f"(16, 32) by (32, 16) matmul time: {small_shapes_time:.5f} ms")
print(f"(4096, 8192) by (8192, 4096) matmul time: {large_shapes_time:.5f} ms")
# (16, 32) by (32, 16) matmul time: 0.00634 ms
# (4096, 8192) by (8192, 4096) matmul time: 0.40575 ms
At first glance, the results differ noticeably from the previous run, but re-running the code several times shifts them by a comparable amount, so the difference is within the expected run-to-run variance. For this particular problem, the L2 cache did not have a significant impact.
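One plausible reason the flush barely mattered for the large shapes is the size of the data. The sketch below (our own helper, assuming BF16 inputs and a BF16 output, i.e. 2 bytes per element) estimates how much data an (m, n) by (n, k) matmul touches: both inputs plus the output.

```python
def matmul_working_set_mib(m: int, n: int, k: int, bytes_per_elem: int = 2) -> float:
    """Approximate data touched by an (m, n) @ (n, k) matmul, in MiB."""
    # a: (m, n) input, b: (n, k) input, c: (m, k) output
    elems = m * n + n * k + m * k
    return elems * bytes_per_elem / 2**20


for shape in [(16, 32, 16), (4096, 8192, 4096)]:
    print(shape, f"{matmul_working_set_mib(*shape):.4f} MiB")
```

The large matmul touches 160 MiB, more than three times the H100’s 50MB L2, so its operands cannot stay fully cached between iterations anyway. As a rough rule of thumb, flushing matters most for kernels whose working set fits in L2 and whose runtime is dominated by memory traffic.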
Built-in solutions
Do we have to remember all of this just to benchmark something? No. Triton, a language and compiler for writing GPU kernels in a Pythonic way that OpenAI released in 2021, ships with a testing module (triton.testing) whose benchmarking utilities take care of all of the above on their own. Let’s take a look:
# Source: https://github.com/triton-lang/triton/blob/main/python/triton/testing.py#L127C1-L190C64
# Excerpt: `runtime` and `_summarize_statistics` are imported/defined elsewhere in triton/testing.py
def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_mode="mean"):
    """
    Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
    the 20-th and 80-th performance percentile.

    :param fn: Function to benchmark
    :type fn: Callable
    :param warmup: Warmup time (in ms)
    :type warmup: int
    :param rep: Repetition time (in ms)
    :type rep: int
    :param grad_to_none: Reset the gradient of the provided tensor to None
    :type grad_to_none: torch.tensor, optional
    :param quantiles: Performance percentile to return in addition to the median.
    :type quantiles: list[float], optional
    :param return_mode: The statistical measure to return. Options are "min", "max", "mean", "median", or "all". Default is "mean".
    :type return_mode: str
    """
    assert return_mode in ["min", "max", "mean", "median", "all"]

    di = runtime.driver.active.get_device_interface()

    fn()
    di.synchronize()

    cache = runtime.driver.active.get_empty_cache_for_benchmark()

    # Estimate the runtime of the function
    start_event = di.Event(enable_timing=True)
    end_event = di.Event(enable_timing=True)
    start_event.record()
    for _ in range(5):
        runtime.driver.active.clear_cache(cache)
        fn()
    end_event.record()
    di.synchronize()
    estimate_ms = start_event.elapsed_time(end_event) / 5

    # compute number of warmup and repeat
    n_warmup = max(1, int(warmup / estimate_ms))
    n_repeat = max(1, int(rep / estimate_ms))
    start_event = [di.Event(enable_timing=True) for i in range(n_repeat)]
    end_event = [di.Event(enable_timing=True) for i in range(n_repeat)]
    # Warm-up
    for _ in range(n_warmup):
        fn()
    # Benchmark
    for i in range(n_repeat):
        # we don't want `fn` to accumulate gradient values
        # if it contains a backward pass. So we clear the
        # provided gradients
        if grad_to_none is not None:
            for x in grad_to_none:
                x.grad = None
        # we clear the L2 cache before each run
        runtime.driver.active.clear_cache(cache)
        # record time of `fn`
        start_event[i].record()
        fn()
        end_event[i].record()
    # Record clocks
    di.synchronize()
    times = [s.elapsed_time(e) for s, e in zip(start_event, end_event)]
    return _summarize_statistics(times, quantiles, return_mode)
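In addition to what we’ve covered, do_bench can return several statistics, not just the mean: return_mode accepts "min", "max", "mean", "median" or "all" (the raw timings), and quantiles returns the requested percentiles, for instance quantiles=[0.99] for the 99th percentile. Note that warmup and rep are time budgets in milliseconds, not iteration counts; do_bench turns them into iteration counts using its initial runtime estimate. Other useful arguments are covered in the docstring.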
To use this function, we can simply call:
import triton
bench_time = triton.testing.do_bench(lambda: matmul(a, b))
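To see why the median or a high percentile can be more informative than the mean, here is a small, hypothetical stand-in for the statistics step at the end of do_bench. It is our own approximation, not Triton’s _summarize_statistics, and its quantiles use linear interpolation between sorted samples. With Triton itself you would instead pass something like quantiles=[0.5, 0.2, 0.8] to do_bench and unpack the returned values.

```python
import math
import statistics
from typing import List, Optional, Sequence, Union


def quantile(sorted_times: Sequence[float], q: float) -> float:
    """Linearly interpolated q-quantile (0 <= q <= 1) of already-sorted samples."""
    pos = q * (len(sorted_times) - 1)
    lo = math.floor(pos)
    hi = min(lo + 1, len(sorted_times) - 1)
    return sorted_times[lo] + (sorted_times[hi] - sorted_times[lo]) * (pos - lo)


def summarize_times(
    times: Sequence[float],
    quantiles: Optional[Sequence[float]] = None,
    return_mode: str = "mean",
) -> Union[float, List[float]]:
    """Hypothetical approximation of do_bench's statistics step (not Triton's code)."""
    if quantiles is not None:
        ordered = sorted(times)
        return [quantile(ordered, q) for q in quantiles]
    if return_mode == "all":
        return list(times)
    reducers = {
        "min": min,
        "max": max,
        "mean": statistics.fmean,
        "median": statistics.median,
    }
    return reducers[return_mode](times)


# Illustrative timings in ms (made-up numbers) where one iteration hit a hiccup.
times = [0.40, 0.41, 0.40, 0.42, 2.50]
print(summarize_times(times))                         # mean, pulled up by the outlier
print(summarize_times(times, return_mode="median"))   # robust to the outlier
print(summarize_times(times, quantiles=[0.5, 0.99]))  # median and tail latency
```

A single slow iteration doubles the mean here, while the median barely moves; the high quantile is what exposes the tail.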
CPU-bound execution and CUDA graphs
For kernels with a very short execution time, the launch overhead can exceed the kernel’s execution time itself. The same happens when there’s significant (usually unoptimized) CPU work between kernel launches. Both can severely distort our measurements.
Let’s add a looong for-loop to our matmul function, since we know that for-loops are really slow in Python:
def cpu_heavy_matmul(a, b):
    # pure-Python busy work that keeps the CPU occupied before the launch
    cnt = 0
    for _ in range(100_000):
        cnt += 1
    return a @ b
Benchmarking this function, we’ll see the (4096, 8192) by (8192, 4096) matmul time increase from about 0.4 ms to 2.2 ms, more than 5x! CUDA events cannot hide this: the start event is enqueued before the Python loop runs, so the GPU sits idle between the two events while the CPU is busy counting.
Depending on what your actual training run looks like, you may or may not want to remove the CPU overhead from the benchmark results. If your training is already CPU-bound, the same overhead will show up in the production run as well, so it most likely makes sense to keep the CPU time in your measurements. In a healthy training run, however, the CPU runs ahead of the GPU and does not stand in the way of optimal performance. In that case, even if one kernel has significant launch-side overhead, that overhead is hidden behind other queued work and does not change the kernel’s runtime on the GPU, so we want to exclude it from our measurements.
To do this, we can use CUDA graphs. A CUDA graph lets us record a series of kernels once and replay them with minimal CPU involvement. Triton provides triton.testing.do_bench_cudagraph, a CUDA-graph-based counterpart of do_bench that captures fn into a graph and times its replays. The Python loop in cpu_heavy_matmul runs only while the graph is being captured; replaying the graph launches just the recorded matmul kernel. Using this for our new matmul, we get back the original kernel time (before the CPU overhead was introduced).
a, b = get_data(4096, 8192, 4096)
triton.testing.do_bench_cudagraph(lambda: cpu_heavy_matmul(a, b)) # back to ~0.4ms!
Of course, CUDA graphs come with their own constraints (static shapes, static control flow and
Not only systems fail: looking at data
Sometimes even a perfect benchmarking methodology doesn’t tell the whole story, because production conditions differ from the isolated environment where we run the benchmark.
For example, let’s say we want to benchmark the grouped GEMM kernel inside a Mixture-of-Experts (MoE) layer. In a typical MoE layer, tokens are dynamically routed to different experts based on router probabilities. As a result, the actual workload seen by each expert (and by the grouped GEMM kernel) depends on how balanced the routing is.
Naively, we’d generate random router probabilities and artificially route tokens based on them. However, throughout training, our router might be imbalanced, utilizing several experts much more heavily than the others.
Let’s see how the underlying distribution of tokens routed to each expert affects our measurements.
import numpy as np
import torch
import triton

def sample_expert_assignments(
    seq_len: int,
    num_experts: int,
    top_k: int,
    use_beta: bool = True,
    alpha: float = 1.0,
    beta: float = 5.0,
) -> tuple[torch.Tensor, torch.Tensor]:
    # Per-expert routing weights: skewed (Beta) or perfectly balanced (uniform)
    if use_beta:
        expert_weights = np.random.beta(alpha, beta, num_experts)
        expert_weights = expert_weights / expert_weights.sum()
    else:
        expert_weights = np.ones(num_experts) / num_experts
    # Gumbel-top-k: pick top_k distinct experts per token according to the weights
    gumbel_noise = -np.log(-np.log(np.random.uniform(0, 1, (seq_len, num_experts))))
    log_weights = np.log(expert_weights + 1e-10)
    scores = log_weights[None, :] + gumbel_noise
    expert_indices = torch.from_numpy(np.argsort(scores, axis=1)[:, -top_k:])
    tokens_per_expert = torch.bincount(
        expert_indices.flatten(), minlength=num_experts
    ).to(torch.int32)
    return expert_indices, tokens_per_expert

def get_tokens(
    seq_len: int,
    hidden: int,
    num_experts: int,
    top_k: int,
    expert_indices: torch.Tensor,
    tokens_per_expert: torch.Tensor,
    device: str = "cuda",
) -> torch.Tensor:
    # Replicate each token top_k times and group the copies by expert
    x_original = torch.randn((seq_len, hidden), dtype=torch.bfloat16, device=device)
    token_indices = torch.arange(seq_len, device=device)[:, None].expand(-1, top_k)
    expert_flat = expert_indices.to(device).flatten()
    token_flat = token_indices.flatten()
    token_sorted = token_flat[torch.argsort(expert_flat)]
    return x_original[token_sorted]

def get_tensors(
    seq_len: int = 1024,
    hidden: int = 4096,
    intermediate: int = 1536,
    num_experts: int = 128,
    top_k: int = 8,
    use_beta: bool = True,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    expert_indices, tokens_per_expert = sample_expert_assignments(
        seq_len,
        num_experts,
        top_k,
        use_beta=use_beta,
    )
    x = get_tokens(
        seq_len,
        hidden,
        num_experts,
        top_k,
        expert_indices,
        tokens_per_expert,
        device="cuda",
    )
    w = torch.randn(
        (num_experts, hidden, intermediate), dtype=torch.bfloat16, device="cuda"
    )
    # Group boundaries for the grouped GEMM: end offset of each expert's rows in x
    offsets = torch.cumsum(tokens_per_expert, dim=0).to(
        dtype=torch.int32, device="cuda"
    )
    return x, w, offsets

def benchmark(
    seq_len: int = 1024,
    hidden: int = 4096,
    intermediate: int = 1536,
    num_experts: int = 128,
    top_k: int = 8,
    use_beta: bool = True,
    rep_ms: int = 250,
):
    x, w, offsets = get_tensors(
        seq_len,
        hidden,
        intermediate,
        num_experts,
        top_k,
        use_beta,
    )
    # do_bench_cudagraph's `rep` is a time budget in milliseconds, not an iteration count
    return triton.testing.do_bench_cudagraph(
        lambda: torch._grouped_mm(x, w, offs=offsets), rep=rep_ms
    )

params = {
    "seq_len": 4096,
    "hidden": 4096,
    "intermediate": 1536,
    "num_experts": 128,
    "top_k": 8,
}
uniform = benchmark(**params, use_beta=False)
beta = benchmark(**params, use_beta=True)
print(f"Uniform: {uniform:.5f} ms")
print(f"Beta: {beta:.5f} ms")
# Uniform: 0.96811 ms
# Beta: 1.02422 ms
With the uniform token assignment, the grouped GEMM kernel is about 6% faster than when we sample expert weights from a Beta distribution (which we use to model imbalanced load)! In practice, this kind of gap can show up as a discrepancy between the performance of an actual training run and what we observed during benchmarking.
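A simple way to see how different the two routing distributions are is to look at the load on the busiest expert; one plausible explanation for the slowdown is that a grouped GEMM cannot finish before its largest group does. The sketch below is a pure-Python restatement of the Gumbel-top-k sampling in sample_expert_assignments (standard library only, so it runs without a GPU; the helper names are our own) that reports the ratio between the most loaded expert and the average.

```python
import heapq
import math
import random
from typing import List


def tokens_per_expert(
    seq_len: int, num_experts: int, top_k: int, use_beta: bool,
    alpha: float = 1.0, beta: float = 5.0, seed: int = 0,
) -> List[int]:
    """Count tokens routed to each expert via Gumbel-top-k sampling."""
    rng = random.Random(seed)
    if use_beta:
        weights = [rng.betavariate(alpha, beta) for _ in range(num_experts)]
    else:
        weights = [1.0] * num_experts
    total = sum(weights)
    log_w = [math.log(w / total + 1e-10) for w in weights]
    counts = [0] * num_experts
    for _ in range(seq_len):
        # Log-weights plus Gumbel noise; keeping the top_k scores draws top_k
        # distinct experts, like sampling without replacement by weight.
        scores = [lw - math.log(-math.log(rng.uniform(1e-12, 1 - 1e-12))) for lw in log_w]
        for e in heapq.nlargest(top_k, range(num_experts), key=scores.__getitem__):
            counts[e] += 1
    return counts


def imbalance(counts: List[int]) -> float:
    """Tokens on the busiest expert divided by the mean tokens per expert."""
    return max(counts) / (sum(counts) / len(counts))


for use_beta in (False, True):
    counts = tokens_per_expert(seq_len=4096, num_experts=128, top_k=8, use_beta=use_beta)
    print(f"use_beta={use_beta}: max/mean load = {imbalance(counts):.2f}")
```

The exact numbers depend on the random draw, but the Beta(1, 5) weights should put noticeably more tokens on the busiest expert than the uniform case, which is the kind of skew a benchmark with uniform routing never exercises.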
Putting everything together
To summarize:
- When writing the benchmarking code by hand, do not forget to:
  - Run a few warmup iterations
  - Flush the L2 cache
  - Use CUDA events
  - Call torch.cuda.synchronize() to wait for the completion of all GPU work
- Alternatively, use triton.testing.do_bench or triton.testing.do_bench_cudagraph for a built-in solution.
- Regardless of approach, do not forget about the underlying data you’re using.
All in all, a benchmark is only as useful as it is representative of what you will actually see in real production runs. If you care about CPU overhead, do not get rid of it. If you expect a specific data distribution during training, sample from it while benchmarking.