vlambda博客
学习文章列表

如何用 Rust 实现朴素贝叶斯分类器

原文链接: https://www.freecodecamp.org/news/implement-naive-bayes-with-rust/

翻译:Ch3nYe

选题:Ch3nYe

本文由 Rustt 翻译,RustCn 荣誉推出

如何用 Rust 实现朴素贝叶斯分类器

我想精进我的 Rust 编程技能,同时也帮助你磨练技能。所以我决定写一系列关于 Rust 编程语言的文章。

在实际使用 Rust 写东西的时候,我能学到很广泛的技术概念。在本期中,我们将学习如何使用 Rust 实现朴素贝叶斯分类器。

您可能会在本文中遇到一些不熟悉的术语或概念。不要气馁,如果你有时间可以自行学习,但无论如何,希望你不要偏离本文的主要思路。

什么是朴素贝叶斯分类器

朴素贝叶斯分类器是一个基于贝叶斯理论的机器学习算法。贝叶斯理论是一种通过给定一些数据 D ,来更新一个假设 H 的概率的方法。

数学表达为:

$$ P(H \mid D)=\frac{P(D \mid H) P(H)}{P(D)} $$

$P(H|D)$ 是给出数据 D 假设 H 成立的概率。

如果我们统计更多数据,就可以根据这些数据更新$P(H|D)$ 。

朴素贝叶斯模型基于一个大假设:数据集中是否存在数据点与该数据集中已经存在的数据无关(参考)。也就是说,每条数据之间相互独立。

显然,这个假设是比较弱的,现实中难以完全成立。但它仍然很有用,它允许我们创建一个用起来还不错的的高效分类器(参考)。

对朴素贝叶斯的描述就停在这里,本文的重点是练习 Rust 。

如果您想了解有关该算法的更多信息,这里有一些资源:

  • • Josh Starmer 的视频讲解非常好.

  • • Joel Grus 在《*Data Science from Scratch*》这本书中关于贝叶斯一节的描述是本文实现的启发。

  • • 如果你更喜欢数学的形式化定义, try section 6.6.3 of *The Elements of Statisical Learning.*

  • • 一篇关于算法工作原理的有用文章

朴素贝叶斯分类器的典型应用是垃圾邮件分类器。这就是我们要实现的东西。代码在这:

https://github.com/josht-jpg/shaking-off-the-rust

我们从使用 Cargo 创建一个新的库开始:

cargo new naive_bayes --lib
cd naive_bayes

Tokenization in Rust

我们的分类器会将邮件消息内容作为输入并返回其是否为垃圾邮件的分类。

为了处理我们收到的消息,我们需要对其进行tokenize(分词)。我们的词汇表将是一堆小写的单词,忽略顺序和重复单词。Rust 的 std::collections::HashSet 结构正合适来实现词汇表。

我们将编写分词的函数将需要使用 regex crate。确保在 Cargo.toml 文件中包含以下依赖项:

[dependencies]
regex = "^1.5.4"

tokenize 分词函数:

// lib.rs

// We'll need HashMap later
use std::collections::{HashMap, HashSet};

extern crate regex;
use regex::Regex;

pub fn tokenize(lower_case_text: &str) -> HashSet<&str> {
    Regex::new(r"[a-z0-9']+")
        .unwrap()
        .find_iter(lower_case_text)
        .map(|mat| mat.as_str())
        .collect()
}

此函数使用正则表达式匹配所有数字和小写字母。每当我们遇到不同类型的符号(通常是空格或标点符号)时,我们就会拆分输入并将自上次拆分后遇到的所有数字和字母组合在一起(你可以在这里阅读有关 Rust 正则表达式的更多内容)。也就是说,我们正在识别和分割输入文本中的单词。

构造结构体

使用 struct 来表示消息是很好的方法。struct 将包含消息文本的字符串切片,以及指示消息是否为垃圾邮件的布尔值:

pub struct Message<'a> {
    pub text: &'a str,
    pub is_spam: bool,
}

