# Writing Probes
`neutrino` supports two levels of probes:

- Tracing DSL in Python syntax, suitable for beginners.
- Assembly probes wrapped in TOML, suitable for advanced usage.

We recommend starting with the Tracing DSL, like the following example. After a run, you can find the compiled assembly version as `probe.toml` in the trace folder, which you can modify for advanced usage.
## Tracing DSL Probes
Inspired by eBPF, a `neutrino` program is centered on two concepts:

- `probe`: small code snippets (function bodies) with tracepoint metadata (`pos`, `level`, `before`).
- `Map`: a structured data abstraction for lock-free persistence.
```python
from neutrino import probe, Map
import neutrino.language as nl

# declare maps for persistence
@Map(level="warp", type="array", size=16, cap=1)
class BlockSched:
    start: nl.u64
    elapsed: nl.u32
    cuid: nl.u32

# declare probe registers shared across probes
start: nl.u64 = 0  # starting clock

# define probes with decorator
@probe(pos="kernel", level="warp", before=True)
def thread_start():
    start = nl.clock()

@probe(pos="kernel", level="warp")
def thread_end():
    elapsed = nl.clock() - start
    BlockSched.save(start, elapsed, nl.cuid())
```
### Map
In the above example, we define a `Map` named `BlockSched` as a decorated class (`@neutrino.Map`, inspired by Python's `dataclass`) with the following metadata:

- `level="warp"` means each warp (of 32/64 threads) has its own map, see Trace Level.
- `type="array"` means the map is a plain array.
- `size=16` means each record (inside the map) takes 16 bytes. Must be a number.
- `cap=1` means each map is capped at 1 record. Can be an integer or `"dynamic"`.
Then we define three "class members" that form a record:

- `start: nl.u64` declares a `u64` value to save.
- `elapsed: nl.u32` and `cuid: nl.u32` declare two `u32` values to save.

It's worth noting that members must be annotated with Neutrino types (`nl.u64` or `nl.u32`, not Python's intrinsic `int`!),
since they define the structure of each record (8 + 4 + 4 bytes) and must add up to the `size` (16 bytes).

Finally, a map exposes a `.save` API to be used in probes, such as `BlockSched.save(...)`.
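As a variation, a record stream can grow instead of being capped. Below is a minimal sketch of a thread-level map with `cap="dynamic"` (the map name and fields are hypothetical; presumably each `.save` call appends one record):

```python
from neutrino import Map
import neutrino.language as nl

# hypothetical map: a dynamically growing, per-thread record stream
@Map(level="thread", type="array", size=16, cap="dynamic")
class MemAccess:
    addr: nl.u64   # 8 bytes
    clock: nl.u32  # 4 bytes
    cuid: nl.u32   # 4 bytes -> 8 + 4 + 4 = 16, matching size=16
```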
### Probe
Then we can define probes as decorated functions (`@neutrino.probe`, inspired by `@triton.jit`)
with metadata as decorator parameters:

- `level="warp"` means only the warp leader will be probed, see Trace Level. Must match the `Map` used.
- `pos="kernel"` means the tracepoint is at `kernel`. Can be `kernel` or any instruction.
- `before=True` means this probe will be injected before `pos` (here `kernel` means at kernel start). Defaults to `False`.
The bodies of the probe functions (`thread_start` and `thread_end`; they must take no parameters and return nothing)
are the probe snippets, with their logic written in Python syntax.
However, because they are not executed by the Python runtime, snippets are strictly limited:

- Calling functions other than Neutrino helpers is prohibited, including Python built-ins like `print`.
- Control statements (`if`, `while`) are prohibited.

But you can still use rich Python expression syntax and Neutrino helpers, such as `nl.clock()`, to build probes quickly.
We are working on supporting the control statements `if` and `for`.
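To illustrate `pos` targeting an instruction rather than the whole kernel, here is a hedged sketch that times a single instruction with a `before`/`after` probe pair. The opcode string `"ld.global"`, the map, and the function names are assumptions for illustration; `ld_start` is a contexted register (explained below):

```python
from neutrino import probe, Map
import neutrino.language as nl

# hypothetical map: a stream of per-warp load timing samples
@Map(level="warp", type="array", size=16, cap="dynamic")
class LoadTime:
    start: nl.u64
    elapsed: nl.u32
    cuid: nl.u32

ld_start: nl.u64 = 0  # contexted register, see Contexted Register

# "ld.global" is an assumed opcode string for illustration
@probe(pos="ld.global", level="warp", before=True)
def before_load():
    ld_start = nl.clock()

@probe(pos="ld.global", level="warp")
def after_load():
    elapsed = nl.clock() - ld_start
    LoadTime.save(ld_start, elapsed, nl.cuid())
```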
### Available Helpers
Neutrino exposes helpers as fields (aliases of instruction operands) or functions (for utilities):

- `nl.out`, `nl.in1`, `nl.in2`, `nl.in3`: the output and input operands of the instruction. These operands are untyped.
- `nl.addr`: the memory address of the instruction, of `nl.u64`. Only applicable to memory access instructions (see the sketch after this list).
- `nl.clock()`: read the CU-local cycle counter (fast but unsynced between CUs).
- `nl.time()`: read the GPU-local timer (slow but synced between CUs).
- `nl.cuid()`: read the index of the Compute Unit / Streaming Multiprocessor the thread is scheduled on.
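For example, `nl.addr` can feed a per-thread address trace. This is a minimal sketch (the map name and opcode string are assumptions for illustration):

```python
from neutrino import probe, Map
import neutrino.language as nl

# hypothetical map: per-thread stream of accessed addresses
@Map(level="thread", type="array", size=16, cap="dynamic")
class AddrTrace:
    addr: nl.u64   # from nl.addr
    clock: nl.u64  # CU-local cycle at the access

# "st.global" is an assumed opcode string for illustration
@probe(pos="st.global", level="thread")
def record_store():
    AddrTrace.save(nl.addr, nl.clock())
```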
### Contexted Registers
The core logic of `thread_end` is `elapsed = nl.clock() - start`, where we read the clock when the warp ends and subtract `start` from it.
But what's the value of `start`?
`start` is initialized in `thread_start` with `nl.clock()`, whose value is the warp's starting clock.
This is a feature of Neutrino's execution model named contexted registers.
Contexted registers are defined via Python assignment syntax in the global scope,
such as `start: nl.u64 = 0` (again, the type must be specified).
Unlike other values (registers) such as `elapsed` that may be cleared after the function ends,
these registers have lifecycles spanning probes and can be used as temporary storage for runtime values (such as the warp starting clock here, kept until the warp ends).
You can imagine them as the state of a background process that gets context-switched.
Contexted registers allow probes to cooperate through shared state (like how eBPF probes cooperate with eBPF maps) and support more advanced usage such as ring buffers. But the most convenient usage is still an intra-kernel instruction timer like the example above.
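For instance, a contexted register can accumulate state across many probe firings. The sketch below (the map, names, and opcode string are hypothetical; it also assumes simple arithmetic on contexted registers is supported) counts the global loads each warp issues and persists the count when the kernel ends:

```python
from neutrino import probe, Map
import neutrino.language as nl

# hypothetical map: one record per warp holding its final load count
@Map(level="warp", type="array", size=8, cap=1)
class LoadCount:
    loads: nl.u64

loads: nl.u64 = 0  # contexted register: persists across probe firings

# "ld.global" is an assumed opcode string for illustration
@probe(pos="ld.global", level="warp")
def on_load():
    loads = loads + 1  # assumes arithmetic on a contexted register

@probe(pos="kernel", level="warp")
def on_kernel_end():
    LoadCount.save(loads)
```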
### Trace Level
The last concept is the trace level (`level=...`).
Neutrino `probe` and `Map` mainly support two levels:

- Thread-Level (`level="thread"`): each thread has its own probe and map, mainly used for value profiling where each thread has a unique register value to work on (see the sketch after this list).
- Warp-Level (`level="warp"`): only the warp-leader thread has the probe and map, mainly used for time profiling. Because threads within the same warp are scheduled together for the same instruction, recording a timestamp only needs one thread; the other threads are masked.
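A minimal thread-level value-profiling sketch (the map name and opcode string are assumptions; `nl.out` is untyped, so storing it in a `u64` field is also an assumption):

```python
from neutrino import probe, Map
import neutrino.language as nl

# hypothetical map: per-thread stream of observed output operands
@Map(level="thread", type="array", size=8, cap="dynamic")
class OutValues:
    value: nl.u64  # nl.out is untyped; stored as u64 here (assumption)

# "add.f32" is an assumed opcode string for illustration
@probe(pos="add.f32", level="thread")
def record_out():
    OutValues.save(nl.out)
```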
## Asm Probes
Since GPUs don't execute Python, the Tracing DSL is not executed directly but compiled (by the DSL compiler) into assembly probes for probing. Instead of relying on the compiler, experienced developers can handcraft assembly probes for advanced usage (the compiler cannot support all possible features).

> **Platform Dependence**: Assemblies target a particular architecture, so assembly probes are not platform-independent. For example, CUDA PTXAsm probes cannot be used for AMD (which needs GCNAsm probes).

To handcraft probes, wrap the snippets in TOML (for multi-line string support) and specify the `platform` as a keyword:
platform = "cuda"
[ map.BlockSched ] # sub of "map"
type = "array"
level = "warp"
size = 16
cap = 1
data = { "start": "u64", "elapsed": "u32", "cuid": "u32" }
[ probe.thread_start_thread_end ] # sub of "probe"
position = "kernel"
level = "warp"
register = {"u32": 2, "u64": 3}
before = """.reg .b64 %PD<3>;
.reg .b32 %P<2>;
mov.u64 %PD0, %clock64;"""
after = """mov.u64 %PD1, %clock64;
sub.u64 %PD1, %PD1, %PD0;
cvt.u32.u64 %P1, %PD1;
mov.u32 %P2, %smid;
SAVE [ BlockSched ] {%PD0, %P1, %P2};"""
platform = "hip"
[ map.BlockSched ] # sub of "map"
type = "array"
level = "warp"
size = 16
cap = 1
data = { "start": "u64", "elapsed": "u32", "cuid": "u32" }
[ probe.thread_start_thread_end ] # sub of "probe"
position = "kernel"
level = "warp"
register ={"u32": 2, "u64": 3, "type": "sgpr"}
before= "s_memrealtime PD0"
after= """s_memrealtime PD1
SUB64 PD1, PD1, PD0
CVT32 P0, PD1
s_getreg_b32 P1, hwreg(HW_REG_HW_ID)
SAVE [ block_sched] {PD0, P0, P1}"""
The `BlockSched` map still follows the DSL definition, and so does the `thread_start_thread_end` probe.
Probes of the same `pos` and `level` are merged, with their names concatenated.
We can see that some helpers are transformed into assembly syntax (such as `nl.cuid()` -> `mov.u32 ..., %smid;`), while some remain (such as `SAVE [ BlockSched ] {...};`).
The following tables list the available helpers and their assembly correspondence:
For CUDA (PTX):

| DSL Helpers | ASM Helpers |
|---|---|
| `Map.save(...)` | `SAVE [ MAP ] { ... }` |
| `nl.out`/`nl.in1`/`nl.in2`/`nl.in3` | `OUT`/`IN1`/`IN2`/`IN3` |
| `nl.clock()` | `mov.u64 ..., %clock64;` |
| `nl.time()` | `mov.u64 ..., %globaltimer;` |
| `nl.cuid()` | `mov.u32 ..., %smid;` |
For AMD (GCN):

| DSL Helpers | ASM Helpers |
|---|---|
| `Map.save(...)` | `SAVE [ MAP ] { ... }` |
| `nl.out`/`nl.in1`/`nl.in2`/`nl.in3` | `OUT`/`IN1`/`IN2`/`IN3` |
| `nl.clock()` | `S_MEMTIME ... ;;` |
| `nl.time()` | `S_MEMREALTIME ... ;;` |
| `nl.cuid()` | `S_GETREG_B32 ..., HWREG(HW_REG_HW_ID)` |