vlambda博客
学习文章列表

基于 Vue.js、FastAPI 和 WebSockets 在浏览器中构建基于 AI 的自动补全功能

本文使用循环神经网络的类似于谷歌的智能自动补全。

引言

本文尝试使用深度学习和循环神经网络构建搜搜索的自动补全。

整体想法

基于 Vue.js、FastAPI 和 WebSockets 在浏览器中构建基于 AI 的自动补全功能

这个想法是,一旦用户开始在搜索栏中输入内容,我们就将输入文本逐个字符输入到循环神经网络中,以生成自动补全预测。

本教程分为3部分:

  • 构建和训练字符 RNN 模型

  • 构建前端和后端

  • 通过 WebSocket 建立通信

准备模型

首先构建 charRNN 模型。

数据集

现在想要训练 charRNN 模型来预测用户要问的问题——所以在问题数据集上训练模型。
从 Kaggle 中找到了以下数据集:

基于 Vue.js、FastAPI 和 WebSockets 在浏览器中构建基于 AI 的自动补全功能

该数据集专为问答相关任务而设计,但是我们只能过滤掉问题。

>> questions_all[10:15]>> array(['Did Lincoln start his political career in 1832?', 'Did Lincoln ever represent Alton & Sangamon Railroad?', 'Which county was Lincoln born in?', 'When did Lincoln first serve as President?', 'Who assassinated Lincoln?'], dtype=object)
上面的代码片段展示了数据集中的一小部分问题。该数据集包含来自各种主题的问题,它可以作为我们app的一个起点,因为问题很短且有点类似于用户向 Google 提出的问题。

数据预处理

现在已经过滤掉了问题,下一步是执行预处理,包括将文本转换为小写;删除标点符号、空值、制表符、换行符和多个空格。
def preprocess_text(question): question = question.replace("S08_", "") question = question.replace("NOTTTT FOUND", "") question = re.sub("([\(\[]).*?([\)\]])", "", question) # remove text between brackets question = re.sub("[^\w\s]", "", question) # remove punctuations re.sub("\s+", " ", question) # remove multiple white spaces re.sub("[\t\n]", "", question) # remove tabs and newline characters question = question.lower().strip() return question questions_processed = [preprocess_text(q) for q in questions_all]questions_processed = [q for q in questions_processed if len(q) != 0] # remove empty strings after preprocessing

上面代码片段为数据处理。

数据统计

目前还没有完成预处理。看一下问题长度的分布图:

基于 Vue.js、FastAPI 和 WebSockets 在浏览器中构建基于 AI 的自动补全功能

从上图中,我们可以看到平均一个问题有大约 50 个字符,但也有超过 200 个字符的问题。如果您考虑一下,用户很少搜索这么长的问题,所以完全删除它们,因为平均问题只有 50 个字符长。现在可以保留与平均值相差一个标准差的问题,以包括大部分数据集并删除其他内容。
# calculate mean and standard deviationmean = np.mean(q_lengths)std = np.std(q_lengths)print(f"mean: {mean}, std: {std}")
# mean: 52.34, std: 29.05
# compute the optimum length of questions,# which is the length one standard deviation away from the meanoptimum_length = int(mean + std)print(f"optimum length: {optimum_length}")
# optimum length: 81
# Only keep the questions with length shorter than the optimum lengthquestions_short = [q for q in questions_processed if len(q) <= optimum_length]
在上面的代码片段计算了分布的均值和标准差,并将它们相加得到最佳长度,即距离均值一个标准差的长度。最后删除超过最佳长度的问题。

准备数据集

在 Pytorch 中通常创建一个继承自 Pytorch 的 Dataset 类的自定义数据集类,这样做有以下几个优点:
  • 对数据有更多的控制权。

  • 它有助于保持代码模块化。

  • 我们可以从这个数据集实例创建一个 Pytorch 数据加载器,它会自动处理批处理、混洗和采样数据,稍后会看到。

