搜公众号
推荐 原创 视频 Java开发 开发工具 Python开发 Kotlin开发 Ruby开发 .NET开发 服务器运维 开放平台 架构师 大数据 云计算 人工智能 开发语言 其它开发 iOS开发 前端开发 JavaScript开发 Android开发 PHP开发 数据库
Lambda在线 > 努力给自己看 > 朴素贝叶斯算法&应用实例

朴素贝叶斯算法&应用实例

努力给自己看 2019-01-27
举报

朴素贝叶斯

朴素贝叶斯中的朴素是指假设各个特征之间相互独立,不会互相影响,所以称为朴素贝叶斯。正是因为这个假设,使得算法的模型简单且容易理解,虽然牺牲了一点准确性,但是如果模型训练的好,也能得到不错的分类效果。

公式简单推导

下面我们简单看一下公式的推导过程

评测指标

我们得出分类的结果后,怎么来评测我们训练的模型的好与不好呢?我们通常「准确度」「精确率」「召回率」这几个指标来进行判断模型的好坏。下边我们用一个简单的例子来说明这几个指标是怎么计算的。

下面我们看一个表。

/ / 预测类别 预测类别
/ / 科技(35) 财经(35)
实际类别 科技(40) 30 10
实际类别 财经(30) 5 25

表中表示实际上科技类的文章有 40 篇,财经类的有 30 篇,然而预测的结果科技类的有 35 篇,其中 30 篇预测正确了,有 5 篇预测错误了;预测结果财经类的有 35 篇,其中 25 篇预测正确了,10 篇预测错误了。

  • 准确度

表示预测正确的文章数比上总的文章数:(30+25)/(40+30)=78%

  • 精确率

表示每一类预测正确的数量比上预测的该类文章总数量,比如科技类精确率:30/(30+5)=85%

  • 召回率

表示每一类预测正确的数量比上实际该类的总数量,比如科技类:30/40=75%

应用实例

上边我们已经了解了朴素贝叶斯公式及推导过程,下边我们来看一下在实际的新闻分类中的应用。

元数据的准备,我们的元数据是网上找来的一些各类的新闻,这里为了简单,我们只选取了科技、财经和体育三类数量不等的新闻,并且都已知他们的类别。然后通过中文结巴分词

对每篇新闻进行分词。这里我们用到的是gihub上的一个开源的python库,有兴趣的可以了解一下。

下面我们来看一下代码的具体实现。

首先我们先把汉字的文章转成每个词所对应的数字id的形式,方便我们后边的操作和计算。

Convert.py

 
   
   
 
  1. import os

  2. import sys

  3. import random

  4. import re


  5. inputPath = sys.argv[1]

  6. outputFile = sys.argv[2]

  7. #训练集所占百分比

  8. trainPercent = 0.8

  9. wordDict = {}

  10. wordList = []


  11. trainOutputFile = open('%s.train' % outputFile, "w")

  12. testOutputFile = open('%s.test' % outputFile, "w")


  13. for fileName in os.listdir(inputPath):

  14.    tag = 0

  15.    if fileName.find('technology') != -1:

  16.        tag = 1

  17.    elif fileName.find('business') != -1:

  18.        tag = 2

  19.    elif fileName.find('sport') != -1:

  20.        tag = 3


  21.    outFile = trainOutputFile

  22.    rd = random.random()

  23.    if rd >= trainPercent:

  24.        outFile = testOutputFile


  25.    inputFile = open(inputPath+'/'+fileName, "r")

  26.    content = inputFile.read().strip()

  27.    content = content.decode('utf-8', 'ignore')

  28.    content = content.replace('\n', ' ')

  29.    r1 = u'[a-zA-Z0-9’!"#$%&\'()*+,-./:;<=>?@,。?★、…【】《》?“”‘’![\\]^_`{|}~]+'

  30.    content = re.sub(r1, '', content)

  31.    outFile.write(str(tag)+' ')

  32.    words = content.split(' ')

  33.    for word in words:

  34.        if word not in wordDict:

  35.            wordList.append(word)

  36.            wordDict[word] = len(wordList)


  37.        outFile.write(str(wordDict[word]) + ' ')


  38.    inputFile.close()


  39. trainOutputFile.close()

  40. testOutputFile.close()

朴素贝叶斯实现过程

