Building AI-based Autocomplete in the Browser with Vue.js, FastAPI, and WebSockets
In this article, we build a Google-style smart autocomplete in the browser, powered by a recurrent neural network.
Introduction
The Overall Idea
This tutorial is divided into 3 parts:
Building and training a character-level RNN model
Building the frontend and backend
Establishing communication over WebSockets
Preparing the Model
Dataset
The dataset is designed for question-answering tasks, but we only need the questions, so we filter them out of the question-answer pairs.
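The article doesn't show the loading code at this point, so the snippet below is a hypothetical sketch: the file name, separator, and "Question" column are assumptions about how the dataset ships, not confirmed details.

import pandas as pd

# hypothetical loading sketch: file name, separator, and column name are assumptions
df = pd.read_csv("question_answer_pairs.txt", sep="\t", encoding="ISO-8859-1")
questions_all = df["Question"].dropna().unique()  # keep only the questions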
questions_all[10:15]

array(['Did Lincoln start his political career in 1832?',
       'Did Lincoln ever represent Alton & Sangamon Railroad?',
       'Which county was Lincoln born in?',
       'When did Lincoln first serve as President?',
       'Who assassinated Lincoln?'], dtype=object)
Data Preprocessing
import re

def preprocess_text(question):
    question = question.replace("S08_", "")
    question = question.replace("NOTTTT FOUND", "")
    question = re.sub(r"([\(\[]).*?([\)\]])", "", question)  # remove text between brackets
    question = re.sub(r"[^\w\s]", "", question)              # remove punctuation
    question = re.sub(r"\s+", " ", question)                 # collapse multiple whitespace
    question = re.sub(r"[\t\n]", "", question)               # remove tabs and newline characters
    question = question.lower().strip()
    return question

questions_processed = [preprocess_text(q) for q in questions_all]
questions_processed = [q for q in questions_processed if len(q) != 0]  # drop empty strings after preprocessing
The snippet above cleans each question: it strips dataset artifacts, bracketed text, punctuation, and extra whitespace, then lowercases the result.
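To see what the function does, here is an illustrative before/after on a made-up question:

print(preprocess_text("S08_What is (roughly) the population of Beijing?"))
# what is the population of beijing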
Data Statistics
import numpy as np

# lengths of the processed questions
q_lengths = [len(q) for q in questions_processed]

# calculate mean and standard deviation
mean = np.mean(q_lengths)
std = np.std(q_lengths)
print(f"mean: {mean}, std: {std}")
# mean: 52.34, std: 29.05

# compute the optimum length of questions,
# which is the length one standard deviation away from the mean
optimum_length = int(mean + std)
print(f"optimum length: {optimum_length}")
# optimum length: 81

# only keep the questions no longer than the optimum length
questions_short = [q for q in questions_processed if len(q) <= optimum_length]
Preparing the Dataset
Wrapping the questions in a custom PyTorch Dataset has two benefits: it gives us more control over the data, and it helps keep the code modular. From this dataset instance we can then create a PyTorch DataLoader, which automatically handles batching, shuffling, and sampling the data, as we'll see later.
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence

class QuestionsDataset(Dataset):
    def __init__(self, questions, vocab, sos_token, eos_token, batch_first=False):
        # initialize parameters
        self.sos_idx = 0
        self.eos_idx = 1
        # insert start-of-sentence and end-of-sentence tokens
        self.int2char = {self.sos_idx: sos_token, self.eos_idx: eos_token}
        self.int2char.update({idx: char for idx, char in enumerate(vocab, start=self.eos_idx + 1)})
        self.char2int = {char: idx for idx, char in self.int2char.items()}
        self.n_chars = len(self.int2char)
        # encode and pad questions
        self.questions_encoded = pad_sequence([self.encode_question(q) for q in questions],
                                              batch_first=batch_first)

    def __len__(self):
        return len(self.questions_encoded)

    def __getitem__(self, idx):
        return self.questions_encoded[idx]

    def encode_question(self, question):
        '''encode question as char indices and perform one-hot encoding'''
        question_encoded = [self.sos_idx]  # append sos
        for char in question:
            question_encoded.append(self.char2int[char])
        question_encoded.append(self.eos_idx)  # append eos
        return F.one_hot(torch.tensor(question_encoded, dtype=torch.long), self.n_chars).float()
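A quick sanity check on a toy question (illustrative, not from the article) confirms the encoding shape: sequence length plus the sos/eos tokens, by vocabulary size plus the two special tokens:

toy_questions = ["who is lincoln"]
toy_vocab = sorted(set("".join(toy_questions)))
ds = QuestionsDataset(toy_questions, toy_vocab, '[', ']', batch_first=True)
print(ds[0].shape)  # torch.Size([16, 11]): 14 chars + sos/eos, 9 unique chars + 2 special tokens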
Train-Validation Split
from torch.utils.data import DataLoader

# define parameters
vocab = sorted(set("".join(questions_short)))
sos_token = '['
eos_token = ']'
BATCH_FIRST = True
BATCH_SIZE = 64
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# train-validation split
val_percent = 0.1
n_val = int(val_percent * len(questions_short))
n_train = len(questions_short) - n_val
print(f"n_train: {n_train}, n_val: {n_val}")
# n_train: 1977, n_val: 219

questions_train = questions_short[:n_train]
questions_val = questions_short[n_train:]

# create datasets and dataloaders
train_dataset = QuestionsDataset(questions_train, vocab, sos_token, eos_token, batch_first=BATCH_FIRST)
val_dataset = QuestionsDataset(questions_val, vocab, sos_token, eos_token, batch_first=BATCH_FIRST)
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=True)
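Pulling a single batch from the dataloader verifies the shapes the model will see:

data_batch = next(iter(train_dataloader))
print(data_batch.shape)  # (BATCH_SIZE, max_question_length + 2, n_chars), since batch_first=True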
Building the charRNN Model
import torch.nn as nn

class charRNN(nn.Module):
    def __init__(self, VOCAB_SIZE, HIDDEN_SIZE, N_LAYERS=2, P_DROPOUT=0.5, batch_first=False):
        super().__init__()
        self.HIDDEN_SIZE = HIDDEN_SIZE
        self.N_LAYERS = N_LAYERS
        self.lstm = nn.LSTM(VOCAB_SIZE, HIDDEN_SIZE, batch_first=batch_first,
                            dropout=P_DROPOUT, num_layers=N_LAYERS)
        self.dropout = nn.Dropout(P_DROPOUT)
        self.fc = nn.Linear(HIDDEN_SIZE, VOCAB_SIZE)

    def forward(self, inputs, hidden):
        lstm_out, hidden = self.lstm(inputs, hidden)
        # flatten the lstm output along the batch dimension
        lstm_out = torch.flatten(lstm_out, start_dim=0, end_dim=1)
        out = self.dropout(lstm_out)
        out = self.fc(out)
        return out, hidden

    def init_hidden(self, BATCH_SIZE, device):
        hidden = (torch.zeros((self.N_LAYERS, BATCH_SIZE, self.HIDDEN_SIZE), dtype=torch.float32).to(device),
                  torch.zeros((self.N_LAYERS, BATCH_SIZE, self.HIDDEN_SIZE), dtype=torch.float32).to(device))
        return hidden
Note that the LSTM output is flattened before being passed through the fully connected layer. This is because the loss is computed over all sequences in the batch at once rather than sequence by sequence, so we flatten along the batch dimension to combine them, as shown below:
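A standalone shape check makes this concrete (the sizes here are illustrative):

lstm_out = torch.randn(64, 80, 512)  # (batch, seq_len, hidden_size)
flat = torch.flatten(lstm_out, start_dim=0, end_dim=1)
print(flat.shape)  # torch.Size([5120, 512]): 64 * 80 rows, one per character position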
Training the Model
import os
import torch.optim as optim
from tqdm import tqdm

# define model hyperparameters
VOCAB_SIZE = train_dataset.n_chars  # size of the vocabulary
HIDDEN_SIZE = 512                   # size of the hidden state vector
N_LAYERS = 3                        # number of stacked RNN layers
P_DROPOUT = 0.4                     # dropout probability

# create the model
model = charRNN(VOCAB_SIZE, HIDDEN_SIZE, N_LAYERS, P_DROPOUT, BATCH_FIRST).to(device)

# define training hyperparameters
n_epochs = 100
optimizer = optim.Adam(model.parameters())
loss = nn.CrossEntropyLoss()
clip = 5
save_dir = "./saved_models"
save_epoch = 10
os.makedirs(save_dir, exist_ok=True)

# train the model
train_loss_list = []
val_loss_list = []
for epoch in tqdm(range(n_epochs)):
    # training
    # -------------
    n_batches_train = 0
    cummulative_loss_train = 0
    model.train()
    # initialize hidden state
    hidden = model.init_hidden(BATCH_SIZE, device)
    for data_batch in train_dataloader:
        # detach hidden state so gradients don't flow across batches
        hidden = tuple([h.detach() for h in hidden])
        if data_batch.shape[0] != BATCH_SIZE:
            continue
        # inputs are all chars except the last; targets are shifted one char ahead
        inputs, targets = data_batch[:, :-1, :].to(device), data_batch[:, 1:, :].to(device)
        # get predictions
        preds, hidden = model(inputs, hidden)
        # compute loss
        target_idx = torch.argmax(targets, dim=2).long()
        target_flatten = torch.flatten(target_idx, start_dim=0, end_dim=1)
        train_loss = loss(preds, target_flatten)
        # backpropagation
        optimizer.zero_grad()
        train_loss.backward()
        # clip the gradient before updating the weights
        nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        n_batches_train += 1
        cummulative_loss_train += train_loss.item()
    loss_per_epoch_train = cummulative_loss_train / n_batches_train
    train_loss_list.append(loss_per_epoch_train)

    # validation
    # ---------------
    n_batches_val = 0
    cummulative_loss_val = 0
    model.eval()
    hidden = model.init_hidden(BATCH_SIZE, device)
    for data_batch in val_dataloader:
        if data_batch.shape[0] != BATCH_SIZE:
            continue
        inputs, targets = data_batch[:, :-1, :].to(device), data_batch[:, 1:, :].to(device)
        # get predictions
        with torch.no_grad():
            preds, hidden = model(inputs, hidden)
        # compute loss
        target_idx = torch.argmax(targets, dim=2).long()
        target_flatten = torch.flatten(target_idx, start_dim=0, end_dim=1)
        val_loss = loss(preds, target_flatten)
        n_batches_val += 1
        cummulative_loss_val += val_loss.item()
    loss_per_epoch_val = cummulative_loss_val / n_batches_val
    val_loss_list.append(loss_per_epoch_val)

    # save a checkpoint every save_epoch epochs
    if epoch % save_epoch == 0:
        model_name = f"charRNN_questions_epoch_{epoch}.pt"
        save_path = os.path.join(save_dir, model_name)
        torch.save(model.state_dict(), save_path)
Note that we carry the hidden state across batches, meaning the final state of one batch becomes the initial state of the next.
We also clip the gradients before updating the model parameters, because RNNs suffer from exploding gradients during backpropagation through time. Gradients with large magnitudes are therefore "clipped" to a fixed threshold.
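In PyTorch, clipping is a single call. A minimal standalone illustration (the layer and threshold here are toy values):

layer = nn.Linear(4, 4)
layer(torch.randn(2, 4)).sum().backward()
# rescales gradients in place so their global norm is at most max_norm;
# returns the total norm measured before clipping
total_norm = nn.utils.clip_grad_norm_(layer.parameters(), max_norm=5.0)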
After training, we can plot the training and validation curves:
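A minimal plotting sketch using the per-epoch loss lists collected above (matplotlib assumed):

import matplotlib.pyplot as plt

plt.plot(train_loss_list, label="train loss")
plt.plot(val_loss_list, label="validation loss")
plt.xlabel("epoch")
plt.ylabel("cross-entropy loss")
plt.legend()
plt.show()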
Generating Questions
class GenerateText:
    def __init__(self, model, k, int2char, char2int, device):
        self.int2char = int2char
        self.char2int = char2int
        self.n_chars = len(int2char)
        self.model = model
        self.device = device
        self.k = k
        self.sos_token = self.int2char[0]
        self.eos_token = self.int2char[1]

    def predict_next_char(self, hidden, input_char):
        # encode char
        char_one_hot = self.encode_char(input_char)
        # get the predictions
        with torch.no_grad():
            out, hidden = self.model(char_one_hot, hidden)
        # convert the output to a character probability distribution
        p = F.softmax(out, dim=1)
        # move to cpu as numpy doesn't support gpu
        p = p.cpu()
        # get top k characters from the distribution
        values, indices = p.topk(self.k)
        indices = indices.squeeze().numpy()
        values = values.squeeze().numpy()
        # sample a char from the top k chars using the output softmax distribution
        char_pred = np.random.choice(indices, size=1, p=values / values.sum())
        return self.int2char[char_pred[0]], hidden

    def generate_text(self, prime, max_chars=80):
        # prepend start token
        prime = self.sos_token + prime
        all_chars = [char for char in prime]
        hidden = self.model.init_hidden(1, self.device)
        # build up the hidden state using the initial prime
        for char in prime:
            char_pred, hidden = self.predict_next_char(hidden, char)
        all_chars.append(char_pred)
        # generate chars until the end token or max_chars is reached
        c = len(all_chars)
        while char_pred != self.eos_token:
            if c == max_chars:
                break
            char_pred, hidden = self.predict_next_char(hidden, all_chars[-1])
            all_chars.append(char_pred)
            c += 1
        # strip the special tokens before returning
        return "".join(all_chars).replace(self.sos_token, "").replace(self.eos_token, "")

    def encode_char(self, char):
        char_int = self.char2int[char]
        char_one_hot = F.one_hot(torch.tensor(char_int), self.n_chars).float()
        return char_one_hot.unsqueeze(0).unsqueeze(0).to(self.device)
We can now sample a few completions. The instantiation below is inferred from the class signature; k=3 is an illustrative choice for top-k sampling:

text_generator = GenerateText(model, 3, train_dataset.int2char, train_dataset.char2int, device)

text_generator.generate_text('was abraham lin')
# was abraham lincoln the first president of the united kingdom

text_generator.generate_text('did lincoln')
# did lincoln born in language family

text_generator.generate_text('when di')
# when did the election of the lowate of the korean lengthened

text_generator.generate_text('who i')
# who is the modern made of chinese

text_generator.generate_text('is it')
# is it true that indonesia has a political citizen

text_generator.generate_text('who determined')
# who determined the dependence of the modern piano
Preparing the Frontend
<template>
  <div tabindex="1" @focus="setCaret" class="autocomplete-container">
    <span ref="editbar" class="editable" contenteditable="true">this is auto</span>
    <span class="placeholder" contenteditable="false">this is autocomplete</span>
  </div>
</template>

<style>
.autocomplete-container {
  position: relative;
}
span {
  display: block;
  min-width: 1px;
  outline: none;
}
.editable {
  position: absolute;
  left: 8px;
  top: 5px;
}
.placeholder {
  /* sits directly behind the editable span (z-index: -1), so the gray
     suggestion shows through beyond the text the user has typed */
  color: gray;
  position: absolute;
  left: 8px;
  top: 5px;
  z-index: -1;
}
</style>
Communicating over WebSockets
Building the Backend
import torch
from fastapi import FastAPI, WebSocket
from model import charRNN, GenerateText
from config import *

app = FastAPI()

# load the trained model
model = charRNN(VOCAB_SIZE, HIDDEN_SIZE, N_LAYERS, P_DROPOUT, BATCH_FIRST)
model.load_state_dict(torch.load(PATH, map_location=device))
model.eval()
# K (the top-k sampling parameter) is assumed to be defined in config
# alongside the other hyperparameters
text_gen = GenerateText(model, K, int2char, char2int, device)

# the route path here is an assumption; match it to the URL the frontend connects to
@app.websocket("/")
async def predict_question(websocket: WebSocket):
    await websocket.accept()
    while True:
        input_text = await websocket.receive_text()
        autocomplete_text = text_gen.generate_text(input_text)
        await websocket.send_text(autocomplete_text)
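To try the backend locally, it can be served with uvicorn; the module name below assumes the file is saved as main.py:

uvicorn main:app --host 0.0.0.0 --port 8000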
Building the Frontend
<template>
  <div class="pad-container">
    <div tabindex="1" @focus="setCaret" class="autocomplete-container">
      <span @input="sendText" @keypress="preventInput" ref="editbar" class="editable" contenteditable="true"></span>
      <span class="placeholder" contenteditable="false">{{autoComplete}}</span>
    </div>
  </div>
</template>

<script>
export default {
  data() {
    return {
      autoComplete: "",  // the suggestion text received from the backend
      connection: null,
    };
  },
  mounted() {
    this.connection = new WebSocket(process.env.VUE_APP_URL);
    this.connection.onopen = () => console.log("connection established");
    this.connection.onmessage = this.receiveText;
  },
  methods: {
    sendText() {
      const inputText = this.$refs.editbar.textContent;
      this.connection.send(inputText);
    },
    receiveText(event) {
      this.autoComplete = event.data;
    },
    // setCaret and preventInput are referenced in the template but were
    // truncated from the original snippet
  },
};
</script>
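For local development, the WebSocket URL can be supplied through a Vue CLI environment file. The value below is illustrative and assumes the backend is served on port 8000 as in the uvicorn command above:

# .env.local (illustrative)
VUE_APP_URL=ws://localhost:8000/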
Running It
kubectl apply -f ./backend/backend-deployment.yaml
kubectl apply -f ./backend/backend-service.yaml
kubectl apply -f ./frontend/frontend-deployment.yaml
kubectl apply -f ./frontend/frontend-service.yaml
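Once applied, the rollout can be verified with the usual status commands:

kubectl get pods
kubectl get services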
Check out the app running at localhost:8080.
OK, that's a wrap.
