Several R packages support KNN:
library(kknn)  main function: kknn (both classification and regression)
library(FNN)   main function: knn (classification)
library(caret) main function: knnreg (regression)
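As a quick orientation, here is a minimal side-by-side sketch of the three interfaces on the built-in iris data; this split and the variable choices are illustrative and not part of the original example.
library(caret); library(FNN); library(kknn)
set.seed(1)
idx   <- sample(nrow(iris), 100)
train <- iris[idx, ]
test  <- iris[-idx, ]
# caret::knnreg -- regression on a numeric response (here Sepal.Length)
fit.reg  <- knnreg(train[, 2:4], train$Sepal.Length, k = 3)
pred.reg <- predict(fit.reg, test[, 2:4])
# FNN::knn -- classification; returns a factor of predicted labels
pred.cls <- knn(train[, 1:4], test[, 1:4], train$Species, k = 3)
# kknn::kknn -- formula interface with kernel-weighted neighbours
fit.kknn <- kknn(Species ~ ., train, test, k = 7, kernel = "triangular")
head(fitted(fit.kknn))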
# Load the BloodBrain dataset; it provides the vector logBBB and the data frame bbbDescr
library(caret)
data(BloodBrain)
class(logBBB)
## [1] "numeric"
dim(bbbDescr)
## [1] 208 134
# ?createDataPartition
# Take about 80% of the observations as the training set
inTrain <- createDataPartition(logBBB, p = 0.8)[[1]]
trainX <- bbbDescr[inTrain,]
trainY <- logBBB[inTrain]
testX <- bbbDescr[-inTrain,]
testY <- logBBB[-inTrain]
# Fit the KNN regression model
fit <- knnreg(trainX, trainY, k = 3)
fit
## 3-nearest neighbor regression model
# Predict the test set with the KNN regression model
pred <- predict(fit, testX)
# Compute the model's test-set MSE
mean((pred-testY)^2)
## [1] 0.3956547
This KNN regression model attains a test-set MSE of only about 0.40, so the regression fits well and the errors are small. Let's compare the results visually.
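MSE depends on the scale of the response, so it can help to report RMSE and MAE alongside it; a quick sketch reusing pred and testY from above:
# Complementary error metrics on the same test split
mse  <- mean((pred - testY)^2)   # mean squared error
rmse <- sqrt(mse)                # root MSE, in the same units as logBBB
mae  <- mean(abs(pred - testY))  # mean absolute error
c(MSE = mse, RMSE = rmse, MAE = mae)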
library(ggplot2)
# Gather the training values, test values, and predictions for comparison
df <- data.frame(class = c(rep("trainY", length(trainY)),
                           rep("testY", length(testY)),
                           rep("predY", length(pred))),
                 Yval = c(trainY, testY, pred))
ggplot(data=df, mapping = aes(x=Yval,fill=class)) + geom_dotplot(alpha=0.8)
This is a dot plot: the horizontal axis shows the value of the response, and the height of each stack shows the frequency. Comparing neighbouring red (predY) and green (testY) dots along the horizontal axis shows how far the test-set predictions are from the actual values.
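If the stacked dots are hard to read, the same comparison can be drawn as overlaid density curves (a sketch reusing the df built above):
# Same comparison as the dot plot, drawn as overlaid densities
ggplot(data = df, mapping = aes(x = Yval, fill = class)) + geom_density(alpha = 0.4)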
# Compare predicted vs. actual values on the test set
df2 <- data.frame(testY,pred)
ggplot(data=df2,mapping = aes(x=testY,y=pred)) +
geom_point(color="steelblue",size=3) +
geom_abline(slope = 1,size=1.5,linetype=2)
This scatter plot compares the actual and predicted values on the test set directly; the dashed line is \(Y=X\). The closer a point lies to this line, the smaller the gap between predicted and actual value.
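To put a number on the agreement shown in the scatter plot, we can also compute an R²-style summary (a sketch; R2() and RMSE() are helpers exported by caret):
# Quantify prediction/actual agreement on the test set
cor(pred, testY)^2        # squared correlation
caret::R2(pred, testY)    # caret's R^2 helper (pred, obs)
caret::RMSE(pred, testY)  # root mean squared error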
# Load the BloodBrain dataset again; it provides the vector logBBB and the data frame bbbDescr
library(caret)
## Loading required package: lattice
## Loading required package: ggplot2
data(BloodBrain)
# class(logBBB)
# dim(bbbDescr)
# length(logBBB)
# Combine predictors and response into the single data frame that kknn expects
BloodBrain_test <- data.frame(bbbDescr, logBBB=logBBB)
dim(BloodBrain_test)
## [1] 208 135
head(BloodBrain_test)
## tpsa nbasic negative vsa_hyd a_aro weight peoe_vsa.0 peoe_vsa.1
## 1 12.03 1 0 167.06700 0 156.293 76.94749 43.44619
## 2 49.33 0 0 92.64243 6 151.165 38.24339 25.52006
## 3 50.53 1 0 295.16700 15 366.485 58.05473 124.74020
## 4 37.39 0 0 319.11220 15 382.552 62.23933 124.74020
## 5 37.39 1 0 299.65800 12 326.464 74.80064 118.04060
## 6 37.39 1 0 289.77770 11 332.492 74.80064 109.50990
## peoe_vsa.2 peoe_vsa.3 peoe_vsa.4 peoe_vsa.5 peoe_vsa.6 peoe_vsa.0.1
## 1 0.00000 0.000000 0.00000 0 17.238030 18.74768
## 2 0.00000 8.619013 23.27370 0 0.000000 49.01962
## 3 21.65084 8.619013 17.44054 0 8.619013 83.82487
## 4 13.19232 21.785640 0.00000 0 8.619013 83.82487
## 5 33.00190 0.000000 0.00000 0 8.619013 83.82487
## 6 13.19232 21.785640 0.00000 0 8.619013 73.54603
## peoe_vsa.1.1 peoe_vsa.2.1 peoe_vsa.3.1 peoe_vsa.4.1 peoe_vsa.5.1
## 1 43.50657 0 0 0.000000 0.000000
## 2 0.00000 0 0 0.000000 13.566920
## 3 49.01962 0 0 5.682576 2.503756
## 4 68.78024 0 0 5.682576 0.000000
## 5 36.76471 0 0 5.682576 0.136891
## 6 44.27042 0 0 5.682576 0.000000
## peoe_vsa.6.1 a_acc a_acid a_base vsa_acc vsa_acid vsa_base vsa_don
## 1 0.000000 0 0 1 0.000000 0 5.682576 5.682576
## 2 7.904431 2 0 0 13.566920 0 0.000000 5.682576
## 3 2.640647 2 0 1 8.186332 0 0.000000 5.682576
## 4 2.640647 2 0 1 8.186332 0 0.000000 5.682576
## 5 2.503756 2 0 1 8.186332 0 0.000000 5.682576
## 6 2.640647 2 0 1 8.186332 0 0.000000 5.682576
## vsa_other vsa_pol slogp_vsa0 slogp_vsa1 slogp_vsa2 slogp_vsa3
## 1 0.00000 0.00000 18.01075 0.00000 3.981969 0.000
## 2 28.10760 13.56692 25.38523 23.26954 23.862220 0.000
## 3 43.56089 0.00000 14.12420 34.79628 0.000000 76.245
## 4 28.32470 0.00000 14.12420 34.79628 0.000000 76.245
## 5 19.64908 0.00000 14.12420 34.79628 0.000000 76.245
## 6 21.62514 0.00000 14.12420 34.79628 0.000000 76.245
## slogp_vsa4 slogp_vsa5 slogp_vsa6 slogp_vsa7 slogp_vsa8 slogp_vsa9
## 1 4.410796 32.897190 0 0.00000 113.21040 33.32602
## 2 0.000000 0.000000 0 70.57274 0.00000 41.32619
## 3 3.185575 9.507346 0 148.12580 75.47363 28.27417
## 4 3.185575 0.000000 0 144.03540 75.47363 55.46144
## 5 3.185575 0.000000 0 140.71660 75.47363 26.01093
## 6 3.185575 0.000000 0 103.60310 75.47363 55.46144
## smr_vsa0 smr_vsa1 smr_vsa2 smr_vsa3 smr_vsa4 smr_vsa5 smr_vsa6
## 1 0.000000 18.01075 4.410796 3.981969 0.00000 113.21040 0.000000
## 2 23.862220 25.38523 0.000000 5.243428 20.76750 70.57274 5.258784
## 3 12.631660 27.78542 0.000000 8.429003 29.58226 235.05870 76.245000
## 4 3.124314 27.78542 0.000000 8.429003 21.40142 235.05870 76.245000
## 5 3.124314 27.78542 0.000000 8.429003 20.33867 234.62990 76.245000
## 6 3.124314 27.78542 0.000000 8.429003 18.51150 197.51630 76.245000
## smr_vsa7 tpsa.1 logp.o.w. frac.anion7. frac.cation7. andrewbind
## 1 66.22321 16.61 2.948 0.000 0.999 3.4
## 2 33.32602 49.33 0.889 0.001 0.000 -3.3
## 3 0.00000 51.73 4.439 0.000 0.986 12.8
## 4 31.27769 38.59 5.254 0.000 0.986 12.8
## 5 0.00000 38.59 3.800 0.000 0.986 10.3
## 6 31.27769 38.59 3.608 0.000 0.986 10.0
## rotatablebonds mlogp clogp mw nocount hbdnr
## 1 3 2.50245 2.970000 155.2856 1 1
## 2 2 1.05973 0.494000 151.1664 3 2
## 3 8 4.66091 5.136999 365.4794 5 1
## 4 8 3.82458 5.877599 381.5440 4 1
## 5 8 3.27214 4.367000 325.4577 4 1
## 6 8 2.89481 4.283600 331.4835 4 1
## rule.of.5violations alert prx ub pol inthb adistm adistd polar_area
## 1 0 0 0 0.0 0 0 0.0000 0.0000 21.1242
## 2 0 0 1 3.0 2 0 395.3757 10.8921 117.4081
## 3 1 0 6 5.3 3 0 1364.5514 25.6784 82.0943
## 4 1 0 2 5.3 3 0 702.6387 10.0232 65.0890
## 5 0 0 2 4.2 2 0 745.5096 10.5753 66.1754
## 6 0 0 2 3.6 2 0 779.2914 10.7712 69.0895
## nonpolar_area psa_npsa tcsa tcpa tcnp ovality surface_area
## 1 379.0702 0.0557 0.0097 0.1842 0.0103 1.0960 400.1944
## 2 247.5371 0.4743 0.0134 0.0417 0.0198 1.1173 364.9453
## 3 637.7242 0.1287 0.0111 0.0972 0.0125 1.3005 719.8185
## 4 667.9713 0.0974 0.0108 0.1218 0.0119 1.3013 733.0603
## 5 601.7463 0.1100 0.0118 0.1186 0.0130 1.2711 667.9218
## 6 588.6569 0.1174 0.0111 0.1061 0.0125 1.2642 657.7465
## volume most_negative_charge most_positive_charge sum_absolute_charge
## 1 656.0650 -0.6174 0.3068 3.8918
## 2 555.0969 -0.8397 0.4967 4.8925
## 3 1224.4553 -0.8012 0.5414 7.9796
## 4 1257.2002 -0.7608 0.4800 7.9308
## 5 1132.6826 -0.8567 0.4547 7.8516
## 6 1115.8672 -0.7672 0.4349 7.3305
## dipole_moment homo lumo hardness ppsa1 ppsa2 ppsa3
## 1 1.1898 -9.6672 3.4038 6.5355 349.1390 679.3832 30.9705
## 2 4.2109 -8.9618 0.1942 4.5780 223.1310 545.8328 42.3030
## 3 3.5234 -8.6271 0.0589 4.3430 517.8218 2066.0186 63.9503
## 4 3.1463 -8.5592 -0.2651 4.1471 507.6144 2012.9060 61.6890
## 5 3.2676 -8.6732 0.3149 4.4940 509.1635 1998.8743 61.5645
## 6 3.2845 -8.6843 -0.0310 4.3266 473.5681 1735.7426 58.4993
## pnsa1 pnsa2 pnsa3 fpsa1 fpsa2 fpsa3 fnsa1 fnsa2 fnsa3
## 1 51.0554 -99.3477 -10.4876 0.8724 1.6976 0.0774 0.1276 -0.2482 -0.0262
## 2 141.8143 -346.9123 -44.0368 0.6114 1.4957 0.1159 0.3886 -0.9506 -0.1207
## 3 201.9967 -805.9311 -43.7587 0.7194 2.8702 0.0888 0.2806 -1.1196 -0.0608
## 4 225.4459 -893.9880 -42.0328 0.6925 2.7459 0.0842 0.3075 -1.2195 -0.0573
## 5 158.7582 -623.2529 -39.8413 0.7623 2.9927 0.0922 0.2377 -0.9331 -0.0596
## 6 184.1784 -675.0588 -41.2100 0.7200 2.6389 0.0889 0.2800 -1.0263 -0.0627
## wpsa1 wpsa2 wpsa3 wnsa1 wnsa2 wnsa3 dpsa1
## 1 139.7235 271.8854 12.3942 20.4321 -39.7584 -4.1971 298.0836
## 2 81.4306 199.1991 15.4383 51.7544 -126.6040 -16.0710 81.3167
## 3 372.7377 1487.1583 46.0326 145.4010 -580.1241 -31.4983 315.8251
## 4 372.1120 1475.5815 45.2218 165.2654 -655.3471 -30.8126 282.1685
## 5 340.0814 1335.0917 41.1203 106.0381 -416.2842 -26.6109 350.4053
## 6 311.4878 1141.6785 38.4777 121.1427 -444.0175 -27.1057 289.3897
## dpsa2 dpsa3 rpcg rncg wpcs wncs sadh1 sadh2 sadh3
## 1 778.7310 41.4580 0.1577 0.3173 2.3805 1.9117 15.0988 15.0988 0.0377
## 2 892.7451 86.3398 0.2030 0.3433 1.3116 2.2546 45.2163 22.6082 0.1239
## 3 2871.9497 107.7089 0.1357 0.2008 1.1351 1.5725 16.7192 16.7192 0.0232
## 4 2906.8940 103.7218 0.1210 0.1919 0.7623 1.5302 17.2491 17.2491 0.0235
## 5 2622.1272 101.4058 0.1158 0.2182 0.7884 1.6795 16.0252 16.0252 0.0240
## 6 2410.8013 99.7093 0.1187 0.2093 2.0505 1.6760 17.2815 17.2815 0.0263
## chdh1 chdh2 chdh3 scdh1 scdh2 scdh3 saaa1 saaa2 saaa3
## 1 0.3068 0.3068 0.0008 4.6321 4.6321 0.0116 6.0255 6.0255 0.0151
## 2 0.7960 0.3980 0.0022 17.6195 8.8098 0.0483 65.6236 32.8118 0.1798
## 3 0.4550 0.4550 0.0006 7.6077 7.6077 0.0106 57.5440 14.3860 0.0799
## 4 0.4354 0.4354 0.0006 7.5102 7.5102 0.0102 39.8638 13.2879 0.0544
## 5 0.4366 0.4366 0.0007 6.9970 6.9970 0.0105 42.4544 14.1515 0.0636
## 6 0.4349 0.4349 0.0007 7.5157 7.5157 0.0114 43.8012 14.6004 0.0666
## chaa1 chaa2 chaa3 scaa1 scaa2 scaa3 ctdh ctaa mchg
## 1 -0.6174 -0.6174 -0.0015 -3.7199 -3.7199 -0.0093 1 1 0.9241
## 2 -0.8371 -0.4185 -0.0023 -27.5143 -13.7571 -0.0754 2 2 1.2685
## 3 -1.3671 -0.3418 -0.0019 -21.7898 -5.4475 -0.0303 1 4 1.2562
## 4 -1.2332 -0.4111 -0.0017 -17.5957 -5.8652 -0.0240 1 3 1.1962
## 5 -1.1480 -0.3827 -0.0017 -17.0447 -5.6816 -0.0255 1 3 1.2934
## 6 -1.2317 -0.4106 -0.0019 -19.8513 -6.6171 -0.0302 1 3 1.2021
## achg rdta n_sp2 n_sp3 o_sp2 o_sp3 logBBB
## 1 0.9241 1.0000 0.0000 6.0255 0.0000 0.0000 1.08
## 2 1.0420 1.0000 0.0000 6.5681 32.0102 33.6135 -0.40
## 3 1.2562 0.2500 26.9733 10.8567 0.0000 27.5451 0.22
## 4 1.1962 0.3333 21.7065 11.0017 0.0000 15.1316 0.14
## 5 1.2934 0.3333 24.2061 10.8109 0.0000 15.1333 0.69
## 6 1.2021 0.3333 25.5529 11.1218 0.0000 15.1333 0.44
library(kknn)
##
## Attaching package: 'kknn'
## The following object is masked from 'package:caret':
##
## contr.dummy
inTrain <- createDataPartition(logBBB, p = 0.8)[[1]]
trainX <- BloodBrain_test[inTrain,]
testX <- BloodBrain_test[-inTrain,]
# Fit kknn on the training frame; testX[, -ncol(testX)] drops the response column
trainX.kknn <- kknn(logBBB~., trainX, testX[,-ncol(testX)], distance = 1, kernel = "triangular")
# Predict the test set with the kknn model
pred <- predict(trainX.kknn)
pred
## [1] 0.20460815 0.17110175 -0.79033452 -0.73467880 -0.68385002
## [6] 0.68398606 0.15470046 0.87431166 0.57579189 0.76480532
## [11] 0.34400660 -0.50014623 -0.46812515 -0.72411570 0.49250559
## [16] 0.46730599 0.54249320 -0.37007600 0.86117961 -0.09477226
## [21] 0.26808981 0.30562356 -0.80339838 0.36469258 -0.28810768
## [26] 0.35239309 -0.99240068 0.09739690 -0.09953798 -0.73764659
## [31] 0.82226362 -0.63633017 -0.36868268 0.34509227 0.06668604
## [36] -1.29037267 -0.93974827 0.31578603 -1.38733982 0.45296064
# pred <- fitted(trainX.kknn)
# Compute the test-set MSE; testX[, ncol(testX)] is the actual logBBB column
testX[,ncol(testX)]
## [1] 0.22 0.69 -0.66 -0.12 -0.27 0.52 -0.62 -0.23 0.88 1.53 0.46
## [12] -0.03 0.34 0.00 0.11 1.26 0.85 -0.05 0.22 0.66 0.60 0.08
## [23] -0.79 0.41 -0.65 0.56 -1.30 0.56 -0.30 -0.78 0.49 -0.20 -1.30
## [34] 0.39 -0.28 -0.82 -1.57 0.54 -2.00 -0.02
mean((pred-testX[,ncol(testX)])^2)
## [1] 0.2492669
This kknn regression model attains a test-set MSE of only about 0.25; the error shrinks further!
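The settings k, kernel, and distance above were fixed by hand. The package's train.kknn() (used later on the ionosphere data) performs leave-one-out cross-validation and works for regression too; a sketch of tuning on the training frame:
# Leave-one-out CV over k and kernel for the kknn regression model
cv.fit <- train.kknn(logBBB ~ ., data = trainX, kmax = 15,
                     kernel = c("triangular", "rectangular", "optimal"),
                     distance = 1)
cv.fit$best.parameters  # the selected kernel and k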
library(kknn)
# ?kknn
data(iris)
m <- dim(iris)[1]
# Randomly hold out one third of the rows as a validation set
val <- sample(1:m, size = round(m/3), replace = FALSE, prob = rep(1/m, m))
iris.learn <- iris[-val,]
iris.valid <- iris[val,]
iris.kknn <- kknn(Species~., iris.learn, iris.valid, distance = 1, kernel = "triangular")
summary(iris.kknn)
##
## Call:
## kknn(formula = Species ~ ., train = iris.learn, test = iris.valid, distance = 1, kernel = "triangular")
##
## Response: "nominal"
## fit prob.setosa prob.versicolor prob.virginica
## 1 versicolor 0 1.0000000 0.0000000
## 2 virginica 0 0.1631409 0.8368591
## 3 versicolor 0 1.0000000 0.0000000
## 4 virginica 0 0.0000000 1.0000000
## 5 virginica 0 0.3982418 0.6017582
## 6 virginica 0 0.0000000 1.0000000
## 7 setosa 1 0.0000000 0.0000000
## 8 virginica 0 0.1072497 0.8927503
## 9 versicolor 0 1.0000000 0.0000000
## 10 versicolor 0 1.0000000 0.0000000
## 11 virginica 0 0.0000000 1.0000000
## 12 versicolor 0 1.0000000 0.0000000
## 13 versicolor 0 0.8040346 0.1959654
## 14 setosa 1 0.0000000 0.0000000
## 15 versicolor 0 0.6420934 0.3579066
## 16 setosa 1 0.0000000 0.0000000
## 17 versicolor 0 1.0000000 0.0000000
## 18 versicolor 0 1.0000000 0.0000000
## 19 versicolor 0 1.0000000 0.0000000
## 20 virginica 0 0.2602202 0.7397798
## 21 versicolor 0 1.0000000 0.0000000
## 22 versicolor 0 0.8556834 0.1443166
## 23 setosa 1 0.0000000 0.0000000
## 24 virginica 0 0.0000000 1.0000000
## 25 setosa 1 0.0000000 0.0000000
## 26 setosa 1 0.0000000 0.0000000
## 27 virginica 0 0.0000000 1.0000000
## 28 virginica 0 0.0000000 1.0000000
## 29 setosa 1 0.0000000 0.0000000
## 30 setosa 1 0.0000000 0.0000000
## 31 virginica 0 0.0000000 1.0000000
## 32 setosa 1 0.0000000 0.0000000
## 33 setosa 1 0.0000000 0.0000000
## 34 virginica 0 0.0000000 1.0000000
## 35 versicolor 0 1.0000000 0.0000000
## 36 versicolor 0 1.0000000 0.0000000
## 37 virginica 0 0.0000000 1.0000000
## 38 virginica 0 0.4488633 0.5511367
## 39 virginica 0 0.1932859 0.8067141
## 40 setosa 1 0.0000000 0.0000000
## 41 setosa 1 0.0000000 0.0000000
## 42 virginica 0 0.0000000 1.0000000
## 43 versicolor 0 0.8047641 0.1952359
## 44 virginica 0 0.0000000 1.0000000
## 45 setosa 1 0.0000000 0.0000000
## 46 virginica 0 0.0000000 1.0000000
## 47 setosa 1 0.0000000 0.0000000
## 48 virginica 0 0.0000000 1.0000000
## 49 setosa 1 0.0000000 0.0000000
## 50 virginica 0 0.0000000 1.0000000
fit <- fitted(iris.kknn)
# fit <- predict(iris.kknn)
table(iris.valid$Species, fit)
## fit
## setosa versicolor virginica
## setosa 15 0 0
## versicolor 0 13 1
## virginica 0 2 19
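From this confusion matrix the validation misclassification rate follows directly:
# Overall misclassification rate on the validation set
mean(iris.valid$Species != fit)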
# Pair plot of the four features; misclassified points are drawn in red
pcol <- as.character(as.numeric(iris.valid$Species))
pairs(iris.valid[1:4], pch = pcol, col = c("green3", "red")[(iris.valid$Species != fit)+1])
library(kknn)
data(ionosphere)
ionosphere.learn <- ionosphere[1:200,]
ionosphere.valid <- ionosphere[-c(1:200),]
fit.kknn <- kknn(class ~ ., ionosphere.learn, ionosphere.valid)
table(ionosphere.valid$class, fit.kknn$fit)
##
## b g
## b 19 8
## g 2 122
(fit.train1 <- train.kknn(class ~ ., ionosphere.learn, kmax = 15,
kernel = c("triangular", "rectangular", "epanechnikov", "optimal"), distance = 1))
##
## Call:
## train.kknn(formula = class ~ ., data = ionosphere.learn, kmax = 15, distance = 1, kernel = c("triangular", "rectangular", "epanechnikov", "optimal"))
##
## Type of response variable: nominal
## Minimal misclassification: 0.12
## Best kernel: rectangular
## Best k: 2
table(predict(fit.train1, ionosphere.valid), ionosphere.valid$class)
##
## b g
## b 25 4
## g 2 120
(fit.train2 <- train.kknn(class ~ ., ionosphere.learn, kmax = 15,
kernel = c("triangular", "rectangular", "epanechnikov", "optimal"), distance = 2))
##
## Call:
## train.kknn(formula = class ~ ., data = ionosphere.learn, kmax = 15, distance = 2, kernel = c("triangular", "rectangular", "epanechnikov", "optimal"))
##
## Type of response variable: nominal
## Minimal misclassification: 0.12
## Best kernel: rectangular
## Best k: 2
table(predict(fit.train2, ionosphere.valid), ionosphere.valid$class)
##
## b g
## b 20 5
## g 7 119
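To compare the Manhattan (distance = 1) and Euclidean (distance = 2) fits head to head, compute each model's validation error (a sketch):
# Validation misclassification rates of the two tuned models
mean(predict(fit.train1, ionosphere.valid) != ionosphere.valid$class)
mean(predict(fit.train2, ionosphere.valid) != ionosphere.valid$class)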
Here we use the Caravan dataset from the ISLR package; first, a quick overview:
library(ISLR)
str(Caravan)
## 'data.frame': 5822 obs. of 86 variables:
## $ MOSTYPE : num 33 37 37 9 40 23 39 33 33 11 ...
## $ MAANTHUI: num 1 1 1 1 1 1 2 1 1 2 ...
## $ MGEMOMV : num 3 2 2 3 4 2 3 2 2 3 ...
## $ MGEMLEEF: num 2 2 2 3 2 1 2 3 4 3 ...
## $ MOSHOOFD: num 8 8 8 3 10 5 9 8 8 3 ...
## $ MGODRK : num 0 1 0 2 1 0 2 0 0 3 ...
## $ MGODPR : num 5 4 4 3 4 5 2 7 1 5 ...
## $ MGODOV : num 1 1 2 2 1 0 0 0 3 0 ...
## $ MGODGE : num 3 4 4 4 4 5 5 2 6 2 ...
## $ MRELGE : num 7 6 3 5 7 0 7 7 6 7 ...
## $ MRELSA : num 0 2 2 2 1 6 2 2 0 0 ...
## $ MRELOV : num 2 2 4 2 2 3 0 0 3 2 ...
## $ MFALLEEN: num 1 0 4 2 2 3 0 0 3 2 ...
## $ MFGEKIND: num 2 4 4 3 4 5 3 5 3 2 ...
## $ MFWEKIND: num 6 5 2 4 4 2 6 4 3 6 ...
## $ MOPLHOOG: num 1 0 0 3 5 0 0 0 0 0 ...
## $ MOPLMIDD: num 2 5 5 4 4 5 4 3 1 4 ...
## $ MOPLLAAG: num 7 4 4 2 0 4 5 6 8 5 ...
## $ MBERHOOG: num 1 0 0 4 0 2 0 2 1 2 ...
## $ MBERZELF: num 0 0 0 0 5 0 0 0 1 0 ...
## $ MBERBOER: num 1 0 0 0 4 0 0 0 0 0 ...
## $ MBERMIDD: num 2 5 7 3 0 4 4 2 1 3 ...
## $ MBERARBG: num 5 0 0 1 0 2 1 5 8 3 ...
## $ MBERARBO: num 2 4 2 2 0 2 5 2 1 3 ...
## $ MSKA : num 1 0 0 3 9 2 0 2 1 1 ...
## $ MSKB1 : num 1 2 5 2 0 2 1 1 1 2 ...
## $ MSKB2 : num 2 3 0 1 0 2 4 2 0 1 ...
## $ MSKC : num 6 5 4 4 0 4 5 5 8 4 ...
## $ MSKD : num 1 0 0 0 0 2 0 2 1 2 ...
## $ MHHUUR : num 1 2 7 5 4 9 6 0 9 0 ...
## $ MHKOOP : num 8 7 2 4 5 0 3 9 0 9 ...
## $ MAUT1 : num 8 7 7 9 6 5 8 4 5 6 ...
## $ MAUT2 : num 0 1 0 0 2 3 0 4 2 1 ...
## $ MAUT0 : num 1 2 2 0 1 3 1 2 3 2 ...
## $ MZFONDS : num 8 6 9 7 5 9 9 6 7 6 ...
## $ MZPART : num 1 3 0 2 4 0 0 3 2 3 ...
## $ MINKM30 : num 0 2 4 1 0 5 4 2 7 2 ...
## $ MINK3045: num 4 0 5 5 0 2 3 5 2 3 ...
## $ MINK4575: num 5 5 0 3 9 3 3 3 1 3 ...
## $ MINK7512: num 0 2 0 0 0 0 0 0 0 1 ...
## $ MINK123M: num 0 0 0 0 0 0 0 0 0 0 ...
## $ MINKGEM : num 4 5 3 4 6 3 3 3 2 4 ...
## $ MKOOPKLA: num 3 4 4 4 3 3 5 3 3 7 ...
## $ PWAPART : num 0 2 2 0 0 0 0 0 0 2 ...
## $ PWABEDR : num 0 0 0 0 0 0 0 0 0 0 ...
## $ PWALAND : num 0 0 0 0 0 0 0 0 0 0 ...
## $ PPERSAUT: num 6 0 6 6 0 6 6 0 5 0 ...
## $ PBESAUT : num 0 0 0 0 0 0 0 0 0 0 ...
## $ PMOTSCO : num 0 0 0 0 0 0 0 0 0 0 ...
## $ PVRAAUT : num 0 0 0 0 0 0 0 0 0 0 ...
## $ PAANHANG: num 0 0 0 0 0 0 0 0 0 0 ...
## $ PTRACTOR: num 0 0 0 0 0 0 0 0 0 0 ...
## $ PWERKT : num 0 0 0 0 0 0 0 0 0 0 ...
## $ PBROM : num 0 0 0 0 0 0 0 3 0 0 ...
## $ PLEVEN : num 0 0 0 0 0 0 0 0 0 0 ...
## $ PPERSONG: num 0 0 0 0 0 0 0 0 0 0 ...
## $ PGEZONG : num 0 0 0 0 0 0 0 0 0 0 ...
## $ PWAOREG : num 0 0 0 0 0 0 0 0 0 0 ...
## $ PBRAND : num 5 2 2 2 6 0 0 0 0 3 ...
## $ PZEILPL : num 0 0 0 0 0 0 0 0 0 0 ...
## $ PPLEZIER: num 0 0 0 0 0 0 0 0 0 0 ...
## $ PFIETS : num 0 0 0 0 0 0 0 0 0 0 ...
## $ PINBOED : num 0 0 0 0 0 0 0 0 0 0 ...
## $ PBYSTAND: num 0 0 0 0 0 0 0 0 0 0 ...
## $ AWAPART : num 0 2 1 0 0 0 0 0 0 1 ...
## $ AWABEDR : num 0 0 0 0 0 0 0 0 0 0 ...
## $ AWALAND : num 0 0 0 0 0 0 0 0 0 0 ...
## $ APERSAUT: num 1 0 1 1 0 1 1 0 1 0 ...
## $ ABESAUT : num 0 0 0 0 0 0 0 0 0 0 ...
## $ AMOTSCO : num 0 0 0 0 0 0 0 0 0 0 ...
## $ AVRAAUT : num 0 0 0 0 0 0 0 0 0 0 ...
## $ AAANHANG: num 0 0 0 0 0 0 0 0 0 0 ...
## $ ATRACTOR: num 0 0 0 0 0 0 0 0 0 0 ...
## $ AWERKT : num 0 0 0 0 0 0 0 0 0 0 ...
## $ ABROM : num 0 0 0 0 0 0 0 1 0 0 ...
## $ ALEVEN : num 0 0 0 0 0 0 0 0 0 0 ...
## $ APERSONG: num 0 0 0 0 0 0 0 0 0 0 ...
## $ AGEZONG : num 0 0 0 0 0 0 0 0 0 0 ...
## $ AWAOREG : num 0 0 0 0 0 0 0 0 0 0 ...
## $ ABRAND : num 1 1 1 1 1 0 0 0 0 1 ...
## $ AZEILPL : num 0 0 0 0 0 0 0 0 0 0 ...
## $ APLEZIER: num 0 0 0 0 0 0 0 0 0 0 ...
## $ AFIETS : num 0 0 0 0 0 0 0 0 0 0 ...
## $ AINBOED : num 0 0 0 0 0 0 0 0 0 0 ...
## $ ABYSTAND: num 0 0 0 0 0 0 0 0 0 0 ...
## $ Purchase: Factor w/ 2 levels "No","Yes": 1 1 1 1 1 1 1 1 1 1 ...
table(Caravan$Purchase)/sum(as.numeric(table(Caravan$Purchase)))
##
## No Yes
## 0.94022673 0.05977327
So the data have 5822 observations of 86 variables. Only Purchase is categorical: a factor whose levels No and Yes indicate whether the person bought caravan insurance; all other variables are numeric. About 6% of people bought insurance.
Because KNN computes distances, and the 85 numeric variables are on different scales, a gap between two points on one feature can dwarf the same gap on another, letting large-scale features dominate. Normalizing the features is therefore standard practice in machine learning. Here we simply use scale() to standardize each numeric variable to mean 0 and standard deviation 1 (note this only standardizes the first two moments; it does not make a variable normally distributed).
standardized.X=scale(Caravan[,-86])
mean(standardized.X[,sample(1:85,1)])
## [1] -1.352488e-16
var(standardized.X[,sample(1:85,1)])
## [1] 1
mean(standardized.X[,sample(1:85,1)])
## [1] 3.030345e-17
var(standardized.X[,sample(1:85,1)])
## [1] 1
mean(standardized.X[,sample(1:85,1)])
## [1] 9.158575e-18
var(standardized.X[,sample(1:85,1)])
## [1] 1
As expected, a randomly sampled standardized variable has mean essentially 0 and variance 1.
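Equivalently, scale() just subtracts each column's mean and divides by its standard deviation; a quick sketch verifying this on the first column:
# Manual standardization of one column, checked against scale()
x <- Caravan[, 1]
z <- (x - mean(x)) / sd(x)
all.equal(as.numeric(scale(Caravan[, 1])), z)  # should be TRUE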
library(FNN)
# Use the first 1000 observations as the test set and the rest as the training set
test <- 1:1000
train.X <- standardized.X[-test,]
test.X <- standardized.X[test,]
train.Y <- Caravan$Purchase[-test]
test.Y <- Caravan$Purchase[test]
knn.pred <- knn(train.X,test.X,train.Y,k=1)
mean(test.Y!=knn.pred)
## [1] 0.117
mean(test.Y!="No")
## [1] 0.059
# sum(test.Y!=knn.pred)/length(test.Y)
# sum(test.Y!="No")/length(test.Y)
With K=1, KNN's overall error rate on the test set is about 12%. Since most people do not buy insurance (the prior is only about 6%), predictions of "No" will be accurate almost by default; there is little point in agonizing over customers predicted not to buy who actually do, and the overall accuracy (Accuracy) is not the right yardstick either. A salesperson only needs to care about how many of the customers the model predicts will buy actually do buy; this is the precision (Precision) of targeted marketing. In this business setting the metric to analyze is therefore the model's precision, not its accuracy.
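A small helper makes the precision calculation explicit (a sketch; the table-indexing one-liner below computes the same quantity):
# Precision of the "Yes" predictions: TP / (TP + FP)
precision <- function(pred, actual, positive = "Yes") {
  tab <- table(pred, actual)
  tab[positive, positive] / sum(tab[positive, ])
}
precision(knn.pred, test.Y)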
table(knn.pred,test.Y)
## test.Y
## knn.pred No Yes
## No 874 50
## Yes 67 9
table(knn.pred,test.Y)[2,2]/rowSums(table(knn.pred,test.Y))[2]
## Yes
## 0.1184211
So with K=1 the model's precision is about 12%, double the roughly 6% base rate of random guessing!
Now try different values of K:
knn.pred <- knn(train.X,test.X,train.Y,k=3)
table(knn.pred,test.Y)[2,2]/rowSums(table(knn.pred,test.Y))[2]
## Yes
## 0.1923077
knn.pred <- knn(train.X,test.X,train.Y,k=5)
table(knn.pred,test.Y)[2,2]/rowSums(table(knn.pred,test.Y))[2]
## Yes
## 0.2857143
With K=3 the precision rises to about 19%, and with K=5 it reaches about 29%.
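Rather than rerunning by hand, a short loop sweeps k (a sketch; knn() breaks distance ties at random, so the exact numbers vary between runs):
# Precision of the "Yes" class for several values of k
for (k in c(1, 3, 5, 7, 9)) {
  p   <- knn(train.X, test.X, train.Y, k = k)
  tab <- table(p, test.Y)
  # if no "Yes" predictions are made, the ratio is NaN
  cat("k =", k, " precision =", round(tab[2, 2] / sum(tab[2, ]), 3), "\n")
}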
For comparison, let's run this case once more with logistic regression.
glm.fit <- glm(Purchase~.,data=Caravan,family = binomial,subset = -test)
glm.probs <- predict(glm.fit,Caravan[test,],type = "response")
glm.pred <- ifelse(glm.probs > 0.5,"Yes","No")
table(glm.pred,test.Y)
## test.Y
## glm.pred No Yes
## No 934 59
## Yes 7 0
This classifier does much worse here: the precision is actually 0! The default probability threshold of 0.5 is suited to events that are roughly equally likely, but buying insurance is clearly not such an event. Lower the threshold to 0.25 and look again:
glm.pred <- ifelse(glm.probs >0.25,"Yes","No")
table(glm.pred,test.Y)
## test.Y
## glm.pred No Yes
## No 919 48
## Yes 22 11
Now the precision reaches 1/3 (11 of the 33 predicted buyers actually buy), more than five times the random-guess rate!
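The 0.25 threshold was picked by eye; a small sweep over candidate thresholds shows the trade-off between how many customers are targeted and how precise the targeting is (a sketch reusing glm.probs):
# Precision of "Yes" predictions at several probability thresholds
for (t in c(0.10, 0.15, 0.20, 0.25, 0.30)) {
  flagged <- glm.probs > t
  cat("threshold =", t,
      " predicted Yes =", sum(flagged),
      " precision =", round(mean(test.Y[flagged] == "Yes"), 3), "\n")
}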
These experiments all show that targeting customers with a machine-learning model is several times more precise than random guessing.
Reference:
http://www.cnblogs.com/cloudtj/p/6688037.html