'a 是声明周期注释。如果你不熟悉生命周期,我推荐你阅读 section 10.3 of The Rust Programming Language Book 。

什么是拉普拉斯平滑?

假设——在我们的训练数据中——单词 fubar 出现在一些非垃圾邮件中,但没有出现在任何垃圾邮件中。此时,朴素贝叶斯分类器把任何包含单词 fubar参考)的消息认定为非垃圾邮件,也就是说该消息是垃圾邮件的概率为 0 。

显然,仅仅因为它还没有发生就给它分配 0 的概率是不合适的。

加入拉普拉斯平滑就是用来解决这个问题的,将一个常数 $\alpha$ 加在每个单词出现的次数统计上。我们来观察一下拉普拉斯平滑常数加入前后,在垃圾邮件中看到单词 w 的概率为:

$$ P(w|S) = \frac{number\ of\ spam\ messages\ containing\ w}{total\ number\ of\ spams} $$

使用拉普拉斯平滑后就是:$$ P(w|S) = \frac{\alpha + number\ of\ spam\ messages\ containing\ w}{2\alpha + total\ number\ of\ spams} $$

具体到分类器结构体就是:

pub struct NaiveBayesClassifier {
    pub alpha: f64,
    pub tokens: HashSet<String>,
    pub token_ham_counts: HashMap<String, i32>,
    pub token_spam_counts: HashMap<String, i32>,
    pub spam_messages_count: i32,
    pub ham_messages_count: i32,
}

NaiveBayesClassifier 的实现将围绕一个 train 方法和一个 predict 方法。

如何训练分类器

train 方法将接收多个 Message ,并循环对每个 Message 进行以下步骤:

  • • 检查邮件是否为垃圾邮件并相应地更新 spam_messages_count 或 ham_messages_count。我们为此创建辅助函数 increment_message_classifications_count

  • • 使用 tokenize 函数将消息分词。

  • • 遍历消息中的每个单词,然后:

  • • 将单词插入词汇表 HashSet ,然后更新token_spam_counts 或 token_ham_counts。我们为此创建辅助函数 increment_token_count

下面是 train 方法的伪代码。如果你愿意,尝试将伪代码转换为 Rust,然后再查看下面的实现。

implementation block for NaiveBayesClassifier {

    train(self, messages) {
        for each message in messages {
            self.increment_message_classifications_count(message)
            
            lowercase_text = to_lowercase(message.text)
            for each token in tokenize(lowercase_text) {
                self.tokens.insert(tokens)
                self.increment_token_count(token, message.is_spam)
            }   
        }
    }

    increment_message_classifications_count(self, message) {
        if message.is_spam {
            self.spam_messages_count = self.spam_messages_count + 1
        } else {
            self.ham_messages_count = self.ham_messages_count + 1
        }
    }

    increment_token_count(&mut self, token, is_spam) {
        if token is not a key of self.token_spam_counts {
            insert record with key=token and value=0 into self.token_spam_counts
        }

        if token is not a key of self.token_ham_counts {
            insert record with key=token and value=0 into self.token_ham_counts
        }

        if is_spam {
            self.token_spam_counts[token] = self.token_spam_counts[token] + 1
        } else {
            self.token_ham_counts[token] = self.token_ham_counts[token] + 1
        }
    }

}

下面是 Rust 的实现:

impl NaiveBayesClassifier {
    pub fn train(&mut self, messages: &[Message]) {
        for message in messages.iter() {
            self.increment_message_classifications_count(message);
            for token in tokenize(&message.text.to_lowercase()) {
                self.tokens.insert(token.to_string());
                self.increment_token_count(token, message.is_spam)
            }
        }
    }

    fn increment_message_classifications_count(&mut self, message: &Message) {
        if message.is_spam {
            self.spam_messages_count += 1;
        } else {
            self.ham_messages_count += 1;
        }
    }

