vlambda博客
学习文章列表

R语言 CART算法和C4.5算法(决策树)


关注CSDN博客:程志伟的博客


R版本:3.4.4



还需要安装java环境,下载jdk,配置环境变量。


draw.tree函数:绘制树状图


J48函数:实现C4.5算法


maptree包:提供draw.tree函数


mvpart包:提供数据集car.test.frame


post函数:对rpart()结果绘制演示图


prune.rpart():对rpart()的结果进行剪枝


rpart包:提供函数rpart()、prune.rpart()、post()


rpart.plot包提供rpart.plot函数,绘制决策时


RWeka包:提供函数J48


sampling包:strata函数用于分层抽样


 


#设置工作路径


> setwd('G:\\R语言\\大三下半年\\数据挖掘:R语言实战\\')

Warning message:

file ‘.RData’ has magic number 'RDX3'

  Use of save versions prior to 2 is deprecated 

> library(mvpart)

> data("car.test.frame")

> head(car.test.frame)

                 Price   Country Reliability Mileage  Type Weight Disp.  HP

Eagle Summit 4    8895       USA           4      33 Small   2560    97 113

Ford Escort   4   7402       USA           2      33 Small   2345   114  90

Ford Festiva 4    6319     Korea           4      37 Small   1845    81  63

Honda Civic 4     6635 Japan/USA           5      32 Small   2260    91  92

Mazda Protege 4   6599     Japan           5      32 Small   2440   113 103

Mercury Tracer 4  8672    Mexico           4      26 Small   2285    97  82

> car.test.frame$Mileage <- 100*4.546/(1.6*car.test.frame$Mileage)

> names(car.test.frame) <- c("价格","产地","可靠性","油耗","类型","车重","发动机功率","净马力")

> head(car.test.frame)

                            价格      产地   可靠性   油耗    类型 车重 发动机功率 净马力

Eagle Summit 4   8895       USA      4  8.609848 Small 2560         97    113

Ford Escort   4  7402       USA      2  8.609848 Small 2345        114     90

Ford Festiva 4   6319     Korea      4  7.679054 Small 1845         81     63

Honda Civic 4    6635 Japan/USA      5  8.878906 Small 2260         91     92

Mazda Protege 4  6599     Japan      5  8.878906 Small 2440        113    103

Mercury Tracer 4 8672    Mexico      4 10.927885 Small 2285         97     82


#str()函数显示该数据集有60行8列

> str(car.test.frame)

'data.frame':    60 obs. of  8 variables:

 $ 价格      : int  8895 7402 6319 6635 6599 8672 7399 7254 9599 5866 ...

 $ 产地      : Factor w/ 8 levels "France","Germany",..: 8 8 5 4 3 6 4 5 3 3 ...

 $ 可靠性    : int  4 2 4 5 5 4 5 1 5 NA ...

 $ 油耗      : num  8.61 8.61 7.68 8.88 8.88 ...

 $ 类型      : Factor w/ 6 levels "Compact","Large",..: 4 4 4 4 4 4 4 4 4 4 ...

 $ 车重      : int  2560 2345 1845 2260 2440 2285 2275 2350 2295 1900 ...

 $ 发动机功率: int  97 114 81 91 113 97 97 98 109 73 ...

 $ 净马力    : int  113 90 63 92 103 82 90 74 90 73 ...



> summary(car.test.frame)

      价格              产地        可靠性           油耗             类型   

 Min.   : 5866   USA      :26   Min.   :1.000   Min.   : 7.679   Compact:15  

 1st Qu.: 9932   Japan    :19   1st Qu.:2.000   1st Qu.:10.523   Large  : 3  

 Median :12216   Japan/USA: 7   Median :3.000   Median :12.353   Medium :13  

 Mean   :12616   Korea    : 3   Mean   :3.388   Mean   :11.962   Small  :13  

 3rd Qu.:14933   Germany  : 2   3rd Qu.:5.000   3rd Qu.:13.530   Sporty : 9  

 Max.   :24760   France   : 1   Max.   :5.000   Max.   :15.785   Van    : 7  

                 (Other)  : 2   NA's   :11                                   

      车重        发动机功率        净马力     

 Min.   :1845   Min.   : 73.0   Min.   : 63.0  

 1st Qu.:2571   1st Qu.:113.8   1st Qu.:101.5  

 Median :2885   Median :144.5   Median :111.5  

 Mean   :2901   Mean   :152.1   Mean   :122.3  

 3rd Qu.:3231   3rd Qu.:180.0   3rd Qu.:142.8  

 Max.   :3855   Max.   :305.0   Max.   :225.0  

                                               

