LLaMA Source Code Walkthrough

1. LLaMA Source Code Walkthrough

Here we use the LLaMA implementation in Huggingface as an example. LLaMA is a causal language model; whether we target a downstream task (e.g. classification) or text generation, we build on the `LlamaModel` class as the base model. It does little more than initialize the token embeddings and stack the Transformer blocks. We start from the `LlamaModel` source and bring in the other modules as we go.

1.1 LlamaModel

`LlamaModel` lives in the Huggingface transformers package, at transformers/models/llama/modeling_llama.py. The source is shown below; for brevity, only the `__init__` method and two simple helper methods are included.

Key points

  • `LlamaModel` keeps two attributes, `padding_idx` and `vocab_size`: `padding_idx` is the token id used for padding, and `vocab_size` is the size of the vocabulary.
  • `embed_tokens` is the token-embedding lookup table. The tokenizer turns a sentence into a sequence of token ids, and this table maps each id to its embedding vector. Its first argument is the vocabulary size, the second is the embedding dimension, and the third is the padding token id (when the mini-batch size is greater than 1, shorter sequences in the batch are filled with `padding_idx`); see the sketch after this list.
  • `layers` holds the core transformer layers. LLaMA uses the `LlamaDecoderLayer` class, 32 layers in total for llama-7b; it is introduced below.
  • `norm` is a `LlamaRMSNorm`, which differs slightly from a standard LayerNorm.
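
A minimal sketch of the embedding lookup described above, with made-up toy sizes (the real values come from `LlamaConfig`, e.g. vocab_size=32000 and hidden_size=4096 for llama-7b):

import torch
import torch.nn as nn

# Toy sizes for illustration only.
vocab_size, hidden_size, padding_idx = 10, 4, 0
embed_tokens = nn.Embedding(vocab_size, hidden_size, padding_idx)

# A batch of 2 sequences; the shorter one is padded with padding_idx (0).
input_ids = torch.tensor([[5, 7, 2, 3],
                          [5, 9, 0, 0]])
inputs_embeds = embed_tokens(input_ids)
print(inputs_embeds.shape)   # torch.Size([2, 4, 4]) -> (batch, seq_len, hidden_size)
print(inputs_embeds[1, 2])   # all zeros: the padding row is zero-initialized and gets no gradient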

In the forward pass, `LlamaModel` looks up the token embeddings and runs them through the attention layers; its output is the final hidden states, which downstream code can build on, for example by adding a head for a classification task.

class LlamaModel(LlamaPreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]

    Args:
        config: LlamaConfig
    """

    def __init__(self, config: LlamaConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value
1.2 LlamaPreTrainedModel

Since `LlamaModel` inherits from `LlamaPreTrainedModel`, let's first look at this parent class.

`LlamaPreTrainedModel` inherits from `PreTrainedModel`, the common parent class of all models in transformers. `LlamaPreTrainedModel` overrides the `_init_weights` and `_set_gradient_checkpointing` methods; we won't go into much detail here.

class LlamaPreTrainedModel(PreTrainedModel):
    config_class = LlamaConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["LlamaDecoderLayer"]
    _skip_keys_device_placement = "past_key_values"

    def _init_weights(self, module):
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, LlamaModel):
            module.gradient_checkpointing = value
1.3 LlamaDecoderLayer

`LlamaDecoderLayer` is the Transformer block; llama-7b stacks 32 of them. As shown below, it contains a few key modules:

  • `hidden_size` is the embedding dimension, i.e. the width of the transformer.
  • `self_attn` is a `LlamaAttention`, the attention operation of the transformer.
  • `mlp` is a `LlamaMLP`, the feed-forward (FFN) layer of the transformer.
  • `input_layernorm` is a `LlamaRMSNorm`, the normalization applied before attention.
  • `post_attention_layernorm` is also a `LlamaRMSNorm`, the normalization applied before the MLP (a minimal RMSNorm sketch follows this list).
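
Both norm layers are `LlamaRMSNorm`. Unlike standard LayerNorm, RMSNorm does not subtract the mean and has no bias term; it only rescales each hidden vector by its root mean square and a learned weight. A minimal sketch of the idea (not the exact Huggingface implementation, which additionally handles fp16/bf16 casting):

import torch
import torch.nn as nn

class SimpleRMSNorm(nn.Module):
    """Minimal RMSNorm: y = x / sqrt(mean(x^2) + eps) * weight (no mean-centering, no bias)."""
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x):
        variance = x.pow(2).mean(-1, keepdim=True)
        return self.weight * x * torch.rsqrt(variance + self.eps)

x = torch.randn(2, 5, 8)
print(SimpleRMSNorm(8)(x).shape)   # torch.Size([2, 5, 8])

The full `LlamaDecoderLayer` from the Huggingface source follows.
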
class LlamaDecoderLayer(nn.Module):
    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = LlamaAttention(config=config)
        self.mlp = LlamaMLP(config)
        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
        """

        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs
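
The forward pass above is a pre-norm residual block: each sub-layer normalizes its input first, and the sub-layer output is added back onto the unnormalized input. A simplified restatement of that data flow (caching, the attention mask, and the returned attention weights are omitted; the stand-in modules are only there to make the sketch runnable):

import torch
import torch.nn as nn

def decoder_block(x, self_attn, mlp, input_norm, post_attn_norm):
    # Pre-norm residual structure, mirroring LlamaDecoderLayer.forward above.
    x = x + self_attn(input_norm(x))    # attention sub-layer
    x = x + mlp(post_attn_norm(x))      # feed-forward sub-layer
    return x

# Tiny stand-ins just to show the data flow; the real modules are
# LlamaAttention, LlamaMLP and LlamaRMSNorm.
d = 8
x = torch.randn(2, 5, d)
out = decoder_block(x, nn.Linear(d, d), nn.Linear(d, d), nn.LayerNorm(d), nn.LayerNorm(d))
print(out.shape)   # torch.Size([2, 5, 8])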