前言
本文回忆一下MHA、GQA、MQA,具体解读下MHA、GQA、MQA这三种经常出现留意力机制的原理。
图1 MHA、GQA、MQA一览
self-attention
self-attention
在自留意力机制中,输入理论是一个一致的输入矩阵,而这个矩阵后续会经过乘以不同的权重矩阵来转换成三个不同的向量汇合:查问向量Q、键向量K和值向量V。这三组向量是经过线性变换模式生成:
1.查问向量 (Q): Q=XW
2.键向量 (K): K=XW
3.值向量 (V): V=XW
W,W和W是 可学习的权重矩阵 ,区分对应于查问、键和值。这些矩阵的维度取决于模型的设计,理论它们的输入维度(列数) 是预先定义的,以满足特定的模型架构要求。 在Transformer模型中,经常使用不同的权重矩阵W,W和W来区分生成查问向量Q、键向量K和值向量V的 目标是为了准许模型在不同的示意空间中学习和抽取特色 。这样做参与了模型的灵敏性和表白才干,准许模型区分优化用于婚配(Q 和K)和用于输入消息分解(V)的示意。
在自留意力和多头留意力机制中,经常使用 作为缩放因子启动缩放操作是为了防止在计算点积时由于维度较高造成的数值稳固性疑问。这里的d是键向量的维度。 假设不启动缩放,当d较大时,点积的结果或者会变得十分大,这会造成在运行softmax函数时发生的梯度十分小。 由于softmax函数是经过指数函数计算的,大的输入值会使得局部输入凑近于1,而其余凑近于0,从而造成梯度隐没,这会在反向流传环节中形成梯度十分小,使得学习变得十分缓慢。
经过点积结果除以 ,可以调整这些值的范畴,使得它们不会太大。这样,softmax的输入在一个适合的范畴内, 有助于防止极其的指数运算结果,从而坚持数值稳固性和更有效的梯度流 。这个操作确保了即使在d很大的状况下, 留意力机制也能稳固并有效地学习。
代码成功
import torchimport torch.nn as nnimport torch.nn.functional as Fclass SelfAttention(nn.Module):def __init__(self, seq_length):super(SelfAttention, self).__init__()self.input_size = seq_length# 定义三个权重矩阵:Wq、Wk、Wvself.Wq = nn.Linear(seq_length, seq_length)# 线性变换self.Wk = nn.Linear(seq_length, seq_length)self.Wv = nn.Linear(seq_length, seq_length)def forward(self, input):# 计算Q,K,V 三个矩阵q = self.Wq(input)k = self.Wk(input)v = self.Wv(input)# 计算QK^T,即向量之间的相关度attention_scores = torch.matmul(q, k.transpose(-1, -2)) / torch.sqrt(torch.tensor(float(self.input_size)))# 计算向量权重,softmax归一化attention_weight = F.softmax(attention_scores, dim=-1)# 计算输入output = torch.matmul(attention_weight, v)return outputx = torch.randn(2, 3, 4)Self_Attention = SelfAttention(4)# 传入输入向量的维度output = Self_Attention(x)print(output.shape)
MHA(多头留意力)
Transformer 编码器块内的缩放点积留意力机制和多头留意力机制
MHA计算环节
代码成功
import torchimport torch.nn as nnclass MultiHeadAttention(nn.Module):def __init__(self, embed_dim, num_heads):super(MultiHeadAttention, self).__init__()self.num_heads = num_headsself.head_dim = embed_dim // num_headsself.wq = nn.Linear(embed_dim, embed_dim)self.wk = nn.Linear(embed_dim, embed_dim)self.wv = nn.Linear(embed_dim, embed_dim)self.wo = nn.Linear(embed_dim, embed_dim)def mh_split(self, hidden):batch_size = hidden.shape[0]x = hidden.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)return xdef forward(self, hidden_states, mask=None):batch_size = hidden_states.size(0)# 线性变换q, k, v = self.wq(hidden_states), self.wk(hidden_states), self.wv(hidden_states)# 多头切分q, k, v = self.mh_split(q), self.mh_split(k), self.mh_split(v)# 留意力计算scores = torch.matmul(q, k.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32))if mask is not None:scores = scores.masked_fill(mask == 0, float('-inf'))attention = torch.softmax(scores, dim=-1)output = torch.matmul(attention, v)# 拼接多头output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.head_dim)# 线性变换output = self.wo(output)return outputx = torch.rand(2, 3, 36)print(x)output = MultiHeadAttention(36, 6)y = output(x)print(y.shape)
MHA 能够了解输入不同局部之间的相关。但是,这种复杂性是有代价的——对内存带宽的需求很大,尤其是在解码器推理时期。重要疑问的关键在于内存开支。 在自回归模型中,每个解码步骤都须要加载解码器权重以及一切留意键和值。这个环节不只计算量大,而且内存带宽也大。随着模型规模的扩展,这种开支也会参与,使得扩展变得越来越艰难。
因此,多查问留意 (MQA) 应运而生,成为缓解这一瓶颈的处置打算。其理念便捷而有效: 经常使用多个查问头,但只经常使用一个键和值头。这种方法清楚缩小了内存负载,提高了推理速度。
MQA(多查问留意力)
图2 MHA和MQA的差异
MQA是MHA的一种变体,也是用于自回归解码的一种留意力机制。,图1、图2很笼统的描述了MHA和MQA的对比,与MHA 不同的是, MQA 让一切的Head之间共享雷同的一份 K 和 V 矩阵(象征K和V的计算惟一),只让 Q 保管了原始多头的性质 (每个Head存在不同的转换),从而大大缩小 K 和 V 矩阵的参数量以及KV Cache的显存占用,以此来到达优化推理速度,但是会带来精度上的损失。MQA被少量运行于LLM中,如ChatGLM2。
左 - 多头留意力,中 - 多查问留意力,右 - 将现有的 MHA 审核点转换为 MQA
如何将现有的预训练多头留意力模型转换为多查问留意力模型 (MQA)? 从现有的多头模型创立多查问留意力模型触及两个步骤:模型结构的转换和随后的预训练。
代码成功
import torchimport torch.nn as nnclass MultiQuerySelfAttention(nn.Module):def __init__(self, embed_dim, num_heads):super(MultiQuerySelfAttention, self).__init__()self.num_heads = num_headsself.head_dim = embed_dim // num_headsself.wq = nn.Linear(embed_dim, embed_dim)# MHA# self.wk = nn.Linear(embed_dim, embed_dim)# self.wv = nn.Linear(embed_dim, embed_dim)# MQAself.wk = nn.Linear(embed_dim, self.head_dim)self.wv = nn.Linear(embed_dim, self.head_dim)self.wo = nn.Linear(embed_dim, embed_dim)def q_h_split(self, hidden, head_num=None):batch_size, seq_len = hidden.size()[:2]# q拆分多头if head_num == None:x = hidden.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)return xelse:# 这是MQA: 须要拆分k和v,这外面的head_num =1 的# 最终前往维度(batch_size, 1, seq_len, head_dim)return hidden.view(batch_size, seq_len, head_num, self.head_dim).transpose(1, 2)def forward(self, hidden_states, mask=None):batch_size = hidden_states.size(0)# 线性变换q, k, v = self.wq(hidden_states), self.wk(hidden_states), self.wv(hidden_states)# 多头切分# 这是MHA的# q, k ,v= self.split(q), self.split(k), self.split(v)# 这是MQA的q, k, v = self.q_h_split(q), self.q_h_split(k, 1), self.q_h_split(v, 1)# 留意力计算scores = torch.matmul(q, k.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32))print("scores:", scores.shape)if mask is not None:scores = scores.masked_fill(mask == 0, float('-inf'))attention = torch.softmax(scores, dim=-1)output = torch.matmul(attention, v)# 多头兼并output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.head_dim)# 线性变换output = self.wo(output)return outputx = torch.rand(3, 12, 512)atten = MultiQuerySelfAttention(512, 8)y = atten(x)print(y.shape)
GQA(分组查问留意力)
只管MQA模式大幅减小了参数数量,但是,带来推理减速的同时会形成模型性能损失,且在训练环节使得模型变得不稳固( 复杂度的降低或者会造成品质降低和训练不稳固 ),因此在此基础上提出了GQA,它将Query启动分组,每个组内共享一组Key、Value。(GQA在LLaMA-2 和 Mistral7B获取运行)
GQA 的数学原理 :
分组:在 GQA 中,传统多头模型中的查问头 (Q) 被分红 G 组。每组调配一个键 (K) 和值 (V) 头。此性能示意为 GQA-G,其中 G 示意组数。
GQA 的不凡状况 :
对每个组边疆始头部的键和值投影矩阵启动均值池化,以将MHA模型转换为 GQA 模型。此技术对组中每个头部的投影矩阵启动平均,从而为该组生成单个键和值投影。
经过 应用 GQA,该模型在 MHA 品质和 MQA 速度之间坚持平衡 。由于键值对较少,内存带宽和数据加载需求被最小化。G 的选用代表了一种掂量:更多的组(更凑近 MHA)可带来更高的品质但性能较慢,而更少的组(凑近 MQA)可提高速度但有就义品质的危险。此外,随着模型规模的扩展,GQA 准许内存带宽和模型容量按比例缩小,与模型规模相对应。相比之下,关于更大的模型,在 MQA 中缩小到单个键和值头或者会过于重大。
代码成功
import torchimport torch.nn as nnclass GroupedQueryAttention(nn.Module):def __init__(self, embed_dim, num_heads):super(GroupedQueryAttention, self).__init__()self.num_heads = num_headsself.head_dim = embed_dim // num_headsself.wq = nn.Linear(embed_dim, embed_dim)# 这是MHA的# self.wk = nn.Linear(embed_dim, embed_dim)# self.wv = nn.Linear(embed_dim, embed_dim)# 这是MQA的# self.wk = nn.Linear(embed_dim, self.head_dim)# self.wv = nn.Linear(embed_dim, self.head_dim)# 这是GQA的self.group_num = 4# 这是4个组self.wk = nn.Linear(embed_dim, self.group_num * self.head_dim)self.wv = nn.Linear(embed_dim, self.group_num * self.head_dim)self.wo = nn.Linear(embed_dim, embed_dim)def split(self, hidden, group_num=None):batch_size, seq_len = hidden.size()[:2]# q须要拆分多头if group_num == None:x = hidden.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)return xelse:# 这是kv须要拆分的多头x = hidden.view(batch_size, seq_len, group_num, self.head_dim).transpose(1, 2)x = x[:, :, None, :, :].expand(batch_size, group_num, self.num_heads // group_num, seq_len,self.head_dim).reshape(batch_size, self.num_heads, seq_len, self.head_dim)return xdef forward(self, hidden_states, mask=None):batch_size = hidden_states.size(0)# 线性变换q, k, v = self.wq(hidden_states), self.wk(hidden_states), self.wv(hidden_states)# 多头切分# 这是MHA的# q, k ,v= self.split(q), self.split(k), self.split(v)# 这是MQA的# q, k ,v= self.split(q), self.split(k, 1), self.split(v, 1)# 这是GQA的q, k, v = self.split(q), self.split(k, self.group_num), self.split(v, self.group_num)# 留意力计算scores = torch.matmul(q, k.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32))print("scores:", scores.shape)if mask is not None:scores = scores.masked_fill(mask == 0, float('-inf'))attention = torch.softmax(scores, dim=-1)output = torch.matmul(attention, v)# 兼并多头output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.head_dim)# 线性变换output = self.wo(output)return outputx = torch.ones(3, 12, 512)atten = GroupedQueryAttention(512, 8)y = atten(x)print(y.shape)
参考文献
原文链接: