train-llm-from-scratch 完整笔记

14691 字
73 分钟
train-llm-from-scratch 完整笔记
文章摘要

从零手写 LLM 的教学仓库阅读笔记:经典 GPT 风格预训练(LayerNorm / 绝对位置编码 / ReLU MLP)与现代 LLaMA 风格架构(RMSNorm / RoPE / SwiGLU / GQA)对照,涵盖数据流水线、模型搭建、训练循环、SFT 和 Reasoning 全流程。

仓库地址:https://github.com/FareedKhan-dev/train-llm-from-scratch


一、项目总览#

1.1 这个仓库是什么#

这是一个**从零手写大语言模型(LLM)**的教学项目,作者 FareedKhan-dev,基于论文 Attention is All You Need(2017)用 PyTorch 从头实现 Transformer,目标是让个人用一张消费级 GPU 就能训练出从百万到十亿参数规模的 LLM。

项目的特点是不调用任何高层封装(不用 HuggingFace 的 Trainer、不 import 现成的 GPT2Model),而是把 Embedding、注意力、前馈网络、训练循环、文本生成全部手写。顺着这些代码看下来,基本能把”原始文本 → token → batch → 模型前向 → 生成结果”这条链路串起来。

作者的经验结论:13M(一千三百万)参数是一个分水岭——再小(比如 2.3M 在莎士比亚语料上)输出基本是乱码,而到了 13M 量级,模型开始能产出拼写正确、语法基本通顺的句子。但盲目堆到十亿参数,如果架构深度不够、数据不足,反而可能过拟合,效果不一定比小模型好。

1.2 两套实现的关系#

这个仓库实际上包含两套独立的 LLM 实现,风格和目标都不同:

第一套(主代码库)第二套(notebook)
位置src/ + scripts/ + config/ + data_loader/sft_rlhf_guide.ipynb
定位经典 GPT 预训练,工程化、可命令行运行端到端教学,“会思考”的 LLM 全流程
架构LayerNorm + 绝对位置编码 + ReLU MLP + 标准多头注意力RMSNorm + RoPE + SwiGLU + GQA + Flash Attention
分词器tiktoken 现成的 r50k_base用 HF tokenizers自己训练 BPE
数据The Pile(825GB 真实语料的子集)notebook 内置的几条玩具样本
训练阶段仅预训练(pretraining)预训练 → SFT → Reasoning 三段式
框架风格torch.nn.Module继承 HF 的 PreTrainedModel / PretrainedConfig

简单说:

  • 想理解 Transformer 最朴素的样子、想真正跑一次大规模预训练 → 看第二部分(主代码库)
  • 想理解现代 LLM(LLaMA 系)长什么样、想搞懂 SFT 和”思维链”训练怎么做 → 看第三部分(notebook)

两套都从零搭 Transformer,但第二套基本就是一个缩小版的 LLaMA + 后训练 pipeline,很多现在常见的 LLM 组件,notebook 里都有一个简化版实现。

1.3 目录结构导读#

train-llm-from-scratch/
├── config/
│ └── config.py # 主代码库的全部超参(模型规模、训练参数、路径)
├── data_loader/
│ └── data_loader.py # get_batch_iterator:从 HDF5 流式取 batch
├── scripts/
│ ├── data_download.py # 从 HuggingFace 下载 The Pile 分片
│ ├── data_preprocess.py # jsonl.zst → 分词 → 存成 HDF5 token 流
│ ├── train_transformer.py # 训练主脚本(import 即开始训练)
│ └── generate_text.py # 加载 checkpoint 做自回归生成
├── src/
│ └── models/
│ ├── __init__.py # 导出 MLP / Head / MultiHeadAttention / Block / Transformer
│ ├── mlp.py # 前馈网络
│ ├── attention.py # 单头 Head + 多头 MultiHeadAttention
│ ├── transformer_block.py # 单个 Transformer Block
│ └── transformer.py # 完整模型(embedding → blocks → lm_head → generate)
├── sft_rlhf_guide.ipynb # 第二套:现代架构 + 预训练/SFT/Reasoning 全流程
├── requirements.txt
└── README.md # 含一篇非常详细的 step-by-step 讲解

数据和模型产物在运行时生成(仓库里没有提交):

data/
├── train/ # 训练用 .jsonl.zst 原始分片 + pile_train.h5(分词后)
└── val/ # 验证用 val.jsonl.zst + pile_dev.h5
models/ # 训练得到的 .pt checkpoint

1.4 环境依赖与快速上手#

requirements.txt 的依赖很轻:

torch # 深度学习框架
numpy # 数值运算 / 索引打乱
h5py # 读写 HDF5(存分词后的 token 流)
tqdm # 进度条
requests # 下载数据
zstandard # 解压 .zst
tiktoken # OpenAI 的分词器

notebook(第二套)额外需要 transformerstokenizers,不在这个 requirements 里。

主代码库的完整运行四步:

Terminal window
# 0. 让 Python 找得到项目根目录(否则 from config.config import ... 会失败)
export PYTHONPATH="$PYTHONPATH:."
# 1. 下载数据(默认只下 1 个训练分片,每个约 11GB;val 总会下)
python scripts/data_download.py --train_max 1
# 2. 预处理:分词 + 存 HDF5(默认每个文件只取前 1000 行,方便快速试跑)
python scripts/data_preprocess.py --max_data 1000
# 3. 训练(超参在 config/config.py 里改)
python scripts/train_transformer.py
# 4. 用训练好的 checkpoint 生成文本
python scripts/generate_text.py --model_path models/transformer_B.pt --input_text "Hello"

关键提醒config/config.py 里默认配置是十亿参数级别(N_EMBED=2048, N_BLOCKS=64),单卡基本跑不动。想先跑通流程,务必先把模型调小(见 2.3 配置项详解)。推荐的 13M 配置是 N_EMBED=128, N_HEAD=8, N_BLOCKS=1, CONTEXT_LENGTH=128


二、主代码库:经典 GPT 风格预训练#

这一部分对应 src/ + scripts/ + config/ + data_loader/,下面就按实际跑代码的顺序来看:先准备数据,再看模型、训练和生成。

2.1 数据流水线#

整条数据链路是:HuggingFace 上的 .jsonl.zst 压缩分片 → 解压逐行读 JSON → 取出 text 字段 → tiktoken 分词 → 拼成一条超长 token 流 → 存进 HDF5 → 训练时按 context_length 切片取 batch

