Neutrino

High-Level Information

Because neutrino is a toolkit built mainly on the driver and assembly layers, it can miss higher-level, framework-specific information that is also important for analysis.

The following utilities can be used to capture such information.

Recording Tensor Shapes

The most fundamental piece of such information is the tensor shape. neutrino provides a small tool, built on Python's built-in sys.settrace, that captures the shapes (and related metadata) of tensors and saves them.
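To illustrate how a sys.settrace-based recorder can work, here is a minimal sketch. It is not neutrino's TensorTrace implementation, which is not shown here; it only assumes one plausible approach: install a global trace function and, on each "call" event, inspect the new frame's f_locals (which at that point hold the function's arguments), recording every value that has a .shape attribute. The names MiniTensorTrace and Shaped are hypothetical, and Shaped stands in for a real tensor so the example runs without PyTorch.

```python
import sys
import time
from types import FrameType
from typing import Any


class MiniTensorTrace:
    """Hypothetical settrace-based recorder; a sketch, not neutrino's TensorTrace."""

    def __init__(self) -> None:
        # Each record: (timestamp_ns, shape, arg_name, function_name, "file:line")
        self.records: list = []
        self._previous: Any = None

    def _tracer(self, frame: FrameType, event: str, arg: Any) -> None:
        # On a "call" event, the new frame's locals hold the function's arguments.
        if event == "call":
            code = frame.f_code
            for name, value in frame.f_locals.items():
                shape = getattr(value, "shape", None)  # duck-typed "is it a tensor?"
                if shape is not None:
                    self.records.append((
                        time.time_ns(),
                        tuple(shape),
                        name,
                        code.co_name,
                        f"{code.co_filename}:{code.co_firstlineno}",
                    ))
        return None  # no line-by-line tracing inside the called function

    def __enter__(self) -> "MiniTensorTrace":
        self._previous = sys.gettrace()
        sys.settrace(self._tracer)
        return self

    def __exit__(self, *exc: Any) -> bool:
        sys.settrace(self._previous)  # restore whatever tracer was active before
        return False


class Shaped:
    """Stand-in for a tensor: only carries a .shape attribute."""

    def __init__(self, *shape: int) -> None:
        self.shape = shape


def forward(q, k):
    return q


with MiniTensorTrace() as t:
    forward(Shaped(4, 32, 4096, 64), Shaped(4, 32, 4096, 64))

for record in t.records:
    print(record)
```

The sketch also suggests why only tensors passed into Python-level functions show up: the tracer looks at arguments at function entry and never inspects expressions inside a function body.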

Usage

from neutrino.utils.tensortrace import TensorTrace

with TensorTrace() as t:
    ...  # call the functions whose tensor arguments you want to capture here

When used with neutrino, the captured tensors are written to a file named tensor.trace in the same folder as event.log, for example:

[call]  1745243368912002086  torch.Size([4, 32, 4096, 64])  67108864  140007813677056  q  forward  /home/root/workdir/neutrino/test/triton_.py:405
[call]  1745243368912131786  torch.Size([4, 32, 4096, 64])  67108864  140007746568192  k  forward  /home/root/workdir/neutrino/test/triton_.py:405
[call]  1745243368912146329  torch.Size([4, 32, 4096, 64])  67108864  140007679459328  v  forward  /home/root/workdir/neutrino/test/triton_.py:405
[call]  1745243368912764072  torch.Size([4, 32, 4096, 64])  67108864  140007813677056  Q  dynamic_func  <string>:1
[call]  1745243368912789106  torch.Size([4, 32, 4096, 64])  67108864  140007746568192  K  dynamic_func  <string>:1
[call]  1745243368912799835  torch.Size([4, 32, 4096, 64])  67108864  140007679459328  V  dynamic_func  <string>:1
[call]  1745243368912810325  torch.Size([4, 32, 4096])  2097152  140008384102400  M  dynamic_func  <string>:1
[call]  1745243368912822484  torch.Size([4, 32, 4096, 64])  67108864  140007612350464  Out  dynamic_func  <string>:1

Fields in the trace file are separated by two spaces and appear in the following order:

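  • event tag (e.g. [call])
  • timestamp
  • shape
  • size (bytes)
  • data pointer (printed as a decimal integer)
  • tensor name (the argument name inside the called function)
  • function name
  • function source (file path and line number, as path:line)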
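Because the fields are separated by two spaces while the shape itself only uses single spaces, each line can be split back into its fields. The following is a minimal parsing sketch, assuming every line has the eight fields seen in the sample above (including the leading [call] tag). TraceRecord, parse_line, parse_trace, bytes_per_element and names_by_pointer are hypothetical helpers, not part of neutrino. The demo reuses three lines from the sample and shows that q in forward and Q in dynamic_func share the same data pointer.

```python
import ast
import math
from collections import defaultdict
from dataclasses import dataclass


@dataclass
class TraceRecord:
    event: str           # e.g. "[call]"
    timestamp_ns: int
    shape: tuple         # e.g. (4, 32, 4096, 64)
    size_bytes: int
    data_ptr: int
    name: str            # argument name inside the called function
    function: str
    file: str
    line: int


def parse_line(line: str) -> TraceRecord:
    # Fields are separated by two spaces; the shape only contains single
    # spaces, so it survives the split intact.
    event, ts, shape, size, ptr, name, func, src = line.strip().split("  ", 7)
    # "torch.Size([4, 32, 4096, 64])" -> "[4, 32, 4096, 64]" -> (4, 32, 4096, 64)
    dims = tuple(ast.literal_eval(shape[shape.index("(") + 1 : shape.rindex(")")]))
    file, lineno = src.rsplit(":", 1)
    # The pointer values in the sample are written as decimal integers.
    return TraceRecord(event, int(ts), dims, int(size), int(ptr),
                       name, func, file, int(lineno))


def parse_trace(text: str) -> list:
    return [parse_line(row) for row in text.splitlines() if row.strip()]


def bytes_per_element(rec: TraceRecord) -> int:
    # e.g. 67108864 bytes / (4*32*4096*64) elements = 2, i.e. a 16-bit dtype
    n = math.prod(rec.shape)
    return rec.size_bytes // n if n else 0


def names_by_pointer(records: list) -> dict:
    # The same buffer can appear under different names in different functions.
    groups = defaultdict(list)
    for rec in records:
        groups[rec.data_ptr].append(f"{rec.function}:{rec.name}")
    return dict(groups)


SAMPLE = """\
[call]  1745243368912002086  torch.Size([4, 32, 4096, 64])  67108864  140007813677056  q  forward  /home/root/workdir/neutrino/test/triton_.py:405
[call]  1745243368912764072  torch.Size([4, 32, 4096, 64])  67108864  140007813677056  Q  dynamic_func  <string>:1
[call]  1745243368912810325  torch.Size([4, 32, 4096])  2097152  140008384102400  M  dynamic_func  <string>:1
"""

records = parse_trace(SAMPLE)
print(names_by_pointer(records)[140007813677056])  # ['forward:q', 'dynamic_func:Q']
print([bytes_per_element(r) for r in records])      # [2, 2, 4]
```

Grouping by data pointer is one way to follow a tensor across the framework boundary, since the argument name usually changes between the user-level function and the generated kernel wrapper.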

If you are using a callback, the tensor trace can be found at os.path.join(os.path.dirname(os.path.dirname(path)), "tensor.trace").
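The expression above climbs two directory levels from path and looks for tensor.trace there. A small helper sketch that applies it and tolerates a missing trace is shown below; tensor_trace_path and read_tensor_trace are hypothetical names, path is whatever path your callback is working with, and the kernel/result.bin layout in the demo is a made-up directory tree used only for illustration.

```python
import os
import tempfile


def tensor_trace_path(path: str) -> str:
    # The expression from the text: tensor.trace sits two directory levels above `path`.
    return os.path.join(os.path.dirname(os.path.dirname(path)), "tensor.trace")


def read_tensor_trace(path: str) -> list:
    """Return each trace line split into its fields, or [] if no tensor.trace exists."""
    trace = tensor_trace_path(path)
    if not os.path.exists(trace):
        return []
    with open(trace) as f:
        return [line.rstrip("\n").split("  ") for line in f if line.strip()]


# Demo on a throwaway tree; "kernel/result.bin" is a hypothetical callback path.
with tempfile.TemporaryDirectory() as run_dir:
    os.makedirs(os.path.join(run_dir, "kernel"))
    with open(os.path.join(run_dir, "tensor.trace"), "w") as f:
        f.write("[call]  1745243368912810325  torch.Size([4, 32, 4096])  2097152  "
                "140008384102400  M  dynamic_func  <string>:1\n")
    rows = read_tensor_trace(os.path.join(run_dir, "kernel", "result.bin"))
    missing = read_tensor_trace(os.path.join(run_dir, "x", "y", "z.bin"))

print(rows[0][2], rows[0][5])  # torch.Size([4, 32, 4096]) M
```

Returning an empty list when the file is absent keeps a callback usable for runs where no TensorTrace block was active.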

The tool can also be used standalone (without neutrino); in that case, the captured tensors are printed directly to stderr.

Known Issues

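  1. JAX is not yet supported.
  2. PyTorch support is currently limited: tensors passed directly to PyTorch operators such as torch.matmul are not captured, so please wrap the call inside a Python function:

import torch
from neutrino.utils.tensortrace import TensorTrace

a, b = torch.randn(64, 128), torch.randn(128, 32)

with TensorTrace():
    torch.matmul(a, b)  # original: a and b are not captured

# Please change to the following
def matmul(a, b):
    return torch.matmul(a, b)

with TensorTrace():
    matmul(a, b)  # a and b are captured as arguments of matmul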

Call for Improvement

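This is because the tool relies on Python function frames, and no frame is captured for torch.matmul. A likely reason is that torch.matmul is a builtin implemented in C++, so calling it does not create a Python frame and sys.settrace never receives a call event for it; wrapping it in a Python function creates the frame the tool needs.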
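The behavior can be reproduced with the standard library alone. In this sketch, math.sqrt stands in for torch.matmul (both are implemented in native code), and the trace function records a call event for a Python wrapper but none for the builtin itself. The names tracer and sqrt are only for illustration.

```python
import math
import sys

calls = []


def tracer(frame, event, arg):
    if event == "call":  # fires only when a new Python frame is created
        calls.append(frame.f_code.co_name)
    return None


def sqrt(x):  # Python wrapper: calling it creates a frame
    return math.sqrt(x)


previous = sys.gettrace()
sys.settrace(tracer)
math.sqrt(2.0)  # C builtin: no Python frame, so no "call" event
sqrt(2.0)       # Python function: one "call" event
sys.settrace(previous)

print(calls)  # ['sqrt']
```

For comparison, sys.setprofile does report "c_call" events for builtins, but it only passes the C function object and the caller's frame, not the call's arguments, so on its own it does not recover the tensors passed to torch.matmul.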

If you're a Python expert, please help improve it!