```python
import torch
from torch import nn
import math
import re

# Trained on Apple silicon; fall back to CPU if MPS is unavailable
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
```
Load the dataset; here we use "The Time Machine".
```python
def read_time_machine():
    with open('timemachine.txt', 'r') as f:
        lines = f.readlines()
    # Keep letters only, collapse everything else into spaces, and lowercase
    return [re.sub('[^A-Za-z]+', ' ', line).strip().lower() for line in lines]
```
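A quick look at the result (a sketch, not from the original post; it assumes timemachine.txt sits next to the script, and the printed contents depend on your copy of the file):

```python
sample_lines = read_time_machine()
print(len(sample_lines))   # number of lines in the cleaned text
print(sample_lines[0])     # first line: lowercase letters and spaces only
```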
Now we implement the model.
```python
class MultiHeadAttention(nn.Module):
    def __init__(self, num_hiddens, num_heads, dropout, bias=False, **kwargs):
        super().__init__()
        self.num_heads = num_heads
        # LazyLinear infers the input dimension at the first forward pass
        self.Wq = nn.LazyLinear(num_hiddens, bias=bias)
        self.Wk = nn.LazyLinear(num_hiddens, bias=bias)
        self.Wv = nn.LazyLinear(num_hiddens, bias=bias)
        self.Wo = nn.LazyLinear(num_hiddens, bias=bias)
        self.dropout = dropout

    def to_multihead_qkv(self, X):
        # (batch, seq, num_hiddens) -> (batch * num_heads, seq, num_hiddens / num_heads)
        X = X.reshape(X.shape[0], X.shape[1], self.num_heads, -1)
        X = X.permute(0, 2, 1, 3).contiguous()
        X = X.reshape(-1, X.shape[2], X.shape[3])
        return X

    def to_single_output(self, X):
        # Inverse of to_multihead_qkv: concatenate the heads back together
        X = X.reshape(-1, self.num_heads, X.shape[1], X.shape[2])
        X = X.permute(0, 2, 1, 3).contiguous()
        X = X.reshape(X.shape[0], X.shape[1], -1)
        return X

    def forward(self, Q, K, V, requires_mask):
        Q = self.to_multihead_qkv(self.Wq(Q))
        K = self.to_multihead_qkv(self.Wk(K))
        V = self.to_multihead_qkv(self.Wv(V))
        d = Q.shape[-1]
        # Scaled dot-product attention
        scores = torch.bmm(Q, K.transpose(1, 2)) / math.sqrt(d)
        if requires_mask:
            # Apply the causal mask only when the score matrix is square
            # (i.e. training on a full sequence); during step-by-step decoding
            # the single query is allowed to attend to all cached positions
            if scores.shape[1] == scores.shape[2]:
                mask = torch.triu(torch.ones((scores.shape[1], scores.shape[2]),
                                             device=scores.device), diagonal=1).bool()
                scores.masked_fill_(mask, -1e9)
        weights = nn.functional.softmax(scores, dim=-1)
        weights = nn.functional.dropout(weights, p=self.dropout, training=self.training)
        Y = torch.bmm(weights, V)
        Y_concat = self.to_single_output(Y)
        return self.Wo(Y_concat)
```
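As a quick sanity check (a sketch with arbitrary sizes, not part of the original post), we can confirm that splitting into heads and merging back round-trips the shape:

```python
attn = MultiHeadAttention(num_hiddens=16, num_heads=4, dropout=0.0)
X = torch.randn(2, 5, 16)                 # (batch, seq_len, features)
out = attn(X, X, X, requires_mask=True)
print(out.shape)                          # torch.Size([2, 5, 16])
```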
```python
class PositionalEncoding(nn.Module):
    def __init__(self, num_hiddens, dropout, maxlen=1000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.P = torch.zeros((1, maxlen, num_hiddens))
        position = torch.arange(maxlen, dtype=torch.float32).reshape(-1, 1)
        # 1 / 10000^(2i/d), computed in log space for numerical stability
        div_term = torch.exp(torch.arange(0, num_hiddens, 2, dtype=torch.float32)
                             * -(math.log(10000.0) / num_hiddens))
        self.P[:, :, 0::2] = torch.sin(position * div_term)
        self.P[:, :, 1::2] = torch.cos(position * div_term)

    def forward(self, X):
        X = X + self.P[:, :X.shape[1], :].to(X.device)
        return self.dropout(X)
```
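Since `div_term[0]` is exp(0) = 1, channel 0 of the table should be exactly sin(pos). A minimal check, with arbitrary sizes, not part of the original post:

```python
pe = PositionalEncoding(num_hiddens=16, dropout=0.0, maxlen=50)
pos = torch.arange(50, dtype=torch.float32)
print(torch.allclose(pe.P[0, :, 0], torch.sin(pos)))  # True: channel 0 is sin(pos)
print(pe(torch.zeros(1, 10, 16)).shape)               # torch.Size([1, 10, 16])
```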
```python
class FFN(nn.Module):
    def __init__(self, num_input, num_hiddens, num_output, **kwargs):
        super().__init__(**kwargs)
        self.l1 = nn.Linear(num_input, num_hiddens)
        self.relu = nn.ReLU()
        self.l2 = nn.Linear(num_hiddens, num_output)

    def forward(self, X):
        return self.l2(self.relu(self.l1(X)))
```
```python
class AddLNorm(nn.Module):
    def __init__(self, LN_shape, dropout, **kwargs):
        super().__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)
        self.ln = nn.LayerNorm(LN_shape)

    def forward(self, Y, X):
        # Residual connection followed by layer normalization
        return self.ln(X + self.dropout(Y))
```
```python
class EncoderBlock(nn.Module):
    def __init__(self, num_hiddens, LN_shape, FFN_num_hiddens, num_heads,
                 dropout, bias=False, **kwargs):
        super().__init__(**kwargs)
        self.attention = MultiHeadAttention(num_hiddens, num_heads, dropout, bias)
        self.addnorm1 = AddLNorm(LN_shape, dropout)
        self.ffn = FFN(num_hiddens, FFN_num_hiddens, num_hiddens)
        self.addnorm2 = AddLNorm(LN_shape, dropout)

    def forward(self, X):
        # No mask: every encoder position may attend to every other position
        Y = self.addnorm1(self.attention(X, X, X, False), X)
        return self.addnorm2(self.ffn(Y), Y)
```
```python
class Encoder(nn.Module):
    def __init__(self, vocab_size, num_hiddens, LN_shape, FFN_num_hiddens,
                 num_heads, dropout, num_layers, bias=False, **kwargs):
        super().__init__(**kwargs)
        self.num_hiddens = num_hiddens
        self.embedding = nn.Embedding(vocab_size, num_hiddens)
        self.pe = PositionalEncoding(num_hiddens, dropout)
        self.blocks = nn.Sequential()
        for i in range(num_layers):
            self.blocks.add_module(
                f'encoder_block_{i}',
                EncoderBlock(num_hiddens, LN_shape, FFN_num_hiddens,
                             num_heads, dropout, bias=bias))

    def forward(self, X):
        # Scale embeddings by sqrt(d) so they are not swamped by the positional encoding
        X = self.pe(self.embedding(X) * math.sqrt(self.num_hiddens))
        X = self.blocks(X)
        return X
```
```python
class DecoderBlock(nn.Module):
    def __init__(self, num_hiddens, LN_shape, FFN_num_hiddens, num_heads,
                 dropout, idx, bias=False, **kwargs):
        super().__init__(**kwargs)
        self.idx = idx
        self.selfattention = MultiHeadAttention(num_hiddens, num_heads, dropout, bias)
        self.addnorm1 = AddLNorm(LN_shape, dropout)
        self.crossattention = MultiHeadAttention(num_hiddens, num_heads, dropout, bias)
        self.addnorm2 = AddLNorm(LN_shape, dropout)
        self.ffn = FFN(num_hiddens, FFN_num_hiddens, num_hiddens)
        self.addnorm3 = AddLNorm(LN_shape, dropout)

    def forward(self, X, state):
        enc_output, kv_cache = state[0], state[1]
        # Training: X is the full target sequence, so the cache is just X.
        # Prediction: X is one new token; append it to this layer's cache so
        # self-attention can see every previously decoded token.
        if kv_cache[self.idx] is None:
            k_and_v = X
        else:
            k_and_v = torch.cat((kv_cache[self.idx], X), dim=1)
        kv_cache[self.idx] = k_and_v
        X2 = self.selfattention(X, k_and_v, k_and_v, True)
        Y = self.addnorm1(X2, X)
        Y2 = self.crossattention(Y, enc_output, enc_output, False)
        Z = self.addnorm2(Y2, Y)
        return self.addnorm3(self.ffn(Z), Z), state
```
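To make the cache behavior concrete, here is a minimal sketch (arbitrary sizes and random inputs, not from the original post) showing the per-layer cache growing by one position at each decoding step:

```python
block = DecoderBlock(num_hiddens=16, LN_shape=[16], FFN_num_hiddens=32,
                     num_heads=4, dropout=0.0, idx=0)
block.eval()
state = [torch.randn(1, 6, 16), [None]]   # [fake encoder output, one-layer kv cache]
for step in range(3):
    X = torch.randn(1, 1, 16)             # one new "token" per step
    _, state = block(X, state)
    print(state[1][0].shape)              # seq dim grows: (1,1,16), (1,2,16), (1,3,16)
```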
```python
class Decoder(nn.Module):
    def __init__(self, vocab_size, num_hiddens, LN_shape, FFN_num_hiddens,
                 num_heads, dropout, num_layers, bias=False, **kwargs):
        super().__init__(**kwargs)
        self.num_hiddens = num_hiddens
        self.num_layers = num_layers
        self.embedding = nn.Embedding(vocab_size, num_hiddens)
        self.pe = PositionalEncoding(num_hiddens, dropout)
        self.blocks = nn.Sequential()
        for i in range(num_layers):
            self.blocks.add_module(
                f'decoder_block_{i}',
                DecoderBlock(num_hiddens, LN_shape, FFN_num_hiddens,
                             num_heads, dropout, idx=i, bias=bias))
        self.dense = nn.Linear(num_hiddens, vocab_size)

    def init_state(self, enc_output):
        # state = [encoder output, one kv-cache slot per layer]
        return [enc_output, [None] * self.num_layers]

    def forward(self, X, state):
        X = self.pe(self.embedding(X) * math.sqrt(self.num_hiddens))
        for block in self.blocks:
            X, state = block(X, state)
        return self.dense(X), state
```
```python
class MyTransformerNet(nn.Module):
    def __init__(self, encoder, decoder, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, enc_X, dec_X):
        enc_outputs = self.encoder(enc_X)
        dec_state = self.decoder.init_state(enc_outputs)
        Y_hat, dec_state = self.decoder(dec_X, dec_state)
        return Y_hat, dec_state
```
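Before training, an end-to-end smoke test with random token ids confirms the plumbing (a sketch with arbitrary sizes, not part of the original post):

```python
enc = Encoder(vocab_size=30, num_hiddens=16, LN_shape=[16], FFN_num_hiddens=32,
              num_heads=4, dropout=0.0, num_layers=2)
dec = Decoder(vocab_size=30, num_hiddens=16, LN_shape=[16], FFN_num_hiddens=32,
              num_heads=4, dropout=0.0, num_layers=2)
model = MyTransformerNet(enc, dec)
enc_X = torch.randint(0, 30, (2, 7))      # (batch, src_len) of token ids
dec_X = torch.randint(0, 30, (2, 7))      # (batch, tgt_len) of token ids
Y_hat, _ = model(enc_X, dec_X)
print(Y_hat.shape)                        # torch.Size([2, 7, 30]): per-token logits
```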
That completes the model. What follows is data preprocessing, training, and prediction.
```python
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader

batch_size = 128
num_steps = 40
lr = 0.001
epochs = 20

class RestoreVocab:
    def __init__(self, tokens):
        # Index 0 is <bos>, index 1 is the mask character '*'
        self.idx_to_token = ['<bos>', '*']
        unique_tokens = list(set(tokens))
        if '*' in unique_tokens:
            unique_tokens.remove('*')
        self.idx_to_token.extend(unique_tokens)
        self.token_to_idx = {token: idx for idx, token in enumerate(self.idx_to_token)}

    def __len__(self):
        return len(self.idx_to_token)

    def __getitem__(self, key):
        if not isinstance(key, (list, tuple)):
            return self.token_to_idx.get(key, 0)
        return [self.__getitem__(token) for token in key]

lines = read_time_machine()
corpus_chars = [char for line in lines for char in line]
vocab = RestoreVocab(corpus_chars)

class VowelRestoreDataset(Dataset):
    def __init__(self, corpus_chars, num_steps, vocab):
        self.corpus = corpus_chars
        self.num_steps = num_steps
        self.vocab = vocab
        self.vowels = set('aeiou')
        self.num_samples = len(corpus_chars) - num_steps

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        chunk = self.corpus[idx : idx + self.num_steps]
        # Encoder input: the chunk with every vowel replaced by '*'
        masked_chunk = ['*' if c in self.vowels else c for c in chunk]
        enc_X = self.vocab[masked_chunk]
        # Decoder input: <bos> plus the target shifted right by one (teacher forcing)
        dec_X = [self.vocab['<bos>']] + self.vocab[chunk[:-1]]
        Y = self.vocab[chunk]
        return (torch.tensor(enc_X, dtype=torch.long),
                torch.tensor(dec_X, dtype=torch.long),
                torch.tensor(Y, dtype=torch.long))

train_loader = DataLoader(VowelRestoreDataset(corpus_chars, num_steps, vocab),
                          batch_size=batch_size, shuffle=True, drop_last=True)
```
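To see what a training sample looks like, we can decode one item back to characters (a quick sketch, not from the original post; the exact characters depend on your copy of timemachine.txt):

```python
dataset = VowelRestoreDataset(corpus_chars, num_steps, vocab)
enc_X, dec_X, Y = dataset[0]
print(''.join(vocab.idx_to_token[i] for i in enc_X.tolist()))  # vowels appear as '*'
print(''.join(vocab.idx_to_token[i] for i in Y.tolist()))      # the original characters
print(dec_X[0].item() == vocab['<bos>'])                       # decoder input starts with <bos>
```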
```python
num_hiddens, num_heads, num_layers, dropout = 256, 8, 3, 0.1
LN_shape = [num_hiddens]
FFN_num_hiddens = num_hiddens * 4

encoder = Encoder(len(vocab), num_hiddens, LN_shape, FFN_num_hiddens,
                  num_heads, dropout, num_layers)
decoder = Decoder(len(vocab), num_hiddens, LN_shape, FFN_num_hiddens,
                  num_heads, dropout, num_layers)
net = MyTransformerNet(encoder, decoder).to(device)

optimizer = torch.optim.Adam(net.parameters(), lr=lr)
loss_fn = nn.CrossEntropyLoss()

print("🚀 Starting 'vowel restoration' training...")
for e in range(epochs):
    net.train()
    total_loss = 0
    for enc_X, dec_X, Y in train_loader:
        enc_X, dec_X, Y = enc_X.to(device), dec_X.to(device), Y.to(device)
        Y_hat, _ = net(enc_X, dec_X)
        l = loss_fn(Y_hat.reshape(-1, Y_hat.shape[-1]), Y.reshape(-1))
        optimizer.zero_grad()
        l.backward()
        # Clip gradients to stabilize training
        nn.utils.clip_grad_norm_(net.parameters(), max_norm=1.0)
        optimizer.step()
        total_loss += l.item()
    print(f'Epoch {e + 1:02d}, Loss: {total_loss / len(train_loader):.4f}')
```
```text
🚀 Starting 'vowel restoration' training on the M5 chip...
Epoch 01, Loss: 1.4954
Epoch 02, Loss: 0.9292
Epoch 03, Loss: 0.5886
Epoch 04, Loss: 0.3895
Epoch 05, Loss: 0.2511
Epoch 06, Loss: 0.1716
Epoch 07, Loss: 0.1105
Epoch 08, Loss: 0.0815
Epoch 09, Loss: 0.0703
Epoch 10, Loss: 0.0623
Epoch 11, Loss: 0.0574
Epoch 12, Loss: 0.0531
Epoch 13, Loss: 0.0500
Epoch 14, Loss: 0.0473
Epoch 15, Loss: 0.0449
Epoch 16, Loss: 0.0423
Epoch 17, Loss: 0.0411
Epoch 18, Loss: 0.0393
Epoch 19, Loss: 0.0383
Epoch 20, Loss: 0.0370
```
```python
def predict_restore(net, masked_text, vocab, device):
    net.eval()
    seq_len = len(masked_text)
    enc_tokens = vocab[list(masked_text)]
    enc_X = torch.tensor([enc_tokens], dtype=torch.long, device=device)
    with torch.no_grad():
        enc_outputs = net.encoder(enc_X)
        dec_state = net.decoder.init_state(enc_outputs)
        # Greedy decoding: feed the argmax token back in, one step at a time
        dec_X = torch.tensor([[vocab['<bos>']]], dtype=torch.long, device=device)
        outputs = []
        for _ in range(seq_len):
            Y_hat, dec_state = net.decoder(dec_X, dec_state)
            next_token_idx = Y_hat[:, 0, :].argmax(dim=-1).item()
            outputs.append(next_token_idx)
            dec_X = torch.tensor([[next_token_idx]], dtype=torch.long, device=device)
    return "".join([vocab.idx_to_token[i] for i in outputs])

test_cases = [
    "th* t*m* tr*v*ll*r s*t d*wn",
    "h* m*d* * w*nd*rf*l d*sc*v*ry",
    "*t w*s * d*rk *nd st*rmy n*ght"
]

print("\n==== Decryption results ====")
for case in test_cases:
    print(f"Input (ciphertext): {case}")
    print(f"Output (restored) : {predict_restore(net, case, vocab, device)}\n")
```
```text
==== Decryption results ====
Input (ciphertext): th* t*m* tr*v*ll*r s*t d*wn
Output (restored) : the time trave traveller sa

Input (ciphertext): h* m*d* * w*nd*rf*l d*sc*v*ry
Output (restored) : he made i wonderful discovery

Input (ciphertext): *t w*s * d*rk *nd st*rmy n*ght
Output (restored) : it was a a dark andark and sto
```
Note: the data processing, training, and prediction code was generated by Gemini. The task is to restore the vowels of a sentence, with each vowel masked by an asterisk.
Afterword

Although the Transformer is practically the "Hello World" of today's deep learning, this attempt showed me that writing one entirely by hand, without reaching for torch's ready-made components, is still genuinely hard. The details in the code, the tensor shapes, the parameters of the various methods, and the coupling between classes all demand a great deal of thought and untangling. On the other hand, this reinvent-the-wheel exercise really did deepen my understanding of the model. Tracing how data flows through the model during training and prediction, which classes and functions it passes through, and how its shape changes along the way goes a long way toward a genuinely deep understanding. One could even say that knowing the theory accounts for only 30% of reproducing a model; the remaining 70% hides in implementation details and design philosophy.

This Transformer has only about 5M parameters, i.e. 0.005B, yet training still takes about 3 min/epoch on my machine. Deep learning really is an extremely resource- and compute-hungry field. If ML is about specialized craftsmanship and precision strikes, then DL is one trick that conquers all: with enough brute force, even a brick can fly.
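For reference, the ~5M figure can be checked with a one-line count (a quick sketch, not from the original post; run it after training so the lazy layers have materialized):

```python
# Count the model's parameters; LazyLinear shapes are concrete after the first forward pass
num_params = sum(p.numel() for p in net.parameters())
print(f'{num_params / 1e6:.2f}M parameters')
```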