> #2. 数据预处理

> #下面我们着重看油耗变量,因为在以下的建模过程中,将以油耗作为目标变量

> #一个数据集来分别构建出以离散型和连续型变量为各自目标变量的分类树和回归树,考虑添加一列变量——分组油耗,即将油耗变量划分为三个组别,A:11.6~15.8个油、B:9~11.6个油、C:7.7~9个油,成为含有3个水平的A、B、C的因子变量。

> Group_Mileage=matrix(0,60,1) #设矩阵Group_Mileage用于存放新变量

> Group_Mileage[which(car.test.frame$"油耗">=11.6)]="A"

> Group_Mileage[which(car.test.frame$"油耗"<=9)]="C"   #将油耗在7.7~9区间的样本Group_Mileage值取C

> Group_Mileage[which(Group_Mileage==0)]="B" #将油耗不在组A、C的样本Group_Mileage值取B

> car.test.frame$"分组油耗"=Group_Mileage  #在数据集中添加新变量分组油耗

> car.test.frame[1:10,c(4,9)]

                      油耗 分组油耗

Eagle Summit 4    8.609848        C

Ford Escort   4   8.609848        C

Ford Festiva 4    7.679054        C

Honda Civic 4     8.878906        C

Mazda Protege 4   8.878906        C

Mercury Tracer 4 10.927885        B

Nissan Sentra 4   8.609848        C

Pontiac LeMans 4 10.147321        B

Subaru Loyale 4  11.365000        B

Subaru Justy 3    8.356618        C

> a=round(1/4*sum(car.test.frame$"分组油耗"=="A"))

> b=round(1/4*sum(car.test.frame$"分组油耗"=="B"))

> c=round(1/4*sum(car.test.frame$"分组油耗"=="C"))

> #分别计算A、B、C组中应抽取测试集样本数,记为a、b、c

> a;b;c

[1] 9

[1] 4

[1] 2


 


#使用strata()函数对car.test.frame中的“分组油耗”变量进行分层抽样

> library(sampling)

> sub=strata(car.test.frame, stratanames="分组油耗", size=c(c,b,a), method="srswor")


> sub

   分组油耗 ID_unit      Prob Stratum

4         C       4 0.2222222       1

7         C       7 0.2222222       1

6         B       6 0.2500000       2

8         B       8 0.2500000       2

17        B      17 0.2500000       2

21        B      21 0.2500000       2

20        A      20 0.2571429       3

38        A      38 0.2571429       3

42        A      42 0.2571429       3

47        A      47 0.2571429       3

51        A      51 0.2571429       3

52        A      52 0.2571429       3

56        A      56 0.2571429       3

57        A      57 0.2571429       3

60        A      60 0.2571429       3

> Train_Car=car.test.frame[-sub$ID_unit,] #生成训练集

> Test_Car=car.test.frame[sub$ID_unit,]   #生成测试集

> nrow(Train_Car);nrow(Test_Car) #显示训练集、测试集行数

[1] 45

[1] 15



> ##################应用案例##################

> library('rpart')


载入程辑包:‘rpart’


The following object is masked _by_ ‘.GlobalEnv’:


    car.test.frame


The following objects are masked from ‘package:mvpart’:


    meanvar, na.rpart, path.rpart, plotcp, post, printcp, prune, prune.rpart,

    rpart, rpart.control, rsq.rpart, snip.rpart, xpred.rpart


> #1. 对油耗变量建立回归树——数字结果


#按照公式对训练集构建回归树method="anova"

> formula_Car_Reg=油耗~价格+产地+可靠性+类型+车重+发动机功率+净马力 #设定模型公式

> rp_Car_Reg=rpart(formula_Car_Reg,Train_Car,method="anova")


> print(rp_Car_Reg)   #导出回归树基本信息

n= 45 


node), split, n, deviance, yval

      * denotes terminal node


1) root 45 186.172800 11.845280  

  2) 发动机功率< 134 19  32.628130  9.977205 *

  3) 发动机功率>=134 26  38.786870 13.210420  

    6) 价格< 11522 7   3.003835 11.877100 *

    7) 价格>=11522 19  18.754370 13.701630 *

#按照节点层次以不同缩进量列出,并在每条节点信息后以星号*标示出是否为叶节点。



> printcp(rp_Car_Reg)  #导出回归树的cp表格


Regression tree:

rpart(formula = formula_Car_Reg, data = Train_Car, method = "anova")


Variables actually used in tree construction:

[1] 发动机功率 价格      


Root node error: 186.17/45 = 4.1372


n= 45 


        CP nsplit rel error  xerror     xstd

1 0.616405      0   1.00000 1.04401 0.182755

2 0.091467      1   0.38360 0.46353 0.075249

3 0.010000      2   0.29213 0.42704 0.063786


#获取决策树rp_Car_Reg详细信息,Variable importance变量的最要程度、improve对分支的提升程度

> summary(rp_Car_Reg)

Call:

rpart(formula = formula_Car_Reg, data = Train_Car, method = "anova")

  n= 45 


          CP nsplit rel error    xerror       xstd

1 0.61640483      0 1.0000000 1.0440087 0.18275540

2 0.09146694      1 0.3835952 0.4635295 0.07524941

3 0.01000000      2 0.2921282 0.4270404 0.06378585


Variable importance

发动机功率       车重       价格     净马力       类型       产地 

        26         20         14         14         13         12 


Node number 1: 45 observations,    complexity param=0.6164048

  mean=11.84528, MSE=4.137174 

  left son=2 (19 obs) right son=3 (26 obs)

  Primary splits:

      发动机功率 < 134    to the left,  improve=0.6164048, (0 missing)

      价格       < 9446.5 to the left,  improve=0.5563919, (0 missing)

      车重       < 2567.5 to the left,  improve=0.5509927, (0 missing)

      类型       splits as  RRRLRR,     improve=0.4392540, (0 missing)

      净马力     < 109    to the left,  improve=0.3982844, (0 missing)

  Surrogate splits:

      车重   < 2747.5 to the left,  agree=0.889, adj=0.737, (0 split)

      净马力 < 109    to the left,  agree=0.800, adj=0.526, (0 split)

      类型   splits as  RRRLRR,     agree=0.778, adj=0.474, (0 split)

      价格   < 9446.5 to the left,  agree=0.756, adj=0.421, (0 split)

      产地   splits as  LLLLR-RR,   agree=0.756, adj=0.421, (0 split)


Node number 2: 19 observations

  mean=9.977205, MSE=1.71727 


Node number 3: 26 observations,    complexity param=0.09146694

  mean=13.21042, MSE=1.491803 

  left son=6 (7 obs) right son=7 (19 obs)

  Primary splits:

      价格       < 11522  to the left,  improve=0.4390315, (0 missing)

      车重       < 3087.5 to the left,  improve=0.3622234, (0 missing)

      类型       splits as  LRL-RR,     improve=0.3121080, (0 missing)

      发动机功率 < 185.5  to the left,  improve=0.1511378, (0 missing)

      净马力     < 148.5  to the left,  improve=0.1511378, (0 missing)

  Surrogate splits:

      车重       < 2757.5 to the left,  agree=0.846, adj=0.429, (0 split)

      产地       splits as  --RLL-RR,   agree=0.808, adj=0.286, (0 split)

      类型       splits as  LRR-RR,     agree=0.808, adj=0.286, (0 split)

      净马力     < 103.5  to the left,  agree=0.808, adj=0.286, (0 split)

      发动机功率 < 142    to the left,  agree=0.769, adj=0.143, (0 split)


Node number 6: 7 observations

  mean=11.8771, MSE=0.4291194 


Node number 7: 19 observations

  mean=13.70163, MSE=0.9870723 


> #下面我们尝试改变rpart()函数的若干参数值,minsplit=10,将分支包含最小样本数minsplit从默认值20改为10,新的回归树记为rp_Car_Reg1。

> rp_Car_Reg1=rpart(formula_Car_Reg,Train_Car,method="anova",minsplit=10)