    fn increment_token_count(&mut self, token: &str, is_spam: bool) {
        if !self.token_spam_counts.contains_key(token) {
            self.token_spam_counts.insert(token.to_string(), 0);
        }

        if !self.token_ham_counts.contains_key(token) {
            self.token_ham_counts.insert(token.to_string(), 0);
        }

        if is_spam {
            self.increment_spam_count(token);
        } else {
            self.increment_ham_count(token);
        }
    }

    fn increment_spam_count(&mut self, token: &str) {
        *self.token_spam_counts.get_mut(token).unwrap() += 1;
    }

    fn increment_ham_count(&mut self, token: &str) {
        *self.token_ham_counts.get_mut(token).unwrap() += 1;
    }
}

请注意,在 HashMap 中增加一个值是非常耗时的。新手 Rust 程序员很难理解下面这行代码在做什么:

*self.token_spam_counts.get_mut(token).unwrap() += 1

为了使代码更明确,我创建了 increment_spam_count 和 increment_ham_count 函数。但我对此并不满意——这仍然很麻烦。如果您对更好的方法有建议,请与我联系。

如何使用分类器做预测

predict 方法接手一个字符串切片,返回模型对于该消息是否为垃圾邮件的预测结果。

我们创建两个辅助函数 probabilities_of_message 和 robabilites_of_token 来完成 predict 的任务。

probabilities_of_message returns P(Message|Spam) and P(Message|ham)

probabilities_of_token returns P(Token|Spam) and P(Token|ham)

计算输入消息是垃圾邮件的概率需要将每个单词在垃圾邮件中出现的概率相乘。

概率是介于 0 和 1 之间的浮点数,将许多概率相乘会导致下溢(参考)。这是因为当计算产生的数字小于计算机可以准确存储的数字(请参阅这里这里)。因此,我们将使用对数和指数将任务转换为一系列加法:

$$ \Pi_{i=0}^{n} p_{i}=\exp \left(\sum_{i=0}^{n} \log \left(p_{i}\right)\right) $$

因为对于任何实数 a 和 b, $$ ab = exp(log(ab))=exp(log(a)+log(b)) $$

我将再次先给出 predict 方法的伪代码:

implementation block for NaiveBayesCalssifier {
    /*...*/

    predict(self, text) {
        lower_case_text = to_lowercase(text)
        message_tokens = tokenize(text)
        (prob_if_spam, prob_if_ham) = self.probabilities_of_message(message_tokens)
        return prob_if_spam / (prob_if_spam + prob_if_ham)
    }
    
    probabilities_of_message(self, message_tokens) {
        log_prob_if_spam = 0
        log_prob_if_ham = 0

        for each token in self.tokens {
            (prob_if_spam, prob_if_ham) = self.probabilites_of_token(token)

            if message_tokens contains token {
                log_prob_if_spam = log_prob_if_spam + ln(prob_if_spam)
                log_prob_if_ham = log_prob_if_ham + ln(prob_if_ham)
            } else {
                log_prob_if_spam = log_prob_if_spam + ln(1 - prob_if_spam)
                log_prob_if_ham = log_prob_if_ham + ln(1 - prob_if_ham)
            }
        }

        prob_if_spam = exp(log_prob_if_spam)
        prob_if_ham = exp(log_prob_if_ham)

        return (prob_if_spam, prob_if_ham)
    }

    probabilites_of_token(self, token) {
        prob_of_token_spam = (self.token_spam_counts[token] + self.alpha) 
                        / (self.spam_messages_count + 2 * self.alpha)
        
        prob_of_token_ham = (self.token_ham_counts[token] + self.alpha) 
                        / (self.ham_messages_count + 2 * self.alpha)

        return (prob_of_token_spam, prob_of_token_ham)
    }
    
    
}

Rust 的具体实现:

impl NaiveBayesClassifier {

        /*...*/

