High-Level Information
As a toolkit built mainly on the driver and assembly layers,
neutrino
can fall short in capturing higher-level, framework-specific information
that is also important for analysis.
The following utilities help collect such information.
Recording Tensor Shapes
Code in neutrino/utils/tensortrace
The most fundamental piece of information is the tensor shape,
and neutrino
provides a small tool based on Python's built-in sys.settrace
to capture and save tensor shapes.
Usage
from neutrino.utils.tensortrace import TensorTrace
with TensorTrace() as t:
... # put your tensors to be captured here
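To illustrate the underlying idea (this is not the actual TensorTrace implementation, just a minimal sketch): a sys.settrace hook fires on every Python-level function call, so the tracer can inspect each new frame's arguments for tensor-like objects and record their shapes. The class and names below (`ShapeTrace`, `FakeTensor`) are hypothetical, and the "tensor-like" test is simplified to "has a `.shape` attribute":

```python
import sys

class ShapeTrace:
    """Minimal sketch of a sys.settrace-based shape recorder.

    On every Python function call, inspect the frame's positional
    arguments for tensor-like objects (anything with a `.shape`
    attribute) and record (arg name, shape, function name).
    """

    def __init__(self):
        self.records = []

    def _tracer(self, frame, event, arg):
        if event == "call":
            code = frame.f_code
            # Positional argument names are the first co_argcount varnames.
            for name in code.co_varnames[:code.co_argcount]:
                value = frame.f_locals.get(name)
                if hasattr(value, "shape"):
                    self.records.append((name, value.shape, code.co_name))
        return None  # no line-level tracing needed

    def __enter__(self):
        sys.settrace(self._tracer)
        return self

    def __exit__(self, *exc):
        sys.settrace(None)

# Demo with a dummy tensor-like object (no torch dependency):
class FakeTensor:
    def __init__(self, shape):
        self.shape = shape

def forward(q, k):
    return q

with ShapeTrace() as t:
    forward(FakeTensor((4, 32, 4096, 64)), FakeTensor((4, 32, 4096, 64)))

print(t.records)
# -> [('q', (4, 32, 4096, 64), 'forward'), ('k', (4, 32, 4096, 64), 'forward')]
```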
When used with neutrino
, captured tensors will be dumped to a file named tensor.trace
in the same folder as event.log
, like the following:
[call] 1745243368912002086 torch.Size([4, 32, 4096, 64]) 67108864 140007813677056 q forward /home/root/workdir/neutrino/test/triton_.py:405
[call] 1745243368912131786 torch.Size([4, 32, 4096, 64]) 67108864 140007746568192 k forward /home/root/workdir/neutrino/test/triton_.py:405
[call] 1745243368912146329 torch.Size([4, 32, 4096, 64]) 67108864 140007679459328 v forward /home/root/workdir/neutrino/test/triton_.py:405
[call] 1745243368912764072 torch.Size([4, 32, 4096, 64]) 67108864 140007813677056 Q dynamic_func <string>:1
[call] 1745243368912789106 torch.Size([4, 32, 4096, 64]) 67108864 140007746568192 K dynamic_func <string>:1
[call] 1745243368912799835 torch.Size([4, 32, 4096, 64]) 67108864 140007679459328 V dynamic_func <string>:1
[call] 1745243368912810325 torch.Size([4, 32, 4096]) 2097152 140008384102400 M dynamic_func <string>:1
[call] 1745243368912822484 torch.Size([4, 32, 4096, 64]) 67108864 140007612350464 Out dynamic_func <string>:1
Fields in this trace file are separated by two spaces (" ")
and each line has the following form:
- timestamp
- shape
- size (bytes)
- data pointer (base 16)
- tensor name
- function name
- function source (path)
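The format lends itself to a small parser. Below is a sketch assuming the two-space field separator described above (which keeps the shape field, with its internal ", ", in one piece); the record field names are our own labels, not part of neutrino's API:

```python
from dataclasses import dataclass

@dataclass
class TraceRecord:
    event: str       # e.g. "[call]"
    timestamp: int   # nanosecond timestamp
    shape: str       # e.g. "torch.Size([4, 32, 4096, 64])"
    size: int        # size in bytes
    data_ptr: int    # data pointer
    name: str        # tensor name
    func: str        # function name
    source: str      # function source (path:line)

def parse_trace_line(line: str) -> TraceRecord:
    # Splitting on the two-space separator keeps the shape field
    # (which contains ", " internally) from being split apart.
    event, ts, shape, size, ptr, name, func, source = line.rstrip("\n").split("  ")
    return TraceRecord(event, int(ts), shape, int(size), int(ptr), name, func, source)

line = ("[call]  1745243368912002086  torch.Size([4, 32, 4096, 64])  "
        "67108864  140007813677056  q  forward  "
        "/home/root/workdir/neutrino/test/triton_.py:405")
rec = parse_trace_line(line)
print(rec.name, rec.size)  # -> q 67108864
```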
If you're using a callback, the tensor trace can be located via os.path.join(os.path.dirname(os.path.dirname(path)), "tensor.trace").
This tool can also be used standalone (without neutrino
); captured tensors will then be printed directly to stderr
.
Known Issues
- JAX is not yet supported.
- Current PyTorch support is somewhat cumbersome; please wrap the call inside a function:
torch.matmul(a, b) # original
# Please change to the following
def matmul(a, b):
    return torch.matmul(a, b)
matmul(a, b)
Call for Improvement
This is because the tool is based on Python function frames, but we haven't figured out why it cannot capture the frame of torch.matmul
directly...
If you're a Python expert, please help improve it!
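One likely explanation (an assumption, not confirmed by the neutrino source): sys.settrace only reports "call" events for Python-level frames, and functions implemented in C never create a Python frame, so the tracer never sees them. Wrapping the call in a Python function gives the tracer a frame to inspect. A minimal demonstration with a C-implemented built-in standing in for torch.matmul:

```python
import sys

calls = []

def tracer(frame, event, arg):
    # Record the name of every Python frame entered.
    if event == "call":
        calls.append(frame.f_code.co_name)
    return None  # no line-level tracing needed

def py_square(x):
    return x * x

sys.settrace(tracer)
py_square(3)    # Python-level call: creates a frame, tracer fires
len([1, 2, 3])  # C-implemented call: no Python frame, no event
sys.settrace(None)

print(calls)  # -> ['py_square']
```

This is also why the wrapping workaround above works: `matmul` is a Python function, so its frame (with `a` and `b` as locals) is visible to the tracer even though `torch.matmul` itself is not.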