> print(rp_Car_Reg1) #导出回归树基本信息

n= 45 


node), split, n, deviance, yval

      * denotes terminal node


 1) root 45 186.172800 11.845280  

   2) 发动机功率< 134 19  32.628130  9.977205  

     4) 价格< 9504.5 8   2.649246  8.582424 *

     5) 价格>=9504.5 11   3.096802 10.991590 *

   3) 发动机功率>=134 26  38.786870 13.210420  

     6) 价格< 11522 7   3.003835 11.877100 *

     7) 价格>=11522 19  18.754370 13.701630  

      14) 类型=Compact,Medium 12   2.880021 13.081890 *

      15) 类型=Large,Sporty,Van 7   3.364168 14.764060 *

> printcp(rp_Car_Reg1)  #导出回归树的cp表格


Regression tree:

rpart(formula = formula_Car_Reg, data = Train_Car, method = "anova", 

    minsplit = 10)


Variables actually used in tree construction:

[1] 发动机功率 价格       类型      


Root node error: 186.17/45 = 4.1372


n= 45 


        CP nsplit rel error  xerror     xstd

1 0.616405      0  1.000000 1.04535 0.183175

2 0.144393      1  0.383595 0.46279 0.074342

3 0.091467      2  0.239202 0.36798 0.080715

4 0.067197      3  0.147735 0.38852 0.111059

5 0.010000      4  0.080538 0.32070 0.111356


 


#cp值表示可以使模型拟合程度提高的节点,剪去不重要的分支


> rp_Car_Reg2=rpart(formula_Car_Reg,Train_Car,method="anova",cp=0.1)

> #将CP值从默认的0.01改为0.1,新的回归树记为rp_Car_Reg2

> print(rp_Car_Reg2) #导出回归树基本信息

n= 45 


node), split, n, deviance, yval

      * denotes terminal node


1) root 45 186.17280 11.845280  

  2) 发动机功率< 134 19  32.62813  9.977205 *

  3) 发动机功率>=134 26  38.78687 13.210420 *

> printcp(rp_Car_Reg2)  #导出回归树的cp表格


Regression tree:

rpart(formula = formula_Car_Reg, data = Train_Car, method = "anova", 

    cp = 0.1)


Variables actually used in tree construction:

[1] 发动机功率


Root node error: 186.17/45 = 4.1372


n= 45 


      CP nsplit rel error  xerror     xstd

1 0.6164      0    1.0000 1.07992 0.189818

2 0.1000      1    0.3836 0.47707 0.078252

> #剪枝函数也可以实现同样的效果

> rp_Car_Reg3=prune.rpart(rp_Car_Reg,cp=0.1)

> print(rp_Car_Reg3)

n= 45 


node), split, n, deviance, yval

      * denotes terminal node


1) root 45 186.17280 11.845280  

  2) 发动机功率< 134 19  32.62813  9.977205 *

  3) 发动机功率>=134 26  38.78687 13.210420 *

> printcp(rp_Car_Reg3)


Regression tree:

rpart(formula = formula_Car_Reg, data = Train_Car, method = "anova")


Variables actually used in tree construction:

[1] 发动机功率


Root node error: 186.17/45 = 4.1372


n= 45 


      CP nsplit rel error  xerror     xstd

1 0.6164      0    1.0000 1.04401 0.182755

2 0.1000      1    0.3836 0.46353 0.075249

> #对所生成树的大小也可以通过深度函数maxdepth来控制

> rp_Car_Reg4=rpart(formula_Car_Reg,Train_Car,method="anova",maxdepth=1)

> print(rp_Car_Reg4)

n= 45 


node), split, n, deviance, yval

      * denotes terminal node


1) root 45 186.17280 11.845280  

  2) 发动机功率< 134 19  32.62813  9.977205 *

  3) 发动机功率>=134 26  38.78687 13.210420 *

> printcp(rp_Car_Reg4)


Regression tree:

rpart(formula = formula_Car_Reg, data = Train_Car, method = "anova", 

    maxdepth = 1)


Variables actually used in tree construction:

[1] 发动机功率


Root node error: 186.17/45 = 4.1372


n= 45 


      CP nsplit rel error  xerror     xstd

1 0.6164      0    1.0000 1.09740 0.191557

