PyTorch One-Line Optimization
Hello, welcome to the Neutrino blog!
This series will continuously post our new findings on GPU programming and computing via neutrino, our fine-grained and programmable GPU kernel observability platform, much like what eBPF is for the Linux kernel!
Observing a PyTorch Performance Issue with Neutrino!
As a GPU observability platform, neutrino provides a user interface similar to its CPU siblings strace/valgrind for profiling nearly any computing workload on GPU, such as Deep Learning with PyTorch.
To demonstrate its power, let's start with a crazily easy one-liner:
torch.zeros((4096, 4096), dtype=torch.float16, device="cuda")
The functionality is straightforward: create a 4096x4096 2D tensor on the GPU, filled with zeros in FP16.
And I believe most of us wouldn't suspect this line of any performance issue, right 🧐?
It's just too simple and straightforward 🫠.
Well, computer science is a practical subject, so let's answer this question experimentally with neutrino, more specifically with neutrino's block scheduling (block_sched) tool:
neutrino -t block_sched python oneline.py # -t to specify tool
# output (truncated)
Here neutrino reports that a kernel named vectorized_elementwise... (the weird name comes from a C++ template) is executed inside torch.zeros.
This is expected: the kernel fills the allocated buffer with 0 for correct initialization.
But the poor performance statistics are not expected: the kernel spends more than 25% of its time on block scheduling! What a waste of time and FLOPs!
What's block scheduling cost?
Now it's time for some GPU programming internals. Most compute platforms like NVIDIA CUDA and AMD ROCm group GPU threads into blocks, 128 threads per block in this example, and blocks are scheduled onto the GPU to run in parallel. But one common blind spot is that block-level parallelism is not unlimited!
In fact, the parallelism of your kernel (vectorized_elementwise here) is bounded by the GPU hardware: on my A100 with 108 SMs and 6 blocks/SM (see CUDA Occupancy Calculator), only 648 blocks can execute in parallel, far fewer than the 65536 required.
In such a case, blocks are executed sequentially rather than in parallel, i.e., the 649-th block needs to wait for an executing block to finish before it can run.
Similar to a context switch on the CPU, the GPU also spends additional time (hundreds of cycles) finalizing an executing block and scheduling the next one. We call this the block scheduling cost, because from the perspective of computing this time is simply wasted.
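As a quick back-of-envelope check, here is a sketch using the numbers from my A100 run above (block and SM counts will differ on other GPUs and kernels):

import math

# Figures from the run above: torch.zeros((4096, 4096), fp16) on an A100
num_blocks = 65536          # blocks launched by vectorized_elementwise (128 threads each)
num_sms = 108               # SMs on the A100
blocks_per_sm = 6           # occupancy limit for this kernel (CUDA Occupancy Calculator)

concurrent_blocks = num_sms * blocks_per_sm        # at most 648 blocks run in parallel
waves = math.ceil(num_blocks / concurrent_blocks)  # ~102 sequential "waves" of blocks

# Every block after the first wave waits for a running block to finish and then
# pays the (hundreds-of-cycles) scheduling cost before it can start.
print(f"{num_blocks} blocks / {concurrent_blocks} concurrent = {waves} waves")

With only a few hundred cycles of useful work per tiny block, roughly a hundred waves of scheduling overhead adds up quickly.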
How does Neutrino measure this?
Since many of us have never heard of this term, block scheduling cost stays mostly invisible. One reason is that it is hard to measure: GPU block scheduling is a hardware behavior, i.e., not programmable, while block execution is programmable but what we want to measure is the gap between blocks!
Neutrino's block_sched tool solves this with a combined approach.
First, during trace collection we record three runtime statistics for every scheduled block: the starting timestamp (clock), the ending timestamp, and the id of the SM it was scheduled onto. With neutrino, this is as simple as a few lines of toml in the probe definition.
Then, in post-trace analysis we sort and analyze the schedule per SM: block scheduling cost can be estimated as the start clock of the next block minus the end clock of the previous block. You can check the code here.
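To illustrate the idea of that post-trace analysis, here is a minimal sketch (not neutrino's actual analysis code; the per-block tuple layout is an assumption for illustration):

from collections import defaultdict

def estimate_block_sched_cost(blocks):
    """blocks: iterable of (sm_id, start_clock, end_clock), one entry per scheduled block."""
    per_sm = defaultdict(list)
    for sm_id, start, end in blocks:
        per_sm[sm_id].append((start, end))

    sched_gap, busy = 0, 0
    for records in per_sm.values():
        records.sort()  # order blocks by their start clock on this SM
        busy += sum(end - start for start, end in records)
        # gap between consecutive blocks on the same SM = block scheduling cost
        for (_, prev_end), (next_start, _) in zip(records, records[1:]):
            sched_gap += max(next_start - prev_end, 0)
    return sched_gap, busy  # compare scheduling time against compute time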
How to optimize?
The best way to verify Neutrino's observation is to optimize in the suggested direction and see whether performance improves.
Here the suggestion from neutrino is to reduce the block scheduling cost, and one straightforward solution is to issue fewer blocks.
To verify this, we can use a Persistent Kernel (like this example) that removes the block scheduling cost by launching a persistent group of thread blocks (at hardware capacity) and handling the sequential part manually.
A persistent vectorized_elementwise can easily be written in Triton (with a little help from GPT):
import torch
import triton
import triton.language as tl


@triton.jit
def zero_init_kernel_persistent(output_ptr, numel, BLOCK_SIZE: tl.constexpr, NUM_SMS: tl.constexpr):
    # Get program ID
    start_pid = tl.program_id(axis=0)
    # Calculate number of blocks needed
    num_blocks = tl.cdiv(numel, BLOCK_SIZE)
    # Calculate blocks per SM
    blocks_per_sm = num_blocks // NUM_SMS
    if start_pid < num_blocks % NUM_SMS:
        blocks_per_sm += 1
    # Initialize block ID
    block_id = start_pid - NUM_SMS
    # Process multiple blocks per persistent program
    for _ in range(blocks_per_sm):
        block_id += NUM_SMS
        # Calculate offsets for this block
        offsets = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        # Create mask for valid elements
        mask = offsets < numel
        # Store zeros
        tl.store(output_ptr + offsets, tl.zeros([BLOCK_SIZE], dtype=tl.float16), mask=mask)


def zero_init_persistent(x: torch.Tensor):
    # Get total number of elements
    numel = x.numel()
    # Get number of SMs
    NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
    # Keep BLOCK_SIZE at 128, matching the original kernel
    BLOCK_SIZE = 128
    # Launch at most NUM_SMS persistent programs
    grid = lambda META: (min(NUM_SMS, triton.cdiv(numel, META['BLOCK_SIZE'])),)
    zero_init_kernel_persistent[grid](x, numel, BLOCK_SIZE=BLOCK_SIZE, NUM_SMS=NUM_SMS, num_warps=8)


# Test script
if __name__ == "__main__":
    # Test correctness
    shape = (4096, 4096)
    old = torch.zeros(shape, dtype=torch.float16, device='cuda')
    new = torch.empty(shape, dtype=torch.float16, device='cuda')
    zero_init_persistent(new)
    assert torch.allclose(old, new)
To apply the change, we call torch.empty to only allocate the memory without initialization and then call zero_init_persistent to initialize it manually:
# 34,493ns (nsys, 1000 runs)
old = torch.zeros(shape, dtype=torch.float16, device='cuda')
# nsys, 1000 runs: 25ms
new = torch.empty(shape, dtype=torch.float16, device='cuda')
zero_init_persistent(new)
Replacing this one line yields around a 33% speedup, which matches Neutrino's observation of ~25% scheduling cost (removing a 25% overhead gives roughly a 1/(1 - 0.25) ≈ 1.33x speedup).
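If you prefer to check the numbers without nsys, a simple CUDA-event timing sketch (assuming the zero_init_persistent wrapper defined above; exact numbers will of course differ from the nsys figures) looks roughly like this:

import torch

def bench(fn, iters=1000):
    for _ in range(10):          # warm up (JIT compile, caching allocator)
        fn()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters * 1e6   # average ns per call

shape = (4096, 4096)
baseline = bench(lambda: torch.zeros(shape, dtype=torch.float16, device="cuda"))
persistent = bench(lambda: zero_init_persistent(
    torch.empty(shape, dtype=torch.float16, device="cuda")))
print(f"torch.zeros: {baseline:.0f}ns/call, persistent: {persistent:.0f}ns/call")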
An even lazier solution is the cuMemset driver call, which eliminates the kernel launch entirely (at least from the perspective of developers outside NVIDIA) and offers a similar improvement:
# 34,493ns (nsys, 1000 runs)
old = torch.zeros(shape, dtype=torch.float16, device='cuda')
# 24,630ns (nsys, 1000 runs)
new = torch.empty(shape, dtype=torch.float16, device='cuda')
# `driver` refers to the CUDA driver API bindings, e.g. `from cuda import cuda as driver` (cuda-python)
driver.cuMemsetD16(new.data_ptr(), 0, 4096*4096)
🥳Well Done🥳
With neutrino, we observed and accelerated a "mission impossible": a performance issue hiding in a surprisingly simple torch.zeros call from our daily workflow!
Moreover, the underlying vectorized_elementwise kernel lies not just behind torch.zeros but also behind torch.ones/arange/exp (actually any elementwise unary operator🫣), and even behind plenty of implicit calls that flush internal buffers🫥!
Conclusion
In this 5-minute journey, we walked through the basic debugging experience of Neutrino with an example as simple as one line of torch.zeros, demonstrating neutrino's capability to discover performance issues from unexpected angles.
Finally, I want to conclude with an inspiring question:
If even such a simple line can hide such a big issue, how many potential performance issues on today's GPUs are still waiting to be explored and optimized with Neutrino?
Further Reading
If you are interested in neutrino, take a look at:
- Introduction: More detailed and insightful information about Neutrino
- Probe Design: How tools like block_sched are designed and implemented!
- Our later blogs!