在机器学习领域,支持向量机SVM(Support Vector Machine)是一个有监督的学习模型,通常用来进行模式识别、分类、以及回归分析。SVM涉及的知识面非常广,目标函数、优化过程、并行方法、算法收敛性、样本复杂度等。

简介

从事渔业生产有经验的从业者可通过观察水色变化调控水质,以维持养殖水体生态系统中浮游植物、微生物类、浮游动物等合理的动态平衡。由于这些多是通过经验和肉眼观察进行判断,存在主观性引起的观察性偏倚,使观察结果的可比性、可重复性降低,不易推广应用。当前,数字图像处理技术为计算机监控技术在水产养殖业的应用提供更大的空间。在水质在线监测方面,数字图像处理技术是基于计算机视觉,以专家经验为基础,对池塘水色进行优劣分级,达到对池塘水色的准确快速判别。

水色分类

水色 水质类别
浅绿色(清水或浊水) 1
灰蓝色 2
黄褐色 3
茶褐色(姜黄、茶褐、红褐、褐中带绿等) 4
绿色(黄绿、油绿、蓝绿、墨绿、绿中带褐等) 5

标准条件下拍摄的水样图像,接下来进行特征提取,采用颜色矩来提取水样图像的特征,水色图像特征与相应的水色类别的部分数据,如下:

数据概览

mdata <- read.csv('http://data.galaxystatistics.com/blog_data/svm/moment.csv',header=T)
class(mdata)
## [1] "data.frame"
dim(mdata)
## [1] 203  11
head(mdata)
##   类别 序号 R通道一阶矩 G通道一阶矩 B通道一阶矩 R通道二阶矩 G通道二阶矩
## 1    1    1   0.5828229   0.5437737   0.2528287 0.014192030 0.016143875
## 2    1   10   0.6416595   0.5706572   0.2137280 0.015438840 0.011177918
## 3    1   11   0.6036844   0.5767189   0.2822538 0.008658572 0.007074807
## 4    1   12   0.5897057   0.5937430   0.2522425 0.007908293 0.005940868
## 5    1   13   0.5910962   0.5920930   0.2535949 0.007448469 0.006494667
## 6    1   14   0.5886801   0.5696339   0.3189053 0.007527690 0.005046087
##   B通道二阶矩  R通道三阶矩  G通道三阶矩  B通道三阶矩
## 1 0.041075252 -0.012643137 -0.016090364 -0.041536239
## 2 0.013707795  0.009727136 -0.003723814 -0.003779448
## 3 0.012203640 -0.004694985 -0.002570890 -0.009450531
## 4 0.010568364  0.003303400 -0.003416659 -0.005273416
## 5 0.012151602  0.000496116 -0.002235644 -0.005095575
## 6 0.008386259 -0.003529253  0.001746734 -0.005790924

SVM预测模型输入变量

序号 变量名称 变量描述 取值范围
1 R通道一阶矩 水样图像在R颜色通道的一阶矩 0~1
2 G通道一阶矩 水样图像在G颜色通道的一阶矩 0~1
3 B通道一阶矩 水样图像在B颜色通道的一阶矩 0~1
4 R通道二阶矩 水样图像在R颜色通道的二阶矩 0~1
5 G通道二阶矩 水样图像在G颜色通道的二阶矩 0~1
6 B通道二阶矩 水样图像在B颜色通道的二阶矩 0~1
7 R通道三阶矩 水样图像在R颜色通道的三阶矩 -1~1
8 G通道三阶矩 水样图像在G颜色通道的三阶矩 -1~1
9 B通道三阶矩 水样图像在B颜色通道的三阶矩 -1~1
10 水质类别 不同类别能表征水中浮游植物的种类和多少 1,2,3,4,5

模型

图形转换为数据后,R语言建模,建模之前,我来介绍下R语言的SVM用法。使用SVM需安装e1071包。

第一种简单方式建模:

svm(formula, data= NULL, subset, na.action = na.omit , scale= TRUE) 

第二种根据所给的数据建模:

svm(x, y = NULL, scale = TRUE, type = NULL, kernel = "radial",degree = 3, gamma = if (is.vector(x)) 1 else 1 / ncol(x),coef0 = 0, cost = 1, nu = 0.5, subset, na.action = na.omit)

一个经验性的结论是,在利用svm()函数建立支持向量机模型时,使用标准化后的数据建立的模型效果更好。

names(mdata) <- c('class', 'id', 'R1', 'G1', 'B1', 'R2', 'G2', 'B2', 'R3', 'G3', 'B3')
head(mdata)
##   class id        R1        G1        B1          R2          G2
## 1     1  1 0.5828229 0.5437737 0.2528287 0.014192030 0.016143875
## 2     1 10 0.6416595 0.5706572 0.2137280 0.015438840 0.011177918
## 3     1 11 0.6036844 0.5767189 0.2822538 0.008658572 0.007074807
## 4     1 12 0.5897057 0.5937430 0.2522425 0.007908293 0.005940868
## 5     1 13 0.5910962 0.5920930 0.2535949 0.007448469 0.006494667
## 6     1 14 0.5886801 0.5696339 0.3189053 0.007527690 0.005046087
##            B2           R3           G3           B3
## 1 0.041075252 -0.012643137 -0.016090364 -0.041536239
## 2 0.013707795  0.009727136 -0.003723814 -0.003779448
## 3 0.012203640 -0.004694985 -0.002570890 -0.009450531
## 4 0.010568364  0.003303400 -0.003416659 -0.005273416
## 5 0.012151602  0.000496116 -0.002235644 -0.005095575
## 6 0.008386259 -0.003529253  0.001746734 -0.005790924

支持向量机

#数据分割
set.seed(1234)#设置随机种子
#定义序列ind,随机抽取1和2,1的个数占80%,2的个数占20%
ind <- sample(2, nrow(mdata), replace=TRUE, prob=c(0.8, 0.2))
traindata <- mdata[ind==1,] # 训练数据
# head(traindata)
# head(traindata[-2])

testdata <- mdata[ind==2,]  # 测试数据
# head(testdata)
# head(testdata[-2])

#将class列转换为factor类型
traindata <- transform(traindata, class=as.factor(class))
testdata <- transform(testdata, class=as.factor(class))
##支持向量机分类模型构建
library(e1071)
#利用svm建立支持向量机分类模型
svm.model <- svm(class~., traindata[,-2])
summary(svm.model)
## 
## Call:
## svm(formula = class ~ ., data = traindata[, -2])
## 
## 
## Parameters:
##    SVM-Type:  C-classification 
##  SVM-Kernel:  radial 
##        cost:  1 
##       gamma:  0.1111111 
## 
## Number of Support Vectors:  119
## 
##  ( 31 26 41 16 5 )
## 
## 
## Number of Classes:  5 
## 
## Levels: 
##  1 2 3 4 5

通过summary函数可以得到关于模型的相关信息。

  • SVM-Type项目说明本模型的类别为C分类器模型。

  • SVM-Kernel项目说明本模型所使用的核函数为高斯内积函数且核函数中参数gamma的取值为0.11。

  • cost项目说明本模型确定的约束违反成本为l。

  • Number of Support Vectors (模型找到了119个支持向量):第一类包含有31个支持向量,第二类包含有26个支持向量,第三类包含41个支持向量,第四类包含有16个支持向量,第三类包含5个支持向量。

  • Levels: 说明模型中的类别分别为1、2、3、4、5。

建立混淆矩阵(交叉矩阵)

