train-llm-from-scratch 完整笔记
从零手写 LLM 的教学仓库阅读笔记:经典 GPT 风格预训练(LayerNorm / 绝对位置编码 / ReLU MLP)与现代 LLaMA 风格架构(RMSNorm / RoPE / SwiGLU / GQA)对照,涵盖数据流水线、模型搭建、训练循环、SFT 和 Reasoning 全流程。
仓库地址:https://github.com/FareedKhan-dev/train-llm-from-scratch
一、项目总览#
1.1 这个仓库是什么#
这是一个**从零手写大语言模型(LLM)**的教学项目,作者 FareedKhan-dev,基于论文 Attention is All You Need(2017)用 PyTorch 从头实现 Transformer,目标是让个人用一张消费级 GPU 就能训练出从百万到十亿参数规模的 LLM。
项目的特点是不调用任何高层封装(不用 HuggingFace 的 Trainer、不 import 现成的 GPT2Model),而是把 Embedding、注意力、前馈网络、训练循环、文本生成全部手写。顺着这些代码看下来,基本能把”原始文本 → token → batch → 模型前向 → 生成结果”这条链路串起来。
作者的经验结论:13M(一千三百万)参数是一个分水岭——再小(比如 2.3M 在莎士比亚语料上)输出基本是乱码,而到了 13M 量级,模型开始能产出拼写正确、语法基本通顺的句子。但盲目堆到十亿参数,如果架构深度不够、数据不足,反而可能过拟合,效果不一定比小模型好。
1.2 两套实现的关系#
这个仓库实际上包含两套独立的 LLM 实现,风格和目标都不同:
| 第一套(主代码库) | 第二套(notebook) | |
|---|---|---|
| 位置 | src/ + scripts/ + config/ + data_loader/ | sft_rlhf_guide.ipynb |
| 定位 | 经典 GPT 预训练,工程化、可命令行运行 | 端到端教学,“会思考”的 LLM 全流程 |
| 架构 | LayerNorm + 绝对位置编码 + ReLU MLP + 标准多头注意力 | RMSNorm + RoPE + SwiGLU + GQA + Flash Attention |
| 分词器 | tiktoken 现成的 r50k_base | 用 HF tokenizers 库自己训练 BPE |
| 数据 | The Pile(825GB 真实语料的子集) | notebook 内置的几条玩具样本 |
| 训练阶段 | 仅预训练(pretraining) | 预训练 → SFT → Reasoning 三段式 |
| 框架风格 | 纯 torch.nn.Module | 继承 HF 的 PreTrainedModel / PretrainedConfig |
简单说:
- 想理解 Transformer 最朴素的样子、想真正跑一次大规模预训练 → 看第二部分(主代码库)。
- 想理解现代 LLM(LLaMA 系)长什么样、想搞懂 SFT 和”思维链”训练怎么做 → 看第三部分(notebook)。
两套都从零搭 Transformer,但第二套基本就是一个缩小版的 LLaMA + 后训练 pipeline,很多现在常见的 LLM 组件,notebook 里都有一个简化版实现。
1.3 目录结构导读#
1train-llm-from-scratch/2├── config/3│ └── config.py # 主代码库的全部超参(模型规模、训练参数、路径)4├── data_loader/5│ └── data_loader.py # get_batch_iterator:从 HDF5 流式取 batch6├── scripts/7│ ├── data_download.py # 从 HuggingFace 下载 The Pile 分片8│ ├── data_preprocess.py # jsonl.zst → 分词 → 存成 HDF5 token 流9│ ├── train_transformer.py # 训练主脚本(import 即开始训练)10│ └── generate_text.py # 加载 checkpoint 做自回归生成11├── src/12│ └── models/13│ ├── __init__.py # 导出 MLP / Head / MultiHeadAttention / Block / Transformer14│ ├── mlp.py # 前馈网络15│ ├── attention.py # 单头 Head + 多头 MultiHeadAttention16│ ├── transformer_block.py # 单个 Transformer Block17│ └── transformer.py # 完整模型(embedding → blocks → lm_head → generate)18├── sft_rlhf_guide.ipynb # 第二套:现代架构 + 预训练/SFT/Reasoning 全流程19├── requirements.txt20└── README.md # 含一篇非常详细的 step-by-step 讲解数据和模型产物在运行时生成(仓库里没有提交):
1data/2├── train/ # 训练用 .jsonl.zst 原始分片 + pile_train.h5(分词后)3└── val/ # 验证用 val.jsonl.zst + pile_dev.h54models/ # 训练得到的 .pt checkpoint1.4 环境依赖与快速上手#
requirements.txt 的依赖很轻:
1torch # 深度学习框架2numpy # 数值运算 / 索引打乱3h5py # 读写 HDF5(存分词后的 token 流)4tqdm # 进度条5requests # 下载数据6zstandard # 解压 .zst7tiktoken # OpenAI 的分词器notebook(第二套)额外需要
transformers和tokenizers,不在这个 requirements 里。
主代码库的完整运行四步:
1# 0. 让 Python 找得到项目根目录(否则 from config.config import ... 会失败)2export PYTHONPATH="$PYTHONPATH:."3
4# 1. 下载数据(默认只下 1 个训练分片,每个约 11GB;val 总会下)5python scripts/data_download.py --train_max 16
7# 2. 预处理:分词 + 存 HDF5(默认每个文件只取前 1000 行,方便快速试跑)8python scripts/data_preprocess.py --max_data 10009
10# 3. 训练(超参在 config/config.py 里改)11python scripts/train_transformer.py12
13# 4. 用训练好的 checkpoint 生成文本14python scripts/generate_text.py --model_path models/transformer_B.pt --input_text "Hello"关键提醒:
config/config.py里默认配置是十亿参数级别(N_EMBED=2048, N_BLOCKS=64),单卡基本跑不动。想先跑通流程,务必先把模型调小(见 2.3 配置项详解)。推荐的 13M 配置是N_EMBED=128, N_HEAD=8, N_BLOCKS=1, CONTEXT_LENGTH=128。
二、主代码库:经典 GPT 风格预训练#
这一部分对应 src/ + scripts/ + config/ + data_loader/,下面就按实际跑代码的顺序来看:先准备数据,再看模型、训练和生成。
2.1 数据流水线#
整条数据链路是:HuggingFace 上的 .jsonl.zst 压缩分片 → 解压逐行读 JSON → 取出 text 字段 → tiktoken 分词 → 拼成一条超长 token 流 → 存进 HDF5 → 训练时按 context_length 切片取 batch。
2.1.1 下载数据(scripts/data_download.py)#
数据集是 The Pile(去版权版 monology/pile-uncopyrighted),22 个领域混合的大规模英文语料,原始 825GB。脚本只下载其中一小部分。
URL 是按分片编号拼出来的:
1BASE_URL = "https://huggingface.co/datasets/monology/pile-uncopyrighted/resolve/main"2VAL_URL = f"{BASE_URL}/val.jsonl.zst" # 验证集(固定一个文件)3TRAIN_URLS = [f"{BASE_URL}/train/{i:02d}.jsonl.zst" for i in range(65)] # 训练集:00~64 共 65 个分片下载本身是标准的流式写盘(stream=True 边下边写,配 tqdm 进度条),每块 1024 字节:
1def download_file(url: str, file_name: str) -> None:2 response = requests.get(url, stream=True)3 total_size = int(response.headers.get('content-length', 0))4 block_size = 10245 with open(file_name, 'wb') as f:6 for chunk in tqdm(response.iter_content(block_size), total=total_size // block_size, desc="Downloading", leave=True):7 f.write(chunk)download_dataset 的逻辑有两个值得注意的设计:
- 可以断点重跑:每个文件下载前先
os.path.exists检查,已存在就跳过。中断后再跑不会重新下载。 - 训练分片可控:
train_urls[:max_train_files]只取前max_train_files个。默认--train_max 1,即只下00.jsonl.zst(约 11GB)。
命令行参数:
| 参数 | 默认 | 含义 |
|---|---|---|
--train_max | 1 | 下载多少个训练分片(最多 65) |
--train_dir | data/train | 训练数据目录 |
--val_dir | data/val | 验证数据目录 |
2.1.2 预处理与分词(scripts/data_preprocess.py)#
这一步把人类可读的文本变成模型能吃的 token id,并存成方便随机读取的 HDF5。核心是 process_files:
1def process_files(input_dir, output_file, tokenizer_name, max_data=None):2 enc = tiktoken.get_encoding(tokenizer_name) # 默认 'r50k_base'(GPT-3 用的分词器)3
4 with h5py.File(output_file, 'w') as out_f:5 # 创建一个一维、可动态扩容的 dataset,存所有 token6 dataset = out_f.create_dataset('tokens', (0,), maxshape=(None,), dtype='i')7 start_index = 08
9 for filename in sorted(os.listdir(input_dir)):10 if filename.endswith(".jsonl.zst"):11 with zstd.open(in_file, 'rt', encoding='utf-8') as in_f: # 流式解压 + 文本模式读12 for line in tqdm(in_f, desc=f"Processing {filename}", total=max_data):13 data = json.loads(line)14 text = data.get('text')15 if text:16 # 每条文本末尾追加 <|endoftext|>,再编码17 encoded = enc.encode(text + "<|endoftext|>", allowed_special={'<|endoftext|>'})18 encoded_len = len(encoded)19 end_index = start_index + encoded_len20 dataset.resize(dataset.shape[0] + encoded_len, axis=0) # 扩容21 dataset[start_index:end_index] = encoded # 追加写入22 start_index = end_index23 processed_lines += 124 if max_data is not None and processed_lines >= max_data:25 break # 每个文件最多处理 max_data 行有三个设计点需要理解:
-
token 流是”一维拼接”而非”二维矩阵”。所有文档的 token 首尾相接,存成一条巨长的一维数组
tokens。文档边界靠<|endoftext|>这个特殊 token 标记——它告诉模型”上一段到此结束”,避免模型把两篇不相关文档当成连续上下文,也是生成时的自然停止信号。 -
为什么用 HDF5。HDF5 支持按切片随机读取且不需要把整个数据集载入内存。训练时要随机取
dataset[idx:idx+context_length+1]这样的片段,HDF5 直接从磁盘读对应区间即可,几十上百 GB 的 token 流也扛得住。maxshape=(None,)表示这一维可以无限扩容,配合resize实现”边读边追加”。 -
max_data主要用于试跑。默认 1000 表示每个分片只处理前 1000 行;正式训练时需要调大或改成处理全部。
| 参数 | 默认 | 含义 |
|---|---|---|
--train_dir / --val_dir | data/train / data/val | 输入目录 |
--out_train_file | data/train/pile_train.h5 | 训练 token 输出 |
--out_val_file | data/val/pile_dev.h5 | 验证 token 输出 |
--tokenizer_name | r50k_base | tiktoken 分词器名 |
--max_data | 1000 | 每个文件最多处理多少行 JSON |
r50k_base是什么:tiktoken 内置的 BPE 分词器,词表 50257,GPT-3/GPT-2 同款。注意它的真实词表大小是 50257,而 config 里VOCAB_SIZE=50304——多出来的 47 个是padding 到 64 的倍数,纯粹为了 GPU 上矩阵运算对齐更高效。这些多出来的 token 不会出现在训练目标里(模型学不到要预测它们),推理采样时被选中的概率也极低,基本不影响使用。
2.1.3 批次迭代器(data_loader/data_loader.py)#
get_batch_iterator 是一个无限生成器,负责从一维 token 流里切出 (输入, 目标) 的 batch,整个训练数据侧都围绕它运转:
1def get_batch_iterator(data_path, batch_size, context_length, device="cpu"):2 with h5py.File(data_path, 'r') as hdf5_file:3 dataset = hdf5_file['tokens']4 dataset_size = dataset.shape[0]5 # 能切出多少个不重叠的样本(-1 是因为目标要往后错一位)6 n_examples = (dataset_size - 1) // context_length7 example_idxs = np.arange(n_examples)8 np.random.shuffle(example_idxs) # 打乱样本顺序9
10 epochs = 011 counter = 012 while True: # 无限循环,训练循环自己控制何时停13 if counter + batch_size > n_examples:14 np.random.shuffle(example_idxs) # 一个 epoch 用完,重新打乱15 counter = 016 print(f"Finished epoch {epochs}")17 epochs += 118
19 # 把"第几个样本"换算成 token 流里的起始下标20 random_indices = example_idxs[counter:counter+batch_size] * context_length21 # 每个样本取 context_length+1 个 token(多取一个用于错位)22 random_samples = torch.tensor(np.array(23 [dataset[idx:idx+context_length+1] for idx in random_indices]24 ))25 xb = random_samples[:, :context_length].to(device) # 输入:前 context_length 个26 yb = random_samples[:, 1:context_length+1].to(device) # 目标:错后一位27 counter += batch_size28 yield xb, ybxb / yb 错一位是最关键的地方,这里能直接看到语言模型的训练目标——用当前位置的输入去预测下一个 token:
1token 流片段: [The] [cat] [sat] [on] [the] [mat]2xb (输入): [The] [cat] [sat] [on] [the] # 位置 0..T-13yb (目标): [cat] [sat] [on] [the] [mat] # 位置 1..T模型在每个位置 t 看到 xb[t],要预测出 yb[t](也就是原序列里 xb[t] 的下一个 token)。所以一个长度 T 的片段不是只提供一个标签,而是同时提供 T 个预测位置。
shape 一览(设 batch=B,context_length=T):
| 变量 | shape | 说明 |
|---|---|---|
random_indices | (B,) | B 个样本在 token 流中的起始位置 |
random_samples | (B, T+1) | 每个样本 T+1 个 token |
xb | (B, T) | 输入 |
yb | (B, T) | 目标(错位) |
两个细节:
- 样本是不重叠切分的:第
i个样本从i * context_length开始,样本之间首尾相接不重叠(不是滑动窗口)。这样数据利用率高、实现简单,代价是每个 token 只会出现在一个固定的上下文位置组合里。 - device 默认
"cpu"。训练脚本调用时会显式传入config['device'](即'cuda'),所以实际训练在 GPU 上。
2.2 模型架构(自底向上)#
主代码库的模型由五个文件、五个类层层组装,依赖关系是:
1Transformer (transformer.py) 完整模型2 └── Block (transformer_block.py) ×N_BLOCKS 层3 ├── MultiHeadAttention (attention.py)4 │ └── Head ×n_head 单个注意力头5 └── MLP (mlp.py) 前馈网络整个前向过程的张量约定是统一的:B = batch size,T = 序列长度(time steps),C = 嵌入维度(n_embed)。下面自底向上拆解。
2.2.1 前馈网络 MLP(src/models/mlp.py)#
MLP 是 Transformer Block 里”思考”的部分——注意力负责”看哪里”,MLP 负责”基于看到的信息做非线性变换”。结构是经典的升维 → 激活 → 降维:
1class MLP(nn.Module):2 def __init__(self, n_embed):3 super().__init__()4 self.hidden = nn.Linear(n_embed, 4 * n_embed) # 升维到 4 倍5 self.relu = nn.ReLU()6 self.proj = nn.Linear(4 * n_embed, n_embed) # 投影回原维度7
8 def forward(self, x):9 x = self.forward_embedding(x) # hidden + relu10 x = self.project_embedding(x) # proj11 return x12
13 def forward_embedding(self, x):14 return self.relu(self.hidden(x))15
16 def project_embedding(self, x):17 return self.proj(x)- 为什么升维到 4 倍:
4 * n_embed是 Transformer 的惯例(原始论文就是 4 倍)。中间维度更大,给模型更多”工作空间”去拟合复杂的非线性映射,再压回原维度保持各层接口一致。 - 为什么拆成
forward_embedding/project_embedding两个方法:把”升维+激活”和”降维投影”拆开,是为了支持后面Block.forward_embedding那种只走一半的中间态前向(拿到激活后、投影前的表示)。常规训练只用forward。 - shape:输入输出都是
(B, T, C),中间隐藏层是(B, T, 4C)。nn.Linear只作用在最后一维,(B, T)原样保留。
2.2.2 单头注意力 Head(src/models/attention.py)#
注意力是模型里最核心的部分。一个 Head 做的事情是:让每个位置的 token 去”查询”序列里它之前所有 token,按相关度加权汇总信息。
1class Head(nn.Module):2 def __init__(self, head_size, n_embed, context_length):3 super().__init__()4 self.key = nn.Linear(n_embed, head_size, bias=False) # K 投影5 self.query = nn.Linear(n_embed, head_size, bias=False) # Q 投影6 self.value = nn.Linear(n_embed, head_size, bias=False) # V 投影7 # 下三角矩阵,注册成 buffer(不是参数,不训练,但随模型存取/搬设备)8 self.register_buffer('tril', torch.tril(torch.ones(context_length, context_length)))9
10 def forward(self, x):11 B, T, C = x.shape12 head_size = self.key.out_features13 k = self.key(x) # (B, T, head_size)14 q = self.query(x) # (B, T, head_size)15 scale_factor = 1 / math.sqrt(head_size)16 # 注意力分数:q 和 k 做点积17 attn_weights = q @ k.transpose(-2, -1) * scale_factor # (B, T, T)18 # 因果掩码:把"未来"位置置成 -inf19 attn_weights = attn_weights.masked_fill(self.tril[:T, :T] == 0, float('-inf'))20 attn_weights = F.softmax(attn_weights, dim=-1) # 每行归一化成概率21 v = self.value(x) # (B, T, head_size)22 out = attn_weights @ v # (B, T, head_size)23 return out对应的数学就是缩放点积注意力(Scaled Dot-Product Attention):
其中 dk 是 head_size,M 是因果掩码矩阵(上三角为 −∞,其余为 0)。逐点拆解:
-
Q、K、V 三个投影。同一份输入
x经过三个独立的线性层,得到 Query(“我想找什么”)、Key(“我有什么特征”)、Value(“我携带什么信息”)。三者都没有 bias——注意力里 bias 作用不大,省去更简洁。 -
打分
q @ k.transpose(-2,-1)。(B,T,head_size) @ (B,head_size,T) → (B,T,T)。结果第i行第j列 = 位置i的 query 和位置j的 key 的点积,衡量”位置i该有多关注位置j”。 -
缩放
1/√head_size。点积会随维度增大而方差变大,把 softmax 推向饱和区(梯度消失)。除以 dk 把分数拉回合理范围。注意缩放用的是
head_size(即1/√d_k),而非整个n_embed——这是原始论文的标准做法。 -
因果掩码
masked_fill(tril==0, -inf)。tril是下三角全 1 矩阵,tril[:T,:T]==0选出上三角(含未来位置)。把这些位置的分数设成-inf,softmax 后权重变成 0——保证位置i只能看到≤ i的 token,看不到未来。这是 decoder-only 语言模型的根本约束(否则预测下一个词时就”作弊偷看答案”了)。 -
softmax 归一化 + 加权求和
attn_weights @ v。每行权重归一化成概率分布,再(B,T,T) @ (B,T,head_size) → (B,T,head_size),把各位置的 Value 按注意力权重加权汇总。
tril用register_buffer而非nn.Parameter:它是固定的掩码常量,不需要梯度更新,但希望它能跟着model.to(device)一起搬到 GPU、跟着state_dict一起存取。buffer 正是为这种”非训练但属于模型状态”的张量设计的。
2.2.3 多头注意力 MultiHeadAttention(src/models/attention.py)#
单个头只能学一种”关注模式”。多头注意力让若干个头并行各看各的(有的头学语法依赖,有的学指代关系……),最后拼接:
1class MultiHeadAttention(nn.Module):2 def __init__(self, n_head, n_embed, context_length):3 super().__init__()4 # n_head 个头,每个头维度 n_embed // n_head5 self.heads = nn.ModuleList(6 [Head(n_embed // n_head, n_embed, context_length) for _ in range(n_head)]7 )8
9 def forward(self, x):10 # 每个头独立处理,再沿最后一维拼接11 x = torch.cat([h(x) for h in self.heads], dim=-1)12 return x- 维度切分:每个头的
head_size = n_embed // n_head。n_head个头各输出(B, T, n_embed//n_head),沿最后一维cat回(B, T, n_embed)。总维度不变,相当于把n_embed维”分配”给多个头。 - 设计取舍:这里是最朴素的实现——用 Python 列表存
n_head个独立Head,前向时for循环逐个跑再cat。这种写法比较直观,但效率一般,因为每个 head 都单独跑一遍线性层和注意力。notebook 里的DemoAttention把 Q/K/V 合到一个大矩阵里一次算完,更接近工程上的常见做法。 - ⚠️ 这个实现里没有输出投影层(标准 Transformer 在多头 cat 之后还有一个
W_O线性层)。本仓库省掉了它,直接把拼接结果送出。这是简化,不是 bug,但和标准/现代实现不同,看代码时留意。
2.2.4 Transformer Block(src/models/transformer_block.py)#
一个 Block = 一层”注意力 + 前馈”,配上 Pre-LN(前置层归一化) 和残差连接:
1class Block(nn.Module):2 def __init__(self, n_head, n_embed, context_length):3 super().__init__()4 self.ln1 = nn.LayerNorm(n_embed)5 self.attn = MultiHeadAttention(n_head, n_embed, context_length)6 self.ln2 = nn.LayerNorm(n_embed)7 self.mlp = MLP(n_embed)8
9 def forward(self, x):10 x = x + self.attn(self.ln1(x)) # 注意力子层 + 残差11 x = x + self.mlp(self.ln2(x)) # 前馈子层 + 残差12 return x这两行 forward 信息密度很高,拆开看:
-
残差连接
x = x + sublayer(...)。每个子层学的是”在原表示上加什么修正量”,而不是完全重写。好处是梯度能通过+直接回传到浅层,深网络也能稳定训练(缓解梯度消失)。这是能把 Block 堆到几十层的前提。 -
Pre-LN:先归一化再进子层。注意
LayerNorm作用在self.attn/self.mlp的输入上(attn(ln1(x))),残差加的是未归一化的x。这种”Pre-LN”结构比原始论文的”Post-LN”(ln(x + sublayer(x)))训练更稳定,是现代实现的主流选择。LayerNorm 对每个 token 的C维特征做标准化(减均值除标准差再仿射),让各层输入分布稳定。 -
forward_embedding方法是给前面提到的”半程前向”用的(返回 MLP 投影前的中间态和残差),常规训练/推理不走它。
2.2.5 完整模型 Transformer(src/models/transformer.py)#
把所有零件组装成端到端的语言模型:
1class Transformer(nn.Module):2 def __init__(self, n_head, n_embed, context_length, vocab_size, N_BLOCKS):3 super().__init__()4 self.context_length = context_length5 self.N_BLOCKS = N_BLOCKS6 self.token_embed = nn.Embedding(vocab_size, n_embed) # token → 向量7 self.position_embed = nn.Embedding(context_length, n_embed) # 位置 → 向量8 self.attn_blocks = nn.ModuleList(9 [Block(n_head, n_embed, context_length) for _ in range(N_BLOCKS)]10 )11 self.layer_norm = nn.LayerNorm(n_embed) # 最后一层归一化12 self.lm_head = nn.Linear(n_embed, vocab_size) # 投影到词表,得到 logits13 self.register_buffer('pos_idxs', torch.arange(context_length))14
15 def _pre_attn_pass(self, idx):16 B, T = idx.shape17 tok_embedding = self.token_embed(idx) # (B, T, C)18 pos_embedding = self.position_embed(self.pos_idxs[:T]) # (T, C),广播相加19 return tok_embedding + pos_embedding20
21 def forward(self, idx, targets=None):22 x = self._pre_attn_pass(idx) # 词嵌入 + 位置嵌入23 for block in self.attn_blocks: # 逐层 Transformer Block24 x = block(x)25 x = self.layer_norm(x) # 最终归一化26 logits = self.lm_head(x) # (B, T, vocab_size)27 loss = None28 if targets is not None:29 B, T, C = logits.shape30 flat_logits = logits.view(B * T, C)31 targets = targets.view(B * T).long()32 loss = F.cross_entropy(flat_logits, targets) # 交叉熵33 return logits, loss前向数据流(设 vocab=V):
1idx (B,T) ──token_embed──▶ (B,T,C) ┐2 ├─加──▶ (B,T,C) ──Block×N──▶ (B,T,C) ──LN──▶ ──lm_head──▶ logits (B,T,V)3pos_idxs[:T] ─position_embed─▶ (T,C)┘关键点:
-
两种 embedding 相加。
token_embed告诉模型”这是哪个词”,position_embed告诉它”这个词在第几个位置”。两个查表向量直接相加注入模型。- 这里用的是可学习的绝对位置编码(
nn.Embedding(context_length, n_embed),每个位置一个独立向量,跟着训练)——和原始论文的固定正弦编码不同,也和第三部分 notebook 的 RoPE 完全不同。 pos_idxs[:T]取前T个位置 id,(T,C)通过广播加到每个 batch 上。
- 这里用的是可学习的绝对位置编码(
-
lm_head产出 logits。最后一个Linear把每个位置的C维表示映射到vocab_size维,得到”下一个 token 是词表中各个词的打分”。 -
loss 计算。训练时传入
targets,把(B,T,V)摊平成(B*T, V)、targets摊平成(B*T,),算交叉熵。注意F.cross_entropy内部自带 softmax,所以模型直接输出原始 logits 即可。
生成方法 generate:
1def generate(self, idx, max_new_tokens):2 for _ in range(max_new_tokens):3 idx_cond = idx[:, -self.context_length:] # 只保留最近 context_length 个 token4 logits, _ = self(idx_cond) # 前向5 logits = logits[:, -1, :] # 只取最后一个位置的预测 (B, V)6 probs = F.softmax(logits, dim=-1) # 转概率7 idx_next = torch.multinomial(probs, num_samples=1) # 按概率采样一个 token8 idx = torch.cat((idx, idx_next), dim=1) # 拼回序列,继续下一轮9 return idx这是标准的自回归生成:每次只用模型预测的最后一个位置,按概率分布采样(multinomial,而非贪心取 argmax)出下一个 token,拼回去再喂进模型,循环 max_new_tokens 次。
idx[:, -context_length:]做上下文截断:序列再长,也只喂最近context_length个 token(因为位置编码只有这么多、注意力窗口也只有这么大)。- 用
multinomial采样让输出有随机性和多样性;若改成argmax就是确定性的贪心解码。这里没有温度(temperature)、top-k、top-p 等采样控制——对比第三部分 notebook 的生成就用上了这些。
2.3 配置项详解#
所有超参集中在 config/config.py,最后打包成 default_config 字典供其它脚本 import。逐项含义:
| 参数 | 默认值 | 含义 | 影响 |
|---|---|---|---|
VOCAB_SIZE | 50304 | 词表大小 | 由分词器决定(r50k 实际 50257,padding 到 50304) |
CONTEXT_LENGTH | 512 | 模型支持的最大序列长度 | 决定位置编码表大小、注意力窗口 |
N_EMBED | 2048 | 嵌入维度 C | 影响参数量最大的旋钮之一 |
N_HEAD | 16 | 注意力头数 | 需能整除 N_EMBED |
N_BLOCKS | 64 | Transformer 层数 | 决定模型深度 |
T_BATCH_SIZE | 32 | 训练 batch 大小 | 受显存限制 |
T_CONTEXT_LENGTH | 16 | 训练时实际用的序列长度 | ⚠️ 见下方说明 |
T_TRAIN_STEPS | 200000 | 总训练步数 | |
T_EVAL_STEPS | 1000 | 每多少步评估一次 | |
T_EVAL_ITERS | 250 | 每次评估跑多少个 batch 取平均 | |
T_LR | 5e-4 | 初始学习率 | |
T_LR_DECAYED | 5e-5 | 衰减后学习率 | |
T_LR_DECAY_STEP | 50000 | 在第几步把学习率降到 T_LR_DECAYED | 单次阶梯衰减 |
T_OUT_PATH | models/transformer_B.pt | 模型保存路径 | |
DEVICE | cuda | 运行设备 | |
TRAIN_PATH / DEV_PATH | pile_train.h5 / pile_dev.h5 | 训练/验证数据 |
两组典型配置:
1# 默认(约 21 亿参数,单卡基本跑不动,作者注释误写成 "3 Billion")2N_EMBED = 2048; N_HEAD = 16; N_BLOCKS = 64; CONTEXT_LENGTH = 5123
4# 推荐的 13M(先跑通流程用这个)5N_EMBED = 128; N_HEAD = 8; N_BLOCKS = 1; CONTEXT_LENGTH = 128参数量主要花在哪? 以 13M 配置粗算:
| 部件 | 参数量 | 占比 |
|---|---|---|
token_embed(50304×128) | ≈ 6.44 M | ~49% |
lm_head(128×50304 + bias) | ≈ 6.49 M | ~49% |
| 1 个 Transformer Block | ≈ 0.18 M | ~1.4% |
| 其余(位置编码、最终 LN) | ≈ 0.02 M | <1% |
| 合计 | ≈ 13.1 M | 100% |
一个反直觉但重要的结论:小模型的参数几乎全在 token embedding 和输出层(因为词表有 5 万),真正做”思考”的 Transformer 层只占 1% 多。这也解释了为什么作者说 13M 已能产出通顺文字——语言的”记忆”大量存在 embedding 里。(注:本仓库 token_embed 和 lm_head 是两套独立权重,没有做权重共享;若共享可省下近一半参数。)
⚠️ 一个必须注意的细节:T_CONTEXT_LENGTH=16 vs CONTEXT_LENGTH=512。 模型按 CONTEXT_LENGTH=512 构建(位置编码表有 512 行、注意力能看 512 个 token),但训练循环里 get_batch_iterator 喂的是 t_context_length=16 的短序列(见 2.4)。这意味着实际训练时每条样本只有 16 个 token,只有前 16 个位置编码会被训练到。想让模型真正学会用长上下文,需要把 T_CONTEXT_LENGTH 调大到接近 CONTEXT_LENGTH。这是套用本仓库时容易忽略的点。
2.4 训练循环(scripts/train_transformer.py)#
这个脚本没有 if __name__ == '__main__' 保护,是模块级直接执行——也就是说 python scripts/train_transformer.py 一运行,import 完成的同时训练就开始了。整体分四段。
① 建模型 + 数清参数:
1model = Transformer(2 n_head=config['n_head'], n_embed=config['n_embed'],3 context_length=config['context_length'], vocab_size=config['vocab_size'],4 N_BLOCKS=config['n_blocks']5).to(config['device'])6
7total_params = sum(p.numel() for p in model.parameters())8print(f"Total number of parameters in the model: {total_params:,}")② 优化器 + 评估函数:
1optimizer = torch.optim.AdamW(model.parameters(), lr=config['t_lr'])2losses = []3AVG_WINDOW = 64 # 显示用的滑动平均窗口4
5@torch.no_grad()6def estimate_loss(steps):7 out = {}8 model.eval() # 评估时切 eval 模式9 for split in ['train', 'dev']:10 data_path = config['train_path'] if split == 'train' else config['dev_path']11 batch_iterator_eval = get_batch_iterator(data_path, config['t_batch_size'],12 config['t_context_length'], device=config['device'])13 losses_eval = torch.zeros(steps)14 for k in range(steps):15 xb, yb = next(batch_iterator_eval)16 _, loss = model(xb, yb)17 losses_eval[k] = loss.item()18 out[split] = losses_eval[:k + 1].mean()19 model.train() # 评估完切回 train 模式20 return out@torch.no_grad():评估不需要梯度,省显存、提速。model.eval()/model.train()切换:本模型没有 dropout/BN,切换影响不大,但这是规范写法。- 在 train 和 dev 两个数据集上都评估,方便观察过拟合(train loss 持续降但 dev loss 不降 = 过拟合)。
③ 主训练循环:
1batch_iterator = get_batch_iterator(config['train_path'], config['t_batch_size'],2 config['t_context_length'], device=config['device'])3pbar = tqdm(range(config['t_train_steps']))4for step in pbar:5 try:6 xb, yb = next(batch_iterator) # 取一个 batch7 _, loss = model(xb, yb) # 前向 + 算 loss8 losses.append(loss.item())9 pbar.set_description(f"Train loss: {np.mean(losses[-AVG_WINDOW:]):.4f}")10
11 optimizer.zero_grad(set_to_none=True) # 清梯度12 loss.backward() # 反向传播13 optimizer.step() # 更新参数14
15 if step % config['t_eval_steps'] == 0: # 周期性评估16 evaluation_losses = estimate_loss(config['t_eval_iters'])17 print(f"Step: {step}, Train loss: {...}, Dev loss: {...}")18
19 if step == config['t_lr_decay_step']: # 到点降学习率20 print('Decaying learning rate')21 for g in optimizer.param_groups:22 g['lr'] = config['t_lr_decayed']23 except StopIteration:24 break标准训练四步:前向算 loss → 清零梯度 → 反向求梯度 → 优化器更新。zero_grad(set_to_none=True) 把梯度置 None 而非置 0,省一点内存和计算。
学习率策略是单次硬切换:跑到第 50000 步时,把所有参数组的学习率从 5e-4 直接降到 5e-5,只有一次阶梯下降。notebook 里则用了 warmup 加余弦退火,每步都在调整。
④ 保存 checkpoint:
1os.makedirs(config['t_out_path'].split('/')[0], exist_ok=True) # 确保 models/ 存在2evaluation_losses = estimate_loss(200) # 最终评估3
4# 防止覆盖已有文件:transformer_B.pt 存在就存成 transformer_B_1.pt …5modified_model_out_path = config['t_out_path']6save_tries = 07while os.path.exists(modified_model_out_path):8 save_tries += 19 model_out_name = os.path.splitext(config['t_out_path'])[0]10 modified_model_out_path = model_out_name + f"_{save_tries}" + ".pt"11
12torch.save({13 'model_state_dict': model.state_dict(),14 'optimizer_state_dict': optimizer.state_dict(), # 存优化器状态便于续训15 'losses': losses, # 完整 loss 曲线16 'train_loss': train_loss, 'dev_loss': dev_loss,17 'steps': len(losses),18}, modified_model_out_path)存的不只是模型权重,还有优化器状态、完整 loss 历史、步数——这些为续训和分析保留了所需数据(不过要注意:本脚本本身并没有实现”加载 checkpoint 继续训练”的逻辑,每次运行都是从头初始化),也方便事后画 loss 曲线。保存时如果目标文件已存在,会自动加后缀避免覆盖。
训练循环外层套了
try/except StopIteration:虽然get_batch_iterator是无限生成器理论上不会耗尽,但这层保护能在意外早停时仍走到保存逻辑,避免训练成果丢失。作者报告十亿参数模型最终 train loss 0.2314、dev loss 0.643。
2.5 文本生成(scripts/generate_text.py)#
训练好之后用这个脚本做推理。核心函数 generate_text:
1def generate_text(model_path, input_text, max_new_tokens=100, device='cuda'):2 checkpoint = torch.load(model_path, map_location=torch.device(device))3 model = Transformer( # 用 config 里的超参重建同样结构4 n_head=config['n_head'], n_embed=config['n_embed'],5 context_length=config['context_length'], vocab_size=config['vocab_size'],6 N_BLOCKS=config['n_blocks']7 )8 model.load_state_dict(checkpoint['model_state_dict']) # 灌入训练好的权重9 model.eval().to(device)10
11 enc = tiktoken.get_encoding("r50k_base") # 必须和训练时同一个分词器12 start_ids = enc.encode_ordinary(input_text) # 文本 → token id13 context = torch.tensor(start_ids, dtype=torch.long, device=device).unsqueeze(0) # 加 batch 维14
15 with torch.no_grad():16 generated_tokens = model.generate(context, max_new_tokens=max_new_tokens)[0].tolist()17 return enc.decode(generated_tokens) # token id → 文本要点:
- 重建结构再灌权重。PyTorch 的
state_dict只存张量数值,不存网络结构。所以必须先用和训练时一模一样的超参实例化Transformer,再load_state_dict。如果config.py改过参数,这里就会 shape 不匹配报错——生成时的 config 必须和训练时一致。 - 分词器必须一致。生成用的
r50k_base要和预处理时同一个,否则 token id 对不上,输出全是乱码。 encode_ordinary:把文本编码成普通 token,不解析<|endoftext|>这类特殊 token(当普通文本处理)。unsqueeze(0)把(T,)变成(1, T)补上 batch 维,因为模型forward期望(B, T)。
命令行用法:
1python scripts/generate_text.py \2 --model_path models/transformer_B.pt \3 --input_text "In 1978" \4 --max_new_tokens 100| 参数 | 默认 | 含义 |
|---|---|---|
--model_path | — | checkpoint 路径 |
--input_text | — | 起始提示词 |
--max_new_tokens | 100 | 生成多少个新 token |
主代码库到这里基本覆盖了下载、预处理、训练和生成整条链路。整体更接近早期 GPT(2017~2019)的实现方式。第三部分的 notebook 会在此基础上换成 LLaMA 风格组件,并加入 SFT 和 reasoning 训练。
三、notebook:现代架构 + 后训练全流程#
sft_rlhf_guide.ipynb 是一份自包含的端到端教程(74 个 cell),标题叫 Building a “Thinking” LLM from Scratch。覆盖的范围比主代码库更大:除了搭一个缩小版 LLaMA 架构,还演示了**预训练 → SFT → Reasoning(思维链)**三个阶段,最终让模型输出 <think>...</think><answer>...</answer> 这种带思考过程的结构化回答。
设计灵感来自开源项目 MiniMind(notebook 注释里多次提及
train_distill_reason.py、MiniMind 的SFTDataset等)。所有数据都是 notebook 内置的几条玩具样本,目的是跑通流程、理解机制,不是真出一个能用的模型。
全局配置先建立认知(这些超参决定模型规模和训练强度):
1DEMO_VOCAB_SIZE = 32000 # 自训分词器的目标词表2DEMO_HIDDEN_SIZE = 1024 # 隐藏维度(相当于主代码库的 n_embed)3DEMO_NUM_LAYERS = 24 # Transformer 层数4DEMO_NUM_ATTENTION_HEADS = 16 # Query 头数5DEMO_NUM_KV_HEADS = 16 # Key/Value 头数(=Q 头数,所以这里实际是 MHA;<16 则是 GQA)6DEMO_MAX_SEQ_LEN = 1024 # 最大序列长度7SPECIAL_TOKENS_LIST = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<pad>"]8
9# 训练超参(三阶段各 10 epoch,学习率递减)10DEMO_BATCH_SIZE = 1611DEMO_PRETRAIN_LR = 3e-412DEMO_SFT_LR = 1e-413DEMO_REASONING_LR= 5e-53.1 自训 BPE Tokenizer#
和主代码库直接用现成的 tiktoken 不同,notebook 从零训练一个 BPE 分词器(用 HuggingFace 的 tokenizers 库)。
BPE(Byte Pair Encoding)原理:从”每个字符是一个 token”开始,反复统计语料里最高频的相邻 token 对,把它合并成一个新 token 加入词表,直到词表达到目标大小。这样高频词(如 “the”)会被合并成单个 token,生僻词则拆成若干子词——既控制了词表规模,又能用已知子词拼出没见过的新词(缓解 OOV 问题)。
训练函数:
1def train_demo_tokenizer(corpus_files, vocab_size, save_path, special_tokens):2 tokenizer_bpe = HFTokenizer(hf_models.BPE(unk_token="<unk>"))3 # ByteLevel:先按字节切分,保证任何 Unicode 字符都能表示(不会真正 OOV)4 tokenizer_bpe.pre_tokenizer = hf_pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=True)5 tokenizer_bpe.decoder = hf_decoders.ByteLevel()6 trainer = hf_trainers.BpeTrainer(7 vocab_size=vocab_size,8 special_tokens=special_tokens,9 initial_alphabet=hf_pre_tokenizers.ByteLevel.alphabet() # 256 个字节全部作为初始字母表10 )11 tokenizer_bpe.train(corpus_files, trainer=trainer)12 tokenizer_bpe.save(save_path) # 存成单个 .json13 return tokenizer_bpe- ByteLevel 字节级处理是现代分词器(GPT-2、LLaMA)的标配:以 256 个字节为最小单位,任何字符(中文、emoji、罕见符号)都能由字节序列表示,理论上零 OOV。
- 训练语料只有 8 句话(notebook 内置)——纯演示,真实场景需要海量语料才能训出好词表。
四个特殊 token 的角色:
| token | 作用 |
|---|---|
| `< | endoftext |
| `< | im_start |
| `< | im_end |
<pad> | 填充,把不等长序列补齐到同一长度 |
用 AutoTokenizer 重新加载,是为了拿到 HuggingFace 生态的便利功能,最关键的是 apply_chat_template——把对话列表按模板拼成字符串。这里定义的是 ChatML 风格模板:
1<|im_start|>user2{用户内容}<|im_end|>3<|im_start|>assistant4{助手内容}<|im_end|>notebook 里还有一大段
try/exceptfallback(CallableTokenizerWrapper):万一AutoTokenizer加载失败,就手写一个包装类模拟__call__/apply_chat_template/encode/decode。这是健壮性兜底代码,理解主流程时可以跳过,知道”它保证后续tokenizer对象一定可用”即可——后续所有代码都通过统一的tokenizer变量调用分词器。
3.2 数据集与 loss mask#
notebook 定义两个 Dataset 类,分别服务预训练和对话微调。看这两个 Dataset 时主要要盯住 loss mask——它决定哪些 token 参与 loss 计算。
3.2.1 DemoCorpusDataset(预训练用)#
1class DemoCorpusDataset(Dataset):2 def __getitem__(self, idx):3 text = self.samples[idx]4 full_text_with_bos = self.tokenizer.bos_token + text # 句首加 BOS5 encoding = self.tokenizer(full_text_with_bos, max_length=self.max_length,6 padding="max_length", truncation=True, return_tensors='pt')7 input_ids = encoding.input_ids.squeeze(0) # (max_length,)8
9 effective_loss_mask = (input_ids != self.tokenizer.pad_token_id).long() # 非 pad 处为 110 X = input_ids[:-1] # 输入11 Y = input_ids[1:] # 目标(错位一格,同主代码库)12 mask_for_loss_calculation = effective_loss_mask[1:] # 对齐 Y13 return X, Y, mask_for_loss_calculation- 错位逻辑和主代码库
get_batch_iterator完全一致(X=input_ids[:-1],Y=input_ids[1:])。 - loss mask 这里只排除 padding:预训练要学整段文本,所以除了填充的
<pad>,每个 token 都参与 loss。mask 切[1:]是为了和Y对齐。
3.2.2 DemoChatDataset(SFT 和 Reasoning 共用)#
对话数据的处理核心是:只让模型学习 assistant 的回复,不学 user 的提问。靠 loss mask 实现:
1class DemoChatDataset(Dataset):2 def __getitem__(self, idx):3 conversations = self.samples[idx]4 # 用 chat template 把整段对话拼接 + 分词5 input_ids = self.tokenizer.apply_chat_template(6 conversations, tokenize=True, add_generation_prompt=False,7 return_tensors="pt", max_length=self.max_length,8 truncation=True, padding="max_length"9 ).squeeze(0)10
11 loss_mask = torch.zeros_like(input_ids, dtype=torch.long) # 默认全 0(都不算 loss)12 bos_assistant_ids = self.tokenizer.encode("<|im_start|>assistant\n", add_special_tokens=False)13 eos_ids = self.tokenizer.encode("<|im_end|>", add_special_tokens=False)14
15 # 扫描 token 序列,把每段 "assistant\n ... <|im_end|>" 之间的 token 标成 116 i = 017 input_ids_list = input_ids.tolist()18 while i < len(input_ids_list):19 # 命中一段 assistant 回复的起始标记20 if input_ids_list[i : i+len(bos_assistant_ids)] == bos_assistant_ids:21 start_of_response = i + len(bos_assistant_ids)22 # 往后找该回复的结束标记 <|im_end|>23 end_marker = -124 j = start_of_response25 while j < len(input_ids_list):26 if input_ids_list[j : j+len(eos_ids)] == eos_ids:27 end_marker = j28 break29 j += 130 if end_marker != -1: # 找到结束标记31 loss_mask[start_of_response : end_marker + len(eos_ids)] = 132 i = end_marker + len(eos_ids) # 跳到这段之后继续扫33 continue34 else: # 没找到(被截断),标到末尾35 loss_mask[start_of_response:] = 136 break37 i += 138
39 loss_mask[input_ids == self.tokenizer.pad_token_id] = 0 # padding 不算40 X = input_ids[:-1]; Y = input_ids[1:]41 mask_for_loss_calculation = loss_mask[1:]42 return X, Y, mask_for_loss_calculation为什么 SFT 要 mask 掉用户输入? 想象训练样本是”用户问 + 助手答”。我们希望模型学的是”给定问题,怎么生成好的回答”,而不是”怎么生成用户的问题”。如果对用户那部分 token 也算 loss,模型会浪费容量去拟合用户输入的分布,甚至学会”自问自答”。所以 mask 只在 assistant 回复区间为 1,loss 只反传这部分。
add_generation_prompt=False:训练时 assistant 的完整回答已经在数据里了,不需要再追加生成提示符(那是推理时才加的)。- 扫描匹配
<|im_start|>assistant\n ... <|im_end|>来定位回复区间——notebook 注释自己也说这是”illustrative”(示意性)的实现,MiniMind 原版用更鲁棒的 token id 直接匹配。 - SFT 和 Reasoning 用的是同一个
DemoChatDataset类,区别只在喂的数据文件不同(SFT 数据是普通问答,Reasoning 数据带<think>/<answer>标签)和 Reasoning 训练循环里额外的”标签加权”(见 3.6)。
3.3 现代架构 DemoLLM#
这一节是 notebook 里最值得细看的部分。DemoLLM 系列类基本就是一个缩小版 LLaMA,把主代码库的经典组件逐一换成 2023+ 的现代做法。整体组装关系:
1DemoLLMForCausalLM # 模型 + lm_head + loss,继承 HF PreTrainedModel2 └── DemoLLMModel # embedding + 层堆叠 + 最终 norm3 └── DemoTransformerBlock ×244 ├── DemoAttention # GQA + RoPE + Flash Attention + KV cache5 │ └── RotaryEmbedding6 ├── DemoFeedForward # SwiGLU7 └── DemoRMSNorm ×2 # Pre-LN可以和主代码库逐项对应:LayerNorm → RMSNorm、绝对位置编码 → RoPE、ReLU MLP → SwiGLU、朴素 MHA → 带 GQA/Flash 的注意力、纯 nn.Module → HF PreTrainedModel。下面逐个看。
3.3.0 配置类 DemoLLMConfig#
继承 HF 的 PretrainedConfig,好处是模型能用 save_pretrained / from_pretrained、能接 HF 的 generate() 等生态工具。几个关键字段:
1self.head_dim = self.hidden_size // self.num_attention_heads # 1024 // 16 = 642self.num_key_value_heads = num_key_value_heads or num_attention_heads3# 约束:Q 头数必须能被 KV 头数整除(GQA 分组的前提)4if num_attention_heads % num_key_value_heads != 0:5 raise ValueError("num_attention_heads 必须能被 num_key_value_heads 整除")6hidden_act = "silu" # SwiGLU 用的激活7rms_norm_eps = 1e-5 # RMSNorm 数值稳定项8rope_theta = 10000.0 # RoPE 的频率基数intermediate_size(FFN 中间维度)的算法值得一提:
1DEMO_INTERMEDIATE_SIZE = int(DEMO_HIDDEN_SIZE * 8 / 3) # 1024 × 8/3 ≈ 27302DEMO_INTERMEDIATE_SIZE = 32 * ((DEMO_INTERMEDIATE_SIZE + 31) // 32) # 向上对齐到 32 的倍数 = 2752- 为什么是 8/3 而不是 4 倍:SwiGLU 比传统 FFN 多一个
gate投影(三个矩阵而非两个),为了让总参数量和传统”4 倍”FFN 大致相当,中间维度取4 × 2/3 = 8/3倍。这是 LLaMA 的标准做法。 - 对齐到 32 的倍数:让矩阵维度对硬件友好(GPU tensor core 喜欢 8/16/32 的整数倍),算得更快。
3.3.1 RMSNorm#
RMSNorm 是 LayerNorm 的简化版,现代 LLM(LLaMA、T5)的主流选择:
1class DemoRMSNorm(nn.Module):2 def __init__(self, dim, eps=1e-5):3 super().__init__()4 self.eps = eps5 self.weight = nn.Parameter(torch.ones(dim)) # 可学习的缩放,无 bias6
7 def _norm(self, x):8 return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)9
10 def forward(self, x):11 output_dtype = x.dtype12 x = x.to(torch.float32) # 在 float32 下算,保证数值稳定13 output = self._norm(x)14 return (output * self.weight).to(output_dtype)公式:
和 LayerNorm 的区别:LayerNorm 要”减均值、除标准差、再仿射(缩放+平移)”:
RMSNorm 不减均值、没有 bias β,只用均方根(RMS)做缩放。少算一个均值、少一组 bias 参数,速度更快,实际效果通常和 LayerNorm 差不多,所以很多现代 LLM 都采用它。
torch.rsqrt是”1/√x”的融合算子,比先 sqrt 再取倒数快。- 混合精度下的稳定技巧:先
.to(torch.float32)再算 norm,最后转回原 dtype。因为 bf16/fp16 下平方求和容易溢出或损失精度,归一化这种对数值敏感的操作放到 float32 做更稳。
3.3.2 RoPE 旋转位置编码#
RoPE(Rotary Positional Embedding)是现代 LLM 注入位置信息的方式,替代主代码库那种”可学习的绝对位置 embedding 相加”。
核心思想:不再把位置信息加到输入上,而是把每个 token 的 Query/Key 向量按其位置旋转一个角度。位置 m 处的向量旋转 m·θ,位置 n 处旋转 n·θ,两者做点积时,结果只依赖旋转角之差 (m−n)·θ——也就是相对位置。这样模型天然获得相对位置感知,还能外推到比训练时更长的序列。
数学上,把 head_dim 维向量两两配对成 d/2 个二维子向量,对位置 m、第 i 对施加旋转矩阵:
低维对(i 小)旋转快(高频,捕捉近距离关系),高维对旋转慢(低频,捕捉远距离关系)。
代码用复数实现旋转(很优雅,但第一次看会懵):
1class RotaryEmbedding(nn.Module):2 def __init__(self, dim, max_seq_len, theta=10000.0, device=None):3 super().__init__()4 # θ_i = 1 / theta^(2i/d),i = 0,1,...,d/2-15 freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim)) # (d/2,)6 t = torch.arange(max_seq_len).float() # (max_seq_len,)7 freqs = torch.outer(t, freqs) # (max_seq_len, d/2),元素 = m·θ_i8 # 用模长 1、角度 m·θ_i 构造复数 cos(m·θ_i) + i·sin(m·θ_i)9 self.register_buffer("freqs_cis", torch.polar(torch.ones_like(freqs), freqs), persistent=False)10
11 def forward(self, xq, xk, seq_len):12 # xq/xk: (bsz, num_heads, seq_len, head_dim)13 # 把相邻两维拼成复数:(..., head_dim) → (..., head_dim/2) 复数14 xq_c = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))15 xk_c = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))16 freqs_cis_pos = self.freqs_cis[:seq_len].unsqueeze(0).unsqueeze(0) # (1,1,seq_len,d/2)17 # 复数乘法 = 旋转18 xq_out = torch.view_as_real(xq_c * freqs_cis_pos).flatten(3)19 xk_out = torch.view_as_real(xk_c * freqs_cis_pos).flatten(3)20 return xq_out.type_as(xq), xk_out.type_as(xk)理解关键:一个复数乘以单位复数 e^{iθ} = cosθ + i·sinθ,几何上就是在复平面旋转 θ 角。代码把 head_dim 的相邻两维 (x_{2i}, x_{2i+1}) 看成复数 x_{2i} + i·x_{2i+1},乘上 freqs_cis[m] = e^{i·m·θ_i},正好实现了上面旋转矩阵的效果——用一次复数乘法代替 2×2 矩阵乘,简洁高效。
freqs_cis在初始化时一次性算好所有位置×所有频率的旋转因子,存成 buffer(persistent=False表示不写进 checkpoint,加载时重新算)。- RoPE 作用在 Q 和 K 上(不作用于 V),且是在注意力内部、投影之后施加(见下面
DemoAttention)。 - 同样先转
float()算再转回,保证混合精度下的精度。
3.3.3 注意力 DemoAttention#
这个 Attention 比主代码库的 MultiHeadAttention 多了几项现代实现里常见的东西:GQA、RoPE、Flash Attention、KV cache:
1class DemoAttention(nn.Module):2 def __init__(self, config):3 super().__init__()4 self.num_q_heads = config.num_attention_heads # Query 头数 165 self.num_kv_heads = config.num_key_value_heads # KV 头数(可 < Q 头数)6 self.num_kv_groups = self.num_q_heads // self.num_kv_heads # 每组共享的 Q 头数7 self.head_dim = config.head_dim # 648 # 注意 K/V 投影的输出维度是 num_kv_heads × head_dim(可能小于 Q)9 self.q_proj = nn.Linear(hidden, num_q_heads * head_dim, bias=False)10 self.k_proj = nn.Linear(hidden, num_kv_heads * head_dim, bias=False)11 self.v_proj = nn.Linear(hidden, num_kv_heads * head_dim, bias=False)12 self.o_proj = nn.Linear(num_q_heads * head_dim, hidden, bias=False) # ← 输出投影(主代码库没有)13 self.rotary_emb = RotaryEmbedding(head_dim, config.max_position_embeddings, theta=config.rope_theta)14 self.flash_available = hasattr(F, 'scaled_dot_product_attention') and config.flash_attnforward 的几个关键步骤:
1# 辅助方法:把 KV 头复制 n_rep 次,(bsz, num_kv_heads, seq, d) → (bsz, num_kv_heads*n_rep, seq, d)2def _repeat_kv(self, x, n_rep):3 bs, num_kv_heads, slen, head_dim = x.shape4 if n_rep == 1:5 return x6 return (x[:, :, None, :, :]7 .expand(bs, num_kv_heads, n_rep, slen, head_dim)8 .reshape(bs, num_kv_heads * n_rep, slen, head_dim))9
10def forward(self, hidden_states, attention_mask=None, position_ids=None,11 past_key_value=None, use_cache=False):12 bsz, q_len, _ = hidden_states.shape13 # 1) 投影 + 拆头:(bsz, heads, q_len, head_dim)14 query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_q_heads, self.head_dim).transpose(1, 2)15 key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_kv_heads, self.head_dim).transpose(1, 2)16 value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_kv_heads, self.head_dim).transpose(1, 2)17
18 # 总的 K/V 序列长度(含 KV cache 里的历史 token)19 kv_seq_len = q_len20 if past_key_value is not None:21 kv_seq_len += past_key_value[0].shape[2]22
23 # 2) 给 Q、K 施加 RoPE(按当前新 token 的长度)24 query_states, key_states = self.rotary_emb(query_states, key_states, seq_len=q_len)25
26 # 3) KV cache:把历史的 K/V 拼到前面(推理加速)27 if past_key_value is not None:28 key_states = torch.cat([past_key_value[0], key_states], dim=2)29 value_states = torch.cat([past_key_value[1], value_states], dim=2)30 current_key_value = (key_states, value_states) if use_cache else None31
32 # 4) GQA:把 KV 头复制 num_kv_groups 次,匹配 Q 头数33 key_states = self._repeat_kv(key_states, self.num_kv_groups)34 value_states = self._repeat_kv(value_states, self.num_kv_groups)35
36 # 5) 算注意力:优先 Flash,否则手动37 if self.flash_available and attention_mask is None:38 attn_output = F.scaled_dot_product_attention(39 query_states, key_states, value_states,40 dropout_p=self.config.dropout if self.training else 0.0,41 is_causal=(q_len == kv_seq_len)) # 无 KV cache 时即标准因果注意力42 else:43 attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)44 # 因果掩码;causal_shift 兼容 KV cache(q_len < kv_seq_len)的情形45 causal_shift = kv_seq_len - q_len46 mask = torch.triu(torch.full((q_len, kv_seq_len), float("-inf"),47 device=query_states.device), diagonal=1 + causal_shift)48 attn_weights = attn_weights + mask[None, None, :, :]49 if attention_mask is not None: # 额外的 padding 掩码(加性)50 attn_weights = attn_weights + attention_mask51 attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).type_as(query_states)52 attn_output = torch.matmul(attn_weights, value_states)53
54 # 6) 合并头 + 输出投影55 attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, q_len, -1)56 return self.o_proj(attn_output), current_key_value四个现代特性逐一解释:
-
GQA(Grouped Query Attention,分组查询注意力)。标准 MHA 里 Q、K、V 头数相同;GQA 让多个 Query 头共享同一组 K/V 头。比如 16 个 Q 头但只有 4 个 KV 头,则每 4 个 Q 头共用 1 组 KV。
_repeat_kv在计算时把 KV 头复制num_kv_groups份对齐 Q。- 目的:推理时 KV cache 占大量显存,GQA 把 KV 头数砍小,显存和带宽都省,几乎不掉精度。是 LLaMA-2/3、Mistral 的标配。
- 本 notebook 默认
num_kv_heads = num_q_heads = 16,所以实际退化成标准 MHA——但代码完整支持 GQA,把DEMO_NUM_KV_HEADS调小即可启用。
-
KV cache。自回归生成时,每生成一个新 token 都要重算注意力。但历史 token 的 K/V 是不变的,没必要重算。
past_key_value把历史 K/V 缓存下来,每步只算新 token 的 K/V 再cat上去,把生成复杂度从 O(n²) 降到接近 O(n)。这是推理提速的关键,主代码库的generate完全没有这个优化。 -
Flash Attention。
F.scaled_dot_product_attention是 PyTorch 内置的融合注意力算子,用分块计算避免显式构造(T,T)的大注意力矩阵,大幅省显存、提速。is_causal=True让它内部直接处理因果掩码。没有它时走下面的手动实现(matmul + triu 掩码 + softmax)。 -
因果掩码的细节。手动路径用
torch.triu(..., diagonal=1+causal_shift)构造上三角-inf掩码。causal_shift = kv_seq_len - q_len是为了兼容 KV cache 场景:有缓存时q_len(新 token 数)小于kv_seq_len(总长度),掩码的对角线要相应偏移,保证新 token 能看到所有历史、但仍看不到”更未来”。
3.3.4 前馈网络 SwiGLU#
现代 LLM 的 FFN 几乎都用 SwiGLU 替代经典的”Linear-ReLU-Linear”:
1class DemoFeedForward(nn.Module):2 def __init__(self, config):3 super().__init__()4 self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)5 self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)6 self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)7 self.act_fn = ACT2FN[config.hidden_act] # SiLU8 self.dropout = nn.Dropout(config.dropout)9
10 def forward(self, x):11 return self.dropout(self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)))公式:
和主代码库 ReLU-MLP 的本质区别是”门控”:经典 MLP 只有”升维→激活→降维”两个矩阵;SwiGLU 有三个矩阵——gate 和 up 把输入都升到中间维度,gate 的输出经 SiLU 激活后作为门控信号,逐元素乘到 up 的输出上,再由 down 降回。这个乘法门让网络能动态地”放大或抑制”每个特征通道的信息,表达力比 ReLU 强,是 LLaMA/PaLM 等模型效果好的因素之一。SiLU(也叫 Swish)是平滑版的 ReLU,处处可导。
3.3.5 Transformer Block#
1class DemoTransformerBlock(nn.Module):2 def __init__(self, config):3 super().__init__()4 self.self_attn = DemoAttention(config)5 self.mlp = DemoFeedForward(config)6 self.input_layernorm = DemoRMSNorm(config.hidden_size, eps=config.rms_norm_eps)7 self.post_attention_layernorm = DemoRMSNorm(config.hidden_size, eps=config.rms_norm_eps)8
9 def forward(self, hidden_states, attention_mask=None, position_ids=None,10 past_key_value=None, use_cache=False):11 residual = hidden_states12 hidden_states = self.input_layernorm(hidden_states)13 attn_out, present_kv = self.self_attn(hidden_states, attention_mask=attention_mask,14 position_ids=position_ids,15 past_key_value=past_key_value, use_cache=use_cache)16 hidden_states = residual + attn_out # 残差 117
18 residual = hidden_states19 hidden_states = self.post_attention_layernorm(hidden_states)20 hidden_states = residual + self.mlp(hidden_states) # 残差 221 return hidden_states, present_kv结构和主代码库的 Block 几乎一模一样——都是 Pre-LN + 双残差。差异仅在具体部件:LayerNorm → DemoRMSNorm、MLP → DemoFeedForward(SwiGLU),以及注意力多返回一个 present_kv(KV cache)。也就是说,Block 的大框架还是 Pre-LN、残差、注意力和 FFN,主要变化发生在 norm、位置编码、FFN 和注意力实现这些部件上。
3.3.6 主干 DemoLLMModel#
1class DemoLLMModel(PreTrainedModel):2 config_class = DemoLLMConfig3
4 def __init__(self, config):5 super().__init__(config)6 self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id)7 self.layers = nn.ModuleList([DemoTransformerBlock(config) for _ in range(config.num_hidden_layers)])8 self.norm = DemoRMSNorm(config.hidden_size, eps=config.rms_norm_eps)9 self.dropout = nn.Dropout(config.dropout)10
11 def forward(self, input_ids=None, attention_mask=None, position_ids=None,12 past_key_values=None, use_cache=None, **kwargs):13 batch_size, seq_length = input_ids.shape14 past_len = past_key_values[0][0].shape[2] if past_key_values is not None else 015
16 # 没传 position_ids 就自动生成(有 KV cache 时从 past_len 往后排)17 if position_ids is None:18 position_ids = torch.arange(past_len, seq_length + past_len,19 dtype=torch.long, device=input_ids.device).unsqueeze(0)20
21 inputs_embeds = self.embed_tokens(input_ids)22 hidden_states = self.dropout(inputs_embeds)23
24 # 把 (bsz, seq) 的 padding mask 扩成 (bsz,1,q,kv) 加性掩码;因果掩码在 attention 内部处理25 _expanded_mask = None26 if attention_mask is not None:27 expanded = attention_mask[:, None, None, :].expand(batch_size, 1, seq_length, seq_length + past_len)28 _expanded_mask = torch.zeros_like(expanded, dtype=hidden_states.dtype)29 _expanded_mask.masked_fill_(expanded == 0, float("-inf"))30
31 next_cache = [] if use_cache else None32 for i, layer in enumerate(self.layers):33 past_kv = past_key_values[i] if past_key_values is not None else None34 hidden_states, kv = layer(hidden_states, attention_mask=_expanded_mask,35 position_ids=position_ids, past_key_value=past_kv, use_cache=use_cache)36 if use_cache:37 next_cache.append(kv)38
39 hidden_states = self.norm(hidden_states)40 return BaseModelOutputWithPast(last_hidden_state=hidden_states, past_key_values=next_cache)最重要的差异:没有位置 embedding 层! 对比主代码库 Transformer.__init__ 里有 self.position_embed = nn.Embedding(context_length, n_embed),这里完全没有——因为位置信息已经由 RoPE 在每层注意力内部注入了。这也是 RoPE 和绝对位置 embedding 的主要区别:位置信息不在输入端相加,而是在每层注意力里作用到 Q/K 上。
其余就是标准的”嵌入 → dropout → 逐层 → 最终 RMSNorm”。继承 PreTrainedModel 让它自动获得 HF 的权重初始化、保存加载等能力。
3.3.7 带语言建模头的 DemoLLMForCausalLM#
最外层,加上 lm_head 和 loss 计算:
1class DemoLLMForCausalLM(PreTrainedModel):2 config_class = DemoLLMConfig3
4 def __init__(self, config):5 super().__init__(config)6 self.model = DemoLLMModel(config)7 self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)8 # self.model.embed_tokens.weight = self.lm_head.weight # 权重共享(注释掉了,可选)9 self.post_init() # HF 的统一权重初始化10
11 def forward(self, input_ids=None, attention_mask=None, position_ids=None,12 past_key_values=None, labels=None, use_cache=None, **kwargs):13 outputs = self.model(input_ids=input_ids, attention_mask=attention_mask,14 position_ids=position_ids, past_key_values=past_key_values, use_cache=use_cache)15 logits = self.lm_head(outputs.last_hidden_state)16 loss = None17 if labels is not None:18 shift_logits = logits[..., :-1, :].contiguous() # 错位:第 t 个位置预测第 t+1 个 token19 shift_labels = labels[..., 1:].contiguous()20 loss = nn.CrossEntropyLoss()(shift_logits.view(-1, self.config.vocab_size),21 shift_labels.view(-1))22 return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=outputs.past_key_values)23
24 # 供 HF generate() 调用:有 KV cache 时只需喂最后一个 token25 def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **kwargs):26 if past_key_values:27 input_ids = input_ids[:, -1:]28 return {"input_ids": input_ids, "past_key_values": past_key_values,29 "attention_mask": attention_mask, "use_cache": kwargs.get("use_cache")}三个要点:
- 复用 HF 的
generate()。因为继承了PreTrainedModel并实现了prepare_inputs_for_generation,这个模型直接支持model.generate(..., do_sample=True, temperature=0.7, top_k=10)——温度采样、top-k、KV cache 全部现成可用。对比主代码库要手写generate循环,这里省了大量代码。 - 内置 loss 用了,但训练时没用。
forward在传入labels时会自己算交叉熵(标准的 shift 错位)。但注意:后面三个训练阶段(预训练 / SFT / Reasoning)都不传labels,而是只取logits然后手动用 loss mask 算 loss——因为预训练/SFT/Reasoning 需要不同的 mask 策略,内置 loss 满足不了。 - 权重共享(weight tying)被注释掉了。把
embed_tokens.weight和lm_head.weight绑成同一份是常见省参手段(GPT-2 就这么做),这里保留了选项但默认不启用。
模型规模:hidden=1024, 层数=24。注意 vocab 取的是分词器训练后的实际词表 DEMO_VOCAB_SIZE_FINAL——配置目标是 32000,但在 notebook 仅 8 句的玩具语料下,BPE 实际产出的词表会远小于此,embedding/lm_head 参数也相应小很多。若按名义 vocab=32000 粗算上限约 3.7 亿(≈369M)参数,其中 24 层主干约 304M 才是主体;实际跑出来总量取决于真实词表,会明显小于该上限。代码里的 print_model_summary 会在初始化后打印精确值。
notebook 后半部分按预训练、SFT、Reasoning 三段来组织。每一段都加载上一阶段保存的权重,再用更小的学习率继续训练:
1预训练 (Pretrain) ──保存 demo_llm_pretrained.pth──▶ 学会基本语言规律2 │ load3 ▼4监督微调 (SFT) ──保存 demo_llm_sft.pth──────────▶ 学会跟随指令、对话5 │ load6 ▼7推理训练 (Reasoning)──保存 demo_llm_reasoning.pth─────▶ 学会输出 <think>/<answer> 结构学习率一路递减(3e-4 → 1e-4 → 5e-5):越到后期,改动越精细,避免破坏已学到的能力。
3.4 预训练 Pretraining#
预训练让模型在原始文本上学”预测下一个 token”,掌握基本的词法、语法、常识。训练循环有几个比主代码库更现代的地方:
1optimizer_pt = optim.AdamW(pt_model.parameters(), lr=DEMO_PRETRAIN_LR)2loss_fct = nn.CrossEntropyLoss(reduction='none') # 'none':返回每个 token 的 loss,便于按 mask 加权3
4# 混合精度上下文5autocast_ctx = nullcontext() if DEVICE.type=='cpu' else torch.amp.autocast(device_type=DEVICE.type, dtype=PTDTYPE)6scaler = torch.cuda.amp.GradScaler(enabled=(DTYPE_STR != 'float32' and DEVICE.type == 'cuda'))7
8for epoch in range(DEMO_PRETRAIN_EPOCHS):9 for step, (X_batch, Y_batch, mask_batch) in enumerate(demo_pt_dataloader):10 # 1) 余弦学习率调度,每步更新11 current_lr = get_lr(current_step, total_steps, DEMO_PRETRAIN_LR)12 for g in optimizer_pt.param_groups: g['lr'] = current_lr13
14 with autocast_ctx:15 outputs = pt_model(input_ids=X_batch) # 注意:不传 labels16 logits = outputs.logits # (bsz, seq-1, vocab)17 raw_loss = loss_fct(logits.view(-1, logits.size(-1)), Y_batch.view(-1)) # 每 token loss18 # 用 mask 加权平均:只统计有效 token19 masked_loss = (raw_loss * mask_batch.view(-1)).sum() / mask_batch.sum().clamp(min=1)20
21 # 2) 混合精度反向:scaler 防止 fp16 梯度下溢22 scaler.scale(masked_loss).backward()23 scaler.step(optimizer_pt)24 scaler.update()25 optimizer_pt.zero_grad(set_to_none=True)三个关键技术:
-
混合精度训练(AMP)。
autocast让前向自动用 bf16/fp16 计算(省显存、提速),关键的归一化/loss 仍在 float32。GradScaler在 fp16 下把 loss 放大再反传、更新前缩回,防止小梯度被舍入成 0(下溢)。bf16 动态范围大其实可以不缩放,但代码统一保留了 scaler。这是主代码库(纯 fp32)完全没有的工程优化。 -
余弦退火 + warmup 学习率(
get_lr工具函数)。
1def get_lr(current_step, total_steps, initial_lr, min_lr_ratio=0.1, warmup_ratio=0.01):2 warmup_steps = int(warmup_ratio * total_steps)3 if current_step < warmup_steps:4 return initial_lr * (current_step / warmup_steps) # 线性 warmup5 progress = (current_step - warmup_steps) / (total_steps - warmup_steps)6 coeff = 0.5 * (1.0 + math.cos(math.pi * progress)) # 余弦衰减7 return min_lr + coeff * (initial_lr - min_lr)前 1% 步线性升温(warmup,避免一开始大学习率震坏模型),之后按余弦曲线平滑降到 min_lr(= 初始的 10%)。比主代码库”第 50000 步一刀切”平滑得多,是现代训练标配。
- mask 加权 loss。
CrossEntropyLoss(reduction='none')返回每个位置的 loss,乘以 mask 后求和、再除以 mask 内 token 数——(raw_loss * mask).sum() / mask.sum()。clamp(min=1)防止除零。预训练阶段 mask 只排除 padding,所以等价于”在所有真实 token 上算平均 loss”。这套 mask 加权机制是三个阶段通用的,区别只在 mask 怎么来。
训练完保存 demo_llm_pretrained.pth,并做一次快速生成测试(greedy 解码)。由于数据只有 7 句话,输出必然是复读或胡言——重点是流程跑通。
3.5 SFT 监督微调#
SFT(Supervised Fine-Tuning)让预训练模型学会跟随指令、进行对话。代码和预训练几乎一样,三处不同:
- 加载预训练权重:
sft_model.load_state_dict(torch.load(final_pretrained_model_path))——不是从随机初始化开始,而是在预训练的结果上继续训。 - 换数据集:用
DemoChatDataset+ SFT 对话数据,于是mask_batch变成”只在 assistant 回复处为 1”。同样的(raw_loss * mask).sum() / mask.sum()公式,效果就变成只在助手回复上算 loss——模型只学怎么回答,不学复述用户问题。 - 更小的学习率:
1e-4(预训练是3e-4)。
从代码看,SFT 和预训练共用同一套训练循环,主要差别就在数据格式和 loss mask。保存为 demo_llm_sft.pth,再用 do_sample=True 采样测试问答。
3.6 Reasoning 推理训练#
Reasoning 这一节在 SFT 基础上加了特殊标签加权,让模型学会在回答前先输出显式思考过程,格式是 <think>思考...</think><answer>答案...</answer>:
1# 取出四个标签的"首 token id"2think_start_id = tokenizer.encode('<think>', add_special_tokens=False)[0]3think_end_id = tokenizer.encode('</think>', add_special_tokens=False)[0]4answer_start_id = tokenizer.encode('<answer>', add_special_tokens=False)[0]5answer_end_id = tokenizer.encode('</answer>', add_special_tokens=False)[0]6special_tag_first_token_ids = torch.tensor([7 think_start_id, think_end_id, answer_start_id, answer_end_id8], device=DEVICE).unique()9REASONING_TAG_LOSS_WEIGHT = 5.0 # 标签 token 的 loss 权重放大 5 倍10
11# 训练循环里:12raw_loss = loss_fct(logits.view(-1, vocab), Y_batch.view(-1)) # 每 token loss13
14effective_loss_weights = sft_style_mask.view(-1).float().clone() # 先用 SFT 那套 mask(assistant 区间)15# 找出 target 里属于特殊标签的位置16is_special = torch.isin(Y_batch.view(-1), special_tag_first_token_ids)17# 在"是 assistant token 且是特殊标签"的位置,权重 ×518apply_extra = is_special & (sft_style_mask.view(-1) == 1)19effective_loss_weights[apply_extra] *= REASONING_TAG_LOSS_WEIGHT20
21weighted_loss = (raw_loss * effective_loss_weights).sum() / sft_style_mask.sum().clamp(min=1)为什么要给标签加权? <think>、</think>、<answer>、</answer> 这些结构标签在文本里出现频率很低,但它们是”思考格式”的骨架。如果所有 token 一视同仁,模型很可能学不牢这些低频却关键的标签,生成时格式就乱了。给它们 5 倍 loss 权重,等价于提高这些位置的梯度权重,让模型更优先拟合这几个格式 token,从而更稳定地生成正确的 <think>...</think><answer>...</answer> 结构。
注意分母仍用 sft_style_mask.sum()(原始 mask 的 token 数)归一化,这样 loss 量级和 SFT 阶段大致可比,只是标签位置的梯度被放大了。学习率进一步降到 5e-5。保存为 demo_llm_reasoning.pth。
可以把这看成一种简单的格式约束:仍然是监督训练,只是把结构标签的 loss 权重调高。真实的推理模型(如 DeepSeek-R1)靠强化学习(RL)来激发思考能力,notebook 用的是有监督 + 标签加权来模仿思考的”形”,注释也坦言小模型的 reasoning 更多是模式模仿而非真正推理。
3.7 推理:让模型”思考”#
最后加载 reasoning 模型,用 get_structured_response 跑端到端推理:
1def get_structured_response(model, user_query, max_new_toks=DEMO_MAX_SEQ_LEN - 10, temp=0.7, tk=10):2 chat_history = [{"role": "user", "content": user_query}]3 # add_generation_prompt=True:在末尾补 "<|im_start|>assistant\n" 提示模型轮到它说4 prompt_text = tokenizer.apply_chat_template(chat_history, tokenize=False, add_generation_prompt=True)5 input_ids = tokenizer(prompt_text, return_tensors="pt").input_ids.to(DEVICE)6
7 generated_ids = model.generate(8 input_ids, max_new_tokens=max_new_toks,9 do_sample=True, temperature=temp, top_k=tk, # 温度采样 + top-k10 eos_token_id=tokenizer.eos_token_id, pad_token_id=tokenizer.pad_token_id11 )12 # 只取新生成的部分(去掉 prompt)13 response = tokenizer.decode(generated_ids[0][input_ids.shape[1]:], skip_special_tokens=True)14
15 # 用字符串查找解析出 <think> 和 <answer> 的内容(容错:标签不全时回退到整段)16 think_part, answer_part = "Not found", response17 if "<think>" in response and "</think>" in response:18 think_part = response[response.find("<think>") + len("<think>") : response.find("</think>")].strip()19 if "<answer>" in response and "</answer>" in response:20 answer_part = response[response.find("<answer>") + len("<answer>") : response.find("</answer>")].strip()21 return think_part, answer_part对比主代码库手写的 generate,这里接入 HF 之后少写了不少生成相关代码:
add_generation_prompt=True:推理时数据里只有 user 输入、没有 assistant 回答,需要手动补上<|im_start|>assistant\n来提示模型开始作答。训练时样本里已经有完整回答,所以不用加(False)。- 采样控制:直接用 HF
generate的temperature=0.7(温度,控制随机性,<1 更确定)、top_k=10(只从概率最高的 10 个 token 里采样,避免选到长尾噪声)。主代码库的generate这些都得手写。 - 结构解析:生成后用普通的字符串
find切出<think>和<answer>之间的内容。模型若没学好、标签不完整,代码有多重 fallback(找不到 think 就把全文当 answer)。
notebook 到这里把 tokenizer 训练、模型搭建、三段训练和结构化生成都跑了一遍。虽然数据只是几条玩具样本,但主要流程都能对应到真实 LLM 训练里的环节。
四、横向对比与速查#
4.1 经典 vs 现代架构对照表#
两套实现逐维度并排来看:
| 维度 | 主代码库(经典 GPT 风) | notebook(现代 LLaMA 风) | 为什么换 |
|---|---|---|---|
| 归一化 | LayerNorm(减均值、除标准差、含 bias) | RMSNorm(只除均方根、无 bias) | 更快、参数更少、效果相当 |
| 归一化位置 | Pre-LN | Pre-LN | 都用 Pre-LN(训练稳定) |
| 位置编码 | 可学习的绝对位置 embedding(相加) | RoPE 旋转位置编码(每层注意力内旋转) | 编码相对位置、可外推、省参数 |
| 激活/FFN | ReLU + 2 矩阵、4× 升维 | SwiGLU + 3 矩阵门控、8/3× 升维 | 门控表达力更强 |
| 注意力 | 朴素 MHA(ModuleList 循环每个头) | 融合 QKV + 支持 GQA | 高效、省 KV cache 显存 |
| 输出投影 W_O | 无(直接拼接输出) | 有 o_proj | 标准 Transformer 设计 |
| 注意力加速 | 无 | Flash Attention(scaled_dot_product_attention) | 省显存、提速 |
| KV cache | 无(生成时每步重算全部) | 有(past_key_value) | 生成提速 O(n²)→~O(n) |
| 分词器 | tiktoken r50k_base(现成,词表 50304) | 自训 BPE(tokenizers,词表 ~32000) | 教学:从零训练 |
| 框架 | 纯 torch.nn.Module | HF PreTrainedModel / PretrainedConfig | 复用 generate/save/load |
| 生成采样 | 手写 multinomial(无温度/top-k) | HF generate(温度、top-k、KV cache) | 现成、可控 |
| 训练精度 | fp32 | bf16/fp16 混合精度(autocast + GradScaler) | 省显存、提速 |
| 学习率调度 | 单次阶梯硬切换(50000 步降一次) | 余弦退火 + 线性 warmup(每步更新) | 平滑、收敛更好 |
| 训练阶段 | 仅预训练 | 预训练 → SFT → Reasoning 三段 | 完整后训练 pipeline |
| loss mask | 无(所有 token 都算) | 有(SFT 只算 assistant、Reasoning 加权标签) | 对齐任务目标 |
| 数据 | The Pile(825GB 子集,真实语料) | notebook 内置几条玩具样本 | 教学演示 |
| 数据存储 | HDF5 一维 token 流 | jsonl + PyTorch Dataset | 各自场景 |
默认配置上:主代码库默认 ~21 亿参数(实跑需调小到 13M);notebook 主体是约 304M 的 24 层主干(按名义 vocab=32000 算总量上限约 3.7 亿,实际词表更小、总量也更小)。
4.2 关键概念速查#
读这个仓库会反复碰到的概念,一句话定义:
- 自回归语言建模:模型每一步预测序列的下一个 token,用”输入”和”错位一格的目标”训练(
X=ids[:-1],Y=ids[1:])。两套实现都这么做。 - 因果掩码(Causal Mask):用下三角/上三角
-inf让位置i只能注意到≤ i的 token,保证预测下一个词时看不到未来。decoder-only 模型的根本约束。 <|endoftext|>:标记文档/序列结束的特殊 token,分隔不相关文本、兼当生成停止信号;常同时用作 BOS/EOS。- BOS / EOS / PAD:序列开始 / 结束 / 填充。PAD 把不等长序列补齐到同长度,且必须用 mask 排除在 loss 之外。
- loss mask:一个 0/1(或加权)张量,决定哪些 token 参与 loss。预训练排除 padding;SFT 只保留 assistant 回复;Reasoning 进一步给结构标签加权。
- 残差连接:
x = x + sublayer(x),让梯度直达浅层,是深层网络能训起来的关键。 - Pre-LN:在进子层之前做归一化(
x + attn(norm(x))),比 Post-LN 训练更稳。两套都用。 - RoPE:把 Q/K 向量按位置旋转角度来注入位置信息,点积时自动体现相对位置;现代 LLM 主流位置编码。
- GQA:多个 Query 头共享一组 K/V 头,省 KV cache 显存;
num_kv_heads = num_q_heads时退化为标准 MHA。 - KV cache:生成时缓存历史 token 的 K/V,避免每步重算,大幅提速。
- 混合精度(AMP):前向用 bf16/fp16 算(省显存提速),敏感部分用 fp32;
GradScaler防 fp16 梯度下溢。 - 温度 / top-k 采样:温度
<1让分布更尖锐(更确定),top-k 只从概率最高的 k 个候选里采样(去长尾噪声)。 - 三段式训练:预训练(学语言)→ SFT(学指令/对话)→ Reasoning(学思考格式),逐步加载权重、学习率递减。
4.3 常见坑与注意事项#
实际上手这个仓库容易踩的点,集中列出:
-
PYTHONPATH必须包含项目根目录。scripts/*.py里用from config.config import ...、from src.models...,不设export PYTHONPATH="$PYTHONPATH:."会直接ModuleNotFoundError。 -
默认 config 跑不动。
config/config.py默认是N_EMBED=2048, N_BLOCKS=64(约 21 亿参数),单卡必然 OOM。先跑通务必改成 13M 配置(N_EMBED=128, N_HEAD=8, N_BLOCKS=1, CONTEXT_LENGTH=128)。代码注释里的 “3 Billion” 是笔误(实际约 2.1B)。 -
训练序列长度被
T_CONTEXT_LENGTH=16卡住。模型按CONTEXT_LENGTH=512构建,但训练循环喂的是长度 16 的序列,等于只训练了前 16 个位置。要让模型用上长上下文,得把T_CONTEXT_LENGTH调大。 -
train_transformer.py没有if __name__ == '__main__'保护。它是模块级脚本,import它就会立刻开始训练。不要在别处随意 import。 -
max_data=1000只是试跑设置。data_preprocess.py默认每个文件只处理前 1000 行 JSON。正经训练要调大或处理全部,否则数据量远远不够。 -
分词器和 config 必须前后一致。生成时用的分词器要和预处理时同一个(都
r50k_base),模型超参要和训练时完全相同——否则 token 对不上 / shape 不匹配。 -
notebook 需要额外依赖。
sft_rlhf_guide.ipynb用到transformers和tokenizers,它们不在requirements.txt里,需自行pip install transformers tokenizers。 -
主代码库的有意简化。相比标准/现代 Transformer,主代码库省略了多头注意力的输出投影
W_O、没有 attention/residual dropout、没有权重共享——都是为教学简化,不是 bug,但移植到生产前要补齐。
小结#
这个仓库适合对照着看两类实现:
- 主代码库偏早期 GPT 写法——LayerNorm、绝对位置编码、ReLU、朴素多头注意力,配上 The Pile 真实数据和命令行训练流程。
- notebook 偏 LLaMA 风格——RMSNorm、RoPE、SwiGLU、GQA、Flash Attention、KV cache,再加上预训练 / SFT / Reasoning 三段后训练和思考格式对齐。
两边都手写了主要模块,读起来比较容易看清哪些是 Transformer 的基本骨架(残差、Pre-LN、注意力 + FFN),哪些是后来逐步换上的工程组件,以及预训练之后的 SFT / reasoning 是怎么在同一套训练循环上叠加出来的。
进一步可以探索的方向(notebook 结尾也提到):放大数据和模型规模、引入 MoE、更彻底地用 GQA、分布式训练(DDP/FSDP)、梯度检查点,以及用 RLHF/DPO 做更精细的偏好对齐。
部分内容可能已过时