Why "starting from 1"? Because the internal logic of an RNN is genuinely convoluted, and RNNs are gradually being replaced by Transformers, so I call nn.LSTM directly. The focus here is on dataset processing, the training loop, and the prediction function.

import torch
from torch import nn
import collections
import re

class SimpleRNNModel(nn.Module):
    def __init__(self, vocab_size, num_hiddens):
        super().__init__()
        self.num_hiddens = num_hiddens
        # 1. Embedding layer: map character IDs to dense vectors
        self.embedding = nn.Embedding(vocab_size, num_hiddens)
        # 2. RNN layer: the core recurrent machinery
        self.rnn = nn.LSTM(num_hiddens, num_hiddens, batch_first=True, num_layers=2)
        # 3. Output layer: project hidden states back to vocab size to predict the next character
        self.linear = nn.Linear(num_hiddens, vocab_size)

    def forward(self, X, state):
        # X shape: (batch_size, num_steps)
        X = self.embedding(X)
        # output shape: (batch_size, num_steps, num_hiddens)
        output, state = self.rnn(X, state)
        # Flatten the time-step dimension so the linear layer runs once, for speed
        # y shape: (batch_size * num_steps, vocab_size)
        y = self.linear(output.reshape(-1, output.shape[-1]))
        return y, state

    def begin_state(self, batch_size, device):
        # Initialize hidden state H and cell state C to zeros (one per LSTM layer)
        return (torch.zeros((2, batch_size, self.num_hiddens), device=device),
                torch.zeros((2, batch_size, self.num_hiddens), device=device))
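A quick shape check of the model (a minimal sketch; vocab_size=28 assumes the character-level vocabulary built below, i.e. 26 letters, space, and &lt;unk&gt;):

net_demo = SimpleRNNModel(vocab_size=28, num_hiddens=256)
X_demo = torch.zeros((4, 32), dtype=torch.long)  # (batch_size, num_steps)
state_demo = net_demo.begin_state(batch_size=4, device=torch.device('cpu'))
y_demo, _ = net_demo(X_demo, state_demo)
print(y_demo.shape)  # torch.Size([128, 28]) = (batch_size * num_steps, vocab_size)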
def read_time_machine():
    # Assumes you already have the txt file; alternatively, download it from the web
    with open('timemachine.txt', 'r') as f:
        lines = f.readlines()
    return [re.sub('[^A-Za-z]+', ' ', line).strip().lower() for line in lines]
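If the file is missing, it can be fetched first. The URL below is the copy hosted for the d2l book and is an assumption on my part, not part of this post:

import urllib.request

url = 'http://d2l-data.s3-accelerate.amazonaws.com/timemachine.txt'  # assumed mirror
urllib.request.urlretrieve(url, 'timemachine.txt')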
class Vocab:
    def __init__(self, tokens):
        counter = collections.Counter(tokens)
        # Index 0 is reserved for the unknown token <unk>
        self.idx_to_token = ['<unk>'] + [token for (token, freq) in counter.items()]
        self.token_to_idx = {token: idx for (idx, token) in enumerate(self.idx_to_token)}

    def __len__(self):
        return len(self.idx_to_token)

    def __getitem__(self, key):
        # Unknown tokens map to index 0 (<unk>); lists/tuples are converted recursively
        if not isinstance(key, (list, tuple)):
            return self.token_to_idx.get(key, 0)
        else:
            return [self.__getitem__(token) for token in key]
lines = read_time_machine()
# Iterating over a string yields characters, so this is character-level tokenization
tokens = [token for line in lines for token in line]
vocab = Vocab(tokens)
corpus = vocab[tokens]
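A quick sanity check that the pipeline round-trips (the exact numbers depend on your copy of the text):

print(len(vocab), len(corpus))  # vocabulary size, total character count
print(vocab[['t', 'h', 'e']])   # character IDs
print(''.join(vocab.idx_to_token[i] for i in corpus[:20]))  # first 20 characters, decoded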
from torch.utils.data import Dataset, DataLoader

class TMDataset(Dataset):
    def __init__(self, corpus, num_steps):
        self.corpus = corpus
        self.num_steps = num_steps
        # Every position with num_steps inputs plus one more target is a valid sample
        self.num_samples = len(corpus) - num_steps

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # The target sequence is the input sequence advanced by one character
        X = torch.tensor(self.corpus[idx:idx + self.num_steps])
        y = torch.tensor(self.corpus[idx + 1:idx + 1 + self.num_steps])
        return X, y
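To see the one-character shift between input and target (a minimal sketch):

demo_ds = TMDataset(corpus, num_steps=8)
X0, y0 = demo_ds[0]
print(''.join(vocab.idx_to_token[i] for i in X0.tolist()))  # e.g. 'the time'
print(''.join(vocab.idx_to_token[i] for i in y0.tolist()))  # same text, one character later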
batch_size = 64
num_steps = 32
lr = 0.1
epochs = 20
dataset = TMDataset(corpus, num_steps)

# drop_last=True keeps every batch exactly batch_size, matching begin_state below
train_loader = DataLoader(dataset, batch_size, shuffle=True, drop_last=True)
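Pulling one batch confirms the shapes the model expects:

Xb, yb = next(iter(train_loader))
print(Xb.shape, yb.shape)  # torch.Size([64, 32]) torch.Size([64, 32])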
net = SimpleRNNModel(vocab_size=len(vocab), num_hiddens=256)
loss = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
device = torch.device('mps')  # Apple GPU; use 'cuda' or 'cpu' as appropriate
net.to(device)

for e in range(epochs):
    net.train()
    tloss = 0
    for X, y in train_loader:
        X, y = X.to(device), y.to(device)
        # Batches are shuffled and independent, so the state restarts at zero each batch
        state = net.begin_state(batch_size, device)
        y_pred, _ = net(X, state)
        # y_pred: (batch_size * num_steps, vocab_size); flatten y to match
        l = loss(y_pred, y.reshape(-1).long())
        optimizer.zero_grad()
        l.backward()
        # Clip gradients to tame the exploding-gradient problem in RNNs
        nn.utils.clip_grad_norm_(net.parameters(), max_norm=1)
        optimizer.step()
        tloss += l.item()

    scheduler.step()
    print(f'epoch {e + 1}, loss {tloss / len(train_loader):.4f}')

epoch 1, loss 1.6600
epoch 2, loss 1.1736
epoch 3, loss 0.9350
epoch 4, loss 0.7585
epoch 5, loss 0.6386
epoch 6, loss 0.5594
epoch 7, loss 0.5061
epoch 8, loss 0.4689
epoch 9, loss 0.4426
epoch 10, loss 0.4229
epoch 11, loss 0.4076
epoch 12, loss 0.3953
epoch 13, loss 0.3854
epoch 14, loss 0.3770
epoch 15, loss 0.3701
epoch 16, loss 0.3644
epoch 17, loss 0.3598
epoch 18, loss 0.3562
epoch 19, loss 0.3536
epoch 20, loss 0.3520
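The printed value is the average cross-entropy per character; character-level language models are often reported as perplexity, which is just its exponential:

import math

print(math.exp(0.3520))  # ~1.42: roughly 1.4 effective choices per character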
def predict(prefix, num_preds, net, vocab, device):
    outputs = []
    net.eval()
    outputs.append(vocab[prefix[0]])
    state = net.begin_state(1, device)

    def get_last_token():
        # Wrap the most recent character ID as a (1, 1) batch for the network
        return torch.tensor([outputs[-1]], device=device, dtype=torch.long).reshape((1, 1))

    # Warm-up: feed the prefix through the network to build up the state
    for i in range(1, len(prefix)):
        _, state = net(get_last_token(), state)
        outputs.append(vocab[prefix[i]])

    # Prediction: greedily take the most likely next character at each step
    for j in range(num_preds):
        y_hat, state = net(get_last_token(), state)
        outputs.append(y_hat.argmax(dim=1).item())

    return ''.join([vocab.idx_to_token[i] for i in outputs])
predict('the', 200, net, vocab, device)
'the sun grow larger and dullerin the westward sky and the life of the old earth ebb away atlast more than thirty million years hence the huge red hot dome ofthe sun had come to obscure nearly a tenth par'
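The argmax above is greedy decoding: deterministic, and prone to falling into loops on longer generations. A common variation (my own sketch, not part of the original code) is to sample from a temperature-scaled softmax in the prediction loop instead:

import torch.nn.functional as F

def sample_next(y_hat, temperature=1.0):
    # Soften (T > 1) or sharpen (T < 1) the distribution, then draw one character ID
    probs = F.softmax(y_hat[-1] / temperature, dim=-1)
    return torch.multinomial(probs, num_samples=1).item()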