2.1.1 下载数据(scripts/data_download.py#

数据集是 The Pile(去版权版 monology/pile-uncopyrighted),22 个领域混合的大规模英文语料,原始 825GB。脚本只下载其中一小部分。

URL 是按分片编号拼出来的:

BASE_URL = "https://huggingface.co/datasets/monology/pile-uncopyrighted/resolve/main"
VAL_URL = f"{BASE_URL}/val.jsonl.zst" # 验证集(固定一个文件)
TRAIN_URLS = [f"{BASE_URL}/train/{i:02d}.jsonl.zst" for i in range(65)] # 训练集:00~64 共 65 个分片

下载本身是标准的流式写盘(stream=True 边下边写,配 tqdm 进度条),每块 1024 字节:

def download_file(url: str, file_name: str) -> None:
response = requests.get(url, stream=True)
total_size = int(response.headers.get('content-length', 0))
block_size = 1024
with open(file_name, 'wb') as f:
for chunk in tqdm(response.iter_content(block_size), total=total_size // block_size, desc="Downloading", leave=True):
f.write(chunk)

download_dataset 的逻辑有两个值得注意的设计:

  1. 可以断点重跑:每个文件下载前先 os.path.exists 检查,已存在就跳过。中断后再跑不会重新下载。
  2. 训练分片可控train_urls[:max_train_files] 只取前 max_train_files 个。默认 --train_max 1,即只下 00.jsonl.zst(约 11GB)。

命令行参数:

参数默认含义
--train_max1下载多少个训练分片(最多 65)
--train_dirdata/train训练数据目录
--val_dirdata/val验证数据目录

2.1.2 预处理与分词(scripts/data_preprocess.py#

这一步把人类可读的文本变成模型能吃的 token id,并存成方便随机读取的 HDF5。核心是 process_files

def process_files(input_dir, output_file, tokenizer_name, max_data=None):
enc = tiktoken.get_encoding(tokenizer_name) # 默认 'r50k_base'(GPT-3 用的分词器)
with h5py.File(output_file, 'w') as out_f:
# 创建一个一维、可动态扩容的 dataset,存所有 token
dataset = out_f.create_dataset('tokens', (0,), maxshape=(None,), dtype='i')
start_index = 0
for filename in sorted(os.listdir(input_dir)):
if filename.endswith(".jsonl.zst"):
with zstd.open(in_file, 'rt', encoding='utf-8') as in_f: # 流式解压 + 文本模式读
for line in tqdm(in_f, desc=f"Processing {filename}", total=max_data):
data = json.loads(line)
text = data.get('text')
if text:
# 每条文本末尾追加 <|endoftext|>,再编码
encoded = enc.encode(text + "<|endoftext|>", allowed_special={'<|endoftext|>'})
encoded_len = len(encoded)
end_index = start_index + encoded_len
dataset.resize(dataset.shape[0] + encoded_len, axis=0) # 扩容
dataset[start_index:end_index] = encoded # 追加写入
start_index = end_index
processed_lines += 1
if max_data is not None and processed_lines >= max_data:
break # 每个文件最多处理 max_data 行

有三个设计点需要理解:

  1. token 流是”一维拼接”而非”二维矩阵”。所有文档的 token 首尾相接,存成一条巨长的一维数组 tokens。文档边界靠 <|endoftext|> 这个特殊 token 标记——它告诉模型”上一段到此结束”,避免模型把两篇不相关文档当成连续上下文,也是生成时的自然停止信号。

  2. 为什么用 HDF5。HDF5 支持按切片随机读取且不需要把整个数据集载入内存。训练时要随机取 dataset[idx:idx+context_length+1] 这样的片段,HDF5 直接从磁盘读对应区间即可,几十上百 GB 的 token 流也扛得住。maxshape=(None,) 表示这一维可以无限扩容,配合 resize 实现”边读边追加”。

  3. max_data 主要用于试跑。默认 1000 表示每个分片只处理前 1000 行;正式训练时需要调大或改成处理全部。

参数默认含义
--train_dir / --val_dirdata/train / data/val输入目录
--out_train_filedata/train/pile_train.h5训练 token 输出
--out_val_filedata/val/pile_dev.h5验证 token 输出
--tokenizer_namer50k_basetiktoken 分词器名
--max_data1000每个文件最多处理多少行 JSON

r50k_base 是什么:tiktoken 内置的 BPE 分词器,词表 50257,GPT-3/GPT-2 同款。注意它的真实词表大小是 50257,而 config 里 VOCAB_SIZE=50304——多出来的 47 个是padding 到 64 的倍数,纯粹为了 GPU 上矩阵运算对齐更高效。这些多出来的 token 不会出现在训练目标里(模型学不到要预测它们),推理采样时被选中的概率也极低,基本不影响使用。

2.1.3 批次迭代器(data_loader/data_loader.py#

get_batch_iterator 是一个无限生成器,负责从一维 token 流里切出 (输入, 目标) 的 batch,整个训练数据侧都围绕它运转:

def get_batch_iterator(data_path, batch_size, context_length, device="cpu"):
with h5py.File(data_path, 'r') as hdf5_file:
dataset = hdf5_file['tokens']
dataset_size = dataset.shape[0]
# 能切出多少个不重叠的样本(-1 是因为目标要往后错一位)
n_examples = (dataset_size - 1) // context_length
example_idxs = np.arange(n_examples)
np.random.shuffle(example_idxs) # 打乱样本顺序
epochs = 0
counter = 0
while True: # 无限循环,训练循环自己控制何时停
if counter + batch_size > n_examples:
np.random.shuffle(example_idxs) # 一个 epoch 用完,重新打乱
counter = 0
print(f"Finished epoch {epochs}")
epochs += 1
# 把"第几个样本"换算成 token 流里的起始下标
random_indices = example_idxs[counter:counter+batch_size] * context_length
# 每个样本取 context_length+1 个 token(多取一个用于错位)
random_samples = torch.tensor(np.array(
[dataset[idx:idx+context_length+1] for idx in random_indices]
))
xb = random_samples[:, :context_length].to(device) # 输入:前 context_length 个
yb = random_samples[:, 1:context_length+1].to(device) # 目标:错后一位
counter += batch_size
yield xb, yb

xb / yb 错一位是最关键的地方,这里能直接看到语言模型的训练目标——用当前位置的输入去预测下一个 token:

token 流片段: [The] [cat] [sat] [on] [the] [mat]
xb (输入): [The] [cat] [sat] [on] [the] # 位置 0..T-1
yb (目标): [cat] [sat] [on] [the] [mat] # 位置 1..T

模型在每个位置 t 看到 xb[t],要预测出 yb[t](也就是原序列里 xb[t] 的下一个 token)。所以一个长度 T 的片段不是只提供一个标签,而是同时提供 T 个预测位置。

shape 一览(设 batch=B,context_length=T):

变量shape说明
random_indices(B,)B 个样本在 token 流中的起始位置
random_samples(B, T+1)每个样本 T+1 个 token
xb(B, T)输入
yb(B, T)目标(错位)

两个细节:

  • 样本是不重叠切分的:第 i 个样本从 i * context_length 开始,样本之间首尾相接不重叠(不是滑动窗口)。这样数据利用率高、实现简单,代价是每个 token 只会出现在一个固定的上下文位置组合里。
  • device 默认 "cpu"。训练脚本调用时会显式传入 config['device'](即 'cuda'),所以实际训练在 GPU 上。

2.2 模型架构(自底向上)#

主代码库的模型由五个文件、五个类层层组装,依赖关系是:

Transformer (transformer.py) 完整模型
└── Block (transformer_block.py) ×N_BLOCKS 层
├── MultiHeadAttention (attention.py)
│ └── Head ×n_head 单个注意力头
└── MLP (mlp.py) 前馈网络

整个前向过程的张量约定是统一的:B = batch size,T = 序列长度(time steps),C = 嵌入维度(n_embed。下面自底向上拆解。

2.2.1 前馈网络 MLP(src/models/mlp.py#

MLP 是 Transformer Block 里”思考”的部分——注意力负责”看哪里”,MLP 负责”基于看到的信息做非线性变换”。结构是经典的升维 → 激活 → 降维

class MLP(nn.Module):
def __init__(self, n_embed):
super().__init__()
self.hidden = nn.Linear(n_embed, 4 * n_embed) # 升维到 4 倍
self.relu = nn.ReLU()
self.proj = nn.Linear(4 * n_embed, n_embed) # 投影回原维度
def forward(self, x):
x = self.forward_embedding(x) # hidden + relu
x = self.project_embedding(x) # proj
return x
def forward_embedding(self, x):
return self.relu(self.hidden(x))
def project_embedding(self, x):
return self.proj(x)
  • 为什么升维到 4 倍4 * n_embed 是 Transformer 的惯例(原始论文就是 4 倍)。中间维度更大,给模型更多”工作空间”去拟合复杂的非线性映射,再压回原维度保持各层接口一致。
  • 为什么拆成 forward_embedding / project_embedding 两个方法:把”升维+激活”和”降维投影”拆开,是为了支持后面 Block.forward_embedding 那种只走一半的中间态前向(拿到激活后、投影前的表示)。常规训练只用 forward
  • shape:输入输出都是 (B, T, C),中间隐藏层是 (B, T, 4C)nn.Linear 只作用在最后一维,(B, T) 原样保留。

2.2.2 单头注意力 Head(src/models/attention.py#

注意力是模型里最核心的部分。一个 Head 做的事情是:让每个位置的 token 去”查询”序列里它之前所有 token,按相关度加权汇总信息。

class Head(nn.Module):
def __init__(self, head_size, n_embed, context_length):
super().__init__()
self.key = nn.Linear(n_embed, head_size, bias=False) # K 投影
self.query = nn.Linear(n_embed, head_size, bias=False) # Q 投影
self.value = nn.Linear(n_embed, head_size, bias=False) # V 投影
# 下三角矩阵,注册成 buffer(不是参数,不训练,但随模型存取/搬设备)
self.register_buffer('tril', torch.tril(torch.ones(context_length, context_length)))
def forward(self, x):
B, T, C = x.shape
head_size = self.key.out_features
k = self.key(x) # (B, T, head_size)
q = self.query(x) # (B, T, head_size)
scale_factor = 1 / math.sqrt(head_size)
# 注意力分数:q 和 k 做点积
attn_weights = q @ k.transpose(-2, -1) * scale_factor # (B, T, T)
# 因果掩码:把"未来"位置置成 -inf
attn_weights = attn_weights.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
attn_weights = F.softmax(attn_weights, dim=-1) # 每行归一化成概率
v = self.value(x) # (B, T, head_size)
out = attn_weights @ v # (B, T, head_size)
return out

对应的数学就是缩放点积注意力(Scaled Dot-Product Attention):

Attention(Q,K,V)=softmax ⁣(QKdk+M)V\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{Q K^\top}{\sqrt{d_k}} + M\right) V

其中 dkd_khead_sizeMM 是因果掩码矩阵(上三角为 -\infty,其余为 0)。逐点拆解:

  1. Q、K、V 三个投影。同一份输入 x 经过三个独立的线性层,得到 Query(“我想找什么”)、Key(“我有什么特征”)、Value(“我携带什么信息”)。三者都没有 bias——注意力里 bias 作用不大,省去更简洁。

  2. 打分 q @ k.transpose(-2,-1)(B,T,head_size) @ (B,head_size,T) → (B,T,T)。结果第 i 行第 j 列 = 位置 i 的 query 和位置 j 的 key 的点积,衡量”位置 i 该有多关注位置 j”。

  3. 缩放 1/√head_size。点积会随维度增大而方差变大,把 softmax 推向饱和区(梯度消失)。除以 dk\sqrt{d_k} 把分数拉回合理范围。

    注意缩放用的是 head_size(即 1/√d_k),而非整个 n_embed——这是原始论文的标准做法。

  4. 因果掩码 masked_fill(tril==0, -inf)tril 是下三角全 1 矩阵,tril[:T,:T]==0 选出上三角(含未来位置)。把这些位置的分数设成 -inf,softmax 后权重变成 0——保证位置 i 只能看到 ≤ i 的 token,看不到未来。这是 decoder-only 语言模型的根本约束(否则预测下一个词时就”作弊偷看答案”了)。

  5. softmax 归一化 + 加权求和 attn_weights @ v。每行权重归一化成概率分布,再 (B,T,T) @ (B,T,head_size) → (B,T,head_size),把各位置的 Value 按注意力权重加权汇总。

trilregister_buffer 而非 nn.Parameter:它是固定的掩码常量,不需要梯度更新,但希望它能跟着 model.to(device) 一起搬到 GPU、跟着 state_dict 一起存取。buffer 正是为这种”非训练但属于模型状态”的张量设计的。

2.2.3 多头注意力 MultiHeadAttention(src/models/attention.py#

单个头只能学一种”关注模式”。多头注意力让若干个头并行各看各的(有的头学语法依赖,有的学指代关系……),最后拼接:

class MultiHeadAttention(nn.Module):
def __init__(self, n_head, n_embed, context_length):
super().__init__()
# n_head 个头,每个头维度 n_embed // n_head
self.heads = nn.ModuleList(
[Head(n_embed // n_head, n_embed, context_length) for _ in range(n_head)]
)
def forward(self, x):
# 每个头独立处理,再沿最后一维拼接
x = torch.cat([h(x) for h in self.heads], dim=-1)
return x
  • 维度切分:每个头的 head_size = n_embed // n_headn_head 个头各输出 (B, T, n_embed//n_head),沿最后一维 cat(B, T, n_embed)。总维度不变,相当于把 n_embed 维”分配”给多个头。
  • 设计取舍:这里是最朴素的实现——用 Python 列表存 n_head 个独立 Head,前向时 for 循环逐个跑再 cat。这种写法比较直观,但效率一般,因为每个 head 都单独跑一遍线性层和注意力。notebook 里的 DemoAttention 把 Q/K/V 合到一个大矩阵里一次算完,更接近工程上的常见做法。
  • ⚠️ 这个实现里没有输出投影层(标准 Transformer 在多头 cat 之后还有一个 W_O 线性层)。本仓库省掉了它,直接把拼接结果送出。这是简化,不是 bug,但和标准/现代实现不同,看代码时留意。

2.2.4 Transformer Block(src/models/transformer_block.py#

一个 Block = 一层”注意力 + 前馈”,配上 Pre-LN(前置层归一化)残差连接

class Block(nn.Module):
def __init__(self, n_head, n_embed, context_length):
super().__init__()
self.ln1 = nn.LayerNorm(n_embed)
self.attn = MultiHeadAttention(n_head, n_embed, context_length)
self.ln2 = nn.LayerNorm(n_embed)
self.mlp = MLP(n_embed)
def forward(self, x):
x = x + self.attn(self.ln1(x)) # 注意力子层 + 残差
x = x + self.mlp(self.ln2(x)) # 前馈子层 + 残差
return x

这两行 forward 信息密度很高,拆开看:

  • 残差连接 x = x + sublayer(...)。每个子层学的是”在原表示上加什么修正量”,而不是完全重写。好处是梯度能通过 + 直接回传到浅层,深网络也能稳定训练(缓解梯度消失)。这是能把 Block 堆到几十层的前提。

  • Pre-LN:先归一化再进子层。注意 LayerNorm 作用在 self.attn/self.mlp输入上(attn(ln1(x))),残差加的是未归一化x。这种”Pre-LN”结构比原始论文的”Post-LN”(ln(x + sublayer(x)))训练更稳定,是现代实现的主流选择。LayerNorm 对每个 token 的 C 维特征做标准化(减均值除标准差再仿射),让各层输入分布稳定。

  • forward_embedding 方法是给前面提到的”半程前向”用的(返回 MLP 投影前的中间态和残差),常规训练/推理不走它。

2.2.5 完整模型 Transformer(src/models/transformer.py#

把所有零件组装成端到端的语言模型:

class Transformer(nn.Module):
def __init__(self, n_head, n_embed, context_length, vocab_size, N_BLOCKS):
super().__init__()
self.context_length = context_length
self.N_BLOCKS = N_BLOCKS
self.token_embed = nn.Embedding(vocab_size, n_embed) # token → 向量
self.position_embed = nn.Embedding(context_length, n_embed) # 位置 → 向量
self.attn_blocks = nn.ModuleList(
[Block(n_head, n_embed, context_length) for _ in range(N_BLOCKS)]
)
self.layer_norm = nn.LayerNorm(n_embed) # 最后一层归一化
self.lm_head = nn.Linear(n_embed, vocab_size) # 投影到词表,得到 logits
self.register_buffer('pos_idxs', torch.arange(context_length))
def _pre_attn_pass(self, idx):
B, T = idx.shape
tok_embedding = self.token_embed(idx) # (B, T, C)
pos_embedding = self.position_embed(self.pos_idxs[:T]) # (T, C),广播相加
return tok_embedding + pos_embedding
def forward(self, idx, targets=None):
x = self._pre_attn_pass(idx) # 词嵌入 + 位置嵌入
for block in self.attn_blocks: # 逐层 Transformer Block
x = block(x)
x = self.layer_norm(x) # 最终归一化
logits = self.lm_head(x) # (B, T, vocab_size)
loss = None
if targets is not None:
B, T, C = logits.shape
flat_logits = logits.view(B * T, C)
targets = targets.view(B * T).long()
loss = F.cross_entropy(flat_logits, targets) # 交叉熵
return logits, loss

前向数据流(设 vocab=V):

idx (B,T) ──token_embed──▶ (B,T,C) ┐
├─加──▶ (B,T,C) ──Block×N──▶ (B,T,C) ──LN──▶ ──lm_head──▶ logits (B,T,V)
pos_idxs[:T] ─position_embed─▶ (T,C)┘

关键点:

  1. 两种 embedding 相加token_embed 告诉模型”这是哪个词”,position_embed 告诉它”这个词在第几个位置”。两个查表向量直接相加注入模型。

    • 这里用的是可学习的绝对位置编码nn.Embedding(context_length, n_embed),每个位置一个独立向量,跟着训练)——和原始论文的固定正弦编码不同,也和第三部分 notebook 的 RoPE 完全不同。
    • pos_idxs[:T] 取前 T 个位置 id,(T,C) 通过广播加到每个 batch 上。
  2. lm_head 产出 logits。最后一个 Linear 把每个位置的 C 维表示映射到 vocab_size 维,得到”下一个 token 是词表中各个词的打分”。

  3. loss 计算。训练时传入 targets,把 (B,T,V) 摊平成 (B*T, V)targets 摊平成 (B*T,),算交叉熵。注意 F.cross_entropy 内部自带 softmax,所以模型直接输出原始 logits 即可。

生成方法 generate

def generate(self, idx, max_new_tokens):
for _ in range(max_new_tokens):
idx_cond = idx[:, -self.context_length:] # 只保留最近 context_length 个 token
logits, _ = self(idx_cond) # 前向
logits = logits[:, -1, :] # 只取最后一个位置的预测 (B, V)
probs = F.softmax(logits, dim=-1) # 转概率
idx_next = torch.multinomial(probs, num_samples=1) # 按概率采样一个 token
idx = torch.cat((idx, idx_next), dim=1) # 拼回序列,继续下一轮
return idx

这是标准的自回归生成:每次只用模型预测的最后一个位置,按概率分布采样multinomial,而非贪心取 argmax)出下一个 token,拼回去再喂进模型,循环 max_new_tokens 次。

  • idx[:, -context_length:]上下文截断:序列再长,也只喂最近 context_length 个 token(因为位置编码只有这么多、注意力窗口也只有这么大)。
  • multinomial 采样让输出有随机性和多样性;若改成 argmax 就是确定性的贪心解码。这里没有温度(temperature)、top-k、top-p 等采样控制——对比第三部分 notebook 的生成就用上了这些。

2.3 配置项详解#

所有超参集中在 config/config.py,最后打包成 default_config 字典供其它脚本 import。逐项含义:

参数默认值含义影响
VOCAB_SIZE50304词表大小由分词器决定(r50k 实际 50257,padding 到 50304)
CONTEXT_LENGTH512模型支持的最大序列长度决定位置编码表大小、注意力窗口
N_EMBED2048嵌入维度 C影响参数量最大的旋钮之一
N_HEAD16注意力头数需能整除 N_EMBED
N_BLOCKS64Transformer 层数决定模型深度
T_BATCH_SIZE32训练 batch 大小受显存限制
T_CONTEXT_LENGTH16训练时实际用的序列长度⚠️ 见下方说明
T_TRAIN_STEPS200000总训练步数
T_EVAL_STEPS1000每多少步评估一次
T_EVAL_ITERS250每次评估跑多少个 batch 取平均
T_LR5e-4初始学习率
T_LR_DECAYED5e-5衰减后学习率
T_LR_DECAY_STEP50000在第几步把学习率降到 T_LR_DECAYED单次阶梯衰减
T_OUT_PATHmodels/transformer_B.pt模型保存路径
DEVICEcuda运行设备
TRAIN_PATH / DEV_PATHpile_train.h5 / pile_dev.h5训练/验证数据

两组典型配置:

# 默认(约 21 亿参数,单卡基本跑不动,作者注释误写成 "3 Billion")
N_EMBED = 2048; N_HEAD = 16; N_BLOCKS = 64; CONTEXT_LENGTH = 512
# 推荐的 13M(先跑通流程用这个)
N_EMBED = 128; N_HEAD = 8; N_BLOCKS = 1; CONTEXT_LENGTH = 128

参数量主要花在哪? 以 13M 配置粗算:

部件参数量占比
token_embed(50304×128)≈ 6.44 M~49%
lm_head(128×50304 + bias)≈ 6.49 M~49%
1 个 Transformer Block≈ 0.18 M~1.4%
其余(位置编码、最终 LN)≈ 0.02 M<1%
合计≈ 13.1 M100%

一个反直觉但重要的结论:小模型的参数几乎全在 token embedding 和输出层(因为词表有 5 万),真正做”思考”的 Transformer 层只占 1% 多。这也解释了为什么作者说 13M 已能产出通顺文字——语言的”记忆”大量存在 embedding 里。(注:本仓库 token_embedlm_head两套独立权重,没有做权重共享;若共享可省下近一半参数。)

⚠️ 一个必须注意的细节:T_CONTEXT_LENGTH=16 vs CONTEXT_LENGTH=512 模型按 CONTEXT_LENGTH=512 构建(位置编码表有 512 行、注意力能看 512 个 token),但训练循环里 get_batch_iterator 喂的是 t_context_length=16 的短序列(见 2.4)。这意味着实际训练时每条样本只有 16 个 token,只有前 16 个位置编码会被训练到。想让模型真正学会用长上下文,需要把 T_CONTEXT_LENGTH 调大到接近 CONTEXT_LENGTH。这是套用本仓库时容易忽略的点。

2.4 训练循环(scripts/train_transformer.py#

这个脚本没有 if __name__ == '__main__' 保护,是模块级直接执行——也就是说 python scripts/train_transformer.py 一运行,import 完成的同时训练就开始了。整体分四段。

① 建模型 + 数清参数:

model = Transformer(
n_head=config['n_head'], n_embed=config['n_embed'],
context_length=config['context_length'], vocab_size=config['vocab_size'],
N_BLOCKS=config['n_blocks']
).to(config['device'])
total_params = sum(p.numel() for p in model.parameters())
print(f"Total number of parameters in the model: {total_params:,}")

② 优化器 + 评估函数:

optimizer = torch.optim.AdamW(model.parameters(), lr=config['t_lr'])
losses = []
AVG_WINDOW = 64 # 显示用的滑动平均窗口
@torch.no_grad()
def estimate_loss(steps):
out = {}
model.eval() # 评估时切 eval 模式
for split in ['train', 'dev']:
data_path = config['train_path'] if split == 'train' else config['dev_path']
batch_iterator_eval = get_batch_iterator(data_path, config['t_batch_size'],
config['t_context_length'], device=config['device'])
losses_eval = torch.zeros(steps)
for k in range(steps):
xb, yb = next(batch_iterator_eval)
_, loss = model(xb, yb)
losses_eval[k] = loss.item()
out[split] = losses_eval[:k + 1].mean()
model.train() # 评估完切回 train 模式
return out
  • @torch.no_grad():评估不需要梯度,省显存、提速。
  • model.eval() / model.train() 切换:本模型没有 dropout/BN,切换影响不大,但这是规范写法。
  • train 和 dev 两个数据集上都评估,方便观察过拟合(train loss 持续降但 dev loss 不降 = 过拟合)。

③ 主训练循环:

batch_iterator = get_batch_iterator(config['train_path'], config['t_batch_size'],
config['t_context_length'], device=config['device'])
pbar = tqdm(range(config['t_train_steps']))
for step in pbar:
try:
xb, yb = next(batch_iterator) # 取一个 batch
_, loss = model(xb, yb) # 前向 + 算 loss
losses.append(loss.item())
pbar.set_description(f"Train loss: {np.mean(losses[-AVG_WINDOW:]):.4f}")
optimizer.zero_grad(set_to_none=True) # 清梯度
loss.backward() # 反向传播
optimizer.step() # 更新参数
if step % config['t_eval_steps'] == 0: # 周期性评估
evaluation_losses = estimate_loss(config['t_eval_iters'])
print(f"Step: {step}, Train loss: {...}, Dev loss: {...}")
if step == config['t_lr_decay_step']: # 到点降学习率
print('Decaying learning rate')
for g in optimizer.param_groups:
g['lr'] = config['t_lr_decayed']
except StopIteration:
break

标准训练四步:前向算 loss → 清零梯度 → 反向求梯度 → 优化器更新zero_grad(set_to_none=True) 把梯度置 None 而非置 0,省一点内存和计算。

学习率策略是单次硬切换:跑到第 50000 步时,把所有参数组的学习率从 5e-4 直接降到 5e-5,只有一次阶梯下降。notebook 里则用了 warmup 加余弦退火,每步都在调整。

④ 保存 checkpoint:

os.makedirs(config['t_out_path'].split('/')[0], exist_ok=True) # 确保 models/ 存在
evaluation_losses = estimate_loss(200) # 最终评估
# 防止覆盖已有文件:transformer_B.pt 存在就存成 transformer_B_1.pt …
modified_model_out_path = config['t_out_path']
save_tries = 0
while os.path.exists(modified_model_out_path):
save_tries += 1
model_out_name = os.path.splitext(config['t_out_path'])[0]
modified_model_out_path = model_out_name + f"_{save_tries}" + ".pt"
torch.save({
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(), # 存优化器状态便于续训
'losses': losses, # 完整 loss 曲线
'train_loss': train_loss, 'dev_loss': dev_loss,
'steps': len(losses),
}, modified_model_out_path)

存的不只是模型权重,还有优化器状态、完整 loss 历史、步数——这些为续训和分析保留了所需数据(不过要注意:本脚本本身并没有实现”加载 checkpoint 继续训练”的逻辑,每次运行都是从头初始化),也方便事后画 loss 曲线。保存时如果目标文件已存在,会自动加后缀避免覆盖。

训练循环外层套了 try/except StopIteration:虽然 get_batch_iterator 是无限生成器理论上不会耗尽,但这层保护能在意外早停时仍走到保存逻辑,避免训练成果丢失。作者报告十亿参数模型最终 train loss 0.2314、dev loss 0.643。

2.5 文本生成(scripts/generate_text.py#

训练好之后用这个脚本做推理。核心函数 generate_text

def generate_text(model_path, input_text, max_new_tokens=100, device='cuda'):
checkpoint = torch.load(model_path, map_location=torch.device(device))
model = Transformer( # 用 config 里的超参重建同样结构
n_head=config['n_head'], n_embed=config['n_embed'],
context_length=config['context_length'], vocab_size=config['vocab_size'],
N_BLOCKS=config['n_blocks']
)
model.load_state_dict(checkpoint['model_state_dict']) # 灌入训练好的权重
model.eval().to(device)
enc = tiktoken.get_encoding("r50k_base") # 必须和训练时同一个分词器
start_ids = enc.encode_ordinary(input_text) # 文本 → token id
context = torch.tensor(start_ids, dtype=torch.long, device=device).unsqueeze(0) # 加 batch 维
with torch.no_grad():
generated_tokens = model.generate(context, max_new_tokens=max_new_tokens)[0].tolist()
return enc.decode(generated_tokens) # token id → 文本

要点:

  • 重建结构再灌权重。PyTorch 的 state_dict 只存张量数值,不存网络结构。所以必须先用和训练时一模一样的超参实例化 Transformer,再 load_state_dict。如果 config.py 改过参数,这里就会 shape 不匹配报错——生成时的 config 必须和训练时一致。
  • 分词器必须一致。生成用的 r50k_base 要和预处理时同一个,否则 token id 对不上,输出全是乱码。
  • encode_ordinary:把文本编码成普通 token,不解析 <|endoftext|> 这类特殊 token(当普通文本处理)。
  • unsqueeze(0)(T,) 变成 (1, T) 补上 batch 维,因为模型 forward 期望 (B, T)

命令行用法:

Terminal window
python scripts/generate_text.py \
--model_path models/transformer_B.pt \
--input_text "In 1978" \
--max_new_tokens 100
参数默认含义
--model_pathcheckpoint 路径
--input_text起始提示词
--max_new_tokens100生成多少个新 token

主代码库到这里基本覆盖了下载、预处理、训练和生成整条链路。整体更接近早期 GPT(2017~2019)的实现方式。第三部分的 notebook 会在此基础上换成 LLaMA 风格组件,并加入 SFT 和 reasoning 训练。


三、notebook:现代架构 + 后训练全流程#

sft_rlhf_guide.ipynb 是一份自包含的端到端教程(74 个 cell),标题叫 Building a “Thinking” LLM from Scratch。覆盖的范围比主代码库更大:除了搭一个缩小版 LLaMA 架构,还演示了**预训练 → SFT → Reasoning(思维链)**三个阶段,最终让模型输出 <think>...</think><answer>...</answer> 这种带思考过程的结构化回答。

设计灵感来自开源项目 MiniMind(notebook 注释里多次提及 train_distill_reason.py、MiniMind 的 SFTDataset 等)。所有数据都是 notebook 内置的几条玩具样本,目的是跑通流程、理解机制,不是真出一个能用的模型。

全局配置先建立认知(这些超参决定模型规模和训练强度):

DEMO_VOCAB_SIZE = 32000 # 自训分词器的目标词表
DEMO_HIDDEN_SIZE = 1024 # 隐藏维度(相当于主代码库的 n_embed)
DEMO_NUM_LAYERS = 24 # Transformer 层数
DEMO_NUM_ATTENTION_HEADS = 16 # Query 头数
DEMO_NUM_KV_HEADS = 16 # Key/Value 头数(=Q 头数,所以这里实际是 MHA;<16 则是 GQA)
DEMO_MAX_SEQ_LEN = 1024 # 最大序列长度
SPECIAL_TOKENS_LIST = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<pad>"]
# 训练超参(三阶段各 10 epoch,学习率递减)
DEMO_BATCH_SIZE = 16
DEMO_PRETRAIN_LR = 3e-4
DEMO_SFT_LR = 1e-4
DEMO_REASONING_LR= 5e-5

3.1 自训 BPE Tokenizer#

和主代码库直接用现成的 tiktoken 不同,notebook 从零训练一个 BPE 分词器(用 HuggingFace 的 tokenizers 库)。

BPE(Byte Pair Encoding)原理:从”每个字符是一个 token”开始,反复统计语料里最高频的相邻 token 对,把它合并成一个新 token 加入词表,直到词表达到目标大小。这样高频词(如 “the”)会被合并成单个 token,生僻词则拆成若干子词——既控制了词表规模,又能用已知子词拼出没见过的新词(缓解 OOV 问题)。

训练函数:

def train_demo_tokenizer(corpus_files, vocab_size, save_path, special_tokens):
tokenizer_bpe = HFTokenizer(hf_models.BPE(unk_token="<unk>"))
# ByteLevel:先按字节切分,保证任何 Unicode 字符都能表示(不会真正 OOV)
tokenizer_bpe.pre_tokenizer = hf_pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=True)
tokenizer_bpe.decoder = hf_decoders.ByteLevel()
trainer = hf_trainers.BpeTrainer(
vocab_size=vocab_size,
special_tokens=special_tokens,
initial_alphabet=hf_pre_tokenizers.ByteLevel.alphabet() # 256 个字节全部作为初始字母表
)
tokenizer_bpe.train(corpus_files, trainer=trainer)
tokenizer_bpe.save(save_path) # 存成单个 .json
return tokenizer_bpe
  • ByteLevel 字节级处理是现代分词器(GPT-2、LLaMA)的标配:以 256 个字节为最小单位,任何字符(中文、emoji、罕见符号)都能由字节序列表示,理论上零 OOV
  • 训练语料只有 8 句话(notebook 内置)——纯演示,真实场景需要海量语料才能训出好词表。

四个特殊 token 的角色:

token作用
`<endoftext
`<im_start
`<im_end
<pad>填充,把不等长序列补齐到同一长度

AutoTokenizer 重新加载,是为了拿到 HuggingFace 生态的便利功能,最关键的是 apply_chat_template——把对话列表按模板拼成字符串。这里定义的是 ChatML 风格模板:

<|im_start|>user
{用户内容}<|im_end|>
<|im_start|>assistant
{助手内容}<|im_end|>

notebook 里还有一大段 try/except fallback(CallableTokenizerWrapper):万一 AutoTokenizer 加载失败,就手写一个包装类模拟 __call__ / apply_chat_template / encode / decode。这是健壮性兜底代码,理解主流程时可以跳过,知道”它保证后续 tokenizer 对象一定可用”即可——后续所有代码都通过统一的 tokenizer 变量调用分词器。

3.2 数据集与 loss mask#

notebook 定义两个 Dataset 类,分别服务预训练和对话微调。看这两个 Dataset 时主要要盯住 loss mask——它决定哪些 token 参与 loss 计算

3.2.1 DemoCorpusDataset(预训练用)#

class DemoCorpusDataset(Dataset):
def __getitem__(self, idx):
text = self.samples[idx]
full_text_with_bos = self.tokenizer.bos_token + text # 句首加 BOS
encoding = self.tokenizer(full_text_with_bos, max_length=self.max_length,
padding="max_length", truncation=True, return_tensors='pt')
input_ids = encoding.input_ids.squeeze(0) # (max_length,)
effective_loss_mask = (input_ids != self.tokenizer.pad_token_id).long() # 非 pad 处为 1
X = input_ids[:-1] # 输入
Y = input_ids[1:] # 目标(错位一格,同主代码库)
mask_for_loss_calculation = effective_loss_mask[1:] # 对齐 Y
return X, Y, mask_for_loss_calculation
  • 错位逻辑和主代码库 get_batch_iterator 完全一致(X=input_ids[:-1]Y=input_ids[1:])。
  • loss mask 这里只排除 padding:预训练要学整段文本,所以除了填充的 <pad>,每个 token 都参与 loss。mask 切 [1:] 是为了和 Y 对齐。

3.2.2 DemoChatDataset(SFT 和 Reasoning 共用)#

对话数据的处理核心是:只让模型学习 assistant 的回复,不学 user 的提问。靠 loss mask 实现:

class DemoChatDataset(Dataset):
def __getitem__(self, idx):
conversations = self.samples[idx]
# 用 chat template 把整段对话拼接 + 分词
input_ids = self.tokenizer.apply_chat_template(
conversations, tokenize=True, add_generation_prompt=False,
return_tensors="pt", max_length=self.max_length,
truncation=True, padding="max_length"
).squeeze(0)
loss_mask = torch.zeros_like(input_ids, dtype=torch.long) # 默认全 0(都不算 loss)
bos_assistant_ids = self.tokenizer.encode("<|im_start|>assistant\n", add_special_tokens=False)
eos_ids = self.tokenizer.encode("<|im_end|>", add_special_tokens=False)
# 扫描 token 序列,把每段 "assistant\n ... <|im_end|>" 之间的 token 标成 1
i = 0
input_ids_list = input_ids.tolist()
while i < len(input_ids_list):
# 命中一段 assistant 回复的起始标记
if input_ids_list[i : i+len(bos_assistant_ids)] == bos_assistant_ids:
start_of_response = i + len(bos_assistant_ids)
# 往后找该回复的结束标记 <|im_end|>
end_marker = -1
j = start_of_response
while j < len(input_ids_list):
if input_ids_list[j : j+len(eos_ids)] == eos_ids:
end_marker = j
break
j += 1
if end_marker != -1: # 找到结束标记
loss_mask[start_of_response : end_marker + len(eos_ids)] = 1
i = end_marker + len(eos_ids) # 跳到这段之后继续扫
continue
else: # 没找到(被截断),标到末尾
loss_mask[start_of_response:] = 1
break
i += 1
loss_mask[input_ids == self.tokenizer.pad_token_id] = 0 # padding 不算
X = input_ids[:-1]; Y = input_ids[1:]
mask_for_loss_calculation = loss_mask[1:]
return X, Y, mask_for_loss_calculation

为什么 SFT 要 mask 掉用户输入? 想象训练样本是”用户问 + 助手答”。我们希望模型学的是”给定问题,怎么生成好的回答”,而不是”怎么生成用户的问题”。如果对用户那部分 token 也算 loss,模型会浪费容量去拟合用户输入的分布,甚至学会”自问自答”。所以 mask 只在 assistant 回复区间为 1,loss 只反传这部分。

  • add_generation_prompt=False:训练时 assistant 的完整回答已经在数据里了,不需要再追加生成提示符(那是推理时才加的)。
  • 扫描匹配 <|im_start|>assistant\n ... <|im_end|> 来定位回复区间——notebook 注释自己也说这是”illustrative”(示意性)的实现,MiniMind 原版用更鲁棒的 token id 直接匹配。
  • SFT 和 Reasoning 用的是同一个 DemoChatDataset,区别只在喂的数据文件不同(SFT 数据是普通问答,Reasoning 数据带 <think>/<answer> 标签)和 Reasoning 训练循环里额外的”标签加权”(见 3.6)。

3.3 现代架构 DemoLLM#

这一节是 notebook 里最值得细看的部分。DemoLLM 系列类基本就是一个缩小版 LLaMA,把主代码库的经典组件逐一换成 2023+ 的现代做法。整体组装关系:

DemoLLMForCausalLM # 模型 + lm_head + loss,继承 HF PreTrainedModel
└── DemoLLMModel # embedding + 层堆叠 + 最终 norm
└── DemoTransformerBlock ×24
├── DemoAttention # GQA + RoPE + Flash Attention + KV cache
│ └── RotaryEmbedding
├── DemoFeedForward # SwiGLU
└── DemoRMSNorm ×2 # Pre-LN

可以和主代码库逐项对应:LayerNorm → RMSNorm绝对位置编码 → RoPEReLU MLP → SwiGLU朴素 MHA → 带 GQA/Flash 的注意力纯 nn.Module → HF PreTrainedModel。下面逐个看。

3.3.0 配置类 DemoLLMConfig#

继承 HF 的 PretrainedConfig,好处是模型能用 save_pretrained / from_pretrained、能接 HF 的 generate() 等生态工具。几个关键字段:

self.head_dim = self.hidden_size // self.num_attention_heads # 1024 // 16 = 64
self.num_key_value_heads = num_key_value_heads or num_attention_heads
# 约束:Q 头数必须能被 KV 头数整除(GQA 分组的前提)
if num_attention_heads % num_key_value_heads != 0:
raise ValueError("num_attention_heads 必须能被 num_key_value_heads 整除")
hidden_act = "silu" # SwiGLU 用的激活
rms_norm_eps = 1e-5 # RMSNorm 数值稳定项
rope_theta = 10000.0 # RoPE 的频率基数

intermediate_size(FFN 中间维度)的算法值得一提:

DEMO_INTERMEDIATE_SIZE = int(DEMO_HIDDEN_SIZE * 8 / 3) # 1024 × 8/3 ≈ 2730
DEMO_INTERMEDIATE_SIZE = 32 * ((DEMO_INTERMEDIATE_SIZE + 31) // 32) # 向上对齐到 32 的倍数 = 2752
  • 为什么是 8/3 而不是 4 倍:SwiGLU 比传统 FFN 多一个 gate 投影(三个矩阵而非两个),为了让总参数量和传统”4 倍”FFN 大致相当,中间维度取 4 × 2/3 = 8/3 倍。这是 LLaMA 的标准做法。
  • 对齐到 32 的倍数:让矩阵维度对硬件友好(GPU tensor core 喜欢 8/16/32 的整数倍),算得更快。

3.3.1 RMSNorm#

RMSNorm 是 LayerNorm 的简化版,现代 LLM(LLaMA、T5)的主流选择:

class DemoRMSNorm(nn.Module):
def __init__(self, dim, eps=1e-5):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim)) # 可学习的缩放,无 bias
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
output_dtype = x.dtype
x = x.to(torch.float32) # 在 float32 下算,保证数值稳定
output = self._norm(x)
return (output * self.weight).to(output_dtype)

公式:

RMSNorm(x)=x1di=1dxi2+ϵγ\text{RMSNorm}(x) = \frac{x}{\sqrt{\dfrac{1}{d}\sum_{i=1}^{d} x_i^2 + \epsilon}} \odot \gamma

和 LayerNorm 的区别:LayerNorm 要”减均值、除标准差、再仿射(缩放+平移)”:

LayerNorm(x)=xμσ2+ϵγ+β\text{LayerNorm}(x) = \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} \odot \gamma + \beta

RMSNorm 不减均值、没有 bias β,只用均方根(RMS)做缩放。少算一个均值、少一组 bias 参数,速度更快,实际效果通常和 LayerNorm 差不多,所以很多现代 LLM 都采用它。

  • torch.rsqrt 是”1/√x”的融合算子,比先 sqrt 再取倒数快。
  • 混合精度下的稳定技巧:先 .to(torch.float32) 再算 norm,最后转回原 dtype。因为 bf16/fp16 下平方求和容易溢出或损失精度,归一化这种对数值敏感的操作放到 float32 做更稳。

3.3.2 RoPE 旋转位置编码#

RoPE(Rotary Positional Embedding)是现代 LLM 注入位置信息的方式,替代主代码库那种”可学习的绝对位置 embedding 相加”。

核心思想:不再把位置信息到输入上,而是把每个 token 的 Query/Key 向量按其位置旋转一个角度。位置 m 处的向量旋转 m·θ,位置 n 处旋转 n·θ,两者做点积时,结果只依赖旋转角之差 (m−n)·θ——也就是相对位置。这样模型天然获得相对位置感知,还能外推到比训练时更长的序列。

数学上,把 head_dim 维向量两两配对成 d/2 个二维子向量,对位置 m、第 i 对施加旋转矩阵:

Rm,i=(cos(mθi)sin(mθi)sin(mθi)cos(mθi)),θi=100002i/dR_{m,i} = \begin{pmatrix} \cos(m\theta_i) & -\sin(m\theta_i) \\ \sin(m\theta_i) & \cos(m\theta_i) \end{pmatrix}, \qquad \theta_i = 10000^{-2i/d}

低维对(i 小)旋转快(高频,捕捉近距离关系),高维对旋转慢(低频,捕捉远距离关系)。

代码用复数实现旋转(很优雅,但第一次看会懵):

class RotaryEmbedding(nn.Module):
def __init__(self, dim, max_seq_len, theta=10000.0, device=None):
super().__init__()
# θ_i = 1 / theta^(2i/d),i = 0,1,...,d/2-1
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim)) # (d/2,)
t = torch.arange(max_seq_len).float() # (max_seq_len,)
freqs = torch.outer(t, freqs) # (max_seq_len, d/2),元素 = m·θ_i
# 用模长 1、角度 m·θ_i 构造复数 cos(m·θ_i) + i·sin(m·θ_i)
self.register_buffer("freqs_cis", torch.polar(torch.ones_like(freqs), freqs), persistent=False)
def forward(self, xq, xk, seq_len):
# xq/xk: (bsz, num_heads, seq_len, head_dim)
# 把相邻两维拼成复数:(..., head_dim) → (..., head_dim/2) 复数
xq_c = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_c = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
freqs_cis_pos = self.freqs_cis[:seq_len].unsqueeze(0).unsqueeze(0) # (1,1,seq_len,d/2)
# 复数乘法 = 旋转
xq_out = torch.view_as_real(xq_c * freqs_cis_pos).flatten(3)
xk_out = torch.view_as_real(xk_c * freqs_cis_pos).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)

理解关键:一个复数乘以单位复数 e^{iθ} = cosθ + i·sinθ,几何上就是在复平面旋转 θ 角。代码把 head_dim 的相邻两维 (x_{2i}, x_{2i+1}) 看成复数 x_{2i} + i·x_{2i+1},乘上 freqs_cis[m] = e^{i·m·θ_i},正好实现了上面旋转矩阵的效果——用一次复数乘法代替 2×2 矩阵乘,简洁高效。

  • freqs_cis 在初始化时一次性算好所有位置×所有频率的旋转因子,存成 buffer(persistent=False 表示不写进 checkpoint,加载时重新算)。
  • RoPE 作用在 Q 和 K 上(不作用于 V),且是在注意力内部、投影之后施加(见下面 DemoAttention)。
  • 同样先转 float() 算再转回,保证混合精度下的精度。

3.3.3 注意力 DemoAttention#

这个 Attention 比主代码库的 MultiHeadAttention 多了几项现代实现里常见的东西:GQA、RoPE、Flash Attention、KV cache

class DemoAttention(nn.Module):
def __init__(self, config):
super().__init__()
self.num_q_heads = config.num_attention_heads # Query 头数 16
self.num_kv_heads = config.num_key_value_heads # KV 头数(可 < Q 头数)
self.num_kv_groups = self.num_q_heads // self.num_kv_heads # 每组共享的 Q 头数
self.head_dim = config.head_dim # 64
# 注意 K/V 投影的输出维度是 num_kv_heads × head_dim(可能小于 Q)
self.q_proj = nn.Linear(hidden, num_q_heads * head_dim, bias=False)
self.k_proj = nn.Linear(hidden, num_kv_heads * head_dim, bias=False)
self.v_proj = nn.Linear(hidden, num_kv_heads * head_dim, bias=False)
self.o_proj = nn.Linear(num_q_heads * head_dim, hidden, bias=False) # ← 输出投影(主代码库没有)
self.rotary_emb = RotaryEmbedding(head_dim, config.max_position_embeddings, theta=config.rope_theta)
self.flash_available = hasattr(F, 'scaled_dot_product_attention') and config.flash_attn

forward 的几个关键步骤:

# 辅助方法:把 KV 头复制 n_rep 次,(bsz, num_kv_heads, seq, d) → (bsz, num_kv_heads*n_rep, seq, d)
def _repeat_kv(self, x, n_rep):
bs, num_kv_heads, slen, head_dim = x.shape
if n_rep == 1:
return x
return (x[:, :, None, :, :]
.expand(bs, num_kv_heads, n_rep, slen, head_dim)
.reshape(bs, num_kv_heads * n_rep, slen, head_dim))
def forward(self, hidden_states, attention_mask=None, position_ids=None,
past_key_value=None, use_cache=False):
bsz, q_len, _ = hidden_states.shape
# 1) 投影 + 拆头:(bsz, heads, q_len, head_dim)
query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_q_heads, self.head_dim).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
# 总的 K/V 序列长度(含 KV cache 里的历史 token)
kv_seq_len = q_len
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[2]
# 2) 给 Q、K 施加 RoPE(按当前新 token 的长度)
query_states, key_states = self.rotary_emb(query_states, key_states, seq_len=q_len)
# 3) KV cache:把历史的 K/V 拼到前面(推理加速)
if past_key_value is not None:
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
current_key_value = (key_states, value_states) if use_cache else None
# 4) GQA:把 KV 头复制 num_kv_groups 次,匹配 Q 头数
key_states = self._repeat_kv(key_states, self.num_kv_groups)
value_states = self._repeat_kv(value_states, self.num_kv_groups)
# 5) 算注意力:优先 Flash,否则手动
if self.flash_available and attention_mask is None:
attn_output = F.scaled_dot_product_attention(
query_states, key_states, value_states,
dropout_p=self.config.dropout if self.training else 0.0,
is_causal=(q_len == kv_seq_len)) # 无 KV cache 时即标准因果注意力
else:
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
# 因果掩码;causal_shift 兼容 KV cache(q_len < kv_seq_len)的情形
causal_shift = kv_seq_len - q_len
mask = torch.triu(torch.full((q_len, kv_seq_len), float("-inf"),
device=query_states.device), diagonal=1 + causal_shift)
attn_weights = attn_weights + mask[None, None, :, :]
if attention_mask is not None: # 额外的 padding 掩码(加性)
attn_weights = attn_weights + attention_mask
attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).type_as(query_states)
attn_output = torch.matmul(attn_weights, value_states)
# 6) 合并头 + 输出投影
attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, q_len, -1)
return self.o_proj(attn_output), current_key_value

四个现代特性逐一解释:

  1. GQA(Grouped Query Attention,分组查询注意力)。标准 MHA 里 Q、K、V 头数相同;GQA 让多个 Query 头共享同一组 K/V 头。比如 16 个 Q 头但只有 4 个 KV 头,则每 4 个 Q 头共用 1 组 KV。_repeat_kv 在计算时把 KV 头复制 num_kv_groups 份对齐 Q。

    • 目的:推理时 KV cache 占大量显存,GQA 把 KV 头数砍小,显存和带宽都省,几乎不掉精度。是 LLaMA-2/3、Mistral 的标配。
    • 本 notebook 默认 num_kv_heads = num_q_heads = 16,所以实际退化成标准 MHA——但代码完整支持 GQA,把 DEMO_NUM_KV_HEADS 调小即可启用。
  2. KV cache。自回归生成时,每生成一个新 token 都要重算注意力。但历史 token 的 K/V 是不变的,没必要重算。past_key_value 把历史 K/V 缓存下来,每步只算新 token 的 K/V 再 cat 上去,把生成复杂度从 O(n²) 降到接近 O(n)。这是推理提速的关键,主代码库的 generate 完全没有这个优化。

  3. Flash AttentionF.scaled_dot_product_attention 是 PyTorch 内置的融合注意力算子,用分块计算避免显式构造 (T,T) 的大注意力矩阵,大幅省显存、提速。is_causal=True 让它内部直接处理因果掩码。没有它时走下面的手动实现(matmul + triu 掩码 + softmax)。

  4. 因果掩码的细节。手动路径用 torch.triu(..., diagonal=1+causal_shift) 构造上三角 -inf 掩码。causal_shift = kv_seq_len - q_len 是为了兼容 KV cache 场景:有缓存时 q_len(新 token 数)小于 kv_seq_len(总长度),掩码的对角线要相应偏移,保证新 token 能看到所有历史、但仍看不到”更未来”。

3.3.4 前馈网络 SwiGLU#

现代 LLM 的 FFN 几乎都用 SwiGLU 替代经典的”Linear-ReLU-Linear”:

class DemoFeedForward(nn.Module):
def __init__(self, config):
super().__init__()
self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
self.act_fn = ACT2FN[config.hidden_act] # SiLU
self.dropout = nn.Dropout(config.dropout)
def forward(self, x):
return self.dropout(self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)))

公式:

SwiGLU(x)=down(SiLU(gate(x))up(x)),SiLU(z)=zσ(z)\text{SwiGLU}(x) = \text{down}\big(\,\text{SiLU}(\text{gate}(x)) \odot \text{up}(x)\,\big), \qquad \text{SiLU}(z) = z \cdot \sigma(z)

和主代码库 ReLU-MLP 的本质区别是”门控”:经典 MLP 只有”升维→激活→降维”两个矩阵;SwiGLU 有三个矩阵——gateup 把输入都升到中间维度,gate 的输出经 SiLU 激活后作为门控信号,逐元素乘到 up 的输出上,再由 down 降回。这个乘法门让网络能动态地”放大或抑制”每个特征通道的信息,表达力比 ReLU 强,是 LLaMA/PaLM 等模型效果好的因素之一。SiLU(也叫 Swish)是平滑版的 ReLU,处处可导。

3.3.5 Transformer Block#

class DemoTransformerBlock(nn.Module):
def __init__(self, config):
super().__init__()
self.self_attn = DemoAttention(config)
self.mlp = DemoFeedForward(config)
self.input_layernorm = DemoRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = DemoRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(self, hidden_states, attention_mask=None, position_ids=None,
past_key_value=None, use_cache=False):
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
attn_out, present_kv = self.self_attn(hidden_states, attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value, use_cache=use_cache)
hidden_states = residual + attn_out # 残差 1
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = residual + self.mlp(hidden_states) # 残差 2
return hidden_states, present_kv

结构和主代码库的 Block 几乎一模一样——都是 Pre-LN + 双残差。差异仅在具体部件:LayerNorm → DemoRMSNormMLP → DemoFeedForward(SwiGLU),以及注意力多返回一个 present_kv(KV cache)。也就是说,Block 的大框架还是 Pre-LN、残差、注意力和 FFN,主要变化发生在 norm、位置编码、FFN 和注意力实现这些部件上。

3.3.6 主干 DemoLLMModel#

class DemoLLMModel(PreTrainedModel):
config_class = DemoLLMConfig
def __init__(self, config):
super().__init__(config)
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id)
self.layers = nn.ModuleList([DemoTransformerBlock(config) for _ in range(config.num_hidden_layers)])
self.norm = DemoRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.dropout = nn.Dropout(config.dropout)
def forward(self, input_ids=None, attention_mask=None, position_ids=None,
past_key_values=None, use_cache=None, **kwargs):
batch_size, seq_length = input_ids.shape
past_len = past_key_values[0][0].shape[2] if past_key_values is not None else 0
# 没传 position_ids 就自动生成(有 KV cache 时从 past_len 往后排)
if position_ids is None:
position_ids = torch.arange(past_len, seq_length + past_len,
dtype=torch.long, device=input_ids.device).unsqueeze(0)
inputs_embeds = self.embed_tokens(input_ids)
hidden_states = self.dropout(inputs_embeds)
# 把 (bsz, seq) 的 padding mask 扩成 (bsz,1,q,kv) 加性掩码;因果掩码在 attention 内部处理
_expanded_mask = None
if attention_mask is not None:
expanded = attention_mask[:, None, None, :].expand(batch_size, 1, seq_length, seq_length + past_len)
_expanded_mask = torch.zeros_like(expanded, dtype=hidden_states.dtype)
_expanded_mask.masked_fill_(expanded == 0, float("-inf"))
next_cache = [] if use_cache else None
for i, layer in enumerate(self.layers):
past_kv = past_key_values[i] if past_key_values is not None else None
hidden_states, kv = layer(hidden_states, attention_mask=_expanded_mask,
position_ids=position_ids, past_key_value=past_kv, use_cache=use_cache)
if use_cache:
next_cache.append(kv)
hidden_states = self.norm(hidden_states)
return BaseModelOutputWithPast(last_hidden_state=hidden_states, past_key_values=next_cache)

最重要的差异:没有位置 embedding 层! 对比主代码库 Transformer.__init__ 里有 self.position_embed = nn.Embedding(context_length, n_embed),这里完全没有——因为位置信息已经由 RoPE 在每层注意力内部注入了。这也是 RoPE 和绝对位置 embedding 的主要区别:位置信息不在输入端相加,而是在每层注意力里作用到 Q/K 上。

其余就是标准的”嵌入 → dropout → 逐层 → 最终 RMSNorm”。继承 PreTrainedModel 让它自动获得 HF 的权重初始化、保存加载等能力。

3.3.7 带语言建模头的 DemoLLMForCausalLM#

最外层,加上 lm_head 和 loss 计算:

class DemoLLMForCausalLM(PreTrainedModel):
config_class = DemoLLMConfig
def __init__(self, config):
super().__init__(config)
self.model = DemoLLMModel(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# self.model.embed_tokens.weight = self.lm_head.weight # 权重共享(注释掉了,可选)
self.post_init() # HF 的统一权重初始化
def forward(self, input_ids=None, attention_mask=None, position_ids=None,
past_key_values=None, labels=None, use_cache=None, **kwargs):
outputs = self.model(input_ids=input_ids, attention_mask=attention_mask,
position_ids=position_ids, past_key_values=past_key_values, use_cache=use_cache)
logits = self.lm_head(outputs.last_hidden_state)
loss = None
if labels is not None:
shift_logits = logits[..., :-1, :].contiguous() # 错位:第 t 个位置预测第 t+1 个 token
shift_labels = labels[..., 1:].contiguous()
loss = nn.CrossEntropyLoss()(shift_logits.view(-1, self.config.vocab_size),
shift_labels.view(-1))
return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=outputs.past_key_values)
# 供 HF generate() 调用:有 KV cache 时只需喂最后一个 token
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **kwargs):
if past_key_values:
input_ids = input_ids[:, -1:]
return {"input_ids": input_ids, "past_key_values": past_key_values,
"attention_mask": attention_mask, "use_cache": kwargs.get("use_cache")}

三个要点:

  • 复用 HF 的 generate()。因为继承了 PreTrainedModel 并实现了 prepare_inputs_for_generation,这个模型直接支持 model.generate(..., do_sample=True, temperature=0.7, top_k=10)——温度采样、top-k、KV cache 全部现成可用。对比主代码库要手写 generate 循环,这里省了大量代码。
  • 内置 loss 用了,但训练时没用forward 在传入 labels 时会自己算交叉熵(标准的 shift 错位)。但注意:后面三个训练阶段(预训练 / SFT / Reasoning)都不传 labels,而是只取 logits 然后手动用 loss mask 算 loss——因为预训练/SFT/Reasoning 需要不同的 mask 策略,内置 loss 满足不了。
  • 权重共享(weight tying)被注释掉了。把 embed_tokens.weightlm_head.weight 绑成同一份是常见省参手段(GPT-2 就这么做),这里保留了选项但默认不启用。

模型规模:hidden=1024, 层数=24。注意 vocab 取的是分词器训练后的实际词表 DEMO_VOCAB_SIZE_FINAL——配置目标是 32000,但在 notebook 仅 8 句的玩具语料下,BPE 实际产出的词表会远小于此,embedding/lm_head 参数也相应小很多。若按名义 vocab=32000 粗算上限约 3.7 亿(≈369M)参数,其中 24 层主干约 304M 才是主体;实际跑出来总量取决于真实词表,会明显小于该上限。代码里的 print_model_summary 会在初始化后打印精确值。


notebook 后半部分按预训练、SFT、Reasoning 三段来组织。每一段都加载上一阶段保存的权重,再用更小的学习率继续训练:

预训练 (Pretrain) ──保存 demo_llm_pretrained.pth──▶ 学会基本语言规律
│ load
监督微调 (SFT) ──保存 demo_llm_sft.pth──────────▶ 学会跟随指令、对话
│ load
推理训练 (Reasoning)──保存 demo_llm_reasoning.pth─────▶ 学会输出 <think>/<answer> 结构

学习率一路递减(3e-4 → 1e-4 → 5e-5):越到后期,改动越精细,避免破坏已学到的能力。

3.4 预训练 Pretraining#

预训练让模型在原始文本上学”预测下一个 token”,掌握基本的词法、语法、常识。训练循环有几个比主代码库更现代的地方:

optimizer_pt = optim.AdamW(pt_model.parameters(), lr=DEMO_PRETRAIN_LR)
loss_fct = nn.CrossEntropyLoss(reduction='none') # 'none':返回每个 token 的 loss,便于按 mask 加权
# 混合精度上下文
autocast_ctx = nullcontext() if DEVICE.type=='cpu' else torch.amp.autocast(device_type=DEVICE.type, dtype=PTDTYPE)
scaler = torch.cuda.amp.GradScaler(enabled=(DTYPE_STR != 'float32' and DEVICE.type == 'cuda'))
for epoch in range(DEMO_PRETRAIN_EPOCHS):
for step, (X_batch, Y_batch, mask_batch) in enumerate(demo_pt_dataloader):
# 1) 余弦学习率调度,每步更新
current_lr = get_lr(current_step, total_steps, DEMO_PRETRAIN_LR)
for g in optimizer_pt.param_groups: g['lr'] = current_lr
with autocast_ctx:
outputs = pt_model(input_ids=X_batch) # 注意:不传 labels
logits = outputs.logits # (bsz, seq-1, vocab)
raw_loss = loss_fct(logits.view(-1, logits.size(-1)), Y_batch.view(-1)) # 每 token loss
# 用 mask 加权平均:只统计有效 token
masked_loss = (raw_loss * mask_batch.view(-1)).sum() / mask_batch.sum().clamp(min=1)
# 2) 混合精度反向:scaler 防止 fp16 梯度下溢
scaler.scale(masked_loss).backward()
scaler.step(optimizer_pt)
scaler.update()
optimizer_pt.zero_grad(set_to_none=True)

三个关键技术:

  1. 混合精度训练(AMP)autocast 让前向自动用 bf16/fp16 计算(省显存、提速),关键的归一化/loss 仍在 float32。GradScaler 在 fp16 下把 loss 放大再反传、更新前缩回,防止小梯度被舍入成 0(下溢)。bf16 动态范围大其实可以不缩放,但代码统一保留了 scaler。这是主代码库(纯 fp32)完全没有的工程优化。

  2. 余弦退火 + warmup 学习率(get_lr 工具函数)

def get_lr(current_step, total_steps, initial_lr, min_lr_ratio=0.1, warmup_ratio=0.01):
warmup_steps = int(warmup_ratio * total_steps)
if current_step < warmup_steps:
return initial_lr * (current_step / warmup_steps) # 线性 warmup
progress = (current_step - warmup_steps) / (total_steps - warmup_steps)
coeff = 0.5 * (1.0 + math.cos(math.pi * progress)) # 余弦衰减
return min_lr + coeff * (initial_lr - min_lr)

前 1% 步线性升温(warmup,避免一开始大学习率震坏模型),之后按余弦曲线平滑降到 min_lr(= 初始的 10%)。比主代码库”第 50000 步一刀切”平滑得多,是现代训练标配。

  1. mask 加权 lossCrossEntropyLoss(reduction='none') 返回每个位置的 loss,乘以 mask 后求和、再除以 mask 内 token 数——(raw_loss * mask).sum() / mask.sum()clamp(min=1) 防止除零。预训练阶段 mask 只排除 padding,所以等价于”在所有真实 token 上算平均 loss”。这套 mask 加权机制是三个阶段通用的,区别只在 mask 怎么来

训练完保存 demo_llm_pretrained.pth,并做一次快速生成测试(greedy 解码)。由于数据只有 7 句话,输出必然是复读或胡言——重点是流程跑通。

3.5 SFT 监督微调#

SFT(Supervised Fine-Tuning)让预训练模型学会跟随指令、进行对话。代码和预训练几乎一样,三处不同

  1. 加载预训练权重sft_model.load_state_dict(torch.load(final_pretrained_model_path))——不是从随机初始化开始,而是在预训练的结果上继续训。
  2. 换数据集:用 DemoChatDataset + SFT 对话数据,于是 mask_batch 变成”只在 assistant 回复处为 1”。同样的 (raw_loss * mask).sum() / mask.sum() 公式,效果就变成只在助手回复上算 loss——模型只学怎么回答,不学复述用户问题。
  3. 更小的学习率1e-4(预训练是 3e-4)。

从代码看,SFT 和预训练共用同一套训练循环,主要差别就在数据格式和 loss mask。保存为 demo_llm_sft.pth,再用 do_sample=True 采样测试问答。

3.6 Reasoning 推理训练#

Reasoning 这一节在 SFT 基础上加了特殊标签加权,让模型学会在回答前先输出显式思考过程,格式是 <think>思考...</think><answer>答案...</answer>

# 取出四个标签的"首 token id"
think_start_id = tokenizer.encode('<think>', add_special_tokens=False)[0]
think_end_id = tokenizer.encode('</think>', add_special_tokens=False)[0]
answer_start_id = tokenizer.encode('<answer>', add_special_tokens=False)[0]
answer_end_id = tokenizer.encode('</answer>', add_special_tokens=False)[0]
special_tag_first_token_ids = torch.tensor([
think_start_id, think_end_id, answer_start_id, answer_end_id
], device=DEVICE).unique()
REASONING_TAG_LOSS_WEIGHT = 5.0 # 标签 token 的 loss 权重放大 5 倍
# 训练循环里:
raw_loss = loss_fct(logits.view(-1, vocab), Y_batch.view(-1)) # 每 token loss
effective_loss_weights = sft_style_mask.view(-1).float().clone() # 先用 SFT 那套 mask(assistant 区间)
# 找出 target 里属于特殊标签的位置
is_special = torch.isin(Y_batch.view(-1), special_tag_first_token_ids)
# 在"是 assistant token 且是特殊标签"的位置,权重 ×5
apply_extra = is_special & (sft_style_mask.view(-1) == 1)
effective_loss_weights[apply_extra] *= REASONING_TAG_LOSS_WEIGHT
weighted_loss = (raw_loss * effective_loss_weights).sum() / sft_style_mask.sum().clamp(min=1)

为什么要给标签加权? <think></think><answer></answer> 这些结构标签在文本里出现频率很低,但它们是”思考格式”的骨架。如果所有 token 一视同仁,模型很可能学不牢这些低频却关键的标签,生成时格式就乱了。给它们 5 倍 loss 权重,等价于提高这些位置的梯度权重,让模型更优先拟合这几个格式 token,从而更稳定地生成正确的 <think>...</think><answer>...</answer> 结构。

注意分母仍用 sft_style_mask.sum()(原始 mask 的 token 数)归一化,这样 loss 量级和 SFT 阶段大致可比,只是标签位置的梯度被放大了。学习率进一步降到 5e-5。保存为 demo_llm_reasoning.pth

可以把这看成一种简单的格式约束:仍然是监督训练,只是把结构标签的 loss 权重调高。真实的推理模型(如 DeepSeek-R1)靠强化学习(RL)来激发思考能力,notebook 用的是有监督 + 标签加权来模仿思考的”形”,注释也坦言小模型的 reasoning 更多是模式模仿而非真正推理。

3.7 推理:让模型”思考”#

最后加载 reasoning 模型,用 get_structured_response 跑端到端推理:

def get_structured_response(model, user_query, max_new_toks=DEMO_MAX_SEQ_LEN - 10, temp=0.7, tk=10):
chat_history = [{"role": "user", "content": user_query}]
# add_generation_prompt=True:在末尾补 "<|im_start|>assistant\n" 提示模型轮到它说
prompt_text = tokenizer.apply_chat_template(chat_history, tokenize=False, add_generation_prompt=True)
input_ids = tokenizer(prompt_text, return_tensors="pt").input_ids.to(DEVICE)
generated_ids = model.generate(
input_ids, max_new_tokens=max_new_toks,
do_sample=True, temperature=temp, top_k=tk, # 温度采样 + top-k
eos_token_id=tokenizer.eos_token_id, pad_token_id=tokenizer.pad_token_id
)
# 只取新生成的部分(去掉 prompt)
response = tokenizer.decode(generated_ids[0][input_ids.shape[1]:], skip_special_tokens=True)
# 用字符串查找解析出 <think> 和 <answer> 的内容(容错:标签不全时回退到整段)
think_part, answer_part = "Not found", response
if "<think>" in response and "</think>" in response:
think_part = response[response.find("<think>") + len("<think>") : response.find("</think>")].strip()
if "<answer>" in response and "</answer>" in response:
answer_part = response[response.find("<answer>") + len("<answer>") : response.find("</answer>")].strip()
return think_part, answer_part

对比主代码库手写的 generate,这里接入 HF 之后少写了不少生成相关代码:

  • add_generation_prompt=True:推理时数据里只有 user 输入、没有 assistant 回答,需要手动补上 <|im_start|>assistant\n 来提示模型开始作答。训练时样本里已经有完整回答,所以不用加(False)。
  • 采样控制:直接用 HF generatetemperature=0.7(温度,控制随机性,<1 更确定)、top_k=10(只从概率最高的 10 个 token 里采样,避免选到长尾噪声)。主代码库的 generate 这些都得手写。
  • 结构解析:生成后用普通的字符串 find 切出 <think><answer> 之间的内容。模型若没学好、标签不完整,代码有多重 fallback(找不到 think 就把全文当 answer)。

notebook 到这里把 tokenizer 训练、模型搭建、三段训练和结构化生成都跑了一遍。虽然数据只是几条玩具样本,但主要流程都能对应到真实 LLM 训练里的环节。


四、横向对比与速查#

4.1 经典 vs 现代架构对照表#

两套实现逐维度并排来看:

维度主代码库(经典 GPT 风)notebook(现代 LLaMA 风)为什么换
归一化LayerNorm(减均值、除标准差、含 bias)RMSNorm(只除均方根、无 bias)更快、参数更少、效果相当
归一化位置Pre-LNPre-LN都用 Pre-LN(训练稳定)
位置编码可学习的绝对位置 embedding(相加)RoPE 旋转位置编码(每层注意力内旋转)编码相对位置、可外推、省参数
激活/FFNReLU + 2 矩阵、4× 升维SwiGLU + 3 矩阵门控、8/3× 升维门控表达力更强
注意力朴素 MHA(ModuleList 循环每个头)融合 QKV + 支持 GQA高效、省 KV cache 显存
输出投影 W_O无(直接拼接输出)o_proj标准 Transformer 设计
注意力加速Flash Attention(scaled_dot_product_attention省显存、提速
KV cache无(生成时每步重算全部)有(past_key_value生成提速 O(n²)→~O(n)
分词器tiktoken r50k_base(现成,词表 50304)自训 BPE(tokenizers,词表 ~32000)教学:从零训练
框架torch.nn.ModuleHF PreTrainedModel / PretrainedConfig复用 generate/save/load
生成采样手写 multinomial(无温度/top-k)HF generate(温度、top-k、KV cache)现成、可控
训练精度fp32bf16/fp16 混合精度(autocast + GradScaler)省显存、提速
学习率调度单次阶梯硬切换(50000 步降一次)余弦退火 + 线性 warmup(每步更新)平滑、收敛更好
训练阶段仅预训练预训练 → SFT → Reasoning 三段完整后训练 pipeline
loss mask无(所有 token 都算)有(SFT 只算 assistant、Reasoning 加权标签)对齐任务目标
数据The Pile(825GB 子集,真实语料)notebook 内置几条玩具样本教学演示
数据存储HDF5 一维 token 流jsonl + PyTorch Dataset各自场景

默认配置上:主代码库默认 ~21 亿参数(实跑需调小到 13M);notebook 主体是约 304M 的 24 层主干(按名义 vocab=32000 算总量上限约 3.7 亿,实际词表更小、总量也更小)。

4.2 关键概念速查#

读这个仓库会反复碰到的概念,一句话定义:

  • 自回归语言建模:模型每一步预测序列的下一个 token,用”输入”和”错位一格的目标”训练(X=ids[:-1]Y=ids[1:])。两套实现都这么做。
  • 因果掩码(Causal Mask):用下三角/上三角 -inf 让位置 i 只能注意到 ≤ i 的 token,保证预测下一个词时看不到未来。decoder-only 模型的根本约束。
  • <|endoftext|>:标记文档/序列结束的特殊 token,分隔不相关文本、兼当生成停止信号;常同时用作 BOS/EOS。
  • BOS / EOS / PAD:序列开始 / 结束 / 填充。PAD 把不等长序列补齐到同长度,且必须用 mask 排除在 loss 之外。
  • loss mask:一个 0/1(或加权)张量,决定哪些 token 参与 loss。预训练排除 padding;SFT 只保留 assistant 回复;Reasoning 进一步给结构标签加权。
  • 残差连接x = x + sublayer(x),让梯度直达浅层,是深层网络能训起来的关键。
  • Pre-LN:在进子层之前做归一化(x + attn(norm(x))),比 Post-LN 训练更稳。两套都用。
  • RoPE:把 Q/K 向量按位置旋转角度来注入位置信息,点积时自动体现相对位置;现代 LLM 主流位置编码。
  • GQA:多个 Query 头共享一组 K/V 头,省 KV cache 显存;num_kv_heads = num_q_heads 时退化为标准 MHA。
  • KV cache:生成时缓存历史 token 的 K/V,避免每步重算,大幅提速。
  • 混合精度(AMP):前向用 bf16/fp16 算(省显存提速),敏感部分用 fp32;GradScaler 防 fp16 梯度下溢。
  • 温度 / top-k 采样:温度 <1 让分布更尖锐(更确定),top-k 只从概率最高的 k 个候选里采样(去长尾噪声)。
  • 三段式训练:预训练(学语言)→ SFT(学指令/对话)→ Reasoning(学思考格式),逐步加载权重、学习率递减。

4.3 常见坑与注意事项#

实际上手这个仓库容易踩的点,集中列出:

  1. PYTHONPATH 必须包含项目根目录scripts/*.py 里用 from config.config import ...from src.models...,不设 export PYTHONPATH="$PYTHONPATH:." 会直接 ModuleNotFoundError

  2. 默认 config 跑不动config/config.py 默认是 N_EMBED=2048, N_BLOCKS=64(约 21 亿参数),单卡必然 OOM。先跑通务必改成 13M 配置(N_EMBED=128, N_HEAD=8, N_BLOCKS=1, CONTEXT_LENGTH=128)。代码注释里的 “3 Billion” 是笔误(实际约 2.1B)。

  3. 训练序列长度被 T_CONTEXT_LENGTH=16 卡住。模型按 CONTEXT_LENGTH=512 构建,但训练循环喂的是长度 16 的序列,等于只训练了前 16 个位置。要让模型用上长上下文,得把 T_CONTEXT_LENGTH 调大。

  4. train_transformer.py 没有 if __name__ == '__main__' 保护。它是模块级脚本,import 它就会立刻开始训练。不要在别处随意 import。

  5. max_data=1000 只是试跑设置data_preprocess.py 默认每个文件只处理前 1000 行 JSON。正经训练要调大或处理全部,否则数据量远远不够。

  6. 分词器和 config 必须前后一致。生成时用的分词器要和预处理时同一个(都 r50k_base),模型超参要和训练时完全相同——否则 token 对不上 / shape 不匹配。

  7. notebook 需要额外依赖sft_rlhf_guide.ipynb 用到 transformerstokenizers,它们不在 requirements.txt 里,需自行 pip install transformers tokenizers

  8. 主代码库的有意简化。相比标准/现代 Transformer,主代码库省略了多头注意力的输出投影 W_O、没有 attention/residual dropout、没有权重共享——都是为教学简化,不是 bug,但移植到生产前要补齐。


小结#

这个仓库适合对照着看两类实现:

  • 主代码库偏早期 GPT 写法——LayerNorm、绝对位置编码、ReLU、朴素多头注意力,配上 The Pile 真实数据和命令行训练流程。
  • notebook 偏 LLaMA 风格——RMSNorm、RoPE、SwiGLU、GQA、Flash Attention、KV cache,再加上预训练 / SFT / Reasoning 三段后训练和思考格式对齐。

两边都手写了主要模块,读起来比较容易看清哪些是 Transformer 的基本骨架(残差、Pre-LN、注意力 + FFN),哪些是后来逐步换上的工程组件,以及预训练之后的 SFT / reasoning 是怎么在同一套训练循环上叠加出来的。

进一步可以探索的方向(notebook 结尾也提到):放大数据和模型规模、引入 MoE、更彻底地用 GQA、分布式训练(DDP/FSDP)、梯度检查点,以及用 RLHF/DPO 做更精细的偏好对齐。

train-llm-from-scratch 完整笔记
https://github.com/FareedKhan-dev/train-llm-from-scratch
作者
xwysyy
发布于
2026-06-03
许可协议
CC BY-NC-SA 4.0
© 2026 xwysyy. All Rights Reserved.
Powered by Astro & Firefly

文章目录