2 0.0100      1    0.3836 0.52651 0.084732


 



#2. 对油耗变量建立回归树——树形结果

> rp_Car_Plot=rpart(formula_Car_Reg,Train_Car,method="anova",minsplit=10)

> #设置minsplit为10,新的回归树记为rp_Car_Plot

> print(rp_Car_Plot)

n= 45 


node), split, n, deviance, yval

      * denotes terminal node


 1) root 45 186.172800 11.845280  

   2) 发动机功率< 134 19  32.628130  9.977205  

     4) 价格< 9504.5 8   2.649246  8.582424 *

     5) 价格>=9504.5 11   3.096802 10.991590 *

   3) 发动机功率>=134 26  38.786870 13.210420  

     6) 价格< 11522 7   3.003835 11.877100 *

     7) 价格>=11522 19  18.754370 13.701630  

      14) 类型=Compact,Medium 12   2.880021 13.081890 *

      15) 类型=Large,Sporty,Van 7   3.364168 14.764060 *

> library(rpart.plot)

> rpart.plot(rp_Car_Plot)            #绘制决策树

> #相比于数字结果,从树状图中我们可以更清晰的看到模型对于目标变量的预测过程



> rpart.plot(rp_Car_Plot,type=4)     #更改type参数为类型4,绘制决策树

> #每一分支的取值范围也被标示出来,而且在每个节点的油耗的预测值也被标出

> #当树的分支较多时,我们还可以选择设置“分支”参数branch=1来获得垂直枝干形状的决策

> #树以减少图形所占空间,使得树状图的枝干不再显得杂乱无章,更方便查看和分析。


> rpart.plot(rp_Car_Plot,type=4,branch=1)

> #参数fallen.leaves设置为TRUE,即表示将所有叶节点一致的摆放在树的最下端

> rpart.plot(rp_Car_Plot,type=4,branch=1,fallen.leaves=TRUE)

> library(maptree)

载入需要的程辑包:cluster

> draw.tree(rp_Car_Plot, col=rep(1,8), nodeinfo=TRUE)#利用draw.tree()绘制决策树

R语言 CART算法和C4.5算法(决策树)

> plot(rp_Car_Plot,uniform=TRUE,main="plot: Regression TRUE")

> text(rp_Car_Plot,use.n=TRUE,all=TRUE)#用plot()直接绘图,并在图中添加相关文字信息

R语言 CART算法和C4.5算法(决策树)


> post(rp_Car_Plot,file="") #用post()函数绘制决策树

R语言 CART算法和C4.5算法(决策树)


> #3. 对分组油耗变量建立分类树

> formula_Car_Cla=分组油耗~价格+产地+可靠性+类型+车重+发动机功率+净马力

> rp_Car_Cla=rpart(formula_Car_Cla, Train_Car, method="class", minsplit=5)

> #按公式formula_Car_Cla对训练集创建分类树

> print(rp_Car_Cla)

n= 45 


node), split, n, loss, yval, (yprob)

      * denotes terminal node


 1) root 45 19 A (0.57777778 0.26666667 0.15555556)  

   2) 发动机功率>=134 26  2 A (0.92307692 0.07692308 0.00000000)  

     4) 价格>=11222 20  0 A (1.00000000 0.00000000 0.00000000) *

     5) 价格< 11222 6  2 A (0.66666667 0.33333333 0.00000000)  

      10) 发动机功率< 152 4  0 A (1.00000000 0.00000000 0.00000000) *

      11) 发动机功率>=152 2  0 B (0.00000000 1.00000000 0.00000000) *

   3) 发动机功率< 134 19  9 B (0.10526316 0.52631579 0.36842105)  

     6) 价格>=9504.5 11  2 B (0.18181818 0.81818182 0.00000000) *

     7) 价格< 9504.5 8  1 C (0.00000000 0.12500000 0.87500000) *

> #以上输出结果与回归树类似,不同之处仅在于每个节点的预测值不再是具体数值,而是A、B、C,即分组油耗的三个取值水平。

> rpart.plot(rp_Car_Cla, type=4, fallen.leaves=TRUE)  #对rp_Car_Cla绘制分类树


#发动机高于134的且价格高于1万美元的车,属于A类,高油耗;最左边的分支,反之属于C类,最右侧分支。


