Skip to content
285 changes: 284 additions & 1 deletion .github/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,287 @@ See also: [experimental support for Windows](WINDOWS.md).

If you have PyTorch on XPU installed from [binaries](https://docs.pytorch.org/docs/stable/notes/get_start_xpu.html#binaries), you already have Triton installed and don't need any additional installations, unless you want to use the latest version of Triton from `main`.

You can check existing installation by running one of the [tutorials](/python/tutorials/01-vector-add.py).
You can check if triton is currently available by running one of the [tutorials](/python/tutorials/01-vector-add.py).

# Improving performance

Basic rules:

1. **Use Tensor Descriptors:** For inputs and outputs of matmul operations (`tl.dot`), use Tensor Descriptors. This utilizes the hardware-optimized DPAS operation and asynchronous loading. You can often expect more than a 2x performance improvement compared to the basic tensor of pointers approach.
2. **Benchmark:** Experiment with the performance of your kernel. You can use `triton.testing.do_bench` for basic benchmarking, as demonstrated in the [tutorials](../python/tutorials/02-fused-softmax.py).
3. **Type Annotations:** Use proper type annotations for your kernels. Good type annotations allow for better optimization, but be careful to avoid excessive recompilation.
4. **Tiling and Autotuning:** Pick appropriate tiling for your machine and tensor shapes. Use `triton.autotune` to try various combinations and find the best one. Key parameters to tune include block sizes, `num_warps`, `num_stages`, and `grf_mode`. The Intel-specific option `grf_mode` determines the number of registers allocated to a kernel. See existing [benchmarks](../benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark.py) for reasonable configuration grids for GEMM and Flash Attention kernels.

## Use Tensor Descriptors to load tl.dot arguments and save results

For the Intel backend, use Tensor Descriptors to load matrices used in GEMM operations. A Tensor Descriptor can be created inside the kernel and used for loading as follows:

```
a_desc = tl.make_tensor_descriptor(
# Base of a memory block that we want to work with.
base=a_ptr,
# Shape of a tensor that starts from that base, will be used for masking.
shape=(M, K),
# Tensor strides, last dimension needs to be contiguous (=1).
# It's important that the last stride (stride_ak) is known at compile time,
# so it must have either `tl.constexpr` type annotation or no annotation at all.
strides=(stride_am, stride_ak),
# Block size that will be actually loaded.
block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_K)
)
m_offset = 0
# This will load (BLOCK_SIZE_M, BLOCK_SIZE_K) block from memory starting from
# a_ptr + m_offset * stride_am + k_offset * stride_ak
a = a_desc.load([m_offset, k_offset])
m_offset += BLOCK_SIZE_M
```

A Tensor Descriptor describes a piece of memory to load for processing. Loading happens in blocks, and the provided shape is used to mask out-of-bounds values to zero.


Similar code is used for saving results back to global memory:
```
c_desc = tl.make_tensor_descriptor(
base=c_ptr,
shape=(M, N),
strides=(stride_cm, stride_cn),
block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_N)
)
c_desc.store([pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N], c)
```

You can view a full example of a GEMM kernel with Tensor Descriptors [here](../benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark.py).

**Before Tensor Descriptors:**
```
offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
accumulator = tl.dot(a, b, accumulator)
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
```

**After:**

```
a_desc = tl.make_tensor_descriptor(
base=a_ptr, shape=(M, K),
strides=(stride_am, stride_ak),
block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_K)
)
b_desc = tl.make_tensor_descriptor(
base=b_ptr, shape=(K, N),
strides=(stride_bk, stride_bn),
block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N)
)
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
off_k = 0
for _ in range(0, K, BLOCK_SIZE_K):
a = a_desc.load([pid_m * BLOCK_SIZE_M, off_k])
b = b_desc.load([off_k, pid_n * BLOCK_SIZE_N])
accumulator += tl.dot(a, b)
off_k += BLOCK_SIZE_K
```

Tensor Descriptors internally perform offset pointer calculations and masking, so manual calculation of these variables is no longer necessary.

---
In the base case, a Tensor Descriptor describes the whole `torch.Tensor` with full shape and strides, but you can also treat it as a "view" of a tensor. You can select a specific slice of a tensor using a mask specific to that block.

For example, consider a 3D tensor `A` where we want to get a slice `A[E, :M, :]` for a Mixture of Experts (MoE) kernel. The original code might look like this (adapted from [vllm](https://github.com/vllm-project/vllm/blob/8005e606bf280b7b6002f57e95ae3210ddc6f041/vllm/model_executor/layers/fused_moe/fused_batched_moe.py#L237)):

```
# Tensor A (a_ptr) has shape [E Experts, M Tokens, K features]
# We want to load K blocks from A[expert_id, cta_m_start:min(cta_m_start+BLOCK_M, e_num_tokens), :]

# We process just one expert in this block
expert_id = tl.program_id(axis=0)
# Defines how many tokens we need to process for this expert in total
e_num_tokens = tl.load(expert_num_tokens_ptr + expert_id)
if e_num_tokens == 0:
return

cta_m_start = pid_m * BLOCK_M
if cta_m_start >= e_num_tokens:
return # Early exit

# Start of a current block, A[expert_id, cta_m_start:cta_m_start+BLOCK_M, :]
a_ptr_block = a_ptr + expert_id * stride_ae + cta_m_start * stride_am

offs_m = tl.arange(0, BLOCK_M)
offs_k = tl.arange(0, BLOCK_K)

# Actual block size of M dimension that we need to process
cta_m_size = min(BLOCK_M, e_num_tokens - cta_m_start)

mask_m = offs_m < cta_m_size

# Each expert needs to process e_num_tokens
# Pointers to our block A[expert_id, cta_m_start:cta_m_start+BLOCK_M, 0:BLOCK_K]
a_ptrs = a_ptr_block + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak

offs_k = tl.arange(0, BLOCK_K)
for k in range(0, tl.cdiv(K, BLOCK_K)):
a = tl.load(
a_ptrs,
# We mask tokens outside of e_num_token range and features larger than K
mask=mask_m[:, None] & (offs_k[None, :] < K - k * BLOCK_K),
other=0.0)
a_ptrs += BLOCK_K * stride_ak
```


We can migrate to Tensor Descriptors to simplify the source code. Let's investigate what Tensor Desciptor we will need here.

From `a_ptrs` shape we can infer block shapes and strides of a possible Tensor Descriptor:
```
block_size=(BLOCK_M, BLOCK_K)
strides=(stride_am, stride_ak)
```

Using Tensor Desciptor we can directly describe a slice that we want to process:

`A[expert_id, :e_num_tokens, :]`

and then only load our block:

`A[expert_id, cta_m_start:cta_m_start+BLOCK_M, k*BLOCK_K:(k+1)*BLOCK_K]`.

We also want to mask tokens outside of `e_num_tokens` (`mask_m`) and features outside of dimension size `K`.
We can just pass that information as tensor shape and avoid manual masking: `shape=(e_num_tokens, K)`

Base of a tensor descriptor can be inferred from `a_ptrs` as well:

`base=a_ptr + expert_id * stride_ae`

Note that we don't add `cta_m_start * stride_am` to the base, because we can pass that offset directly during loading.

So we can rewrite that code with Tensor Desciptors, which will look much cleaner and will work faster on XPU:

```
expert_id = tl.program_id(axis=0)
e_num_tokens = tl.load(expert_num_tokens + expert_id)
if e_num_tokens == 0:
# Early exit
return

cta_m_start = pid_m * BLOCK_M
if cta_m_start >= e_num_tokens:
# Early exit
return

a_desc = tl.make_tensor_descriptor(
base=a_ptr + expert_id * stride_ae,
shape=(e_num_tokens, K),
strides=(stride_am, stride_ak),
block_shape=(BLOCK_M, BLOCK_K))

for k in range(0, tl.cdiv(K, BLOCK_K)):
a = a_desc.load([pid_m * BLOCK_M, k * BLOCK_K])
```

---


Tensor Descriptors support shapes up to 5 dimensions, but for performance, it is best to use 2 dimensions whenever possible.
Consider this example based on the unified attention kernel from [vllm](https://github.com/vllm-project/vllm/blob/9a161307f5f096c63ae4134c5055d87a36d224a8/vllm/attention/ops/triton_unified_attention.py#L52). This code loads a block of K values from a cache of shape `[NUM_BLOCKS, BLOCK_SIZE, KV_HEADS, HEAD_SIZE]`:


```
offs_d = tl.arange(0, HEAD_SIZE_PADDED)
dim_mask = tl.where(offs_d < HEAD_SIZE, 1, 0).to(tl.int1)
block_table_offset = seq_idx * block_table_stride

# iterate through tiles
for j in range(0, num_blocks):
physical_block_idx = tl.load(block_tables_ptr + block_table_offset + j)

offs_n = tl.arange(0, BLOCK_SIZE)

k_offset = (physical_block_idx * stride_k_cache_0 +
kv_head_idx * stride_k_cache_2 +
offs_d[:, None] * stride_k_cache_3 +
offs_n[None, :] * stride_k_cache_1)

# K : (HEAD_SIZE, BLOCK_SIZE)
K_load = tl.load(key_cache_ptr + k_offset,
mask=dim_mask[:, None],
other=0.0)

```

The code above loads 2D block of `K[physical_block_idx, :BLOCK_SIZE, kv_head_idx, :HEAD_SIZE].T`.

A 3D Tensor Descriptor implementation following the tensor shape might look like this:

```
# iterate through tiles
for j in range(0, num_blocks):
physical_block_idx = tl.load(block_tables_ptr + block_table_offset + j)
k_base = key_cache_ptr + physical_block_idx * stride_k_cache_0 + kv_head_idx * stride_k_cache_2
k_desc = tl.make_tensor_descriptor(base=k_base, shape=(BLOCK_SIZE, 1, HEAD_SIZE),
strides=(stride_k_cache_1, stride_k_cache_2, stride_k_cache_3),
block_shape=(BLOCK_SIZE, 1, HEAD_SIZE_PADDED))
K_load = k_desc.load([0, 0, 0]).reshape(BLOCK_SIZE, HEAD_SIZE_PADDED).T
```

However, describing this memory as a 2D block yields significantly better performance:

```
# iterate through tiles
for j in range(0, num_blocks):
physical_block_idx = tl.load(block_tables_ptr + block_table_offset + j)
k_base = key_cache_ptr + physical_block_idx * stride_k_cache_0 + kv_head_idx * stride_k_cache_2
k_desc = tl.make_tensor_descriptor(base=k_base, shape=(BLOCK_SIZE, HEAD_SIZE),
strides=(stride_k_cache_1, stride_k_cache_3),
block_shape=(BLOCK_SIZE, HEAD_SIZE_PADDED))
K_load = k_desc.load([0, 0]).T
```

---
Summary:
1. Use Tensor Desciptors to load memory reqired for `tl.dot` and to save results.
2. Strive to use 2D tensor desctiptors for better performance.
3. Last tensor stride should be `tl.constexpr` or have no type annotation. Annotating with `tl.int64` will result in poor perfomance.

## Use proper type annotations
1. Set `tl.constexpr` type annotation for block sizes and boolean flags to let the compiler optimize. Each combination of arguments with this annotation is compiled separately. Avoid setting it for values that vary widely at runtime (like the number of tokens) to prevent excessive recompilation.
2. No Annotation: You can keep type annotations empty and let the compiler guess. This is good for parameters that change often (like strides) to avoid recompilation.
3. Avoid writing `tl.int64` type annotation for the last stride of a tensor. It is often important for the compiler to know that the tensor is contiguous.

Example of a good type annotation for a GEMM kernel:
```
@triton.jit
def matmul_kernel(
# Pointers to matrices
a_ptr, b_ptr, c_ptr,
# Matrix dimensions
M, N, K,
# The stride variables represent how much to increase the ptr by when moving by 1
# element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
# by to get the element one row down (A has M rows).
stride_am, stride_ak, #
stride_bk, stride_bn, #
stride_cm, stride_cn,
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, #
GROUP_SIZE_M: tl.constexpr, #
ACTIVATION: tl.constexpr #
):
```

## Tune kernel configuration

### GRF Mode
Setting it higher can be good for kernel that uses many registers, but will decrease hardware utilizaion.

# Quick Installation

Expand Down Expand Up @@ -305,6 +585,9 @@ optimized_mod = torch.compile(xpu_model)
graph_result = optimized_mod(x)
```

### Example 3 : GEMM operations
Intel backend for triton requires

## Performance Analysis Guide

There are several ways of doing performance analysis.
Expand Down
Loading