class QuestionsDataset(Dataset): def __init__(self, questions, vocab, sos_token, eos_token, batch_first=False):  # initialize parameters self.sos_idx = 0 self.eos_idx = 1 self.int2char = {self.sos_idx: sos_token, self.eos_idx: eos_token} # insert start of sentence and end of sentence tokens self.int2char.update({idx: char for idx, char in enumerate(vocab, start=self.eos_idx+1)}) self.char2int = {char: idx for idx, char in self.int2char.items()} self.n_chars = len(self.int2char)  # encode and pad questions self.questions_encoded = pad_sequence([self.encode_question(q) for q in questions], \ batch_first=batch_first)  def __len__(self): return len(self.questions_encoded)  def __getitem__(self, idx): return self.questions_encoded[idx]  def encode_question(self, question): ''' encode question as char indices and perform one-hot encoding ''' question_encoded = [self.sos_idx] # append sos for char in question: question_encoded.append(self.char2int[char]) question_encoded.append(self.eos_idx) # append eos return F.one_hot(torch.tensor(question_encoded, dtype=torch.long), self.n_chars).float()
上面的代码片段展示了 QuestionsDataset 的自定义数据集类:
首先在 __init__ 方法中,加载通过从数据集中获取唯一字符而准备的问题和词汇表。我们还添加了一个额外的开始和结束标记,因为每当处理有限序列时,模型应该知道何时开始和结束句子。
encode_questions 方法用于对单个问题进行编码。这里首先使用我们在 __init__ 方法中创建的词汇字典将问题中的字符编码为索引,然后执行 one-hot 编码。请注意在 one-hot 编码之前分别在句子的开头和结尾附加了开始和结束标记。也可以使用嵌入代替 one-hot 编码,但由于这里的词汇量很小,可以继续使用 one-hot 编码。
回到 __init__ 方法,对所有问题进行编码,用零填充它们,以使它们具有相同的长度。Padding 可对问题进行批处理,这可以在不影响模型性能的情况下缩短训练时间。
还需要定义两个额外的方法:__len__ 方法应该返回数据集中数据点的总数,__getitem__ 方法应该根据索引返回一个数据点。Pytorch 的 Dataloader 使用这些方法来批处理和打乱数据。
编码过程如下图所示:

基于 Vue.js、FastAPI 和 WebSockets 在浏览器中构建基于 AI 的自动补全功能

训练,验证拆分

接下来将数据拆分为训练集和验证集,创建数据加载器。
# Define Parametersvocab = sorted(set("".join(questions_short)))sos_token = '['eos_token = ']'BATCH_FIRST=TrueBATCH_SIZE=64device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# train-validation splitval_percent = 0.1n_val = int(val_percent * len(questions_short))n_train = len(questions_short) - n_valprint(f"n_train: {n_train}, n_val: {n_val}")# train: 1977, val: 219
questions_train = questions_short[:n_train]questions_val = questions_short[n_train:]
# Create Datasets and Dataloaderstrain_dataset = QuestionsDataset(questions_train, vocab, sos_token, eos_token, batch_first=BATCH_FIRST)val_dataset = QuestionsDataset(questions_val, vocab, sos_token, eos_token, batch_first=BATCH_FIRST)
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=True)
在上面的代码片段中,首先通过从所有问题中获取唯一字符来创建词汇表,定义批大小以及可以是词汇表中不存在的任何唯一字符的开始和结束标记。
接下来以 9:1 的比例构建训练验证拆分,使用我们之前定义的 QuestionsDataset 类创建数据集。Pytorch 的 Dataloader 构造函数从这些数据集实例创建数据加载器,这些实例可以自动批处理、采样和打乱数据。
最后准备好将数据输入模型。接下来构建模型。

构建 charRNN 模型

charRNN 模型在每个时间步接收一个输入字符,并输出下一个适当字符的概率分布,如下图所示:

基于 Vue.js、FastAPI 和 WebSockets 在浏览器中构建基于 AI 的自动补全功能

这个想法是将一个编码问题传递给循环神经网络,该网络在每个时间步输出一个隐藏状态。接下来将这些隐藏状态中的每一个通过一个全连接网络来获取 logits。然后根据 logits 和目标计算交叉熵损失。在每个时间步计算损失,将它们添加到所有时间步,以获得将用于整个网络的反向传播的总损失。
模型的代码:
class charRNN(nn.Module):  def __init__(self, VOCAB_SIZE, HIDDEN_SIZE, N_LAYERS=2, P_DROPOUT=0.5, batch_first=False): super().__init__() self.HIDDEN_SIZE = HIDDEN_SIZE self.N_LAYERS = N_LAYERS self.lstm = nn.LSTM(VOCAB_SIZE, HIDDEN_SIZE, batch_first=batch_first,  dropout=P_DROPOUT, num_layers=N_LAYERS) self.dropout = nn.Dropout(P_DROPOUT) self.fc = nn.Linear(HIDDEN_SIZE, VOCAB_SIZE)  def forward(self, inputs, hidden): lstm_out, hidden = self.lstm(inputs, hidden)  # flatten the lstm output lstm_out = torch.flatten(lstm_out, start_dim=0, end_dim=1)  out = self.dropout(lstm_out) out = self.fc(out)  return out, hidden  def init_hidden(self, BATCH_SIZE, device): hidden = (torch.zeros((self.N_LAYERS, BATCH_SIZE, self.HIDDEN_SIZE), dtype=torch.float32).to(device), torch.zeros((self.N_LAYERS, BATCH_SIZE, self.HIDDEN_SIZE), dtype=torch.float32).to(device)) return hidden