#4. 对测试集Test_Car预测目标变量

> pre_Car_Cla=predict(rp_Car_Cla,Test_Car,type="class")

> #对测试集Test_Car中观测样本中的分组油耗指标进行预测

> pre_Car_Cla                                                          #显示预测结果

       Honda Civic 4      Nissan Sentra 4     Mercury Tracer 4     Pontiac LeMans 4 

                   C                    C                    C                    C 

          Ford Probe       Plymouth Laser       Nissan 240SX 4      Acura Legend V6 

                   B                    B                    A                    A 

    Eagle Premier V6     Nissan Maxima V6    Buick Le Sabre V6 Chevrolet Caprice V8 

                   A                    A                    A                    A 

    Ford Aerostar V6         Mazda MPV V6         Nissan Van 4 

                   A                    A                    A 

Levels: A B C

> (p=sum(as.numeric(pre_Car_Cla!=Test_Car$"分组油耗"))/nrow(Test_Car)) #计算错误率

[1] 0.1333333

> table(Test_Car$"分组油耗", pre_Car_Cla)                              #获取混淆矩阵

   pre_Car_Cla

    A B C

  A 9 0 0

  B 0 2 2

  C 0 0 2



 ########################    C4.5应用             ###################



> #C4.5算法仅适用离散变量,即构建分类树,因此这里我们就继续沿用上面的数据集来对分组油

> #耗指标进行建树。需要说明的是,用于实现C4.5算法的核心函数J48()对中文识别不太完善,因此我们将使用原

> #英文数据集中的变量名称。#“价格(Price)”、产地(Country)、可靠性(Reliability)、英里数(Mileage)、类型(Type)、

> #车重(Weight)、发动机功率(Disp.),以及净马力(HP),分组油耗(Oil_Consumption)。

> #install.packages("rJava")

> library(RWeka)

> names(Train_Car)=c("Price","Country","Reliability","Mileage","Type","Weight","Disp.",

+                    "HP","Oil_Consumption") #更改为英文变量名

> Train_Car$Oil_Consumption=as.factor(Train_Car$Oil_Consumption)

> #将分组哟好的变量类型改为因子型,使J48()函数可识别

> formula=Oil_Consumption~Price+Country+Reliability+Type+Weight+Disp.+HP

> C45_0=J48(formula,Train_Car)    #在默认参数去之下,构建分类树模型C45_0

> C45_0

J48 pruned tree

------------------


Price <= 9410: C (7.0/1.0)

Price > 9410

|   Disp. <= 132: B (6.0)

|   Disp. > 132

|   |   Price <= 10989

|   |   |   Reliability <= 1: B (2.0)

|   |   |   Reliability > 1: A (4.0/1.0)

|   |   Price > 10989: A (17.0)


Number of Leaves  :     5


Size of the tree :     9


> #共计10个叶节点,15个节点,最后括号中的数字表示有多少观测样本被归入该分支,且其中有

> #几个是被错分的。

> summary(C45_0)


=== Summary ===


Correctly Classified Instances          34               94.4444 %

Incorrectly Classified Instances         2                5.5556 %

Kappa statistic                          0.9045

Mean absolute error                      0.0595

Root mean squared error                  0.1725

Relative absolute error                 15.067  %

Root relative squared error             39.0042 %

Total Number of Instances               36     


=== Confusion Matrix ===


  a  b  c   <-- classified as

 20  0  0 |  a = A

  1  8  1 |  b = B

  0  0  6 |  c = C

> #下面我们通过control参数控制分类树的生成过程。参数M,即对每个叶节点设置最小观测样本

> #量来对树进行剪枝。我们知道,M的默认值为2,现在将其取值为3来减去若干所含样本量较小

> #的分支。

> C45_1=J48(formula,Train_Car,control=Weka_control(M=3))

> #取control参数的M值为3,构建分类树模型C45_1

> C45_1

J48 pruned tree

------------------


Price <= 9410: C (7.0/1.0)

Price > 9410

|   Disp. <= 132: B (6.0)

|   Disp. > 132: A (23.0/3.0)


Number of Leaves  :     3


Size of the tree :     5


> plot(C45_1)                           #对C45_1绘制分类树