Several packages in R support KNN computation, as follows:

KNN Regression

1. Regression with the knnreg() function.

# Load the BloodBrain data set, which provides the vector logBBB and the data frame bbbDescr
library(caret)
data(BloodBrain)
class(logBBB)
## [1] "numeric"
dim(bbbDescr)
## [1] 208 134
# ?createDataPartition
# Use about 80% of the observations as the training set.
inTrain <- createDataPartition(logBBB, p = 0.8)[[1]]
trainX <- bbbDescr[inTrain,] 
trainY <- logBBB[inTrain]
testX <- bbbDescr[-inTrain,]
testY <- logBBB[-inTrain]
# Build the KNN regression model
fit <- knnreg(trainX, trainY, k = 3) 
fit
## 3-nearest neighbor regression model
# Predict the test set with the KNN regression model
pred <- predict(fit, testX)
# Compute the model's MSE
mean((pred-testY)^2)
## [1] 0.3956547

The MSE of this KNN regression model is only about 0.40, so the fit is quite good and the error is small. Let's compare the results visually.

library(ggplot2)
# Gather the training set, test set, and predictions for comparison
df <- data.frame(class = c(rep("trainY", length(trainY)),
                           rep("testY", length(testY)),
                           rep("predY", length(pred))),
                 Yval = c(trainY, testY, pred))
ggplot(data=df, mapping = aes(x=Yval,fill=class)) + geom_dotplot(alpha=0.8)

This is a dot plot: the horizontal axis shows the value of the response variable, and the vertical axis the frequency. Comparing adjacent red and green dots along the horizontal axis shows how far the test-set predictions deviate from the actual values.

# Compare predicted vs. actual values on the test set
df2 <- data.frame(testY,pred)
ggplot(data=df2,mapping = aes(x=testY,y=pred)) +
    geom_point(color="steelblue",size=3) +
    geom_abline(slope = 1,size=1.5,linetype=2)

This scatter plot compares the actual and predicted values on the test set directly; the dashed line is \(Y=X\). The closer a point lies to this line, the smaller the difference between its predicted and actual values.
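Note that k = 3 above was an arbitrary choice. A minimal sketch (reusing the split created above) that scans a few candidate values of k and compares their test-set MSE:

# Sketch: test-set MSE of knnreg for several values of k
sapply(c(1, 3, 5, 7, 9), function(k) {
  fit.k <- knnreg(trainX, trainY, k = k)
  mean((predict(fit.k, testX) - testY)^2)
})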

2. Regression with the kknn() function.

# Load the BloodBrain data set, which provides the vector logBBB and the data frame bbbDescr
library(caret)
## Loading required package: lattice
## Loading required package: ggplot2
data(BloodBrain)
# class(logBBB)
# dim(bbbDescr)
# length(logBBB)

# Combine into a single data frame in the format kknn() expects
BloodBrain_test <- data.frame(bbbDescr, logBBB=logBBB)
dim(BloodBrain_test)
## [1] 208 135
head(BloodBrain_test)
##    tpsa nbasic negative   vsa_hyd a_aro  weight peoe_vsa.0 peoe_vsa.1
## 1 12.03      1        0 167.06700     0 156.293   76.94749   43.44619
## 2 49.33      0        0  92.64243     6 151.165   38.24339   25.52006
## 3 50.53      1        0 295.16700    15 366.485   58.05473  124.74020
## 4 37.39      0        0 319.11220    15 382.552   62.23933  124.74020
## 5 37.39      1        0 299.65800    12 326.464   74.80064  118.04060
## 6 37.39      1        0 289.77770    11 332.492   74.80064  109.50990
##   peoe_vsa.2 peoe_vsa.3 peoe_vsa.4 peoe_vsa.5 peoe_vsa.6 peoe_vsa.0.1
## 1    0.00000   0.000000    0.00000          0  17.238030     18.74768
## 2    0.00000   8.619013   23.27370          0   0.000000     49.01962
## 3   21.65084   8.619013   17.44054          0   8.619013     83.82487
## 4   13.19232  21.785640    0.00000          0   8.619013     83.82487
## 5   33.00190   0.000000    0.00000          0   8.619013     83.82487
## 6   13.19232  21.785640    0.00000          0   8.619013     73.54603
##   peoe_vsa.1.1 peoe_vsa.2.1 peoe_vsa.3.1 peoe_vsa.4.1 peoe_vsa.5.1
## 1     43.50657            0            0     0.000000     0.000000
## 2      0.00000            0            0     0.000000    13.566920
## 3     49.01962            0            0     5.682576     2.503756
## 4     68.78024            0            0     5.682576     0.000000
## 5     36.76471            0            0     5.682576     0.136891
## 6     44.27042            0            0     5.682576     0.000000
##   peoe_vsa.6.1 a_acc a_acid a_base   vsa_acc vsa_acid vsa_base  vsa_don
## 1     0.000000     0      0      1  0.000000        0 5.682576 5.682576
## 2     7.904431     2      0      0 13.566920        0 0.000000 5.682576
## 3     2.640647     2      0      1  8.186332        0 0.000000 5.682576
## 4     2.640647     2      0      1  8.186332        0 0.000000 5.682576
## 5     2.503756     2      0      1  8.186332        0 0.000000 5.682576
## 6     2.640647     2      0      1  8.186332        0 0.000000 5.682576
##   vsa_other  vsa_pol slogp_vsa0 slogp_vsa1 slogp_vsa2 slogp_vsa3
## 1   0.00000  0.00000   18.01075    0.00000   3.981969      0.000
## 2  28.10760 13.56692   25.38523   23.26954  23.862220      0.000
## 3  43.56089  0.00000   14.12420   34.79628   0.000000     76.245
## 4  28.32470  0.00000   14.12420   34.79628   0.000000     76.245
## 5  19.64908  0.00000   14.12420   34.79628   0.000000     76.245
## 6  21.62514  0.00000   14.12420   34.79628   0.000000     76.245
##   slogp_vsa4 slogp_vsa5 slogp_vsa6 slogp_vsa7 slogp_vsa8 slogp_vsa9
## 1   4.410796  32.897190          0    0.00000  113.21040   33.32602
## 2   0.000000   0.000000          0   70.57274    0.00000   41.32619
## 3   3.185575   9.507346          0  148.12580   75.47363   28.27417
## 4   3.185575   0.000000          0  144.03540   75.47363   55.46144
## 5   3.185575   0.000000          0  140.71660   75.47363   26.01093
## 6   3.185575   0.000000          0  103.60310   75.47363   55.46144
##    smr_vsa0 smr_vsa1 smr_vsa2 smr_vsa3 smr_vsa4  smr_vsa5  smr_vsa6
## 1  0.000000 18.01075 4.410796 3.981969  0.00000 113.21040  0.000000
## 2 23.862220 25.38523 0.000000 5.243428 20.76750  70.57274  5.258784
## 3 12.631660 27.78542 0.000000 8.429003 29.58226 235.05870 76.245000
## 4  3.124314 27.78542 0.000000 8.429003 21.40142 235.05870 76.245000
## 5  3.124314 27.78542 0.000000 8.429003 20.33867 234.62990 76.245000
## 6  3.124314 27.78542 0.000000 8.429003 18.51150 197.51630 76.245000
##   smr_vsa7 tpsa.1 logp.o.w. frac.anion7. frac.cation7. andrewbind
## 1 66.22321  16.61     2.948        0.000         0.999        3.4
## 2 33.32602  49.33     0.889        0.001         0.000       -3.3
## 3  0.00000  51.73     4.439        0.000         0.986       12.8
## 4 31.27769  38.59     5.254        0.000         0.986       12.8
## 5  0.00000  38.59     3.800        0.000         0.986       10.3
## 6 31.27769  38.59     3.608        0.000         0.986       10.0
##   rotatablebonds   mlogp    clogp       mw nocount hbdnr
## 1              3 2.50245 2.970000 155.2856       1     1
## 2              2 1.05973 0.494000 151.1664       3     2
## 3              8 4.66091 5.136999 365.4794       5     1
## 4              8 3.82458 5.877599 381.5440       4     1
## 5              8 3.27214 4.367000 325.4577       4     1
## 6              8 2.89481 4.283600 331.4835       4     1
##   rule.of.5violations alert prx  ub pol inthb    adistm  adistd polar_area
## 1                   0     0   0 0.0   0     0    0.0000  0.0000    21.1242
## 2                   0     0   1 3.0   2     0  395.3757 10.8921   117.4081
## 3                   1     0   6 5.3   3     0 1364.5514 25.6784    82.0943
## 4                   1     0   2 5.3   3     0  702.6387 10.0232    65.0890
## 5                   0     0   2 4.2   2     0  745.5096 10.5753    66.1754
## 6                   0     0   2 3.6   2     0  779.2914 10.7712    69.0895
##   nonpolar_area psa_npsa   tcsa   tcpa   tcnp ovality surface_area
## 1      379.0702   0.0557 0.0097 0.1842 0.0103  1.0960     400.1944
## 2      247.5371   0.4743 0.0134 0.0417 0.0198  1.1173     364.9453
## 3      637.7242   0.1287 0.0111 0.0972 0.0125  1.3005     719.8185
## 4      667.9713   0.0974 0.0108 0.1218 0.0119  1.3013     733.0603
## 5      601.7463   0.1100 0.0118 0.1186 0.0130  1.2711     667.9218
## 6      588.6569   0.1174 0.0111 0.1061 0.0125  1.2642     657.7465
##      volume most_negative_charge most_positive_charge sum_absolute_charge
## 1  656.0650              -0.6174               0.3068              3.8918
## 2  555.0969              -0.8397               0.4967              4.8925
## 3 1224.4553              -0.8012               0.5414              7.9796
## 4 1257.2002              -0.7608               0.4800              7.9308
## 5 1132.6826              -0.8567               0.4547              7.8516
## 6 1115.8672              -0.7672               0.4349              7.3305
##   dipole_moment    homo    lumo hardness    ppsa1     ppsa2   ppsa3
## 1        1.1898 -9.6672  3.4038   6.5355 349.1390  679.3832 30.9705
## 2        4.2109 -8.9618  0.1942   4.5780 223.1310  545.8328 42.3030
## 3        3.5234 -8.6271  0.0589   4.3430 517.8218 2066.0186 63.9503
## 4        3.1463 -8.5592 -0.2651   4.1471 507.6144 2012.9060 61.6890
## 5        3.2676 -8.6732  0.3149   4.4940 509.1635 1998.8743 61.5645
## 6        3.2845 -8.6843 -0.0310   4.3266 473.5681 1735.7426 58.4993
##      pnsa1     pnsa2    pnsa3  fpsa1  fpsa2  fpsa3  fnsa1   fnsa2   fnsa3
## 1  51.0554  -99.3477 -10.4876 0.8724 1.6976 0.0774 0.1276 -0.2482 -0.0262
## 2 141.8143 -346.9123 -44.0368 0.6114 1.4957 0.1159 0.3886 -0.9506 -0.1207
## 3 201.9967 -805.9311 -43.7587 0.7194 2.8702 0.0888 0.2806 -1.1196 -0.0608
## 4 225.4459 -893.9880 -42.0328 0.6925 2.7459 0.0842 0.3075 -1.2195 -0.0573
## 5 158.7582 -623.2529 -39.8413 0.7623 2.9927 0.0922 0.2377 -0.9331 -0.0596
## 6 184.1784 -675.0588 -41.2100 0.7200 2.6389 0.0889 0.2800 -1.0263 -0.0627
##      wpsa1     wpsa2   wpsa3    wnsa1     wnsa2    wnsa3    dpsa1
## 1 139.7235  271.8854 12.3942  20.4321  -39.7584  -4.1971 298.0836
## 2  81.4306  199.1991 15.4383  51.7544 -126.6040 -16.0710  81.3167
## 3 372.7377 1487.1583 46.0326 145.4010 -580.1241 -31.4983 315.8251
## 4 372.1120 1475.5815 45.2218 165.2654 -655.3471 -30.8126 282.1685
## 5 340.0814 1335.0917 41.1203 106.0381 -416.2842 -26.6109 350.4053
## 6 311.4878 1141.6785 38.4777 121.1427 -444.0175 -27.1057 289.3897
##       dpsa2    dpsa3   rpcg   rncg   wpcs   wncs   sadh1   sadh2  sadh3
## 1  778.7310  41.4580 0.1577 0.3173 2.3805 1.9117 15.0988 15.0988 0.0377
## 2  892.7451  86.3398 0.2030 0.3433 1.3116 2.2546 45.2163 22.6082 0.1239
## 3 2871.9497 107.7089 0.1357 0.2008 1.1351 1.5725 16.7192 16.7192 0.0232
## 4 2906.8940 103.7218 0.1210 0.1919 0.7623 1.5302 17.2491 17.2491 0.0235
## 5 2622.1272 101.4058 0.1158 0.2182 0.7884 1.6795 16.0252 16.0252 0.0240
## 6 2410.8013  99.7093 0.1187 0.2093 2.0505 1.6760 17.2815 17.2815 0.0263
##    chdh1  chdh2  chdh3   scdh1  scdh2  scdh3   saaa1   saaa2  saaa3
## 1 0.3068 0.3068 0.0008  4.6321 4.6321 0.0116  6.0255  6.0255 0.0151
## 2 0.7960 0.3980 0.0022 17.6195 8.8098 0.0483 65.6236 32.8118 0.1798
## 3 0.4550 0.4550 0.0006  7.6077 7.6077 0.0106 57.5440 14.3860 0.0799
## 4 0.4354 0.4354 0.0006  7.5102 7.5102 0.0102 39.8638 13.2879 0.0544
## 5 0.4366 0.4366 0.0007  6.9970 6.9970 0.0105 42.4544 14.1515 0.0636
## 6 0.4349 0.4349 0.0007  7.5157 7.5157 0.0114 43.8012 14.6004 0.0666
##     chaa1   chaa2   chaa3    scaa1    scaa2   scaa3 ctdh ctaa   mchg
## 1 -0.6174 -0.6174 -0.0015  -3.7199  -3.7199 -0.0093    1    1 0.9241
## 2 -0.8371 -0.4185 -0.0023 -27.5143 -13.7571 -0.0754    2    2 1.2685
## 3 -1.3671 -0.3418 -0.0019 -21.7898  -5.4475 -0.0303    1    4 1.2562
## 4 -1.2332 -0.4111 -0.0017 -17.5957  -5.8652 -0.0240    1    3 1.1962
## 5 -1.1480 -0.3827 -0.0017 -17.0447  -5.6816 -0.0255    1    3 1.2934
## 6 -1.2317 -0.4106 -0.0019 -19.8513  -6.6171 -0.0302    1    3 1.2021
##     achg   rdta   n_sp2   n_sp3   o_sp2   o_sp3 logBBB
## 1 0.9241 1.0000  0.0000  6.0255  0.0000  0.0000   1.08
## 2 1.0420 1.0000  0.0000  6.5681 32.0102 33.6135  -0.40
## 3 1.2562 0.2500 26.9733 10.8567  0.0000 27.5451   0.22
## 4 1.1962 0.3333 21.7065 11.0017  0.0000 15.1316   0.14
## 5 1.2934 0.3333 24.2061 10.8109  0.0000 15.1333   0.69
## 6 1.2021 0.3333 25.5529 11.1218  0.0000 15.1333   0.44
library(kknn)
## 
## Attaching package: 'kknn'
## The following object is masked from 'package:caret':
## 
##     contr.dummy
inTrain <- createDataPartition(logBBB, p = 0.8)[[1]]
trainX <- BloodBrain_test[inTrain,] 
testX <- BloodBrain_test[-inTrain,]

trainX.kknn <- kknn(logBBB~., trainX, testX[,-ncol(testX)], distance = 1, kernel = "triangular")
# Predict the test set with the KNN regression model
pred <- predict(trainX.kknn)
pred
##  [1]  0.20460815  0.17110175 -0.79033452 -0.73467880 -0.68385002
##  [6]  0.68398606  0.15470046  0.87431166  0.57579189  0.76480532
## [11]  0.34400660 -0.50014623 -0.46812515 -0.72411570  0.49250559
## [16]  0.46730599  0.54249320 -0.37007600  0.86117961 -0.09477226
## [21]  0.26808981  0.30562356 -0.80339838  0.36469258 -0.28810768
## [26]  0.35239309 -0.99240068  0.09739690 -0.09953798 -0.73764659
## [31]  0.82226362 -0.63633017 -0.36868268  0.34509227  0.06668604
## [36] -1.29037267 -0.93974827  0.31578603 -1.38733982  0.45296064
# pred <- fitted(trainX.kknn)
# Compute the model's MSE
testX[,ncol(testX)]
##  [1]  0.22  0.69 -0.66 -0.12 -0.27  0.52 -0.62 -0.23  0.88  1.53  0.46
## [12] -0.03  0.34  0.00  0.11  1.26  0.85 -0.05  0.22  0.66  0.60  0.08
## [23] -0.79  0.41 -0.65  0.56 -1.30  0.56 -0.30 -0.78  0.49 -0.20 -1.30
## [34]  0.39 -0.28 -0.82 -1.57  0.54 -2.00 -0.02
mean((pred-testX[,ncol(testX)])^2)
## [1] 0.2492669

The MSE of this KNN regression model is only about 0.25; the error has shrunk further.
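The kknn package also provides train.kknn(), which selects k and the kernel by leave-one-out cross-validation. A sketch on the training data from the split above:

# Sketch: leave-one-out CV over k and kernel for the regression model
fit.cv <- train.kknn(logBBB ~ ., trainX, kmax = 15,
                     kernel = c("triangular", "rectangular", "optimal"),
                     distance = 1)
fit.cv$best.parameters

The chosen k and kernel could then be passed back into a kknn() call like the one above.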

KNN Classification

1. Classification with the knn() function (from the FNN package). The complete worked example on the Caravan data set, with commentary, appears in the insurance application section at the end of this post, so the code is not repeated here.

2. Classification with the kknn() function.

library(kknn)

# ?kknn
data(iris)
m <- dim(iris)[1]
val <- sample(1:m, size = round(m/3), replace = FALSE, prob = rep(1/m, m)) 
iris.learn <- iris[-val,]
iris.valid <- iris[val,]
iris.kknn <- kknn(Species~., iris.learn, iris.valid, distance = 1, kernel = "triangular")
summary(iris.kknn)
## 
## Call:
## kknn(formula = Species ~ ., train = iris.learn, test = iris.valid,     distance = 1, kernel = "triangular")
## 
## Response: "nominal"
##           fit prob.setosa prob.versicolor prob.virginica
## 1  versicolor           0       1.0000000      0.0000000
## 2   virginica           0       0.1631409      0.8368591
## 3  versicolor           0       1.0000000      0.0000000
## 4   virginica           0       0.0000000      1.0000000
## 5   virginica           0       0.3982418      0.6017582
## 6   virginica           0       0.0000000      1.0000000
## 7      setosa           1       0.0000000      0.0000000
## 8   virginica           0       0.1072497      0.8927503
## 9  versicolor           0       1.0000000      0.0000000
## 10 versicolor           0       1.0000000      0.0000000
## 11  virginica           0       0.0000000      1.0000000
## 12 versicolor           0       1.0000000      0.0000000
## 13 versicolor           0       0.8040346      0.1959654
## 14     setosa           1       0.0000000      0.0000000
## 15 versicolor           0       0.6420934      0.3579066
## 16     setosa           1       0.0000000      0.0000000
## 17 versicolor           0       1.0000000      0.0000000
## 18 versicolor           0       1.0000000      0.0000000
## 19 versicolor           0       1.0000000      0.0000000
## 20  virginica           0       0.2602202      0.7397798
## 21 versicolor           0       1.0000000      0.0000000
## 22 versicolor           0       0.8556834      0.1443166
## 23     setosa           1       0.0000000      0.0000000
## 24  virginica           0       0.0000000      1.0000000
## 25     setosa           1       0.0000000      0.0000000
## 26     setosa           1       0.0000000      0.0000000
## 27  virginica           0       0.0000000      1.0000000
## 28  virginica           0       0.0000000      1.0000000
## 29     setosa           1       0.0000000      0.0000000
## 30     setosa           1       0.0000000      0.0000000
## 31  virginica           0       0.0000000      1.0000000
## 32     setosa           1       0.0000000      0.0000000
## 33     setosa           1       0.0000000      0.0000000
## 34  virginica           0       0.0000000      1.0000000
## 35 versicolor           0       1.0000000      0.0000000
## 36 versicolor           0       1.0000000      0.0000000
## 37  virginica           0       0.0000000      1.0000000
## 38  virginica           0       0.4488633      0.5511367
## 39  virginica           0       0.1932859      0.8067141
## 40     setosa           1       0.0000000      0.0000000
## 41     setosa           1       0.0000000      0.0000000
## 42  virginica           0       0.0000000      1.0000000
## 43 versicolor           0       0.8047641      0.1952359
## 44  virginica           0       0.0000000      1.0000000
## 45     setosa           1       0.0000000      0.0000000
## 46  virginica           0       0.0000000      1.0000000
## 47     setosa           1       0.0000000      0.0000000
## 48  virginica           0       0.0000000      1.0000000
## 49     setosa           1       0.0000000      0.0000000
## 50  virginica           0       0.0000000      1.0000000
fit <- fitted(iris.kknn)
# fit <- predict(iris.kknn)
table(iris.valid$Species, fit)
##             fit
##              setosa versicolor virginica
##   setosa         15          0         0
##   versicolor      0         13         1
##   virginica       0          2        19
pcol <- as.character(as.numeric(iris.valid$Species))
pairs(iris.valid[1:4], pch = pcol, col = c("green3", "red")[(iris.valid$Species != fit)+1])
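As a quick numeric summary of the confusion table above, the overall misclassification rate on the validation set:

# Overall misclassification rate on the validation set
# (for the confusion table shown above this is 3/50 = 0.06)
mean(iris.valid$Species != fit)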

library(kknn)
data(ionosphere)
ionosphere.learn <- ionosphere[1:200,]
ionosphere.valid <- ionosphere[-c(1:200),]
fit.kknn <- kknn(class ~ ., ionosphere.learn, ionosphere.valid)
table(ionosphere.valid$class, fit.kknn$fit)
##    
##       b   g
##   b  19   8
##   g   2 122
(fit.train1 <- train.kknn(class ~ ., ionosphere.learn, kmax = 15, 
                          kernel = c("triangular", "rectangular", "epanechnikov", "optimal"), distance = 1))
## 
## Call:
## train.kknn(formula = class ~ ., data = ionosphere.learn, kmax = 15,     distance = 1, kernel = c("triangular", "rectangular", "epanechnikov",         "optimal"))
## 
## Type of response variable: nominal
## Minimal misclassification: 0.12
## Best kernel: rectangular
## Best k: 2
table(predict(fit.train1, ionosphere.valid), ionosphere.valid$class)
##    
##       b   g
##   b  25   4
##   g   2 120
(fit.train2 <- train.kknn(class ~ ., ionosphere.learn, kmax = 15, 
                          kernel = c("triangular", "rectangular", "epanechnikov", "optimal"), distance = 2))
## 
## Call:
## train.kknn(formula = class ~ ., data = ionosphere.learn, kmax = 15,     distance = 2, kernel = c("triangular", "rectangular", "epanechnikov",         "optimal"))
## 
## Type of response variable: nominal
## Minimal misclassification: 0.12
## Best kernel: rectangular
## Best k: 2
table(predict(fit.train2, ionosphere.valid), ionosphere.valid$class)
##    
##       b   g
##   b  20   5
##   g   7 119
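The parameters selected by train.kknn() can also be fed back into a single kknn() fit; a sketch using the selection stored in fit.train1:

# Sketch: refit one kknn model with the CV-selected k and kernel
best <- fit.train1$best.parameters
fit.best <- kknn(class ~ ., ionosphere.learn, ionosphere.valid,
                 k = best$k, kernel = best$kernel, distance = 1)
mean(ionosphere.valid$class != fitted(fit.best))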

Applying KNN to Find Potential Customers in the Insurance Industry

This application uses the Caravan data set from the ISLR package. First, take a quick look at it:

library(ISLR)
str(Caravan)
## 'data.frame':    5822 obs. of  86 variables:
##  $ MOSTYPE : num  33 37 37 9 40 23 39 33 33 11 ...
##  $ MAANTHUI: num  1 1 1 1 1 1 2 1 1 2 ...
##  $ MGEMOMV : num  3 2 2 3 4 2 3 2 2 3 ...
##  $ MGEMLEEF: num  2 2 2 3 2 1 2 3 4 3 ...
##  $ MOSHOOFD: num  8 8 8 3 10 5 9 8 8 3 ...
##  $ MGODRK  : num  0 1 0 2 1 0 2 0 0 3 ...
##  $ MGODPR  : num  5 4 4 3 4 5 2 7 1 5 ...
##  $ MGODOV  : num  1 1 2 2 1 0 0 0 3 0 ...
##  $ MGODGE  : num  3 4 4 4 4 5 5 2 6 2 ...
##  $ MRELGE  : num  7 6 3 5 7 0 7 7 6 7 ...
##  $ MRELSA  : num  0 2 2 2 1 6 2 2 0 0 ...
##  $ MRELOV  : num  2 2 4 2 2 3 0 0 3 2 ...
##  $ MFALLEEN: num  1 0 4 2 2 3 0 0 3 2 ...
##  $ MFGEKIND: num  2 4 4 3 4 5 3 5 3 2 ...
##  $ MFWEKIND: num  6 5 2 4 4 2 6 4 3 6 ...
##  $ MOPLHOOG: num  1 0 0 3 5 0 0 0 0 0 ...
##  $ MOPLMIDD: num  2 5 5 4 4 5 4 3 1 4 ...
##  $ MOPLLAAG: num  7 4 4 2 0 4 5 6 8 5 ...
##  $ MBERHOOG: num  1 0 0 4 0 2 0 2 1 2 ...
##  $ MBERZELF: num  0 0 0 0 5 0 0 0 1 0 ...
##  $ MBERBOER: num  1 0 0 0 4 0 0 0 0 0 ...
##  $ MBERMIDD: num  2 5 7 3 0 4 4 2 1 3 ...
##  $ MBERARBG: num  5 0 0 1 0 2 1 5 8 3 ...
##  $ MBERARBO: num  2 4 2 2 0 2 5 2 1 3 ...
##  $ MSKA    : num  1 0 0 3 9 2 0 2 1 1 ...
##  $ MSKB1   : num  1 2 5 2 0 2 1 1 1 2 ...
##  $ MSKB2   : num  2 3 0 1 0 2 4 2 0 1 ...
##  $ MSKC    : num  6 5 4 4 0 4 5 5 8 4 ...
##  $ MSKD    : num  1 0 0 0 0 2 0 2 1 2 ...
##  $ MHHUUR  : num  1 2 7 5 4 9 6 0 9 0 ...
##  $ MHKOOP  : num  8 7 2 4 5 0 3 9 0 9 ...
##  $ MAUT1   : num  8 7 7 9 6 5 8 4 5 6 ...
##  $ MAUT2   : num  0 1 0 0 2 3 0 4 2 1 ...
##  $ MAUT0   : num  1 2 2 0 1 3 1 2 3 2 ...
##  $ MZFONDS : num  8 6 9 7 5 9 9 6 7 6 ...
##  $ MZPART  : num  1 3 0 2 4 0 0 3 2 3 ...
##  $ MINKM30 : num  0 2 4 1 0 5 4 2 7 2 ...
##  $ MINK3045: num  4 0 5 5 0 2 3 5 2 3 ...
##  $ MINK4575: num  5 5 0 3 9 3 3 3 1 3 ...
##  $ MINK7512: num  0 2 0 0 0 0 0 0 0 1 ...
##  $ MINK123M: num  0 0 0 0 0 0 0 0 0 0 ...
##  $ MINKGEM : num  4 5 3 4 6 3 3 3 2 4 ...
##  $ MKOOPKLA: num  3 4 4 4 3 3 5 3 3 7 ...
##  $ PWAPART : num  0 2 2 0 0 0 0 0 0 2 ...
##  $ PWABEDR : num  0 0 0 0 0 0 0 0 0 0 ...
##  $ PWALAND : num  0 0 0 0 0 0 0 0 0 0 ...
##  $ PPERSAUT: num  6 0 6 6 0 6 6 0 5 0 ...
##  $ PBESAUT : num  0 0 0 0 0 0 0 0 0 0 ...
##  $ PMOTSCO : num  0 0 0 0 0 0 0 0 0 0 ...
##  $ PVRAAUT : num  0 0 0 0 0 0 0 0 0 0 ...
##  $ PAANHANG: num  0 0 0 0 0 0 0 0 0 0 ...
##  $ PTRACTOR: num  0 0 0 0 0 0 0 0 0 0 ...
##  $ PWERKT  : num  0 0 0 0 0 0 0 0 0 0 ...
##  $ PBROM   : num  0 0 0 0 0 0 0 3 0 0 ...
##  $ PLEVEN  : num  0 0 0 0 0 0 0 0 0 0 ...
##  $ PPERSONG: num  0 0 0 0 0 0 0 0 0 0 ...
##  $ PGEZONG : num  0 0 0 0 0 0 0 0 0 0 ...
##  $ PWAOREG : num  0 0 0 0 0 0 0 0 0 0 ...
##  $ PBRAND  : num  5 2 2 2 6 0 0 0 0 3 ...
##  $ PZEILPL : num  0 0 0 0 0 0 0 0 0 0 ...
##  $ PPLEZIER: num  0 0 0 0 0 0 0 0 0 0 ...
##  $ PFIETS  : num  0 0 0 0 0 0 0 0 0 0 ...
##  $ PINBOED : num  0 0 0 0 0 0 0 0 0 0 ...
##  $ PBYSTAND: num  0 0 0 0 0 0 0 0 0 0 ...
##  $ AWAPART : num  0 2 1 0 0 0 0 0 0 1 ...
##  $ AWABEDR : num  0 0 0 0 0 0 0 0 0 0 ...
##  $ AWALAND : num  0 0 0 0 0 0 0 0 0 0 ...
##  $ APERSAUT: num  1 0 1 1 0 1 1 0 1 0 ...
##  $ ABESAUT : num  0 0 0 0 0 0 0 0 0 0 ...
##  $ AMOTSCO : num  0 0 0 0 0 0 0 0 0 0 ...
##  $ AVRAAUT : num  0 0 0 0 0 0 0 0 0 0 ...
##  $ AAANHANG: num  0 0 0 0 0 0 0 0 0 0 ...
##  $ ATRACTOR: num  0 0 0 0 0 0 0 0 0 0 ...
##  $ AWERKT  : num  0 0 0 0 0 0 0 0 0 0 ...
##  $ ABROM   : num  0 0 0 0 0 0 0 1 0 0 ...
##  $ ALEVEN  : num  0 0 0 0 0 0 0 0 0 0 ...
##  $ APERSONG: num  0 0 0 0 0 0 0 0 0 0 ...
##  $ AGEZONG : num  0 0 0 0 0 0 0 0 0 0 ...
##  $ AWAOREG : num  0 0 0 0 0 0 0 0 0 0 ...
##  $ ABRAND  : num  1 1 1 1 1 0 0 0 0 1 ...
##  $ AZEILPL : num  0 0 0 0 0 0 0 0 0 0 ...
##  $ APLEZIER: num  0 0 0 0 0 0 0 0 0 0 ...
##  $ AFIETS  : num  0 0 0 0 0 0 0 0 0 0 ...
##  $ AINBOED : num  0 0 0 0 0 0 0 0 0 0 ...
##  $ ABYSTAND: num  0 0 0 0 0 0 0 0 0 0 ...
##  $ Purchase: Factor w/ 2 levels "No","Yes": 1 1 1 1 1 1 1 1 1 1 ...
table(Caravan$Purchase)/sum(as.numeric(table(Caravan$Purchase)))
## 
##         No        Yes 
## 0.94022673 0.05977327

There are 5822 observations and 86 variables. Only Purchase is categorical; all the others are numeric. Purchase has two levels, No and Yes, indicating whether a person did not buy or did buy insurance. About 6% of the people bought insurance.

Because the KNN algorithm computes distances, and these 85 numeric variables are measured on different scales, the distance contribution of the same pair of points can differ enormously across features. The features must therefore be standardized, which is basic machine-learning practice. Here scale() is used to center and scale each numeric variable to mean 0 and standard deviation 1. (Note that standardization only fixes the mean and standard deviation; it does not make a variable normally distributed.)

standardized.X=scale(Caravan[,-86])

mean(standardized.X[,sample(1:85,1)])
## [1] -1.352488e-16
var(standardized.X[,sample(1:85,1)])
## [1] 1
mean(standardized.X[,sample(1:85,1)])
## [1] 3.030345e-17
var(standardized.X[,sample(1:85,1)])
## [1] 1
mean(standardized.X[,sample(1:85,1)])
## [1] 9.158575e-18
var(standardized.X[,sample(1:85,1)])
## [1] 1

As these checks show, any randomly chosen standardized variable has mean approximately 0 and standard deviation 1.
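Incidentally, what scale() does here can be reproduced by hand; a minimal sketch, relying on the fact that all 85 predictor columns in Caravan are numeric:

# Sketch: manual column-wise standardization, equivalent to scale()
X <- as.matrix(Caravan[, -86])
manual <- sweep(sweep(X, 2, colMeans(X), "-"), 2, apply(X, 2, sd), "/")
all.equal(manual, scale(X), check.attributes = FALSE)  # expected: TRUE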

library(FNN)
# First 1000 observations as the test set, the rest as the training set
test <- 1:1000
train.X <- standardized.X[-test,]
test.X <- standardized.X[test,]
train.Y <- Caravan$Purchase[-test]
test.Y <- Caravan$Purchase[test]
knn.pred <- knn(train.X,test.X,train.Y,k=1)

mean(test.Y!=knn.pred)
## [1] 0.117
mean(test.Y!="No")
## [1] 0.059
# sum(test.Y!=knn.pred)/length(test.Y)
# sum(test.Y!="No")/length(test.Y)

With K=1, the overall error rate of KNN on the test set is about 12%. Since the vast majority of people do not buy insurance (the prior probability of buying is only 6%), the model's "No" predictions should already be highly accurate; it is pointless to agonize over the few predicted non-buyers who actually buy, and likewise there is no need to focus on overall accuracy (Accuracy). As an insurance salesperson, all you care about is how many of the people the model predicts will buy actually do buy, which is the precision (Precision) of the targeted marketing. In this business context, the model's Precision, not its Accuracy, should be the focus of the analysis.
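Since precision will be computed repeatedly below, it helps to wrap it in a small helper; a minimal sketch (prec_recall is a hypothetical name, not part of any package), assuming predicted labels in rows and true labels in columns:

# Sketch: precision and recall for the positive class, from a confusion
# table with predicted labels in rows and true labels in columns
prec_recall <- function(pred, truth, positive = "Yes") {
  tab <- table(pred, truth)
  c(precision = tab[positive, positive] / sum(tab[positive, ]),
    recall    = tab[positive, positive] / sum(tab[, positive]))
}
# e.g. prec_recall(knn.pred, test.Y)

The explicit tables below compute the same precision by hand.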

table(knn.pred,test.Y)
##         test.Y
## knn.pred  No Yes
##      No  874  50
##      Yes  67   9
table(knn.pred,test.Y)[2,2]/rowSums(table(knn.pred,test.Y))[2]
##       Yes 
## 0.1184211

So with K=1, the KNN model's precision is about 12%, double the random-guess rate (6%)!

Next, try different values of K:

knn.pred <- knn(train.X,test.X,train.Y,k=3)
table(knn.pred,test.Y)[2,2]/rowSums(table(knn.pred,test.Y))[2]
##       Yes 
## 0.1923077
knn.pred <- knn(train.X,test.X,train.Y,k=5)
table(knn.pred,test.Y)[2,2]/rowSums(table(knn.pred,test.Y))[2]
##       Yes 
## 0.2857143

When K=3 the precision rises to about 19%, and when K=5 to about 29%.
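The same scan can be written as a short loop; a sketch reusing the objects defined above:

# Sketch: precision of the "Yes" predictions for several values of k
for (k in c(1, 3, 5, 7, 9)) {
  pred <- knn(train.X, test.X, train.Y, k = k)
  tab  <- table(pred, test.Y)
  cat("k =", k, "precision =", round(tab[2, 2] / sum(tab[2, ]), 3), "\n")
}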

For comparison, let's redo this case with logistic regression.

glm.fit <- glm(Purchase~.,data=Caravan,family = binomial,subset = -test)
glm.probs <- predict(glm.fit,Caravan[test,],type = "response")
glm.pred <- ifelse(glm.probs > 0.5,"Yes","No")
table(glm.pred,test.Y)
##         test.Y
## glm.pred  No Yes
##      No  934  59
##      Yes   7   0

This classifier does much worse: the precision is actually 0! In fact, a classification threshold of 0.5 is suited to equally likely outcomes, but buying insurance is clearly not an equally likely event. Lower the threshold to 0.25 and look again:

glm.pred <- ifelse(glm.probs >0.25,"Yes","No")
table(glm.pred,test.Y)
##         test.Y
## glm.pred  No Yes
##      No  919  48
##      Yes  22  11

Now the precision reaches 1/3, more than five times the random-guess rate!
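A quick scan over several cutoffs (a sketch, reusing glm.probs and test.Y from above) shows how the threshold drives precision:

# Sketch: precision of the logistic model at several cutoffs
for (t in c(0.5, 0.25, 0.1)) {
  pred <- factor(ifelse(glm.probs > t, "Yes", "No"), levels = c("No", "Yes"))
  tab  <- table(pred, test.Y)
  cat("threshold =", t, "precision =",
      round(tab["Yes", "Yes"] / sum(tab["Yes", ]), 3), "\n")
}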

These experiments show that targeted marketing based on machine-learning models can be several times more precise than random guessing.

References:

http://www.cnblogs.com/cloudtj/p/6688037.html

http://www.cnblogs.com/Leo_wl/p/5602481.html

http://datartisan.com/article/detail/86.html