请注意,在将输出通过全连接网络之前对其进行了展平。这是因为每当计算损失时,我们都会对所有批次一起进行计算,而不是逐批次进行。所以沿批次维度展平以组合所有批次,如下所示:

基于 Vue.js、FastAPI 和 WebSockets 在浏览器中构建基于 AI 的自动补全功能

训练模型

# define model hyperparametersVOCAB_SIZE=train_dataset.n_chars # size of the vocabularyHIDDEN_SIZE=512 # the size of the hidden state vectorN_LAYERS=3 # number of stacked RNN layersP_DROPOUT = 0.4 # dropout probability
# create the modelmodel = charRNN(VOCAB_SIZE, HIDDEN_SIZE, N_LAYERS, P_DROPOUT, BATCH_FIRST)
# define training hyperparametersn_epochs = 100optimizer = optim.Adam(model.parameters())loss = nn.CrossEntropyLoss()clip = 5
save_dir = "./saved_models"save_epoch = 10
# train the modeltrain_loss_list = []val_loss_list = []
for epoch in tqdm(range(n_epochs)): # training # -------------
n_batches_train = 0 cummulative_loss_train = 0 model.train() # initialize hidden state hidden = model.init_hidden(BATCH_SIZE, device) for data_batch in train_dataloader: # detach hidden state hidden = tuple([h.detach() for h in hidden]) if data_batch.shape[0] != BATCH_SIZE: continue # get data labels, targets = data_batch[:, :-1, :].to(device), data_batch[:, 1:, :].to(device) # get predictions preds, hidden = model(labels, hidden) # compute loss target_idx = torch.argmax(targets, dim=2).long() target_flatten = torch.flatten(target_idx, start_dim=0, end_dim=1) train_loss = loss(preds, target_flatten)
# backpropagation optimizer.zero_grad() train_loss.backward() # clip the gradient before updating the weights clip_grad_norm_(model.parameters(), clip) optimizer.step() n_batches_train += 1 cummulative_loss_train += train_loss.item() loss_per_epoch_train = cummulative_loss_train / n_batches_train train_loss_list.append(loss_per_epoch_train) # validation # --------------- n_batches_val = 0 cummulative_loss_val = 0 model.eval() hidden = model.init_hidden(BATCH_SIZE, device) for data_batch in val_dataloader: if data_batch.shape[0] != BATCH_SIZE: continue
# get data labels, targets = data_batch[:, :-1, :].to(device), data_batch[:, 1:, :].to(device) # get predictions with torch.no_grad(): preds, hidden = model(labels, hidden) # compute loss target_idx = torch.argmax(targets, dim=2).long() target_flatten = torch.flatten(target_idx, start_dim=0, end_dim=1) val_loss = loss(preds, target_flatten) n_batches_val += 1 cummulative_loss_val += val_loss.item() loss_per_epoch_val = cummulative_loss_val / n_batches_val val_loss_list.append(loss_per_epoch_val) # save model every 10 epochs if epoch % save_epoch == 0: model_name = f"charRNN_questions_epoch_{epoch}.pt" save_path = os.path.join(save_dir, model_name) torch.save(model.state_dict(), save_path)
定义了一个三层 LSTM 网络,隐藏状态大小为 512,dropout 为 0.4。
使用了默认学习率为 1e-3 和交叉熵作为损失函数的 Adam 优化器。训练和验证循环运行 100 个 epoch,模型每 10 个 epoch 保存一次。这些是可以根据模型的性能进行调整的超参数。
另外,这里有几点需要注意:
  • 我们跨批次传递隐藏状态,这意味着一个批次的最终状态将是下一批的初始状态。

  • 我们在更新模型参数之前对梯度进行了裁剪,因为 RNN 面临着随着时间的反向传播而导致梯度爆炸的问题。因此将幅度较大的梯度“裁剪”到特定阈值。

训练后,我们可以绘制训练和验证曲线:

仔细观察,训练和验证损失彼此接近。这意味着模型无法捕获数据中的表示。这是有道理的,因为我们使用的是简单的 charRNN 模型。如果我们使用基于transformer的模型可能会提高性能,但现在继续使用这个模型。

产生问题

现在模型已经训练好了,了解如何进行推理,在我们的例子中是根据用户的输入来预测问题。
这个想法是给定一个输入字符,我们对其进行编码并通过网络获得logits作为输出。然后我们在 logits 上应用 softmax 函数来获得下一个字符的概率分布。现在可以直接从这个分布中挑选出前k个字符,但是这样会导致过拟合,所以从分布中挑选出前k个字符,并随机选择其中一个作为下一个字符。
此过程在以下代码片段的 GenerateText 类中的 predict_next_char 方法中进行了说明。
class GenerateText: def __init__(self, model, k, int2char, char2int, device): self.int2char = int2char self.char2int = char2int self.n_chars = len(int2char) self.model = model self.device = device self.k = k self.sos_token = self.int2char[0] self.eos_token = self.int2char[1]  def predict_next_char(self, hidden, input_char):  # encode char char_one_hot = self.encode_char(input_char)
# get the predictions with torch.no_grad(): out, hidden = self.model(char_one_hot, hidden) # convert the output to a character probability distribution p = F.softmax(out, dim=1)
# move to cpu as numpy doesn't support gpu p = p.cpu()
# get top k characters from the distribution values, indices = p.topk(self.k)
indices = indices.squeeze().numpy() values = values.squeeze().numpy()
# sample any char from the top k chars using the output softmax distribution char_pred = np.random.choice(indices, size=1, p=values/values.sum())
return self.int2char[char_pred[0]], hidden def generate_text(self, prime, max_chars=80): # append start token prime = self.sos_token + prime
all_chars = [char for char in prime] hidden = model.init_hidden(1, self.device)
# build up the hidden state using the initial prime for char in prime: char_pred, hidden = self.predict_next_char(hidden, char)
all_chars.append(char_pred)
# generate n chars c = len(all_chars) while char_pred != self.eos_token: if c == max_chars: break char_pred, hidden = self.predict_next_char(hidden, all_chars[-1]) all_chars.append(char_pred) c += 1
return "".join(all_chars) def encode_char(self, char): char_int = self.char2int[char] char_one_hot = F.one_hot(torch.tensor(char_int), self.n_chars).float() return char_one_hot.unsqueeze(0).unsqueeze(0).to(self.device)
predict_next_char 方法采用输入字符和隐藏状态来预测下一个字符。
现在的目标是根据用户的上下文预测问题。上下文是一组初始字符。这个想法是将这些初始字符一个一个地输入模型并建立隐藏状态。接下来,使用隐藏状态和上下文中的最后一个字符,使用 predict_next_char 方法预测下一个字符,然后我们将这个预测的字符作为输入来预测下一个字符。重复这个过程,直到到达结束标记,表明已经到达问题的结尾。
该过程在 generate_text 方法中进行了说明。
看看下面的一些例子:
text_generator.generate_text('was abraham lin')# was abraham lincoln the first president of the united kingdom
text_generator.generate_text('did lincoln')# did lincoln born in language family
text_generator.generate_text('when di')# when did the election of the lowate of the korean lengthened
text_generator.generate_text('who i')# who is the modern made of chinese
text_generator.generate_text('is it')# is it true that indonesia has a political citizen
text_generator.generate_text('who determined')# who determined the dependence of the modern piano

准备前端

前端很简单,只是一个搜索栏,前面显示用户输入,后面显示自动完成,降低了不透明度,如下图所示:

为了达到这种效果,我们的想法是在同一个 div 容器中放置两个 span 元素,这可以通过将 div 容器的位置设置为 relative 并将 span 元素的位置设置为 absolute 来完成。本质上将 span 元素相对于 div 容器放置,以便它们可以放在相同的位置。
代码如下所示:
<template> <div tabindex="1" @focus="setCaret" class="autocomplete-container"> <span ref="editbar" class="editable" contenteditable="true">this is auto</span> <span class="placeholder" contenteditable="false">this is autocomplete</span>  </div></template>

<style>.autocomplete-container { position: relative;}span { display: block; min-width: 1px; outline: none;}.editable { position: absolute; left: 8px; top: 5px;}.placeholder { color: gray; position: absolute; left: 8px; top: 5px; z-index: -1;}</style>
现在可以通过将其 contenteditable 属性设置为 true 来使 span 元素在浏览器中可编辑。由于希望在用户输入后自动完成,因此它的 z-index 设置为 -1。最后加入一些 CSS 来获得漂亮的自动补全搜索栏。

