阅读本文前请先对LLM、注意力机制以及Transformer有基本的了解。完整的源代码可在 GitHub 上获取:yalm(另一个语言模型,https://github.com/andrewkchan/yalm)。
(本文作者Andrew Chan是技术专家社区South Park Commons的软件工程师,毕业于加州大学伯克利分校。本文由OneFlow编译发布,转载请联系授权。原文:https://andrewkchan.dev/posts/yalm.html)
1.1 推理概述
// Generates up to `steps` tokens that continue `prompt`, streaming each
// decoded token to stdout as soon as it is sampled. Stops early on EOS.
// NOTE: `tokenizer` and `sampler` are not declared in this excerpt and are
// assumed to be in scope.
void generate(Model& model, std::string prompt, int steps) {
  std::vector<int> encoded = tokenizer.encode(prompt);
  InferenceState s(model);
  // 1. Prefill step: Forward the model on each prompt token, discarding
  // the output. This lets the model read the prompt and hydrates the KV
  // cache. Once the last prompt token has been forwarded, `s.logits`
  // describes the first token to generate.
  for (int token : encoded) {
    model.forward(s, token);
  }
  // 2. Decode step: Forward the model repeatedly, generating 1 token at a time.
  for (int i = 0; i < steps; i++) {
    int next_token = sampler.sample(s.logits);
    std::cout << tokenizer.decode_one(next_token) << std::flush;
    if (next_token == tokenizer.EOS) {
      break;
    }
    // Feed the token we just sampled back in so the next iteration samples
    // its successor (rather than re-forwarding the last prompt token).
    model.forward(s, next_token);
  }
}
// InferenceState is the minimum set of buffers needed to
// hold state during the forward pass and exists to avoid
// extra allocations
//
// Runs one forward pass for a single token: embeds it into `s.x`, applies
// every block in order, then writes the classifier output to `s.logits`.
void Model::forward(InferenceState& s, int token) {
  // The embedding table maps token IDs to embedding vectors,
  // which are copied into a buffer of the inference state
  s.x = copy_embedding(token, this->token_embedding_table);

  // Models consist of a sequence of transformer blocks which
  // mutate the inference state in order
  for (Block& b : this->blocks) {
    b.block(s);
  }

  // Usually there is a layer norm right before the final classifier
  s.x = layernorm(s.x, this->lm_head_prenorm_weights);
  // Typically we end with a linear transform from (dim) -> (vocab_size)
  s.logits = linear(s.x, this->lm_head_classifier_weights);
}
// Applies one transformer block to the inference state in place:
// pre-norm -> attention -> residual add, then pre-norm -> FFN -> residual add.
void Block::block(InferenceState& s) {
  s.x_resid = layernorm(s.x, this->att_prenorm_weights);
  // Multi-head attention typically includes:
  // 1. RoPE on input (element-wise mutation w/ sines/cosines)
  // 2. QKV matmuls and updating the KV cache
  // 3. Causal self-attention, softmax, and value mixing
  // 4. Projection back into the residual stream
  // NOTE(review): the argument list of this call is missing from the excerpt.
  // Presumably it takes the normalized input plus this block's attention
  // weights and KV cache -- restore it from the original source.
  s.x_resid = multi_head_attn(s.x_resid /* , ...attention weights, KV cache */);
  s.x += s.x_resid;
  s.x_resid = layernorm(s.x, this->ffn_prenorm_weights);
  // On modern architectures like Llama, this is a GLU feedforward
  // with 3 linear transforms, not a simple MLP:
  // -> w2(F.silu(w1(x)) * w3(x))
  // Some architectures also split the FFN into a mixture of experts.
  s.x_resid = ffn(s.x_resid, this->w1, this->w2, this->w3);
  s.x += s.x_resid;
}
1.2 瓶颈与基准
现代CPU和GPU在浮点运算方面速度极快。关键指标是每秒浮点运算次数(FLOPs)与内存带宽的比率(FLOPs / byte)。例如,AMD Ryzen 7950X的这一比率约为40∶1,而英伟达RTX 4090的比率为82∶1。我服务器的AMD EPYC 7702P的这一比率没那么亮眼,但也较为可观,为10∶1。
EPYC 7702P的最大带宽 [4]:204.8GB / 秒
RTX 4090的最大带宽 [5]:1008GB / 秒
带有4096(4k)上下文窗口以及32位浮点数(FP32)键值缓存的Mistral-7B-Instruct-v0.2模型大小为29516398592字节
204.8×10⁹字节 / 秒 ÷29516398592字节 / 词元≈6.9 词元 / 秒(针对EPYC 7702P)
该模型无法装入RTX 4090的24GB显存中,所以我们跳过这部分计算。
带有4096(4k)上下文窗口以及16位浮点数(FP16)键值缓存的Mistral-7B-Instruct-v0.2模型大小为15020875776字节
204.8×10⁹字节 / 秒 ÷15020875776字节 / 词元≈13.6 词元 / 秒(针对EPYC 7702P)
1008×10⁹字节 / 秒 ÷15020875776字节 / 词元≈67.1 词元 / 秒(针对 RTX 4090)
2.1 多线程
// Matrix-vector product: xout = W @ x.
//   xout: output vector of length d
//   x:    input vector of length n
//   w:    row-major (d, n) matrix
// Each output row is an independent dot product, so rows are split across
// OpenMP threads when the file is compiled with OpenMP enabled; without it
// the pragma is ignored and the loop runs serially.
static void matmul(float* xout, float* x, float* w, int n, int d) {
  // W (d,n) @ x (n,) -> xout (d,)
  int i;
#pragma omp parallel for private(i)
  for (i = 0; i < d; i++) {
    float val = 0.0f;
    for (int j = 0; j < n; j++) {
      val += w[i * n + j] * x[j];
    }
    xout[i] = val;
  }
}
// Storage type for packed half-precision values. The F16C conversion code
// works on raw 16-bit unsigned integers, so that is what the alias names.
using f16_t = uint16_t;
// matmul supporting float16 weights via the F16C extension, which allows
// conversion into float32 values before calculations.
//   xout: output vector of length d (float32)
//   x:    input vector of length n (float32)
//   w:    row-major (d, n) matrix of packed float16 values
// Requires n to be a multiple of 16 (one 256-bit load of weights per step).
// On builds without AVX2/FMA/F16C the function asserts instead of computing.
static void matmul(float* xout, float* x, f16_t* w, int n, int d) {
#if defined(__AVX2__) && defined(__FMA__) && defined(__F16C__)
  // W (d,n) @ x (n,) -> xout (d,)
  assert(n % 16 == 0);
  int i;
#pragma omp parallel for private(i)
  for (i = 0; i < d; i++) {
    // Vectorized dot product of w[i][:] and x[:] where w is a packed float16 array.
    __m256 sumlo = _mm256_setzero_ps();
    __m256 sumhi = _mm256_setzero_ps();
    for (int j = 0; j < n; j+=16) {
      // Extract the next set of 16 float16 weights from `w` and store them
      // to two separate float32 vectors of width 8 (`wveclo_ps`, `wvechi_ps`)
      __m256i wvec = _mm256_loadu_si256((__m256i*)&w[i * n + j]);
      __m128i wveclo = _mm256_extractf128_si256(wvec, 0);
      __m128i wvechi = _mm256_extractf128_si256(wvec, 1);
      __m256 wveclo_ps = _mm256_cvtph_ps(wveclo);
      __m256 wvechi_ps = _mm256_cvtph_ps(wvechi);
      // Extract the next two float32 vectors of width 8 `xveclo`, `xvechi` from `x`
      __m256 xveclo = _mm256_loadu_ps(&x[j]);
      __m256 xvechi = _mm256_loadu_ps(&x[j + 8]);
      // Compute vectorized FMAs: sumlo += wveclo * xveclo, sumhi += wvechi * xvechi
      sumlo = _mm256_fmadd_ps(wveclo_ps, xveclo, sumlo);
      sumhi = _mm256_fmadd_ps(wvechi_ps, xvechi, sumhi);
    }
    // Horizontally reduce width-8 float32 vectors sumlo, sumhi to a scalar.
    __m256 sum8 = _mm256_add_ps(sumlo, sumhi);              // sum8[0:8] = sumlo[0:8] + sumhi[0:8]
    __m128 sum4 = _mm_add_ps(                               // sum4[0:4] = sum8[0:4] + sum8[4:8]
      _mm256_extractf128_ps(sum8, 0),
      _mm256_extractf128_ps(sum8, 1)
    );
    __m128 sum1 = _mm_dp_ps(sum4, _mm_set1_ps(1.0f), 0xf1); // sum1[0] = dot(sum4, [1,1,1,1])
    xout[i] = _mm_cvtss_f32(sum1);
  }
#else
  // No scalar fallback is provided; silence unused-parameter warnings.
  (void)xout; (void)x; (void)w; (void)n; (void)d;
  assert(false && "float16 not supported on this platform");
#endif
}
在调用kernel时,通过三重尖括号<<< numBlocks, threadsPerBlock >>>来指定块数和每块的线程数,可以是int或dim3类型。
3.1 简单的CUDA移植
// mix self.w2(F.silu(self.w1(x)) * self.w3(x))
// Note this is a feedforward with a GLU, not a simple MLP.
matmul<<<c.hidden_dim, WARP_SIZE>>>(
w1(), s.xb(), c.dim, c.hidden_dim, s.hb()
matmul<<<c.hidden_dim, WARP_SIZE>>>(
w3(), s.xb(), c.dim, c.hidden_dim, s.hb2()
s.hb(), s.hb2(), s.hb()
matmul<<<c.dim, WARP_SIZE>>>(
w2(), s.hb(), c.hidden_dim, c.dim, s.xb2()
// ffn residual back into x
s.x(), s.xb2(), c.dim, s.x()
// mix self.w2(F.silu(self.w1(x)) * self.w3(x))
// Note this is a feedforward with a GLU, not a simple MLP.
// NOTE(review): the formula above names SiLU but the loop below calls gelu();
// confirm which activation is intended for the model being run.
matmul(s.hb(), s.xb(), w1<T>(), c.dim, c.hidden_dim);
matmul(s.hb2(), s.xb(), w3<T>(), c.dim, c.hidden_dim);
// Gate: activation(w1 x) * (w3 x), written in place into hb.
for (int i = 0; i < c.hidden_dim; ++i) {
  s.hb()[i] = gelu(s.hb()[i]) * s.hb2()[i];
}
matmul(s.xb2(), s.hb(), w2<T>(), c.hidden_dim, c.dim);
// residual connection back into x
for (int i = 0; i < c.dim; ++i) {
  s.x()[i] += s.xb2()[i];
}
3.2 高效矩阵乘法
// Naive matrix-vector product kernel: each thread computes one output row.
//   A:   row-major (d, n) matrix
//   x:   input vector of length n
//   out: output vector of length d
// Threads whose global index is past the last row exit immediately, so the
// grid may be rounded up to a whole number of blocks.
__global__
void matmul(const float* A, const float* x, int n, int d, float* out) {
  // A (d,n) @ x (n,) -> out (d,)
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= d) return;
  float sum = 0.0;
  for (int j = 0; j < n; j++) {
    sum += A[n * i + j] * x[j];
  }
  out[i] = sum;
}
/* usage */
>>>(A, x, n, d, out);
这种方法有一个很大的问题:它无法充分利用我们的CUDA核心。以Mistral-7B为例,其transformer的输入/输出维度为4096,因此,如果我们计算输出之前的最后一个matmul,我们将启动4096个线程。但是,RTX 4090可以同时支持16384个线程!这意味着许多核心将处于空闲状态,我们无法达到最大浮点运算能力(FLOPs/s)。
// Sums `val` across the lanes of a warp using shuffle-down steps with
// halving offsets. The full mask (0xffffffff) means every lane of the warp
// must call this together. After the loop only lane 0 is guaranteed to hold
// the complete sum; other lanes hold partial sums.
__device__
inline float warp_reduce_sum(float val) {
  for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2)
    val += __shfl_down_sync(0xffffffff, val, offset);
  return val;
}
// Warp-cooperative dot product of `row` and `x` (both of length `dim`).
// Each caller starts at element `offset` and strides by WARP_SIZE, so when
// the lanes of a warp pass offsets 0..WARP_SIZE-1 they jointly cover every
// element once. The per-lane partial sums are then combined with
// warp_reduce_sum, so the complete result is the value returned to lane 0.
__device__
inline float matmul_row(const float* row, const float* x, int offset, int dim) {
  float sum = 0.0;
  for (int j = offset; j < dim; j += WARP_SIZE) {
    float v = row[j] * x[j];
    sum += v;
  }
  return warp_reduce_sum(sum);
}
// Matrix-vector product kernel with one block (a single warp) per output row.
//   A:   row-major (d, n) matrix
//   x:   input vector of length n
//   out: output vector of length d
// The lanes of the warp split the row's dot product between them via
// matmul_row; thread 0 then stores the reduced sum.
__global__
void matmul(const float* A, const float* x, int n, int d, float* out) {
  // A (d,n) @ x (n,) -> out (d,)
  // PRECOND: Blocks are 1-D and same size as warp.
  int i = blockIdx.x;
  if (i >= d) return;
  int offset = threadIdx.x;
  float rowSum = matmul_row(&A[n * i], x, offset, n);
  if (threadIdx.x == 0) {
    out[i] = rowSum;
  }
}
/* usage */
matmul<<<d, BLOCK_SIZE>>>(A, x, n, d, out);
3.3 融合和更高效的矩阵乘法
Section: Memory Workload Analysis
--------------------------- ------------ ------------
Metric Name Metric Unit Metric Value
--------------------------- ------------ ------------
Memory Throughput Gbyte/second 533.22
Mem Busy % 24.82
Max Bandwidth % 90.26
L1/TEX Hit Rate % 65.94
L2 Compression Success Rate % 0
L2 Compression Ratio 0
L2 Hit Rate % 2.03
Mem Pipes Busy % 28.33
--------------------------- ------------ ------------
Section: Memory Workload Analysis Tables
----- --------------------------------------------------------------------------------------------------------------
WRN The memory access pattern for stores from L1TEX to L2 is not optimal. The granularity of an L1TEX request to
L2 is a 128 byte cache line. That is 4 consecutive 32-byte sectors per L2 request. However, this kernel only
accesses an average of 1.0 sectors out of the possible 4 sectors per cache line. Check the Source Counters
section for uncoalesced stores and try to minimize how many cache lines need to be accessed per memory
----- --------------------------------------------------------------------------------------------------------------
// Performs block-and-warp transpose operation:
//   For a block containing K warps where lane 0 of warp k holds val_k,
//   this function returns:
//   - For lane k with k < K (in any warp): val_k
//   - For all other lanes: def
// Every thread of the block must call this function, because it contains a
// block-wide barrier.
__device__ inline float blocktranspose(float v, float def) {
  int lane = threadIdx.x % warpSize;
  int warp = threadIdx.x / warpSize;

  // Will hold the lane-0 value of every warp in the block.
  // NOTE: Assumes warpSize is 32 (so a block has at most 32 warps).
  __shared__ float sm[32];
  if (lane == 0) sm[warp] = v;
  // Make every warp's write visible before any thread reads it back.
  __syncthreads();

  return lane < blockDim.x / warpSize ? sm[lane] : def;
}
// Matrix-vector product kernel where each block holds several warps and each
// warp computes one output row; warp 0 then writes all of the block's rows.
//   A:   row-major (d, n) matrix
//   x:   input vector of length n
//   out: output vector of length d
// `d` does not need to be a multiple of the warps per block: rows past the
// end are neither computed nor stored.
template <typename T>
__global__
void matmul_wide(const T* A, const float* x, int n, int d, float* out) {
  // A (d,n) @ x (n,) -> out (d,)
  // PRECOND: Block is 1-D and contains WPB warps.
  int i = (blockIdx.x * blockDim.x + threadIdx.x) / warpSize;
  // Warp j computes sum for row at <blockIdx.x*WPB + j>
  // Lane 0 of each warp will hold result
  int k = threadIdx.x % warpSize;
  // No early return for out-of-range rows: blocktranspose exchanges values
  // through shared memory, so every warp of the block has to reach it.
  // `i` is the same for all lanes of a warp, so this branch is warp-uniform.
  float rowSum = 0.0f;
  if (i < d) {
    rowSum = matmul_row(&A[n * i], x, k, n);
  }
  // Transpose values so lane k in warp 0 contains row at <blockIdx.x*WPB + k>
  // For WPB=32, this allows us to coalesce 32 float32 writes into a single 128-byte store
  rowSum = blocktranspose(rowSum, 1.0);
  if (threadIdx.x < blockDim.x / warpSize) {
    int block_start_i = blockIdx.x * blockDim.x / warpSize;
    // The last block may cover fewer than WPB valid rows.
    if (block_start_i + k < d) {
      out[block_start_i + k] = rowSum;
    }
  }
}
3.4 注意力与长上下文生成
总共有n_heads * head_dim个点积,点积大小为head_dim
总共有n_heads * seq_len个点积,点积大小为head_dim
总共有n_heads * seq_len个乘加运算(对大小为head_dim的向量进行运算)
att_mix(const float *, const float *, int, int, int, int, int, float *) (32, 128, 1)x(32, 1, 1), Context 1, Stream 7, Device 0, CC 8.6
Section: GPU Speed Of Light Throughput
----------------------- ------------- ------------
Metric Name Metric Unit Metric Value
----------------------- ------------- ------------
DRAM Frequency cycle/nsecond 6.27
SM Frequency cycle/nsecond 1.33
Elapsed Cycles cycle 5765045
Memory Throughput % 8.81
DRAM Throughput % 1.68
Duration msecond 4.35
L1/TEX Cache Throughput % 5.17
L2 Cache Throughput % 8.81
SM Active Cycles cycle 5685841.81
Compute (SM) Throughput % 0.47
----------------------- ------------- ------------
WRN This kernel exhibits low compute throughput and memory bandwidth utilization relative to the peak performance
of this device. Achieved compute throughput and/or memory bandwidth below 60.0% of peak typically indicate
latency issues. Look at Scheduler Statistics and Warp State Statistics for potential reasons.
att - 注意力分数,形状为 (n_heads, kv_len)
vb - 值向量,形状为 (max_seq_len, n_kv_heads, head_dim)
我们希望输出一个形状为 (n_heads, head_dim) 的 out 张量,满足 out[q] = att[q, :] @ vb[:, q//G, :],其中 G = n_heads//n_kv_heads 是分组查询注意力的组大小。
下面,我们假设KV缓存中的时间步数kv_len 等于 max_seq_len:
// Sums `val` across the lanes of a warp using shuffle-down steps with
// halving offsets. The full mask (0xffffffff) means every lane of the warp
// must call this together. After the loop only lane 0 is guaranteed to hold
// the complete sum; other lanes hold partial sums.
__device__
inline float warp_reduce_sum(float val) {
  for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2)
    val += __shfl_down_sync(0xffffffff, val, offset);
  return val;
}
// Attention value mixing: out[h, i] = sum over t of att[h, t] * vb[t, g, i],
// where g = h / (n_heads / n_kv_heads) is the KV head shared by query head h.
// One block (a single warp) computes one output element: blockIdx.x selects
// the head h and blockIdx.y the element i. The lanes stride over the
// timesteps and their partial sums are combined with warp_reduce_sum.
// Rows of `att` are addressed with a stride of max_seq_len.
__global__
void att_mix(
  const float* vb,  // (max_seq_len, n_kv_heads, head_dim)
  const float* att, // (n_heads, kv_len)
  int head_dim,
  int n_heads,
  int n_kv_heads,
  int seq_len,
  int max_seq_len,
  float* out // (n_heads, head_dim)
) {
  // PRECOND: blocks are 1-D and blockDim.x == WARP_SIZE
  int h = blockIdx.x;
  int group_size = n_heads / n_kv_heads;
  int g = h / group_size;
  int i = blockIdx.y;
  int offset = threadIdx.x;
  int kv_stride = n_kv_heads * head_dim;

  const float* atth = att + max_seq_len * h;
  const float* vh = vb + head_dim * g;
  float* outh = out + head_dim * h;

  float sum = 0.0;
  for (int t = offset; t < seq_len; t += WARP_SIZE) {
    sum += vh[kv_stride * t + i] * atth[t];
  }
  sum = warp_reduce_sum(sum);
  // Only lane 0 holds the complete sum after the reduction.
  if (offset == 0) outh[i] = sum;
}
/* usage */
// One warp per block; the grid has one block per (head, output element).
dim3 tpb;
tpb.x = WARP_SIZE;
dim3 blocks;
blocks.x = n_heads;
blocks.y = head_dim;
att_mix<<<blocks, tpb>>>(
  vb, att,
  head_dim, n_heads, n_kv_heads,
  seq_len, max_seq_len, out
);
多个块写入单个输出元素out[h, i]。
线程使用atomicAdd将结果累加到out[h, i]。
// Attention value mixing with the timestep range split across blocks:
// out[h, i] = sum over t of att[h, t] * vb[t, g, i].
// blockIdx.x selects the head h; blockIdx.y selects a contiguous chunk of
// timesteps. Each thread handles output elements i = threadIdx.x,
// threadIdx.x + blockDim.x, ... and adds its chunk's partial sum into
// out[h, i] with atomicAdd, which is why `out` must be zeroed beforehand.
__global__
void att_mix(
  const float* vb,  // (max_seq_len, n_kv_heads, head_dim)
  const float* att, // (n_heads, kv_len)
  int head_dim,
  int n_heads,
  int n_kv_heads,
  int seq_len,
  int max_seq_len,
  float* out // (n_heads, head_dim)
) {
  // PRECOND: blocks are 1-D and `out` has been zeroed
  int h = blockIdx.x;
  int group_size = n_heads / n_kv_heads;
  int g = h / group_size;
  int kv_stride = n_kv_heads * head_dim;

  const float* atth = att + max_seq_len * h;
  const float* vh = vb + head_dim * g;
  float* outh = out + head_dim * h;

  // Round the chunk size up and clamp the end, so that timesteps are not
  // dropped when seq_len is not a multiple of gridDim.y.
  int n_chunks = gridDim.y;
  int t_per_thread = (seq_len + n_chunks - 1) / n_chunks;
  int t_start = blockIdx.y * t_per_thread;
  int t_end = min(t_start + t_per_thread, seq_len);

  for (int i = threadIdx.x; i < head_dim; i += blockDim.x) {
    float sum = 0.0;
    for (int t = t_start; t < t_end; t++) {
      sum += vh[kv_stride * t + i] * atth[t];
    }
    atomicAdd(&outh[i], sum);
  }
}
// Launch configuration: one warp per block; the grid is (head, timestep
// chunk), with enough chunks that each covers at most max_t_per_thread steps.
int max_t_per_thread = 256;
dim3 tpb(warp_size);
dim3 blocks(
  c.n_heads,
  (kv_len + max_t_per_thread - 1) / max_t_per_thread
);
// The kernel accumulates into its output, so clear the buffer first.
cudaMemset(s.xb2(), 0, c.n_heads * c.head_dim * sizeof(float));
att_mix<<<blocks, tpb>>>(
  vb, s.att(),
  c.head_dim, c.n_heads, c.n_kv_heads,
  kv_len, c.max_seq_len, s.xb2());
// Attention value mixing with one block per head:
// out[h, i] = sum over t of att[h, t] * vb[t, g, i].
// threadIdx.x picks the output element i (stepping by warpSize) and
// threadIdx.y picks which timesteps a warp handles (t = warp_id,
// warp_id + t_stride, ...). The warps' partial sums for one element are
// combined in shared memory and stored by warp 0.
// Assumes head_dim is a multiple of warpSize so that every thread reaches
// every __syncthreads() -- TODO confirm for the models in use.
__global__
void att_mix(
  const float* vb,  // (max_seq_len, n_kv_heads, head_dim)
  const float* att, // (n_heads, kv_len)
  int head_dim,
  int n_heads,
  int n_kv_heads,
  int seq_len,
  int max_seq_len,
  float* out // (n_heads, head_dim)
) {
  // PRECOND: blocks are 2-D (warp_size, t_stride)
  int h = blockIdx.x;
  int group_size = n_heads / n_kv_heads;
  int g = h / group_size;
  int kv_stride = n_kv_heads * head_dim;

  const float* atth = att + max_seq_len * h;
  const float* vh = vb + head_dim * g;
  float* outh = out + head_dim * h;

  int warp_id = threadIdx.y;
  int t_stride = blockDim.y;

  // One accumulator per lane: indexed by threadIdx.x, which is below 32
  // given the block shape in PRECOND.
  __shared__ float shared[32];

  for (int i = threadIdx.x; i < head_dim; i += warpSize) {
    if (warp_id == 0) {
      shared[threadIdx.x] = 0;
    }
    // The accumulator must be zeroed before any warp adds to it.
    __syncthreads();
    float sum = 0.0;
    for (int t = warp_id; t < seq_len; t += t_stride) {
      sum += vh[kv_stride * t + i] * atth[t];
    }
    atomicAdd(&shared[threadIdx.x], sum);
    // Every warp's contribution must land before warp 0 reads the total.
    __syncthreads();
    if (warp_id == 0) {
      outh[i] = shared[threadIdx.x];
      shared[threadIdx.x] = 0;
    }
  }
}
// 2-D blocks of (warp_size, number of warps); one block per head.
dim3 tpb;
tpb.x = warp_size;
tpb.y = min(kv_len, max_threads_per_block / warp_size);
dim3 blocks;
blocks.x = c.n_heads;
att_mix<<<blocks, tpb>>>(
  vb, s.att(),
  c.head_dim, c.n_heads, c.n_kv_heads,
  kv_len, c.max_seq_len, s.xb2()
);
3.5 KV量化和编译器陷阱
在我们之前的基准测试设置中(§1.2),llama.cpp和calm实际上使用了FP16 KV缓存条目(因为这是它们的默认设置),我们也假设相同的设置来计算速度。
将KV缓存条目从FP32量化到FP16不会带来与量化权重相同的巨大收益,因为KV缓存占总内存的比例较小。但是,我们仍然期待会有一定的收益,因为每次前向传递的总内存读取量将从15.5 GB减少到15.0 GB。这种效果应该在长上下文和注意力kernel中最为明显,因为KV缓存只在注意力中使用。
// Per-warp partial sum over this warp's timesteps (float32 values).
float sum = 0.0;
for (int t = warp_id; t < seq_len; t += t_stride) {
  sum += vh[kv_stride * t + i] * atth[t];
}
// atomicAdd `sum` when we're done
// Same loop with half-precision values: each one is widened to float32
// before the multiply, and accumulation stays in float32.
float sum = 0.0;
for (int t = warp_id; t < seq_len; t += t_stride) {
  sum += __half2float(vh[kv_stride * t + i]) * atth[t];
}
// atomicAdd `sum` when we're done
// Loads two adjacent half values at once (as a half2) and accumulates the
// two lanes in a float2, so each thread produces two output elements.
float2 sum01 = make_float2(0.0, 0.0);
for (int t = warp_id; t < seq_len; t += t_stride) {
  float2 v01 = __half22float2(*((half2*)&vh[kv_stride * t + i]));
  float att_t = atth[t];
  // Sadly CUDA does not have float2 SIMD ops
  sum01.x += v01.x * att_t;
  sum01.y += v01.y * att_t;
}
// atomicAdd both `sum01` lanes when we're done
// Manually unrolled variant: every UNROLL iterations, issue UNROLL loads up
// front into named registers, then consume one prefetched pair per iteration.
float2 sum01 = make_float2(0.0, 0.0);
constexpr int UNROLL = 16;
half2 v01_0; float att_0;
half2 v01_1; float att_1;
half2 v01_2; float att_2;
/* ... SNIP ... */
half2 v01_15; float att_15;
int t = warp_id;
for (int ctr = 0; ctr < seq_len / t_stride; t += t_stride, ctr++) {
  int ctr_mod = ctr % UNROLL;
  if (ctr_mod == 0) {
    // prefetch every UNROLL iterations
    // NOTE(review): the batch reads timesteps up to t + 15*t_stride without
    // checking them against seq_len; confirm the buffers are large enough
    // when seq_len / t_stride is not a multiple of UNROLL.
    #define PREFETCH(j) \
      v01_##j = *((half2*)&vh[kv_stride * (t + j*t_stride) + i]); \
      att_##j = atth[t + j*t_stride];
    PREFETCH(0)
    PREFETCH(1)
    PREFETCH(2)
    /* ... SNIP ... */
    PREFETCH(15)
    #undef PREFETCH
  }
  // Pull this iteration's pair out of the prefetched batch.
  float2 v01;
  float att_t;
  switch (ctr_mod) {
    #define CASE(j) \
      case j: v01 = __half22float2(v01_##j); att_t = att_##j; break;
    CASE(0)
    CASE(1)
    CASE(2)
    /* ... SNIP ... */
    CASE(15)
    #undef CASE
  }
  sum01.x += v01.x * att_t;
  sum01.y += v01.y * att_t;
}
// Handle any loop remainder that can't be unrolled
for (; t < seq_len; t += t_stride) {
  float2 v01 = __half22float2(*((half2*)&vh[kv_stride * t + i]));
  float att_t = atth[t];
  sum01.x += v01.x * att_t;
  sum01.y += v01.y * att_t;
}
// atomicAdd both `sum01` lanes when we're done
不幸的是,我无法测试calm CPU,因为我的机器不支持编译所需的扩展。