General Matrix-Vector Multiplication (GEMV)
===========================================
Warning
This document is still experimental and may be incomplete.
Suggestions and improvements are highly encouraged—please submit a PR!
Tip
Example code can be found at examples/gemv/example_gemv.py.
General matrix-vector multiplication (GEMV) can be viewed as a special case of general matrix-matrix multiplication (GEMM) in which one operand has a single row. It plays a critical role in deep learning, especially during the inference phase of large language models. In this tutorial, we optimize GEMV step by step from a thread-level perspective using TileLang.
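For reference, the computation that every kernel below implements fits in a few lines of plain Python. This is a minimal sketch using the same layout as the TileLang kernels in this tutorial (A has shape (K,), B has shape (N, K), and C[n] = Σ_k A[k]·B[n, k]); the name gemv_reference is our own and is not part of TileLang.

```python
from typing import List


def gemv_reference(A: List[float], B: List[List[float]]) -> List[float]:
    """Compute C[n] = sum_k A[k] * B[n][k], i.e. a GEMM with a single row (M = 1)."""
    K = len(A)
    C = []
    for row in B:
        # Every row of B must have exactly K entries.
        assert len(row) == K, "B must have shape (N, K)"
        # One output element is one dot product of length K.
        C.append(sum(a * b for a, b in zip(A, row)))
    return C


if __name__ == "__main__":
    A = [1.0, 2.0, 3.0]
    B = [[1.0, 0.0, 0.0],
         [0.0, 1.0, 0.0],
         [1.0, 1.0, 1.0]]
    print(gemv_reference(A, B))  # [1.0, 2.0, 6.0]
```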
Triton Implementation
When implementing a GEMV kernel, you might start with a high-level approach using a tool like Triton. A simple Triton kernel for GEMV might look like this:
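import triton
import triton.language as tl


@triton.jit
def _gemv_naive(
    x_ptr, A_ptr, y_ptr,
    N, K,
    BLOCK_SIZE_K: tl.constexpr,
):
    # One program instance computes one output element y[n] = dot(A[n, :], x).
    n = tl.program_id(0)
    # The whole row is loaded at once, so BLOCK_SIZE_K must be a power of two >= K.
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    mask = offs_k < K
    a_ptrs = A_ptr + n * K + offs_k
    a_vals = tl.load(a_ptrs, mask=mask, other=0.0)
    x_vals = tl.load(x_ptr + offs_k, mask=mask, other=0.0)
    dot = tl.sum(a_vals * x_vals, axis=0)
    tl.store(y_ptr + n, dot)

# Launch one program per output row, e.g.:
# _gemv_naive[(N,)](x, A, y, N, K, BLOCK_SIZE_K=triton.next_power_of_2(K))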
Triton is straightforward to use because it operates at the block level. However, that abstraction does not give you fine-grained, thread-level control. In the rest of this tutorial, we show how to write an optimized GEMV kernel in TileLang, which exposes more low-level control.
Naive Implementation in TileLang
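If you have a basic understanding of CUDA C, it is natural to start with a naive GEMV kernel by adapting a GEMM tiling strategy. You can think of GEMV as a (1, K) × (K, N) GEMM, where the (K, N) operand is stored transposed as B with shape (N, K). Each thread block computes BLOCK_N outputs, and each thread computes one output by walking over K in tiles of BLOCK_K. Below is a simple example: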
import tilelang.language as T


def naive_gemv(
    N: int,
    K: int,
    BLOCK_N: int,
    BLOCK_K: int,
    dtype: str = "float16",
    accum_dtype: str = "float",
):
    # Assumes N % BLOCK_N == 0 and K % BLOCK_K == 0 (no bounds checks).

    @T.prim_func
    def main(
        A: T.Buffer((K,), dtype),
        B: T.Buffer((N, K), dtype),
        C: T.Buffer((N,), dtype),
    ):
        # One block per BLOCK_N outputs, one thread per output element.
        with T.Kernel(T.ceildiv(N, BLOCK_N), threads=BLOCK_N) as bn:
            tn = T.get_thread_binding(0)  # tn = threadIdx.x
            A_shared = T.alloc_shared((BLOCK_K,), dtype)
            B_shared = T.alloc_shared((BLOCK_N, BLOCK_K), dtype)
            C_reg = T.alloc_local((1,), accum_dtype)
            T.clear(C_reg)
            for bk in T.serial(T.ceildiv(K, BLOCK_K)):
                # Stage one K-tile of A and this thread's row of B in shared memory.
                for tk in T.serial(BLOCK_K):
                    A_shared[tk] = A[bk * BLOCK_K + tk]
                    B_shared[tn, tk] = B[bn * BLOCK_N + tn, bk * BLOCK_K + tk]
                # Each thread accumulates its own dot product serially over the tile.
                for tk in T.serial(BLOCK_K):
                    C_reg[0] += A_shared[tk].astype(accum_dtype) * B_shared[tn, tk].astype(accum_dtype)
            C[bn * BLOCK_N + tn] = C_reg[0]

    return main
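To make the loop structure of naive_gemv concrete, the sketch below emulates it sequentially on the CPU in plain Python. It is only an illustration under the same divisibility assumption as the kernel: the two outer loops stand in for the block index bn and the thread index tn, which run in parallel on the GPU, and the helper name naive_gemv_sim is hypothetical.

```python
from typing import List


def naive_gemv_sim(A: List[float], B: List[List[float]],
                   BLOCK_N: int, BLOCK_K: int) -> List[float]:
    """Sequential CPU emulation of naive_gemv's loop nest (illustrative only)."""
    N, K = len(B), len(A)
    # Like the kernel, assume the tiles divide the problem evenly.
    assert N % BLOCK_N == 0 and K % BLOCK_K == 0
    C = [0.0] * N
    for bn in range(N // BLOCK_N):          # one iteration per thread block
        for tn in range(BLOCK_N):           # one iteration per thread in the block
            C_reg = 0.0                     # per-thread accumulator (T.alloc_local)
            for bk in range(K // BLOCK_K):  # walk K one tile at a time
                # Stage this K-tile of A and this thread's row of B ("shared memory").
                A_shared = A[bk * BLOCK_K:(bk + 1) * BLOCK_K]
                B_row = B[bn * BLOCK_N + tn][bk * BLOCK_K:(bk + 1) * BLOCK_K]
                for tk in range(BLOCK_K):
                    C_reg += A_shared[tk] * B_row[tk]
            C[bn * BLOCK_N + tn] = C_reg
    return C
```

Each emulated thread walks the entire K dimension on its own, which helps explain why this version leaves most of the GPU idle.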
TileLang compiles naive_gemv into the following CUDA code (cached under ~/.tilelang/cache):
extern "C" __global__ void __launch_bounds__(256, 1) main_kernel(half_t* __restrict__ A, half_t* __restrict__ B, half_t* __restrict__ C) {
  extern __shared__ __align__(1024) uchar buf_dyn_shmem[];
  float C_reg[1];
  __shared__ uint64_t _mbarrier[2];
  if (((int)threadIdx.x) == 0) {
    tl::mbarrier_init(_mbarrier[0], 128);
    tl::mbarrier_init(_mbarrier[1], 128);
  }
  __syncthreads();
  if (128 <= ((int)threadIdx.x)) {
    tl::warpgroup_reg_dealloc<24>();
    for (int bk = 0; bk < 8; ++bk) {
      tl::mbarrier_wait(_mbarrier[1], ((bk & 1) ^ 1));
      for (int tk = 0; tk < 128; ++tk) {
        ((half_t*)buf_dyn_shmem)[tk] = A[((bk * 128) + tk)];
        ((half_t*)buf_dyn_shmem)[(((((int)threadIdx.x) * 128) + tk) - 16256)] = B[(((((((int)blockIdx.x) * 131072) + (((int)threadIdx.x) * 1024)) + (bk * 128)) + tk) - 131072)];
      }
      tl::fence_proxy_async();
      tl::mbarrier_cp_async_arrive(_mbarrier[0]);
      tl::mbarrier_arrive(_mbarrier[0]);
    }
  } else {
    tl::warpgroup_reg_alloc<240>();
    C_reg[0] = 0.000000e+00f;
    for (int bk_1 = 0; bk_1 < 8; ++bk_1) {
      tl::mbarrier_wait(_mbarrier[0], (bk_1 & 1));
      for (int tk_1 = 0; tk_1 < 128; ++tk_1) {
        C_reg[0] = (C_reg[0] + (((float)((half_t*)buf_dyn_shmem)[tk_1]) * ((float)((half_t*)buf_dyn_shmem)[(((((int)threadIdx.x) * 128) + tk_1) + 128)])));
      }
      tl::fence_proxy_async();
      tl::mbarrier_arrive(_mbarrier[1]);
    }
    C[((((int)blockIdx.x) * 128) + ((int)threadIdx.x))] = ((half_t)C_reg[0]);
  }
}
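In this generated code, TileLang has applied warp specialization: the 128 threads with threadIdx.x >= 128 act as producers that copy tiles of A and B into shared memory, while the first 128 threads act as consumers that compute the dot products. The two groups hand tiles back and forth through the two mbarriers in _mbarrier.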
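The producer/consumer handoff can be illustrated on the CPU with two Python threads and a one-slot buffer. In this sketch (an analogy, not what the hardware does), two threading.Semaphore objects play the roles of the two mbarriers: full corresponds to _mbarrier[0], which the producer arrives on after writing a tile and the consumer waits on, and empty corresponds to _mbarrier[1], which the consumer arrives on after reading a tile and the producer waits on. The function name producer_consumer_dot is hypothetical.

```python
import threading
from typing import List


def producer_consumer_dot(A: List[float], row: List[float], BLOCK_K: int) -> float:
    """Dot product of A and row, staged tile by tile through a one-slot buffer."""
    K = len(A)
    assert K % BLOCK_K == 0
    num_tiles = K // BLOCK_K
    shared = {}                     # stands in for the shared-memory tile
    full = threading.Semaphore(0)   # "tile ready" signal: producer -> consumer
    empty = threading.Semaphore(1)  # "slot free" signal: consumer -> producer
    result = [0.0]

    def producer() -> None:
        for bk in range(num_tiles):
            empty.acquire()         # wait until the consumer has released the slot
            shared["a"] = A[bk * BLOCK_K:(bk + 1) * BLOCK_K]
            shared["b"] = row[bk * BLOCK_K:(bk + 1) * BLOCK_K]
            full.release()          # publish the tile

    def consumer() -> None:
        acc = 0.0
        for _ in range(num_tiles):
            full.acquire()          # wait for the next tile
            acc += sum(x * y for x, y in zip(shared["a"], shared["b"]))
            empty.release()         # hand the slot back to the producer
        result[0] = acc

    workers = [threading.Thread(target=producer), threading.Thread(target=consumer)]
    for w in workers:
        w.start()
    for w in workers:
        w.join()
    return result[0]
```

Starting empty at 1 lets the producer's first wait pass immediately; as we read the generated code, the flipped parity ((bk & 1) ^ 1) in the producer's mbarrier_wait serves the same purpose.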
This naive kernel makes poor use of the GPU: it takes about 0.17 ms, versus about 0.008 ms for torch/cuBLAS, which is roughly 20x slower.
More Concurrency
To further increase concurrency, we can exploit finer-grained thread-level parallelism. Instead of having each thread compute a single output element of C on its own, we also parallelize along the K dimension: each thread computes a partial sum over a slice of K, and the partial sums are then combined. Combining them requires a primitive such as atomicAdd in CUDA (T.atomic_add in TileLang).
Here’s a simplified version:
import tilelang.language as T


def naive_splitk_gemv(
    N: int,
    K: int,
    BLOCK_N: int,
    BLOCK_K: int,
    dtype: str = "float16",
    accum_dtype: str = "float",
):

    @T.prim_func
    def main(
        A: T.Buffer((K,), dtype),
        B: T.Buffer((N, K), dtype),
        C: T.Buffer((N,), dtype),
    ):
        # 2D thread block: tn indexes outputs, tk indexes positions along K.
        with T.Kernel(T.ceildiv(N, BLOCK_N), threads=(BLOCK_N, BLOCK_K)) as bn:
            tn = T.get_thread_binding(0)
            tk = T.get_thread_binding(1)
            A_local = T.alloc_local((1,), dtype)
            B_local = T.alloc_local((1,), dtype)
            C_accum = T.alloc_local((1,), accum_dtype)
            C_shared = T.alloc_shared((BLOCK_N,), accum_dtype)
            # One thread per output row initializes the shared accumulator.
            if tk == 0:
                C_shared[tn] = 0
            T.clear(C_accum)
            # Thread tk handles elements tk, tk + BLOCK_K, tk + 2 * BLOCK_K, ...
            for bk in T.serial(T.ceildiv(K, BLOCK_K)):
                A_local[0] = A[bk * BLOCK_K + tk]
                B_local[0] = B[bn * BLOCK_N + tn, bk * BLOCK_K + tk]
                C_accum[0] += A_local[0].astype(accum_dtype) * B_local[0].astype(accum_dtype)
            # Combine the per-thread partial sums for row tn.
            T.atomic_add(C_shared[tn], C_accum[0])
            C[bn * BLOCK_N + tn] = C_shared[tn]

    return main
By parallelizing along the K dimension, the kernel now runs in ~0.024 ms: roughly 7x faster than the naive version, but still about 3x slower than torch/cuBLAS.
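The split-K idea can be emulated for a single output row in plain Python: BLOCK_K worker threads each accumulate a private partial sum over a strided slice of K, then add it into one shared total. In this sketch a threading.Lock stands in for the atomicity that atomicAdd (T.atomic_add) provides on the GPU; the function name splitk_row_dot is hypothetical.

```python
import threading
from typing import List


def splitk_row_dot(A: List[float], row: List[float], BLOCK_K: int) -> float:
    """Emulate naive_splitk_gemv for one output row using BLOCK_K worker threads."""
    K = len(A)
    assert K % BLOCK_K == 0
    C_shared = [0.0]              # shared accumulator for this row
    lock = threading.Lock()       # provides the atomicity of atomicAdd

    def worker(tk: int) -> None:
        C_accum = 0.0             # private partial sum (T.alloc_local)
        for bk in range(K // BLOCK_K):
            k = bk * BLOCK_K + tk # thread tk touches every BLOCK_K-th element
            C_accum += A[k] * row[k]
        with lock:                # like T.atomic_add(C_shared[tn], C_accum[0])
            C_shared[0] += C_accum

    workers = [threading.Thread(target=worker, args=(tk,)) for tk in range(BLOCK_K)]
    for w in workers:
        w.start()
    for w in workers:
        w.join()
    return C_shared[0]
```

Because the partial sums are combined in whatever order the threads finish, atomic-based reductions on floating-point data can differ in the last bits from run to run; with integer-valued inputs, as in small tests, the result is exact regardless of order.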
Customizing Parallelism in the K Dimension
If K is large, you can further control how much work each thread does by introducing a reduce_threads parameter. Each K-tile of BLOCK_K elements is split among reduce_threads threads, so each thread handles TILE_K = ceildiv(BLOCK_K, reduce_threads) consecutive elements per iteration:
import tilelang.language as T


def splitk_gemv(
    N: int,
    K: int,
    BLOCK_N: int,
    BLOCK_K: int,
    reduce_threads: int,
    dtype: str = "float16",
    accum_dtype: str = "float",
):
    # Number of consecutive K elements each thread handles per tile.
    TILE_K = T.ceildiv(BLOCK_K, reduce_threads)

    @T.prim_func
    def main(
        A: T.Buffer((K,), dtype),
        B: T.Buffer((N, K), dtype),
        C: T.Buffer((N,), dtype),
    ):
        with T.Kernel(T.ceildiv(N, BLOCK_N), threads=(BLOCK_N, reduce_threads)) as bn:
            tn = T.get_thread_binding(0)
            tk = T.get_thread_binding(1)
            A_local = T.alloc_local((TILE_K,), dtype)
            B_local = T.alloc_local((TILE_K,), dtype)
            C_shared = T.alloc_shared((BLOCK_N,), accum_dtype)
            C_accum = T.alloc_local((1,), accum_dtype)
            if tk == 0:
                C_shared[tn] = 0
            T.clear(C_accum)
            for bk in T.serial(T.ceildiv(K, BLOCK_K)):
                # Load this thread's TILE_K consecutive elements of the tile.
                for k in T.serial(TILE_K):
                    A_local[k] = A[bk * BLOCK_K + tk * TILE_K + k]
                    B_local[k] = B[bn * BLOCK_N + tn, bk * BLOCK_K + tk * TILE_K + k]
                for k in T.serial(TILE_K):
                    C_accum[0] += A_local[k].astype(accum_dtype) * B_local[k].astype(accum_dtype)
            T.atomic_add(C_shared[tn], C_accum[0])
            C[bn * BLOCK_N + tn] = C_shared[tn]

    return main
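The helper below lists which K indices a given thread tk reads in splitk_gemv, assuming BLOCK_K is divisible by reduce_threads. It is a plain-Python illustration with a hypothetical name, not TileLang code.

```python
from typing import List


def splitk_indices(K: int, BLOCK_K: int, reduce_threads: int, tk: int) -> List[int]:
    """K indices read by thread tk in splitk_gemv (illustrative helper)."""
    TILE_K = -(-BLOCK_K // reduce_threads)  # ceildiv, as in the kernel
    idx = []
    for bk in range(-(-K // BLOCK_K)):      # T.ceildiv(K, BLOCK_K) tiles
        for k in range(TILE_K):
            idx.append(bk * BLOCK_K + tk * TILE_K + k)
    return idx


if __name__ == "__main__":
    K, BLOCK_K, reduce_threads = 16, 8, 4   # TILE_K = 2
    for tk in range(reduce_threads):
        print(tk, splitk_indices(K, BLOCK_K, reduce_threads, tk))
```

With K = 16, BLOCK_K = 8 and reduce_threads = 4, thread 0 reads [0, 1, 8, 9] and thread 3 reads [6, 7, 14, 15]. Each thread reads a contiguous run of TILE_K elements per tile, and together the threads cover every index exactly once; that contiguity is what makes vectorized loads possible.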
Vectorized Reads
GEMV has much lower arithmetic intensity than GEMM: each element of B is loaded once and used in a single multiply-add, so memory throughput, not computation, is the optimization bottleneck. One effective strategy is to use vectorized loads and stores (e.g., float2, float4), so that each thread moves up to 128 bits per memory instruction. In TileLang, you can request vectorized accesses with T.vectorized:
import tilelang.language as T
from tvm import DataType


def splitk_gemv_vectorized(
    N: int,
    K: int,
    BLOCK_N: int,
    reduce_threads: int,
    dtype: str = "float16",
    accum_dtype: str = "float",
):
    # One 128-bit transaction per thread: 8 elements for float16.
    MAX_TRANSACTION_SIZE_IN_BITS = 128
    TILE_K = MAX_TRANSACTION_SIZE_IN_BITS // DataType(dtype).bits
    BLOCK_K = reduce_threads * TILE_K

    @T.prim_func
    def main(
        A: T.Buffer((K,), dtype),
        B: T.Buffer((N, K), dtype),
        C: T.Buffer((N,), dtype),
    ):
        with T.Kernel(T.ceildiv(N, BLOCK_N), threads=(BLOCK_N, reduce_threads)) as bn:
            tn = T.get_thread_binding(0)
            tk = T.get_thread_binding(1)
            A_local = T.alloc_local((TILE_K,), dtype)
            B_local = T.alloc_local((TILE_K,), dtype)
            C_shared = T.alloc_shared((BLOCK_N,), accum_dtype)
            C_accum = T.alloc_local((1,), accum_dtype)
            if tk == 0:
                C_shared[tn] = 0
            T.clear(C_accum)
            for bk in T.serial(T.ceildiv(K, BLOCK_K)):
                # Vectorized: the TILE_K elements are loaded as one wide access.
                for k in T.vectorized(TILE_K):
                    A_local[k] = A[bk * BLOCK_K + tk * TILE_K + k]
                    B_local[k] = B[bn * BLOCK_N + tn, bk * BLOCK_K + tk * TILE_K + k]
                for k in T.serial(TILE_K):
                    C_accum[0] += A_local[k].astype(accum_dtype) * B_local[k].astype(accum_dtype)
            T.atomic_add(C_shared[tn], C_accum[0])
            C[bn * BLOCK_N + tn] = C_shared[tn]

    return main
With vectorized reads, the kernel now finishes in ~0.0084 ms, close to cuBLAS performance.
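As a back-of-the-envelope check of why GEMV is memory-bound, the sketch below computes the tile sizes that splitk_gemv_vectorized derives and the arithmetic intensity (FLOPs per byte of compulsory memory traffic) of a GEMV. It assumes A, B, and C each cross DRAM exactly once and counts a multiply-add as 2 FLOPs; the function names are hypothetical.

```python
from typing import Tuple


def vector_tile(dtype_bits: int, reduce_threads: int,
                max_transaction_bits: int = 128) -> Tuple[int, int]:
    """Return (TILE_K, BLOCK_K) the way splitk_gemv_vectorized derives them."""
    TILE_K = max_transaction_bits // dtype_bits  # elements per 128-bit load
    BLOCK_K = reduce_threads * TILE_K            # K elements covered per iteration
    return TILE_K, BLOCK_K


def gemv_arithmetic_intensity(N: int, K: int, bytes_per_elem: int) -> float:
    """FLOPs per byte for C = B @ A with A: (K,), B: (N, K), C: (N,)."""
    flops = 2 * N * K                                # one multiply + one add per B element
    bytes_moved = (N * K + K + N) * bytes_per_elem   # read B and A once, write C once
    return flops / bytes_moved


if __name__ == "__main__":
    print(vector_tile(16, 32))
    print(round(gemv_arithmetic_intensity(1024, 1024, 2), 3))
```

For float16 with reduce_threads = 32, TILE_K = 8 and BLOCK_K = 256, and a 1024 × 1024 fp16 GEMV performs just under 1 FLOP per byte moved. That is far below the FLOP-to-byte ratio at which a modern GPU becomes compute-bound, so memory bandwidth, not arithmetic, sets the speed limit.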
tvm_thread_allreduce Instead of atomicAdd
tvm_thread_allreduce implements an optimized all-reduce across a group of threads, which should outperform our plain shared-memory + atomicAdd approach. Here the reduction runs over the tk threads, and the combined result lands in C_reduced:
import tilelang.language as T
from tvm import DataType


def splitk_gemv_vectorized_tvm(
    N: int,
    K: int,
    BLOCK_N: int,
    reduce_threads: int,
    dtype: str = "float16",
    accum_dtype: str = "float",
):
    MAX_TRANSACTION_SIZE_IN_BITS = 128
    TILE_K = MAX_TRANSACTION_SIZE_IN_BITS // DataType(dtype).bits
    BLOCK_K = reduce_threads * TILE_K

    @T.prim_func
    def main(
        A: T.Buffer((K,), dtype),
        B: T.Buffer((N, K), dtype),
        C: T.Buffer((N,), dtype),
    ):
        with T.Kernel(T.ceildiv(N, BLOCK_N), threads=(BLOCK_N, reduce_threads)) as bn:
            tn = T.get_thread_binding(0)
            tk = T.get_thread_binding(1)
            A_local = T.alloc_local((TILE_K,), dtype)
            B_local = T.alloc_local((TILE_K,), dtype)
            C_accum = T.alloc_local((1,), accum_dtype)
            T.clear(C_accum)
            for bk in T.serial(T.ceildiv(K, BLOCK_K)):
                for k in T.vectorized(TILE_K):
                    A_local[k] = A[bk * BLOCK_K + tk * TILE_K + k]
                    B_local[k] = B[bn * BLOCK_N + tn, bk * BLOCK_K + tk * TILE_K + k]
                for k in T.serial(TILE_K):
                    C_accum[0] += A_local[k].astype(accum_dtype) * B_local[k].astype(accum_dtype)
            C_reduced = T.alloc_local((1,), accum_dtype)
            # Sum-reduce C_accum across the threads bound to tk.
            with T.attr(
                    T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]),
                    "reduce_scope",
                    T.reinterpret(T.uint64(0), dtype="handle"),
            ):
                T.evaluate(
                    T.tvm_thread_allreduce(
                        T.uint32(1),
                        C_accum[0],
                        True,
                        C_reduced[0],
                        tk,
                        dtype="handle",
                    ))
            C[bn * BLOCK_N + tn] = C_reduced[0]

    return main
With this optimization, the kernel latency drops from ~0.0084 ms to ~0.0069 ms, which is faster than torch/cuBLAS!
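One common way to implement an all-reduce across a group of threads is a butterfly exchange, in which every thread repeatedly adds the value held by a partner thread whose index differs in one bit (on NVIDIA GPUs, typically one __shfl_xor_sync per step). The sketch below is a generic plain-Python illustration of that technique with a hypothetical function name; we do not claim it is the exact code tvm_thread_allreduce emits (the autotuned kernel later in this tutorial, for instance, uses a shared-memory tree).

```python
from typing import List


def butterfly_allreduce(values: List[float]) -> List[float]:
    """XOR-butterfly sum all-reduce: afterwards every lane holds the total."""
    n = len(values)
    assert n > 0 and n & (n - 1) == 0, "lane count must be a power of two"
    vals = list(values)
    offset = n // 2
    while offset > 0:
        # All lanes update simultaneously, each adding its partner lane (i ^ offset).
        vals = [vals[i] + vals[i ^ offset] for i in range(n)]
        offset //= 2
    return vals


if __name__ == "__main__":
    print(butterfly_allreduce([1.0, 2.0, 3.0, 4.0]))  # [10.0, 10.0, 10.0, 10.0]
```

Compared with many threads issuing atomicAdd to one shared location, the reduction finishes in log2(n) contention-free steps, and every participating thread ends up holding the result.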
Autotune
BLOCK_N and reduce_threads are hyperparameters of our kernel (BLOCK_K follows from them as reduce_threads × TILE_K), and tuning them can improve performance. We can use the tilelang.autotune feature to search for the best configuration automatically:
import itertools

import tilelang as tl
import tilelang.language as T
from tilelang.autotuner import autotune, jit
from tvm import DataType


def ref_program(A, B):
    # Reference result used for the correctness check: C = B @ A.
    return A @ B.T


def get_best_config(N, K):

    def get_configs():
        BLOCK_N = [2, 4, 8, 32, 64, 128]
        reduce_threads = [4, 8, 32]
        _configs = list(itertools.product(
            BLOCK_N,
            reduce_threads,
        ))
        configs = [{
            "BLOCK_N": c[0],
            "reduce_threads": c[1],
        } for c in _configs]
        return configs

    @autotune(
        configs=get_configs(),
        keys=[
            "BLOCK_N",
            "reduce_threads",
        ],
        warmup=3,
        rep=20,
    )
    @jit(
        out_idx=[-1],
        supply_type=tl.TensorSupplyType.Integer,
        ref_prog=ref_program,
        skip_check=False,
        target="auto",
    )
    def kernel(
        BLOCK_N=None,
        reduce_threads=None,
    ):
        dtype = "float16"
        accum_dtype = "float"
        MAX_TRANSACTION_SIZE_IN_BITS = 128
        TILE_K = MAX_TRANSACTION_SIZE_IN_BITS // DataType(dtype).bits
        BLOCK_K = reduce_threads * TILE_K

        @T.prim_func
        def main(
            A: T.Buffer((K,), dtype),
            B: T.Buffer((N, K), dtype),
            C: T.Buffer((N,), dtype),
        ):
            with T.Kernel(T.ceildiv(N, BLOCK_N), threads=(BLOCK_N, reduce_threads)) as bn:
                tn = T.get_thread_binding(0)
                tk = T.get_thread_binding(1)
                A_local = T.alloc_local((TILE_K,), dtype)
                B_local = T.alloc_local((TILE_K,), dtype)
                C_accum = T.alloc_local((1,), accum_dtype)
                T.clear(C_accum)
                for bk in T.serial(T.ceildiv(K, BLOCK_K)):
                    for k in T.vectorized(TILE_K):
                        A_local[k] = A[bk * BLOCK_K + tk * TILE_K + k]
                        B_local[k] = B[bn * BLOCK_N + tn, bk * BLOCK_K + tk * TILE_K + k]
                    for k in T.serial(TILE_K):
                        C_accum[0] += A_local[k].astype(accum_dtype) * B_local[k].astype(accum_dtype)
                C_reduced = T.alloc_local((1,), accum_dtype)
                with T.attr(
                        T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]),
                        "reduce_scope",
                        T.reinterpret(T.uint64(0), dtype="handle"),
                ):
                    T.evaluate(
                        T.tvm_thread_allreduce(
                            T.uint32(1),
                            C_accum[0],
                            True,
                            C_reduced[0],
                            tk,
                            dtype="handle",
                        ))
                C[bn * BLOCK_N + tn] = C_reduced[0]

        return main

    return kernel()
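The search space above is small enough to inspect by hand. This plain-Python sketch rebuilds the same 18 configurations and derives, for float16, the threads per block (BLOCK_N × reduce_threads) and BLOCK_K of each; it does not call TileLang, and the helper name launch_shape is hypothetical.

```python
import itertools
from typing import Dict, List


def get_configs() -> List[Dict[str, int]]:
    """Same search space as get_configs() in the autotuning example."""
    BLOCK_N = [2, 4, 8, 32, 64, 128]
    reduce_threads = [4, 8, 32]
    return [{"BLOCK_N": n, "reduce_threads": r}
            for n, r in itertools.product(BLOCK_N, reduce_threads)]


def launch_shape(cfg: Dict[str, int], dtype_bits: int = 16) -> Dict[str, int]:
    """Derived launch quantities for one configuration."""
    TILE_K = 128 // dtype_bits
    return {
        "threads_per_block": cfg["BLOCK_N"] * cfg["reduce_threads"],
        "BLOCK_K": cfg["reduce_threads"] * TILE_K,
    }


if __name__ == "__main__":
    configs = get_configs()
    print(len(configs))  # 18
    for cfg in configs:
        shape = launch_shape(cfg)
        too_big = shape["threads_per_block"] > 1024  # CUDA per-block thread limit
        print(cfg, shape, "cannot launch" if too_big else "")
```

Two of the 18 configurations (BLOCK_N = 64 or 128 with reduce_threads = 32) exceed CUDA's limit of 1024 threads per block. The generated kernel below, with __launch_bounds__(64, 1), a blockIdx.x * 2 output offset, and a bk * 256 stride, corresponds to BLOCK_N = 2 and reduce_threads = 32.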
After autotuning, the kernel runs in ~0.0067 ms. The final generated CUDA kernel looks like this:
extern "C" __global__ void __launch_bounds__(64, 1) main_kernel(half_t* __restrict__ A, half_t* __restrict__ B, half_t* __restrict__ C) {
  float C_accum[1];
  half_t A_local[8];
  half_t B_local[8];
  __shared__ float red_buf0[64];
  C_accum[0] = 0.000000e+00f;
  for (int bk = 0; bk < 4; ++bk) {
    *(uint4*)(A_local + 0) = *(uint4*)(A + ((bk * 256) + (((int)threadIdx.y) * 8)));
    *(uint4*)(B_local + 0) = *(uint4*)(B + ((((((int)blockIdx.x) * 2048) + (((int)threadIdx.x) * 1024)) + (bk * 256)) + (((int)threadIdx.y) * 8)));
    for (int k = 0; k < 8; ++k) {
      C_accum[0] = (C_accum[0] + (((float)A_local[k]) * ((float)B_local[k])));
    }
  }
  tl::fence_proxy_async();
  __syncthreads();
  red_buf0[((((int)threadIdx.x) * 32) + ((int)threadIdx.y))] = C_accum[0];
  __syncthreads();
  if (((int)threadIdx.y) < 16) {
    red_buf0[((((int)threadIdx.x) * 32) + ((int)threadIdx.y))] = (red_buf0[((((int)threadIdx.x) * 32) + ((int)threadIdx.y))] + red_buf0[(((((int)threadIdx.x) * 32) + ((int)threadIdx.y)) + 16)]);
  }
  __syncthreads();
  if (((int)threadIdx.y) < 8) {
    red_buf0[((((int)threadIdx.x) * 32) + ((int)threadIdx.y))] = (red_buf0[((((int)threadIdx.x) * 32) + ((int)threadIdx.y))] + red_buf0[(((((int)threadIdx.x) * 32) + ((int)threadIdx.y)) + 8)]);
  }
  __syncthreads();
  if (((int)threadIdx.y) < 4) {
    red_buf0[((((int)threadIdx.x) * 32) + ((int)threadIdx.y))] = (red_buf0[((((int)threadIdx.x) * 32) + ((int)threadIdx.y))] + red_buf0[(((((int)threadIdx.x) * 32) + ((int)threadIdx.y)) + 4)]);
  }
  __syncthreads();
  if (((int)threadIdx.y) < 2) {
    red_buf0[((((int)threadIdx.x) * 32) + ((int)threadIdx.y))] = (red_buf0[((((int)threadIdx.x) * 32) + ((int)threadIdx.y))] + red_buf0[(((((int)threadIdx.x) * 32) + ((int)threadIdx.y)) + 2)]);
  }
  __syncthreads();
  if (((int)threadIdx.y) < 1) {
    red_buf0[((((int)threadIdx.x) * 32) + ((int)threadIdx.y))] = (red_buf0[((((int)threadIdx.x) * 32) + ((int)threadIdx.y))] + red_buf0[(((((int)threadIdx.x) * 32) + ((int)threadIdx.y)) + 1)]);
  }
  __syncthreads();
  C[((((int)blockIdx.x) * 2) + ((int)threadIdx.x))] = ((half_t)red_buf0[(((int)threadIdx.x) * 32)]);
}
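This corresponds closely to our TileLang program: the T.vectorized loops become 128-bit uint4 loads, and tvm_thread_allreduce is lowered to a shared-memory tree reduction over threadIdx.y in red_buf0, with the necessary synchronization and low-level optimizations inserted automatically.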
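The red_buf0 section of the generated kernel is a classic shared-memory tree reduction: at each step, the threads with threadIdx.y < stride add in the value stride slots away, and stride halves (16, 8, 4, 2, 1) until each row's total sits in slot threadIdx.x * 32. The sketch below replays that pattern sequentially in plain Python under a hypothetical helper name. Because the slots written at each step never overlap the slots read, the sequential replay gives the same result as the parallel version.

```python
from typing import List


def shared_tree_reduce(red_buf: List[float], group: int) -> None:
    """In-place tree reduction of each contiguous group of `group` slots.

    Mirrors red_buf0: slot tx*group + ty accumulates slot tx*group + ty + stride
    for ty < stride, with stride = group/2, ..., 1.
    """
    assert group > 0 and group & (group - 1) == 0, "group must be a power of two"
    assert len(red_buf) % group == 0
    stride = group // 2
    while stride > 0:
        # On the GPU these updates run in parallel, followed by __syncthreads().
        for tx in range(len(red_buf) // group):
            for ty in range(stride):
                red_buf[tx * group + ty] += red_buf[tx * group + ty + stride]
        stride //= 2


if __name__ == "__main__":
    buf = [1.0] * 32 + [2.0] * 32  # two rows (BLOCK_N = 2), 32 partial sums each
    shared_tree_reduce(buf, 32)
    print(buf[0], buf[32])         # 32.0 64.0
```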
Conclusion
Benchmark Table on Hopper GPU
| Kernel Name | Latency |
|---|---|
| torch/cuBLAS | 0.00784 ms |
| Triton | 0.00773 ms |
| naive_gemv | 0.16607 ms |
| splitk_gemv | 0.02419 ms |
| splitk_gemv_vectorized | 0.00809 ms |
| splitk_gemv_vectorized_tvm | 0.00675 ms |
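To put the benchmark table in relative terms, the snippet below divides the cuBLAS latency by each kernel's latency, so values above 1.0 mean faster than cuBLAS. The numbers are simply the ones in the table, measured on one Hopper GPU and problem size, and should not be read as general.

```python
# Latencies (ms) copied from the benchmark table.
latency_ms = {
    "torch/cuBLAS": 0.00784,
    "Triton": 0.00773,
    "naive_gemv": 0.16607,
    "splitk_gemv": 0.02419,
    "splitk_gemv_vectorized": 0.00809,
    "splitk_gemv_vectorized_tvm": 0.00675,
}

baseline = latency_ms["torch/cuBLAS"]
for name, ms in latency_ms.items():
    # > 1.0 means faster than cuBLAS, < 1.0 means slower.
    print(f"{name:28s} {baseline / ms:5.2f}x vs cuBLAS")
```

By this measure, splitk_gemv_vectorized_tvm is about 1.16x faster than cuBLAS, while naive_gemv is about 21x slower.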
In this tutorial, we implemented a simple GEMV kernel, optimized it step by step, and saw that TileLang exposes low-level control to the user, such as thread-level programming and CUDA primitives.