# 训练集
# 建立预测交叉矩阵
confusion.train.svm <- table(traindata$class,predict(svm.model,traindata[,-2],type="class"))
# 预测正确率
accuracy.train.svm <- sum(diag(confusion.train.svm))/sum(confusion.train.svm)
# confusion.train.svm
# accuracy.train.svm
list('训练集预测混淆矩阵'=confusion.train.svm, '训练集预测正确率'=accuracy.train.svm)
## $训练集预测混淆矩阵
##    
##      1  2  3  4  5
##   1 41  0  2  0  0
##   2  0 37  0  0  0
##   3  1  0 61  0  0
##   4  0  0  2 15  0
##   5  1  0  0  0  4
## 
## $训练集预测正确率
## [1] 0.9634146
# 测试集
# 建立预测交叉矩阵
confusion.test.svm <- table(testdata$class,predict(svm.model,testdata[,-2],type="class"))
# 预测正确率
accuracy.test.svm <- sum(diag(confusion.test.svm))/sum(confusion.test.svm)
list('测试集预测混淆矩阵'=confusion.test.svm, '测试集预测正确率'=accuracy.test.svm)
## $测试集预测混淆矩阵
##    
##      1  2  3  4  5
##   1  8  0  0  0  0
##   2  0  7  0  0  0
##   3  0  0 16  0  0
##   4  0  1  2  4  0
##   5  0  0  0  0  1
## 
## $测试集预测正确率
## [1] 0.9230769

随机森林

library(randomForest)
randomForest.model <- randomForest(class~., traindata[,-2])
summary(randomForest.model)
##                 Length Class  Mode     
## call               3   -none- call     
## type               1   -none- character
## predicted        164   factor numeric  
## err.rate        3000   -none- numeric  
## confusion         30   -none- numeric  
## votes            820   matrix numeric  
## oob.times        164   -none- numeric  
## classes            5   -none- character
## importance         9   -none- numeric  
## importanceSD       0   -none- NULL     
## localImportance    0   -none- NULL     
## proximity          0   -none- NULL     
## ntree              1   -none- numeric  
## mtry               1   -none- numeric  
## forest            14   -none- list     
## y                164   factor numeric  
## test               0   -none- NULL     
## inbag              0   -none- NULL     
## terms              3   terms  call
randomForest.model
## 
## Call:
##  randomForest(formula = class ~ ., data = traindata[, -2]) 
##                Type of random forest: classification
##                      Number of trees: 500
## No. of variables tried at each split: 3
## 
##         OOB estimate of  error rate: 9.15%
## Confusion matrix:
##    1  2  3  4 5 class.error
## 1 36  3  4  0 0  0.16279070
## 2  1 35  1  0 0  0.05405405
## 3  2  0 60  0 0  0.03225806
## 4  0  0  1 16 0  0.05882353
## 5  1  0  0  2 2  0.60000000
#训练集
confusion.train.randomForest <- table(traindata$class,predict(randomForest.model,traindata[,-2],type="class"))
accuracy.train.randomForest <- sum(diag(confusion.train.randomForest))/sum(confusion.train.randomForest)
# confusion.train.randomForest
# accuracy.train.randomForest
list('训练集预测混淆矩阵'=confusion.train.randomForest, '训练集预测正确率'=accuracy.train.randomForest)
## $训练集预测混淆矩阵
##    
##      1  2  3  4  5
##   1 43  0  0  0  0
##   2  0 37  0  0  0
##   3  0  0 62  0  0
##   4  0  0  0 17  0
##   5  0  0  0  0  5
## 
## $训练集预测正确率
## [1] 1
#测试集
confusion.test.randomForest <-table(testdata$class,predict(randomForest.model,testdata[,-2],type="class"))
accuracy.test.randomForest <-sum(diag(confusion.test.randomForest))/sum(confusion.test.randomForest)
# confusion.test.randomForest
# accuracy.test.randomForest
list('测试集预测混淆矩阵'=confusion.test.randomForest, '测试集预测正确率'=accuracy.test.randomForest)
## $测试集预测混淆矩阵
##    
##      1  2  3  4  5
##   1  7  0  1  0  0
##   2  0  7  0  0  0
##   3  1  0 15  0  0
##   4  0  0  2  5  0
##   5  0  0  0  0  1
## 
## $测试集预测正确率
## [1] 0.8974359

神经网络

