Implementing a Transformer from Scratch in PyTorch

import torch
from torch import nn
import math
import re

device = torch.device('mps')  # Apple Silicon GPU backend
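
This assumes Apple Silicon. As a portable alternative (my addition, not in the original code), a small fallback picks whatever accelerator is available:

if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')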

Load the dataset; here we use "The Time Machine".

def read_time_machine():
    # Assumes the txt file is already available locally, or has been downloaded beforehand
    with open('timemachine.txt', 'r') as f:
        lines = f.readlines()
    return [re.sub('[^A-Za-z]+', ' ', line).strip().lower() for line in lines]
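
A quick sanity check (the expected first line assumes the d2l copy of timemachine.txt; other editions may differ):

lines = read_time_machine()
print(lines[0])  # expected: 'the time machine by h g wells'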

Now let's implement the model.

# Multi-head attention
class MultiHeadAttention(nn.Module):
    def __init__(self, num_hiddens, num_heads, dropout, bias=False, **kwargs):
        super().__init__()
        self.num_heads = num_heads
        self.Wq = nn.LazyLinear(num_hiddens, bias=bias)
        self.Wk = nn.LazyLinear(num_hiddens, bias=bias)
        self.Wv = nn.LazyLinear(num_hiddens, bias=bias)
        self.Wo = nn.LazyLinear(num_hiddens, bias=bias)
        self.dropout = dropout

    def to_multihead_qkv(self, X):
        # X: (batch_size, steps, d) -> (batch_size*num_heads, steps, d/num_heads)
        X = X.reshape(X.shape[0], X.shape[1], self.num_heads, -1)  # (bs, st, h, d/h)
        X = X.permute(0, 2, 1, 3).contiguous()  # (bs, h, st, d/h)
        X = X.reshape(-1, X.shape[2], X.shape[3])
        return X

    def to_single_output(self, X):
        # X: (batch_size*num_heads, steps, d/num_heads) -> (batch_size, steps, d)
        X = X.reshape(-1, self.num_heads, X.shape[1], X.shape[2])  # (bs, h, st, d/h)
        X = X.permute(0, 2, 1, 3).contiguous()  # (bs, st, h, d/h)
        X = X.reshape(X.shape[0], X.shape[1], -1)
        return X

    def forward(self, Q, K, V, requires_mask):
        # For efficiency, apply the projections first, then split into heads
        Q = self.to_multihead_qkv(self.Wq(Q))
        K = self.to_multihead_qkv(self.Wk(K))
        V = self.to_multihead_qkv(self.Wv(V))

        d = Q.shape[-1]
        scores = torch.bmm(Q, K.transpose(1, 2)) / math.sqrt(d)  # scores: (batch_size*num_heads, queries, steps)
        if requires_mask:
            # Apply the causal mask only when queries and keys have the same length
            # (training-time self-attention); during step-by-step prediction the single
            # query may attend to the entire cache, so no mask is needed
            if scores.shape[1] == scores.shape[2]:
                mask = torch.triu(torch.ones((scores.shape[1], scores.shape[2]), device=scores.device), diagonal=1).bool()
                scores.masked_fill_(mask, -1e9)
        weights = nn.functional.softmax(scores, dim=-1)
        weights = nn.functional.dropout(weights, p=self.dropout, training=self.training)
        Y = torch.bmm(weights, V)
        Y_concat = self.to_single_output(Y)
        return self.Wo(Y_concat)
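
A minimal shape check for the attention block (my snippet; the sizes are arbitrary):

attn = MultiHeadAttention(num_hiddens=256, num_heads=8, dropout=0.1)
X = torch.randn(2, 10, 256)
print(attn(X, X, X, requires_mask=True).shape)  # torch.Size([2, 10, 256])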
# Positional encoding
class PositionalEncoding(nn.Module):
    def __init__(self, num_hiddens, dropout, maxlen=1000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Precompute the sinusoidal table P of shape (1, maxlen, num_hiddens)
        self.P = torch.zeros((1, maxlen, num_hiddens))
        position = torch.arange(maxlen, dtype=torch.float32).reshape(-1, 1)
        div_term = torch.exp(torch.arange(0, num_hiddens, 2, dtype=torch.float32) * -(math.log(10000.0) / num_hiddens))
        self.P[:, :, 0::2] = torch.sin(position * div_term)
        self.P[:, :, 1::2] = torch.cos(position * div_term)

    def forward(self, X):
        X = X + self.P[:, :X.shape[1], :].to(X.device)
        return self.dropout(X)
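
The table implements the standard sinusoids PE(pos, 2i) = sin(pos / 10000^(2i/d)) and PE(pos, 2i+1) = cos(pos / 10000^(2i/d)). A quick check at position 0 (my snippet, with dropout disabled so the raw table is visible):

pe = PositionalEncoding(num_hiddens=256, dropout=0.0)
out = pe(torch.zeros(1, 5, 256))
print(out[0, 0, 0].item(), out[0, 0, 1].item())  # 0.0 1.0, i.e. sin(0) and cos(0)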
# Position-wise feed-forward network
class FFN(nn.Module):
    def __init__(self, num_input, num_hiddens, num_output, **kwargs):
        super().__init__(**kwargs)
        self.l1 = nn.Linear(num_input, num_hiddens)
        self.relu = nn.ReLU()
        self.l2 = nn.Linear(num_hiddens, num_output)

    def forward(self, X):
        return self.l2(self.relu(self.l1(X)))
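
"Position-wise" means the same two linear layers act on every time step independently; since nn.Linear operates on the last dimension, a (batch, steps, d) tensor needs no reshaping. A tiny illustration (arbitrary sizes):

ffn = FFN(num_input=256, num_hiddens=1024, num_output=256)
print(ffn(torch.randn(2, 10, 256)).shape)  # torch.Size([2, 10, 256]); each position transformed independently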
# Residual connection & layer normalization
class AddLNorm(nn.Module):
    def __init__(self, LN_shape, dropout, **kwargs):
        super().__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)
        self.ln = nn.LayerNorm(LN_shape)

    def forward(self, Y, X):
        # Y is the sublayer output, X the residual input (post-norm, as in the original paper)
        return self.ln(X + self.dropout(Y))
# Encoder block
class EncoderBlock(nn.Module):
    def __init__(self, num_hiddens, LN_shape, FFN_num_hiddens, num_heads, dropout, bias=False, **kwargs):
        super().__init__(**kwargs)
        self.attention = MultiHeadAttention(num_hiddens, num_heads, dropout, bias)
        self.addnorm1 = AddLNorm(LN_shape, dropout)
        self.ffn = FFN(num_hiddens, FFN_num_hiddens, num_hiddens)
        self.addnorm2 = AddLNorm(LN_shape, dropout)

    def forward(self, X):
        # Encoder self-attention is bidirectional, so no causal mask
        Y = self.addnorm1(self.attention(X, X, X, False), X)
        return self.addnorm2(self.ffn(Y), Y)
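
Every sublayer preserves the hidden size, so blocks can be stacked freely; a quick check (my snippet, illustrative values):

blk = EncoderBlock(num_hiddens=256, LN_shape=[256], FFN_num_hiddens=1024, num_heads=8, dropout=0.1)
print(blk(torch.randn(2, 10, 256)).shape)  # torch.Size([2, 10, 256]), same as the input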

# Encoder
class Encoder(nn.Module):
    def __init__(self, vocab_size, num_hiddens, LN_shape, FFN_num_hiddens, num_heads, dropout, num_layers, bias=False, **kwargs):
        super().__init__(**kwargs)
        self.num_hiddens = num_hiddens
        self.embedding = nn.Embedding(vocab_size, num_hiddens)
        self.pe = PositionalEncoding(num_hiddens, dropout)
        self.blocks = nn.Sequential()
        for i in range(num_layers):
            self.blocks.add_module(f'encoder_block_{i}', EncoderBlock(num_hiddens, LN_shape, FFN_num_hiddens, num_heads, dropout, bias=bias))

    def forward(self, X):
        # Scale embeddings by sqrt(d) so they are not drowned out by the positional encoding
        X = self.pe(self.embedding(X) * math.sqrt(self.num_hiddens))
        X = self.blocks(X)
        return X
# Decoder block
class DecoderBlock(nn.Module):
    def __init__(self, num_hiddens, LN_shape, FFN_num_hiddens, num_heads, dropout, idx, bias=False, **kwargs):
        super().__init__(**kwargs)
        self.idx = idx
        self.selfattention = MultiHeadAttention(num_hiddens, num_heads, dropout, bias)
        self.addnorm1 = AddLNorm(LN_shape, dropout)
        self.crossattention = MultiHeadAttention(num_hiddens, num_heads, dropout, bias)
        self.addnorm2 = AddLNorm(LN_shape, dropout)
        self.ffn = FFN(num_hiddens, FFN_num_hiddens, num_hiddens)
        self.addnorm3 = AddLNorm(LN_shape, dropout)

    def forward(self, X, state):
        enc_output, kv_cache = state[0], state[1]
        if kv_cache[self.idx] is None:  # training, or the first token during prediction
            k_and_v = X
        else:  # prediction: append the new token to this block's cache
            k_and_v = torch.cat((kv_cache[self.idx], X), dim=1)
        kv_cache[self.idx] = k_and_v
        X2 = self.selfattention(X, k_and_v, k_and_v, True)
        Y = self.addnorm1(X2, X)
        Y2 = self.crossattention(Y, enc_output, enc_output, False)
        Z = self.addnorm2(Y2, Y)
        return self.addnorm3(self.ffn(Z), Z), state

# Decoder
class Decoder(nn.Module):
    def __init__(self, vocab_size, num_hiddens, LN_shape, FFN_num_hiddens, num_heads, dropout, num_layers, bias=False, **kwargs):
        super().__init__(**kwargs)
        self.num_hiddens = num_hiddens
        self.num_layers = num_layers
        self.embedding = nn.Embedding(vocab_size, num_hiddens)
        self.pe = PositionalEncoding(num_hiddens, dropout)
        self.blocks = nn.Sequential()
        for i in range(num_layers):
            self.blocks.add_module(f'decoder_block_{i}', DecoderBlock(num_hiddens, LN_shape, FFN_num_hiddens, num_heads, dropout, idx=i, bias=bias))
        self.dense = nn.Linear(num_hiddens, vocab_size)

    def init_state(self, enc_output):
        # state = [encoder output, one KV-cache slot per decoder block]
        return [enc_output, [None] * self.num_layers]

    def forward(self, X, state):
        X = self.pe(self.embedding(X) * math.sqrt(self.num_hiddens))
        for block in self.blocks:
            X, state = block(X, state)
        return self.dense(X), state
# Encoder-decoder model
class MyTransformerNet(nn.Module):
    def __init__(self, encoder, decoder, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, enc_X, dec_X):
        enc_outputs = self.encoder(enc_X)
        dec_state = self.decoder.init_state(enc_outputs)
        Y_hat, dec_state = self.decoder(dec_X, dec_state)
        return Y_hat, dec_state
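
Before training, an end-to-end smoke test with dummy token IDs (my snippet; sizes are arbitrary) confirms the pieces fit together:

enc = Encoder(vocab_size=30, num_hiddens=256, LN_shape=[256], FFN_num_hiddens=1024,
              num_heads=8, dropout=0.1, num_layers=2)
dec = Decoder(vocab_size=30, num_hiddens=256, LN_shape=[256], FFN_num_hiddens=1024,
              num_heads=8, dropout=0.1, num_layers=2)
model = MyTransformerNet(enc, dec)
Y_hat, _ = model(torch.randint(0, 30, (2, 10)), torch.randint(0, 30, (2, 10)))
print(Y_hat.shape)  # torch.Size([2, 10, 30]): per-position logits over the vocabulary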

That completes the model. Next come data preprocessing, training, and prediction.

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader

# ====== 1. Hyperparameters ======
batch_size = 128
num_steps = 40   # strictly fixed sequence length (no padding)
lr = 0.001       # 0.001 is a solid default for the Adam optimizer
epochs = 20

# ====== 2. Minimal vocabulary (only <bos> and * are special) ======
class RestoreVocab:
    def __init__(self, tokens):
        # Only <bos> is needed to kick off decoding; no <pad>, no <eos>
        self.idx_to_token = ['<bos>', '*']
        unique_tokens = list(set(tokens))
        if '*' in unique_tokens:
            unique_tokens.remove('*')
        self.idx_to_token.extend(unique_tokens)
        self.token_to_idx = {token: idx for idx, token in enumerate(self.idx_to_token)}

    def __len__(self):
        return len(self.idx_to_token)

    def __getitem__(self, key):
        if not isinstance(key, (list, tuple)):
            return self.token_to_idx.get(key, 0)  # unknown tokens map to index 0
        return [self.__getitem__(token) for token in key]

# ====== 3. Dataset construction ======
lines = read_time_machine()
corpus_chars = [char for line in lines for char in line]
vocab = RestoreVocab(corpus_chars)

class VowelRestoreDataset(Dataset):
    def __init__(self, corpus_chars, num_steps, vocab):
        self.corpus = corpus_chars
        self.num_steps = num_steps
        self.vocab = vocab
        self.vowels = set('aeiou')
        self.num_samples = len(corpus_chars) - num_steps

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # Slice out a fixed-length plaintext chunk
        chunk = self.corpus[idx : idx + self.num_steps]

        # Encoder input: replace vowels with * (length unchanged)
        masked_chunk = ['*' if c in self.vowels else c for c in chunk]
        enc_X = self.vocab[masked_chunk]

        # Decoder input: prepend <bos> and drop the chunk's last character, keeping length num_steps
        dec_X = [self.vocab['<bos>']] + self.vocab[chunk[:-1]]

        # Label: the plaintext chunk itself
        Y = self.vocab[chunk]

        return (torch.tensor(enc_X, dtype=torch.long),
                torch.tensor(dec_X, dtype=torch.long),
                torch.tensor(Y, dtype=torch.long))

# Drop the final batch if it cannot be filled to batch_size
train_loader = DataLoader(VowelRestoreDataset(corpus_chars, num_steps, vocab),
                          batch_size=batch_size, shuffle=True, drop_last=True)
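
To see what a single training example looks like (my snippet; the decoded strings in the comments are schematic, and the actual content depends on the corpus position):

enc_X, dec_X, Y = VowelRestoreDataset(corpus_chars, num_steps, vocab)[0]
print(''.join(vocab.idx_to_token[i] for i in enc_X.tolist()))  # masked input, e.g. 'th* t*m* m*ch*n* ...'
print(''.join(vocab.idx_to_token[i] for i in Y.tolist()))      # restoration target, e.g. 'the time machine ...'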
# Instantiate the Transformer (using the hand-rolled code above)
num_hiddens, num_heads, num_layers, dropout = 256, 8, 3, 0.1
LN_shape = [num_hiddens]
FFN_num_hiddens = num_hiddens * 4

encoder = Encoder(len(vocab), num_hiddens, LN_shape, FFN_num_hiddens, num_heads, dropout, num_layers)
decoder = Decoder(len(vocab), num_hiddens, LN_shape, FFN_num_hiddens, num_heads, dropout, num_layers)
net = MyTransformerNet(encoder, decoder).to(device)

# Optimizer and plain cross-entropy loss (nothing needs masking)
optimizer = torch.optim.Adam(net.parameters(), lr=lr)
loss_fn = nn.CrossEntropyLoss()

print("🚀 Starting 'vowel restoration' training...")
for e in range(epochs):
    net.train()
    total_loss = 0
    for enc_X, dec_X, Y in train_loader:
        enc_X, dec_X, Y = enc_X.to(device), dec_X.to(device), Y.to(device)

        # Forward pass
        Y_hat, _ = net(enc_X, dec_X)

        # Loss: flatten (batch, steps, vocab) -> (batch*steps, vocab)
        l = loss_fn(Y_hat.reshape(-1, Y_hat.shape[-1]), Y.reshape(-1))

        optimizer.zero_grad()
        l.backward()
        nn.utils.clip_grad_norm_(net.parameters(), max_norm=1.0)  # clip to guard against exploding gradients
        optimizer.step()

        total_loss += l.item()

    print(f'Epoch {e + 1:02d}, Loss: {total_loss / len(train_loader):.4f}')
🚀 Starting 'vowel restoration' training on the M5 chip...
Epoch 01, Loss: 1.4954
Epoch 02, Loss: 0.9292
Epoch 03, Loss: 0.5886
Epoch 04, Loss: 0.3895
Epoch 05, Loss: 0.2511
Epoch 06, Loss: 0.1716
Epoch 07, Loss: 0.1105
Epoch 08, Loss: 0.0815
Epoch 09, Loss: 0.0703
Epoch 10, Loss: 0.0623
Epoch 11, Loss: 0.0574
Epoch 12, Loss: 0.0531
Epoch 13, Loss: 0.0500
Epoch 14, Loss: 0.0473
Epoch 15, Loss: 0.0449
Epoch 16, Loss: 0.0423
Epoch 17, Loss: 0.0411
Epoch 18, Loss: 0.0393
Epoch 19, Loss: 0.0383
Epoch 20, Loss: 0.0370
def predict_restore(net, masked_text, vocab, device):
    net.eval()
    seq_len = len(masked_text)

    # Encode the masked "ciphertext"
    enc_tokens = vocab[list(masked_text)]
    enc_X = torch.tensor([enc_tokens], dtype=torch.long, device=device)

    with torch.no_grad():
        enc_outputs = net.encoder(enc_X)
        dec_state = net.decoder.init_state(enc_outputs)

        # Seed decoding with <bos>
        dec_X = torch.tensor([[vocab['<bos>']]], dtype=torch.long, device=device)
        outputs = []

        # Predict exactly as many steps as the input is long
        for _ in range(seq_len):
            Y_hat, dec_state = net.decoder(dec_X, dec_state)
            # dec_X has length 1, so position 0 is the latest step; decode greedily
            next_token_idx = Y_hat[:, 0, :].argmax(dim=-1).item()

            outputs.append(next_token_idx)
            # Feed the prediction back in as the next input
            dec_X = torch.tensor([[next_token_idx]], dtype=torch.long, device=device)

    return "".join([vocab.idx_to_token[i] for i in outputs])

# ================= The moment of truth =================
# Run this after the 20 epochs finish:
test_cases = [
    "th* t*m* tr*v*ll*r s*t d*wn",      # the time traveller sat down
    "h* m*d* * w*nd*rf*l d*sc*v*ry",    # he made a wonderful discovery
    "*t w*s * d*rk *nd st*rmy n*ght"    # it was a dark and stormy night
]

print("\n==== Restoration results ====")
for case in test_cases:
    print(f"Input  (masked)   : {case}")
    print(f"Output (restored) : {predict_restore(net, case, vocab, device)}\n")
==== Restoration results ====
Input  (masked)   : th* t*m* tr*v*ll*r s*t d*wn
Output (restored) : the time trave traveller sa

Input  (masked)   : h* m*d* * w*nd*rf*l d*sc*v*ry
Output (restored) : he made i wonderful discovery

Input  (masked)   : *t w*s * d*rk *nd st*rmy n*ght
Output (restored) : it was a a dark andark and sto

Note: the data processing, training, and prediction code was generated by Gemini; the task is to restore the vowels in a sentence (masked with asterisks). The duplicated words in the output ("trave traveller", "a a dark andark") are likely an artifact of step-by-step decoding: Decoder.forward slices the positional-encoding table by the length of the current input, so every single-token prediction step receives the encoding for position 0 instead of its true position.
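
Below is a minimal sketch of one possible fix (my addition, not part of the original post; the name PositionAwareDecoder and the extra state slot are assumptions): carry a position offset in the decoder state so each incremental step indexes the sinusoidal table at its true position.

# Sketch (assumed modification): a position-aware variant of the Decoder above
class PositionAwareDecoder(Decoder):
    def init_state(self, enc_output):
        # third slot records how many tokens have been decoded so far
        return [enc_output, [None] * self.num_layers, 0]

    def forward(self, X, state):
        offset = state[2]
        X = self.embedding(X) * math.sqrt(self.num_hiddens)
        # add the encoding for the true positions offset .. offset+steps-1
        X = X + self.pe.P[:, offset:offset + X.shape[1], :].to(X.device)
        X = self.pe.dropout(X)
        state[2] = offset + X.shape[1]
        for block in self.blocks:
            X, state = block(X, state)  # DecoderBlock reads only state[0] and state[1]
        return self.dense(X), state

During training nothing changes (the offset starts at 0 and the whole sequence passes through at once); during prediction each single-token step now gets its own positional encoding.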

Postscript

Although the Transformer is by now the "Hello World" of deep learning, today's exercise showed that writing one by hand, without reaching for PyTorch's built-in components, is still genuinely hard: the details in the code, the tensor shapes, the parameters of the various methods, and the coupling between classes all demand a great deal of careful thought. On the other hand, this kind of wheel-reinventing really does deepen one's understanding of the model. Working out how data flows through the model during training and prediction, which classes and functions it passes through, and how its shape changes at each step is enormously helpful for building a solid mental model. You could even say that understanding the theory is only 30% of reproducing a model; the other 70% hides in the implementation details and design decisions.
This Transformer has only about 5M parameters, i.e. 0.005B, yet training still takes about 3 min/epoch on my machine. Deep learning really is an extravagantly resource- and compute-hungry field: if classical ML is about specialized tools and surgical strikes, DL is "one trick conquers all", brute force made to fly.
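
For reference, the 5M figure can be checked with a one-liner (run after training, so the LazyLinear layers have been materialized by a forward pass; the exact count varies slightly with the vocabulary size):

num_params = sum(p.numel() for p in net.parameters())
print(f'{num_params / 1e6:.2f}M parameters')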