NB.py

 
   
   
 
  1. #Usage:

  2. #Training: NB.py 1 TrainingDataFile ModelFile

  3. #Testing: NB.py 0 TestDataFile ModelFile OutFile


  4. import sys

  5. import os

  6. import math



  7. DefaultFreq = 0.1

  8. TrainingDataFile = "nb_data.train"

  9. ModelFile = "nb_data.model"

  10. TestDataFile = "nb_data.test"

  11. TestOutFile = "nb_data.out"

  12. ClassFeaDic = {}

  13. ClassFreq = {}

  14. WordDic = {}

  15. ClassFeaProb = {}

  16. ClassDefaultProb = {}

  17. ClassProb = {}


  18. #加载数据

  19. def LoadData():

  20.    i =0

  21.    infile = open(TrainingDataFile, 'r')

  22.    sline = infile.readline().strip()

  23.    while len(sline) > 0:

  24.        pos = sline.find("#")

  25.        if pos > 0:

  26.            sline = sline[:pos].strip()

  27.        words = sline.split(' ')

  28.        if len(words) < 1:

  29.            print("Format error!")

  30.            break

  31.        classid = int(words[0])

  32.        if classid not in ClassFeaDic:

  33.            ClassFeaDic[classid] = {}

  34.            ClassFeaProb[classid] = {}

  35.            ClassFreq[classid]  = 0

  36.        ClassFreq[classid] += 1

  37.        words = words[1:]

  38.        for word in words:

  39.            if len(word) < 1:

  40.                continue

  41.            wid = int(word)

  42.            if wid not in WordDic:

  43.                WordDic[wid] = 1

  44.            if wid not in ClassFeaDic[classid]:

  45.                ClassFeaDic[classid][wid] = 1

  46.            else:

  47.                ClassFeaDic[classid][wid] += 1

  48.        i += 1

  49.        sline = infile.readline().strip()

  50.    infile.close()

  51.    print(i, "instances loaded!")

  52.    print(len(ClassFreq), "classes!", len(WordDic), "words!")


  53. #计算模型

  54. def ComputeModel():

  55.    sum = 0.0

  56.    for freq in ClassFreq.values():

  57.        sum += freq

  58.    for classid in ClassFreq.keys():

  59.        ClassProb[classid] = (float)(ClassFreq[classid])/(float)(sum)

  60.    for classid in ClassFeaDic.keys():

  61.        sum = 0.0

  62.        for wid in ClassFeaDic[classid].keys():

  63.            sum += ClassFeaDic[classid][wid]

  64.        newsum = (float)(sum + 1)

  65.        for wid in ClassFeaDic[classid].keys():

  66.            ClassFeaProb[classid][wid] = (float)(ClassFeaDic[classid][wid]+DefaultFreq)/newsum

  67.        ClassDefaultProb[classid] = (float)(DefaultFreq) / newsum

  68.    return


  69. #保存模型

  70. def SaveModel():

  71.    outfile = open(ModelFile, 'w')

  72.    for classid in ClassFreq.keys():

  73.        outfile.write(str(classid))

  74.        outfile.write(' ')

  75.        outfile.write(str(ClassProb[classid]))

  76.        outfile.write(' ')

  77.        outfile.write(str(ClassDefaultProb[classid]))

  78.        outfile.write(' ' )

  79.    outfile.write('\n')

  80.    for classid in ClassFeaDic.keys():

  81.        for wid in ClassFeaDic[classid].keys():

  82.            outfile.write(str(wid)+' '+str(ClassFeaProb[classid][wid]))

  83.            outfile.write(' ')

  84.        outfile.write('\n')

  85.    outfile.close()


  86. #加载模型

  87. def LoadModel():

  88.    global WordDic

  89.    WordDic = {}

  90.    global ClassFeaProb

  91.    ClassFeaProb = {}

  92.    global ClassDefaultProb

  93.    ClassDefaultProb = {}

  94.    global ClassProb

  95.    ClassProb = {}

  96.    infile = open(ModelFile, 'r')

  97.    sline = infile.readline().strip()

  98.    items = sline.split(' ')

  99.    if len(items) < 6:

  100.        print("Model format error!")

  101.        return

  102.    i = 0

  103.    while i < len(items):

  104.        classid = int(items[i])

  105.        ClassFeaProb[classid] = {}

  106.        i += 1

  107.        if i >= len(items):

  108.            print("Model format error!")

  109.            return

  110.        ClassProb[classid] = float(items[i])

  111.        i += 1

  112.        if i >= len(items):

  113.            print("Model format error!")

  114.            return

  115.        ClassDefaultProb[classid] = float(items[i])

  116.        i += 1

  117.    for classid in ClassProb.keys():

  118.        sline = infile.readline().strip()

  119.        items = sline.split(' ')

  120.        i = 0

  121.        while i < len(items):

  122.            wid  = int(items[i])

  123.            if wid not in WordDic:

  124.                WordDic[wid] = 1

  125.            i += 1

  126.            if i >= len(items):

  127.                print("Model format error!")

  128.                return

  129.            ClassFeaProb[classid][wid] = float(items[i])

  130.            i += 1

  131.    infile.close()

  132.    print(len(ClassProb), "classes!", len(WordDic), "words!")


  133. #预测类别

  134. def Predict():

  135.    global WordDic

  136.    global ClassFeaProb

  137.    global ClassDefaultProb

  138.    global ClassProb


  139.    TrueLabelList = []

  140.    PredLabelList = []

  141.    i =0

  142.    infile = open(TestDataFile, 'r')

  143.    outfile = open(TestOutFile, 'w')

  144.    sline = infile.readline().strip()

  145.    scoreDic = {}

  146.    iline = 0

  147.    while len(sline) > 0:

  148.        iline += 1

  149.        if iline % 10 == 0:

  150.            print(iline," lines finished!\r")

  151.        pos = sline.find("#")

  152.        if pos > 0:

  153.            sline = sline[:pos].strip()

  154.        words = sline.split(' ')

  155.        if len(words) < 1:

  156.            print("Format error!")

  157.            break

  158.        classid = int(words[0])

  159.        TrueLabelList.append(classid)

  160.        words = words[1:]

  161.        for classid in ClassProb.keys():

  162.            scoreDic[classid] = math.log(ClassProb[classid])

  163.        for word in words:

  164.            if len(word) < 1:

  165.                continue

  166.            wid = int(word)

  167.            if wid not in WordDic:

  168.                continue

  169.            for classid in ClassProb.keys():

  170.                if wid not in ClassFeaProb[classid]:

  171.                    scoreDic[classid] += math.log(ClassDefaultProb[classid])

  172.                else:

  173.                    scoreDic[classid] += math.log(ClassFeaProb[classid][wid])

  174.        i += 1

  175.        maxProb = max(scoreDic.values())

  176.        for classid in scoreDic.keys():

  177.            if scoreDic[classid] == maxProb:

  178.                PredLabelList.append(classid)

  179.        sline = infile.readline().strip()

  180.    infile.close()

  181.    outfile.close()

  182.    print(len(PredLabelList),len(TrueLabelList))

  183.    return TrueLabelList,PredLabelList


  184. #计算准确度

  185. def Evaluate(TrueList, PredList):

  186.    accuracy = 0

  187.    i = 0

  188.    while i < len(TrueList):

  189.        if TrueList[i] == PredList[i]:

  190.            accuracy += 1

  191.        i += 1

  192.    accuracy = (float)(accuracy)/(float)(len(TrueList))

  193.    print("Accuracy:",accuracy)


  194. #计算精确率和召回率

  195. def CalPreRec(TrueList,PredList,classid):

  196.    correctNum = 0

  197.    allNum = 0

  198.    predNum = 0

  199.    i = 0

  200.    while i < len(TrueList):

  201.        if TrueList[i] == classid:

  202.            allNum += 1

  203.            if PredList[i] == TrueList[i]:

  204.                correctNum += 1

  205.        if PredList[i] == classid:

  206.            predNum += 1

  207.        i += 1

  208.    return (float)(correctNum)/(float)(predNum),(float)(correctNum)/(float)(allNum)


  209. #main framework

  210. if sys.argv[1] == '1':

  211.    print("start training:")

  212.    LoadData()

  213.    ComputeModel()

  214.    SaveModel()

  215. elif sys.argv[1] == '0':

  216.    print("start testing:")


  217.    LoadModel()

  218.    TList,PList = Predict()

  219.    i = 0

  220.    outfile = open(TestOutFile, 'w')

  221.    while i < len(TList):

  222.        outfile.write(str(TList[i]))

  223.        outfile.write(' ')

  224.        outfile.write(str(PList[i]))

  225.        outfile.write('\n')

  226.        i += 1

  227.    outfile.close()

  228.    Evaluate(TList,PList)

  229.    for classid in ClassProb.keys():

  230.        pre,rec = CalPreRec(TList, PList,classid)

  231.        print("Precision and recall for Class",classid,":",pre,rec)

  232. else:

  233.    print("Usage incorrect!")





如果觉得好看,可以点一下 好看 ,如果觉得对你有一点点帮助,可以赞赏作者一点,还可以推荐和分享给你的朋友

「努力给自己看」

版权声明:本站内容全部来自于腾讯微信公众号,属第三方自助推荐收录。《朴素贝叶斯算法&应用实例》的版权归原作者「努力给自己看」所有,文章言论观点不代表Lambda在线的观点, Lambda在线不承担任何法律责任。如需删除可联系QQ:516101458

文章来源: 阅读原文

相关阅读

关注努力给自己看微信公众号

努力给自己看微信公众号:gh_43402edb1aba

努力给自己看

手机扫描上方二维码即可关注努力给自己看微信公众号

努力给自己看最新文章

精品公众号随机推荐

举报