library(nnet)
nnet.model <- nnet(class~., traindata[,-2], size=30, decay=.001)
## # weights:  455
## initial  value 318.920319 
## iter  10 value 176.714302
## iter  20 value 57.798855
## iter  30 value 42.657486
## iter  40 value 27.296733
## iter  50 value 20.803959
## iter  60 value 18.519644
## iter  70 value 16.706718
## iter  80 value 15.700517
## iter  90 value 15.200025
## iter 100 value 14.797823
## final  value 14.797823 
## stopped after 100 iterations
summary(nnet.model)
## a 9-30-5 network with 455 weights
## options were - softmax modelling  decay=0.001
##  b->h1 i1->h1 i2->h1 i3->h1 i4->h1 i5->h1 i6->h1 i7->h1 i8->h1 i9->h1 
##  -2.75  -1.05  -1.31  -0.04   0.00   0.00  -0.03   0.06   0.00   0.11 
##  b->h2 i1->h2 i2->h2 i3->h2 i4->h2 i5->h2 i6->h2 i7->h2 i8->h2 i9->h2 
##   1.55  -2.29  -0.37  -0.76   1.02   1.46   1.91  -1.90  -2.21  -2.26 
##  b->h3 i1->h3 i2->h3 i3->h3 i4->h3 i5->h3 i6->h3 i7->h3 i8->h3 i9->h3 
##   3.06   2.93   2.01 -17.11   1.57   0.56   0.62  -0.89   0.67   3.71 
##  b->h4 i1->h4 i2->h4 i3->h4 i4->h4 i5->h4 i6->h4 i7->h4 i8->h4 i9->h4 
##  13.76 -20.60  -2.70 -13.91   0.05   0.26   1.69  -0.41  -0.87  -1.86 
##  b->h5 i1->h5 i2->h5 i3->h5 i4->h5 i5->h5 i6->h5 i7->h5 i8->h5 i9->h5 
##   8.63  -7.74  -8.29  -0.52  -5.14  -4.83  -5.11   6.94   2.07   0.17 
##  b->h6 i1->h6 i2->h6 i3->h6 i4->h6 i5->h6 i6->h6 i7->h6 i8->h6 i9->h6 
##   2.16  -7.64   0.96   4.96   1.28   2.07   2.49  -2.65  -1.87  -3.63 
##  b->h7 i1->h7 i2->h7 i3->h7 i4->h7 i5->h7 i6->h7 i7->h7 i8->h7 i9->h7 
##   7.74  -7.29  -6.89  -4.14  -1.00  -0.61   0.63   1.61  -1.54  -5.57 
##  b->h8 i1->h8 i2->h8 i3->h8 i4->h8 i5->h8 i6->h8 i7->h8 i8->h8 i9->h8 
##  -6.20   6.18   5.23  -0.35   4.25   3.92   4.70  -5.18  -2.24  -3.47 
##  b->h9 i1->h9 i2->h9 i3->h9 i4->h9 i5->h9 i6->h9 i7->h9 i8->h9 i9->h9 
##   7.43  -6.77 -11.18   7.93  -5.95  -5.05  -4.73   7.39   1.18  -4.61 
##  b->h10 i1->h10 i2->h10 i3->h10 i4->h10 i5->h10 i6->h10 i7->h10 i8->h10 
##    2.12    0.33    0.54   -0.99    0.11    0.04    0.11   -0.03   -0.09 
## i9->h10 
##    0.06 
##  b->h11 i1->h11 i2->h11 i3->h11 i4->h11 i5->h11 i6->h11 i7->h11 i8->h11 
##   -2.55    0.01   -0.82   -0.21   -0.22   -0.18   -0.32    0.06    0.12 
## i9->h11 
##    0.54 
##  b->h12 i1->h12 i2->h12 i3->h12 i4->h12 i5->h12 i6->h12 i7->h12 i8->h12 
##  -18.76   15.10    9.42   20.70    1.89    0.88    2.24    1.13    3.40 
## i9->h12 
##  -11.18 
##  b->h13 i1->h13 i2->h13 i3->h13 i4->h13 i5->h13 i6->h13 i7->h13 i8->h13 
##    2.17  -11.66    0.77   13.47   -2.00   -0.48   -1.18   -0.16   -0.14 
## i9->h13 
##   -0.44 
##  b->h14 i1->h14 i2->h14 i3->h14 i4->h14 i5->h14 i6->h14 i7->h14 i8->h14 
##    4.90  -14.11    4.32   -7.64    1.13    1.22    1.62   -2.77   -0.60 
## i9->h14 
##    1.82 
##  b->h15 i1->h15 i2->h15 i3->h15 i4->h15 i5->h15 i6->h15 i7->h15 i8->h15 
##   -2.00   -0.21   -1.04   -0.65   -0.22   -0.17   -0.26    0.19    0.06 
## i9->h15 
##    0.34 
##  b->h16 i1->h16 i2->h16 i3->h16 i4->h16 i5->h16 i6->h16 i7->h16 i8->h16 
##    0.55   -0.72    1.13    1.70    0.21    0.33    0.16   -0.40   -0.18 
## i9->h16 
##    0.23 
##  b->h17 i1->h17 i2->h17 i3->h17 i4->h17 i5->h17 i6->h17 i7->h17 i8->h17 
##    1.95   -1.02    0.93   -0.71    0.08    0.13    0.02   -0.18   -0.07 
## i9->h17 
##   -0.02 
##  b->h18 i1->h18 i2->h18 i3->h18 i4->h18 i5->h18 i6->h18 i7->h18 i8->h18 
##   -1.94    0.39   -0.65   -0.33   -0.43   -0.58   -0.58    0.56    0.36 
## i9->h18 
##    0.89 
##  b->h19 i1->h19 i2->h19 i3->h19 i4->h19 i5->h19 i6->h19 i7->h19 i8->h19 
##   -2.89   -0.62   -1.17   -0.62   -0.03   -0.05   -0.15    0.05    0.05 
## i9->h19 
##    0.25 
##  b->h20 i1->h20 i2->h20 i3->h20 i4->h20 i5->h20 i6->h20 i7->h20 i8->h20 
##    2.69    0.93    1.39    0.74    0.30    0.32    0.45   -0.33   -0.34 
## i9->h20 
##   -0.31 
##  b->h21 i1->h21 i2->h21 i3->h21 i4->h21 i5->h21 i6->h21 i7->h21 i8->h21 
##   -2.97   -0.45   -1.26    0.46   -0.13   -0.19   -0.35    0.24    0.15 
## i9->h21 
##    0.53 
##  b->h22 i1->h22 i2->h22 i3->h22 i4->h22 i5->h22 i6->h22 i7->h22 i8->h22 
##   -2.02   -0.48   -1.09   -0.70   -0.07   -0.14   -0.26    0.21    0.04 
## i9->h22 
##    0.34 
##  b->h23 i1->h23 i2->h23 i3->h23 i4->h23 i5->h23 i6->h23 i7->h23 i8->h23 
##   11.00   -9.85   -5.03   -7.26   -5.00   -5.03   -6.66    6.29    3.49 
## i9->h23 
##    9.93 
##  b->h24 i1->h24 i2->h24 i3->h24 i4->h24 i5->h24 i6->h24 i7->h24 i8->h24 
##    0.09    0.10    1.19    0.87    0.15    0.18    0.02   -0.27   -0.03 
## i9->h24 
##    0.35 
##  b->h25 i1->h25 i2->h25 i3->h25 i4->h25 i5->h25 i6->h25 i7->h25 i8->h25 
##   -1.65    4.19   -0.24   -1.84   -1.58   -2.09   -3.09    2.29    2.50 
## i9->h25 
##    6.02 
##  b->h26 i1->h26 i2->h26 i3->h26 i4->h26 i5->h26 i6->h26 i7->h26 i8->h26 
##    1.60    2.12    0.63   -9.24    3.25    3.09    3.24   -3.76   -2.22 
## i9->h26 
##   -0.40 
##  b->h27 i1->h27 i2->h27 i3->h27 i4->h27 i5->h27 i6->h27 i7->h27 i8->h27 
##   -1.77    1.13   -1.39   -1.13   -0.43   -0.47   -0.68    0.41    0.18 
## i9->h27 
##    1.08 
##  b->h28 i1->h28 i2->h28 i3->h28 i4->h28 i5->h28 i6->h28 i7->h28 i8->h28 
##   -0.24    4.65    0.83   -9.53    2.28    2.06    2.00   -2.98   -2.04 
## i9->h28 
##    1.40 
##  b->h29 i1->h29 i2->h29 i3->h29 i4->h29 i5->h29 i6->h29 i7->h29 i8->h29 
##   -2.92   -0.57   -1.21    0.07   -0.18   -0.08   -0.14    0.13    0.06 
## i9->h29 
##    0.25 
##  b->h30 i1->h30 i2->h30 i3->h30 i4->h30 i5->h30 i6->h30 i7->h30 i8->h30 
##   -2.17    2.89    2.08   -0.17   -0.80   -1.19   -2.03    1.25    2.02 
## i9->h30 
##    5.09 
##   b->o1  h1->o1  h2->o1  h3->o1  h4->o1  h5->o1  h6->o1  h7->o1  h8->o1 
##   -1.61   -0.73   -1.36   11.20   -5.48   -8.67   -3.12   -5.21    5.32 
##  h9->o1 h10->o1 h11->o1 h12->o1 h13->o1 h14->o1 h15->o1 h16->o1 h17->o1 
##  -12.47   -0.23   -0.50   15.65  -11.70   -3.57   -1.02   -1.60   -0.80 
## h18->o1 h19->o1 h20->o1 h21->o1 h22->o1 h23->o1 h24->o1 h25->o1 h26->o1 
##    0.30   -0.47    1.03   -2.01   -0.76   -4.20   -0.88    3.70    3.09 
## h27->o1 h28->o1 h29->o1 h30->o1 
##   -0.48    3.23   -0.84    2.52 
##   b->o2  h1->o2  h2->o2  h3->o2  h4->o2  h5->o2  h6->o2  h7->o2  h8->o2 
##    4.22   -0.06   -2.83  -10.27   -4.22    5.12    1.71   -2.68   -4.57 
##  h9->o2 h10->o2 h11->o2 h12->o2 h13->o2 h14->o2 h15->o2 h16->o2 h17->o2 
##    8.36   -1.34   -0.73    5.57   13.82   -2.43   -0.22    1.78    0.33 
## h18->o2 h19->o2 h20->o2 h21->o2 h22->o2 h23->o2 h24->o2 h25->o2 h26->o2 
##   -0.10   -0.19   -0.19    0.28   -0.18    6.00    1.17    1.99  -10.20 
## h27->o2 h28->o2 h29->o2 h30->o2 
##   -0.72   -9.77   -0.24    0.65 
##   b->o3  h1->o3  h2->o3  h3->o3  h4->o3  h5->o3  h6->o3  h7->o3  h8->o3 
##   -1.54    4.15   -0.36    5.06  -15.39   -0.59   -4.92   -3.20    0.79 
##  h9->o3 h10->o3 h11->o3 h12->o3 h13->o3 h14->o3 h15->o3 h16->o3 h17->o3 
##   -6.78    2.10    2.95  -16.51   -4.10   -4.52    2.53    0.26    0.79 
## h18->o3 h19->o3 h20->o3 h21->o3 h22->o3 h23->o3 h24->o3 h25->o3 h26->o3 
##    1.46    3.31   -1.69    5.18    2.72    7.33    1.37    3.03    5.39 
## h27->o3 h28->o3 h29->o3 h30->o3 
##    1.82    8.05    5.08    3.49 
##   b->o4  h1->o4  h2->o4  h3->o4  h4->o4  h5->o4  h6->o4  h7->o4  h8->o4 
##   -0.22    0.95   -0.08   -2.06    5.33   11.89   -2.25    8.77   -5.54 
##  h9->o4 h10->o4 h11->o4 h12->o4 h13->o4 h14->o4 h15->o4 h16->o4 h17->o4 
##   13.74   -2.61    0.54   -9.44   -4.01   -4.70    1.03   -2.56   -1.48 
## h18->o4 h19->o4 h20->o4 h21->o4 h22->o4 h23->o4 h24->o4 h25->o4 h26->o4 
##    0.51    0.74   -2.16    1.63    1.09    5.32   -2.31   -0.28   -0.19 
## h27->o4 h28->o4 h29->o4 h30->o4 
##    1.54   -0.34    1.04   -2.99 
##   b->o5  h1->o5  h2->o5  h3->o5  h4->o5  h5->o5  h6->o5  h7->o5  h8->o5 
##   -0.96   -4.20    4.76   -4.01   19.82   -7.68    8.59    2.49    3.97 
##  h9->o5 h10->o5 h11->o5 h12->o5 h13->o5 h14->o5 h15->o5 h16->o5 h17->o5 
##   -2.97    2.07   -2.33    4.71    6.04   15.20   -2.48    2.17    1.26 
## h18->o5 h19->o5 h20->o5 h21->o5 h22->o5 h23->o5 h24->o5 h25->o5 h26->o5 
##   -1.96   -3.36    2.99   -5.00   -2.73  -14.51    0.60   -8.41    1.90 
## h27->o5 h28->o5 h29->o5 h30->o5 
##   -2.11   -1.15   -5.09   -3.74
nnet.model
## a 9-30-5 network with 455 weights
## inputs: R1 G1 B1 R2 G2 B2 R3 G3 B3 
## output(s): class 
## options were - softmax modelling  decay=0.001
#训练集
confusion.train.nnet=table(traindata$class,predict(nnet.model,traindata[,-2],type="class"))
accuracy.train.nnet=sum(diag(confusion.train.nnet))/sum(confusion.train.nnet)
# confusion.train.nnet
# accuracy.train.nnet
list('训练集预测混淆矩阵'=confusion.train.nnet, '训练集预测正确率'=accuracy.train.nnet)
## $训练集预测混淆矩阵
##    
##      1  2  3  4  5
##   1 43  0  0  0  0
##   2  0 37  0  0  0
##   3  0  0 62  0  0
##   4  0  0  0 17  0
##   5  0  0  0  0  5
## 
## $训练集预测正确率
## [1] 1
#测试集
confusion.test.nnet=table(testdata$class,predict(nnet.model,testdata[,-2],type="class"))
accuracy.test.nnet=sum(diag(confusion.test.nnet))/sum(confusion.test.nnet)
# confusion.test.nnet
# accuracy.test.nnet
list('测试集预测混淆矩阵'=confusion.test.nnet, '测试集预测正确率'=accuracy.test.nnet)
## $测试集预测混淆矩阵
##    
##      1  2  3  4  5
##   1  8  0  0  0  0
##   2  0  7  0  0  0
##   3  0  0 16  0  0
##   4  0  0  1  6  0
##   5  0  0  0  0  1
## 
## $测试集预测正确率
## [1] 0.974359