    pub fn predict(&self, text: &str) -> f64 {
        let lower_case_text = text.to_lowercase();
        let message_tokens = tokenize(&lower_case_text);
        let (prob_if_spam, prob_if_ham) = self.probabilities_of_message(message_tokens);

        return prob_if_spam / (prob_if_spam + prob_if_ham);
    }

    fn probabilities_of_message(&self, message_tokens: HashSet<&str>) -> (f64, f64) {
        let mut log_prob_if_spam = 0.;
        let mut log_prob_if_ham = 0.;

        for token in self.tokens.iter() {
            let (prob_if_spam, prob_if_ham) = self.probabilites_of_token(&token);

            if message_tokens.contains(token.as_str()) {
                log_prob_if_spam += prob_if_spam.ln();
                log_prob_if_ham += prob_if_ham.ln();
            } else {
                log_prob_if_spam += (1. - prob_if_spam).ln();
                log_prob_if_ham += (1. - prob_if_ham).ln();
            }
        }

        let prob_if_spam = log_prob_if_spam.exp();
        let prob_if_ham = log_prob_if_ham.exp();

        return (prob_if_spam, prob_if_ham);
    }

    fn probabilites_of_token(&self, token: &str) -> (f64, f64) {
        let prob_of_token_spam = (self.token_spam_counts[token] as f64 + self.alpha)
            / (self.spam_messages_count as f64 + 2. * self.alpha);

        let prob_of_token_ham = (self.token_ham_counts[token] as f64 + self.alpha)
            / (self.ham_messages_count as f64 + 2. * self.alpha);

        return (prob_of_token_spam, prob_of_token_ham);
    }
}

如何测试分类器

让我们对模型做个测试。下面的代码中的样例手动打上了分类标签,然后检查我们的模型是否给出了相同的结果。

检查代码逻辑是很有必要的,你可以将代码粘贴到 lib.rs 文件的底部以检查你的代码是否有效。

// ...lib.rs