通过 WebSocket 进行通信

这个想法是,一旦用户开始在浏览器中输入,文本就需要作为输入发送到后端的 charRNN 模型以生成预测。然后预测将被发送回前端并显示在搜索栏中用户输入后面的自动补全占位符中。
在前端和后端之间建立实时的双向通信,WebSockets 是理想的选择。
接下来构建前端、后端,并使用 WebSockets 建立通信。

构建后端

对于后端只需要围绕 charRNN 模型创建一个 API 包装器来服务请求。在这里使用 FastAPI 来定义 WebSocket API。
看一下代码:
from fastapi import FastAPI, WebSocket
from model import charRNN, GenerateTextfrom config import *
app = FastAPI()
# load modelmodel = charRNN(VOCAB_SIZE, HIDDEN_SIZE, N_LAYERS, P_DROPOUT, BATCH_FIRST)model.load_state_dict(torch.load(PATH, map_location=device))model.eval()
text_gen = GenerateText(model, int2char, char2int, device)
@app.websocket("/")async def predict_question(websocket: WebSocket): await websocket.accept() while True: input_text = await websocket.receive_text() autocomplete_text = text_gen.generate_text(input_text) await websocket.send_text(autocomplete_text)
首先定义并加载预训练的 charRNN 模型并将其设置为评估模式。接下来定义一个使用 WebSocket 协议的路由和一个函数 predict_question,只要请求到达该路由,就会调用该函数。
predict_question 函数是一个异步函数,它不是等待特定任务完成,而是执行其他部分并在任务完成时返回。 
在这个函数中,首先接受来自前端服务器的握手并建立通信。然后,一旦从前端接收到输入文本,就会通过 charRNN 模型运行以生成自动补全预测并使用 WebSocket 协议将其发送回前端。
接下来在前端添加 WebSocket 通信机制。

构建前端

在前端需要一种机制来检测用户输入,以便可以在用户开始输入时立即将文本发送到后端。为此可以使用 Vue 指令和事件处理程序。v-on 指令(也由@符号表示)监听元素上的特定事件,并在事件触发时调用处理函数。
看一下代码:
<template><div class="pad-container"> <div tabindex="1" @focus="setCaret" class="autocomplete-container"> <span @input="sendText" @keypress="preventInput" ref="editbar" class="editable" contenteditable="true"></span> <span class="placeholder" contenteditable="false">{{autoComplete}}</span>  </div></div></template>
<script>export default { mounted() { this.connection = new WebSocket(process.env.VUE_APP_URL); this.connection.onopen = () => console.log("connection established"); this.connection.onmessage = this.receiveText; }, methods: { sendText() { const inputText = this.$refs.editbar.textContent; this.connection.send(inputText); }, receiveText(event) { this.autoComplete = event.data; }}</script>view raw
首先在挂载的生命周期钩子hook中创建 WebSocket 对象。生命周期钩子是在组件创建的不同阶段调用的函数。一旦组件渲染到 DOM 中,就会调用挂载的生命周期钩子,这是实例化 WebSocket 对象并与后端建立通信的好阶段。
接下来,在搜索栏元素中添加行 @input="sendText" 来监听输入事件并触发函数 sendText 抓取搜索栏中的文本并将其发送到后端。一旦收到自动补全从后端预测,函数 receiveText 通过 WebSocket 对象的 onmessage 回调方法触发,该函数填充 autoComplete 占位符,该占位符显示在搜索栏中的用户输入后面。
请注意搜索栏元素中的 @keypress=”preventInput” 行。这是一个事件处理程序,可防止用户在搜索栏中输入数值或标点符号,在 charRNN 模型的预处理步骤中,我们已经删除了仅保留字母字符的所有内容。因此,为了防止无法识别的输入进入我们的 charRNN 模型,使用了这个事件处理程序。
这样完成了应用程序。看看它的实际效果。

跑起来

假设读者已经安装了 Kubernetes,请按顺序运行以下命令:
kubectl apply -f ./backend/backend-deployment.yamlkubectl apply -f ./backend/backend-service.yamlkubectl apply -f ./frontend/frontend-deployment.yamlkubectl apply -f ./frontend/frontend-service.yaml

查看在 localhost:8080 运行的应用程序

OK,到此结束。