各种算法的准确率对比

  accuracy.svm <- c(accuracy.train.svm, accuracy.test.svm)
  accuracy.randomForest <- c(accuracy.train.randomForest, accuracy.test.randomForest)
  accuracy.nnet <- c(accuracy.train.nnet, accuracy.test.nnet)
  accuracy.data <- data.frame(accuracy.svm, accuracy.randomForest, accuracy.nnet)
  accuracy.data
##   accuracy.svm accuracy.randomForest accuracy.nnet
## 1    0.9634146             1.0000000      1.000000
## 2    0.9230769             0.8974359      0.974359

第一行是训练集准确率,第二行是测试集准确率。

  • 1.支持向量机虽然在训练集拟合度不如随机森林和神经网络,但是测试集准确率较高;

  • 2.随机森林明显过拟合;

  • 3.对比发现神经网络不管训练集还是测试集效果都最好。

该对比只是简单的对比,不能直接说明哪种算法最好,原因:

  • 1.数据样本过少;

  • 2.实际使用算法中还要考虑到算法运行的时间,当面对海量数据时,准确复杂的算法往往运行过慢。

  • 3.算法得出的模型“好坏”,不仅仅只看准确率,还要看其他指标比如:recall、percision、F1-score等。比如地震预测更看重recall指标

  • 4.实际中还是要结合具体情况,选择合适的算法。