pub fn new_classifier(alpha: f64) -> NaiveBayesClassifier {
    return NaiveBayesClassifier {
        alpha,
        tokens: HashSet::new(),
        token_ham_counts: HashMap::new(),
        token_spam_counts: HashMap::new(),
        spam_messages_count: 0,
        ham_messages_count: 0,
    };
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn naive_bayes() {
        let train_messages = [
            Message {
                text: "Free Bitcoin viagra XXX christmas deals 😻😻😻",
                is_spam: true,
            },
            Message {
                text: "My dear Granddaughter, please explain Bitcoin over Christmas dinner",
                is_spam: false,
            },
            Message {
                text: "Here in my garage...",
                is_spam: true,
            },
        ];

        let alpha = 1.;
        let num_spam_messages = 2.;
        let num_ham_messages = 1.;

        let mut model = new_classifier(alpha);
        model.train(&train_messages);

        let mut expected_tokens: HashSet<String> = HashSet::new();
        for message in train_messages.iter() {
            for token in tokenize(&message.text.to_lowercase()) {
                expected_tokens.insert(token.to_string());
            }
        }

        let input_text = "Bitcoin crypto academy Christmas deals";

        let probs_if_spam = [
            1. - (1. + alpha) / (num_spam_messages + 2. * alpha), // "Free"  (not present)
            (1. + alpha) / (num_spam_messages + 2. * alpha),      // "Bitcoin"  (present)
            1. - (1. + alpha) / (num_spam_messages + 2. * alpha), // "viagra"  (not present)
            1. - (1. + alpha) / (num_spam_messages + 2. * alpha), // "XXX"  (not present)
            (1. + alpha) / (num_spam_messages + 2. * alpha),      // "christmas"  (present)
            (1. + alpha) / (num_spam_messages + 2. * alpha),      // "deals"  (present)
            1. - (1. + alpha) / (num_spam_messages + 2. * alpha), // "my"  (not present)
            1. - (0. + alpha) / (num_spam_messages + 2. * alpha), // "dear"  (not present)
            1. - (0. + alpha) / (num_spam_messages + 2. * alpha), // "granddaughter"  (not present)
            1. - (0. + alpha) / (num_spam_messages + 2. * alpha), // "please"  (not present)
            1. - (0. + alpha) / (num_spam_messages + 2. * alpha), // "explain"  (not present)
            1. - (0. + alpha) / (num_spam_messages + 2. * alpha), // "over"  (not present)
            1. - (0. + alpha) / (num_spam_messages + 2. * alpha), // "dinner"  (not present)
            1. - (1. + alpha) / (num_spam_messages + 2. * alpha), // "here"  (not present)
            1. - (1. + alpha) / (num_spam_messages + 2. * alpha), // "in"  (not present)
            1. - (1. + alpha) / (num_spam_messages + 2. * alpha), // "garage"  (not present)
        ];

        let probs_if_ham = [
            1. - (0. + alpha) / (num_ham_messages + 2. * alpha), // "Free"  (not present)
            (1. + alpha) / (num_ham_messages + 2. * alpha),      // "Bitcoin"  (present)
            1. - (0. + alpha) / (num_ham_messages + 2. * alpha), // "viagra"  (not present)
            1. - (0. + alpha) / (num_ham_messages + 2. * alpha), // "XXX"  (not present)
            (1. + alpha) / (num_ham_messages + 2. * alpha),      // "christmas"  (present)
            (0. + alpha) / (num_ham_messages + 2. * alpha),      // "deals"  (present)
            1. - (1. + alpha) / (num_ham_messages + 2. * alpha), // "my"  (not present)
            1. - (1. + alpha) / (num_ham_messages + 2. * alpha), // "dear"  (not present)
            1. - (1. + alpha) / (num_ham_messages + 2. * alpha), // "granddaughter"  (not present)
            1. - (1. + alpha) / (num_ham_messages + 2. * alpha), // "please"  (not present)
            1. - (1. + alpha) / (num_ham_messages + 2. * alpha), // "explain"  (not present)
            1. - (1. + alpha) / (num_ham_messages + 2. * alpha), // "over"  (not present)
            1. - (1. + alpha) / (num_ham_messages + 2. * alpha), // "dinner"  (not present)
            1. - (0. + alpha) / (num_ham_messages + 2. * alpha), // "here"  (not present)
            1. - (0. + alpha) / (num_ham_messages + 2. * alpha), // "in"  (not present)
            1. - (0. + alpha) / (num_ham_messages + 2. * alpha), // "garage"  (not present)
        ];

        let p_if_spam_log: f64 = probs_if_spam.iter().map(|p| p.ln()).sum();
        let p_if_spam = p_if_spam_log.exp();

        let p_if_ham_log: f64 = probs_if_ham.iter().map(|p| p.ln()).sum();
        let p_if_ham = p_if_ham_log.exp();

        // P(message | spam) / (P(messge | spam) + P(message | ham)) rounds to 0.97
        assert!((model.predict(input_text) - p_if_spam / (p_if_spam + p_if_ham)).abs() < 0.000001);
    }
}

现在可以通过 cargo test 进行测试,如果你成功通过了测试,你用 Rust 实现的朴素贝叶斯模型没有问题了!

感谢你看到这里。如果您有任何问题或建议,请随时与我们联系。

References

  1. 1. Grus, J. (2019). Data Science from Scratch: First Principles with Python, 2nd edition. O’Reilly Media.

  2. 2. Downey, A. (2021). Think Bayes: Bayesian Statistics in Python, 2nd edition. O’Reilly Media.

  3. 3. Murphy, K. (2012). Machine Learning: A Probabilistic Perspective. MIT Press.

  4. 4. Dhinakaran, V. (2017). Rust Cookbook. Packt.

  5. 5. Ng, A. (2018). Stanford CS229: Lecture 5 - GDA & Naive Bayes.

  6. 6. Burden, R. Faires, J. Burden, A. (2015). Numerical Analysis, 10th edition. Brooks Cole.

  7. 7. *Underflow.* Technopedia.