http://scikit-learn.org/stable/modules/ensemble.html

分类-RandomForestClassifier

http://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestClassifier.html#sklearn.ensemble.RandomForestClassifier

# !/usr/bin/env python
# -*- coding: utf-8 -*-
# __author__ = "abdata"

import time
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
# from sklearn import datasets
# import matplotlib.pyplot as plt

# Train a random-forest classifier on the iris CSV and print its predictions
# for the very rows it was trained on (no train/test split is made here, so
# the printed predictions are not an estimate of generalization accuracy).
mdata = pd.read_csv('http://data.galaxystatistics.com/blog_data/regression/iris.csv')

# Features are selected by position (columns 2-5). Per the sample output
# recorded further down in this file, these are the four sepal/petal
# measurements.
X = mdata.iloc[:,2:6]
print(X.head())
print(X.shape)
print(type(X))
# Target is column 1 by position -- presumably the species label, judging by
# the 'setosa'/'versicolor'/'virginica' values in the recorded predictions.
y = mdata.iloc[:,1]

# Time model construction + fitting only. perf_counter() is a monotonic,
# high-resolution clock, so the measured interval cannot be skewed by system
# clock adjustments the way a time.time() difference can.
t0 = time.perf_counter()
clf = RandomForestClassifier(n_estimators=100)
clf = clf.fit(X, y)
t = time.perf_counter() - t0

# Elapsed fit time in seconds.
print(t)

print(clf.predict(X))
print(type(clf))


######
### Example 1 (commented out, kept for reference):
######
# from sklearn.datasets import load_iris
# from sklearn.ensemble import RandomForestClassifier
# import pandas as pd
# import numpy as np
# 
# iris = load_iris()
# df = pd.DataFrame(iris.data, columns=iris.feature_names)
# 
# print(df)
# print(iris.target)
# print(iris.target_names)
# 
# df['is_train'] = np.random.uniform(0, 1, len(df)) <= .75
# # df['species'] = pd.Factor(iris.target, iris.target_names)
# df['species'] = iris.target
# 
# print(df.head())
# 
# train, test = df[df['is_train']==True], df[df['is_train']==False]
# 
# features = df.columns[:4]
# print(features)
# 
# clf = RandomForestClassifier(n_jobs=2)
# 
# y, _ = pd.factorize(train['species'])
# print(y)
# 
# clf.fit(train[features], y)
# 
# preds = iris.target_names[clf.predict(test[features])]
# print(preds)
# 
# result = pd.crosstab(test['species'], preds, rownames=['actual'], colnames=['preds'])
# print(result)


######
### Example 2 (commented out, kept for reference):
######
# from sklearn.ensemble import RandomForestClassifier
# X = [[0, 0], [1, 1]]
# Y = [0, 1]
# clf = RandomForestClassifier(n_estimators=10)
# clf = clf.fit(X, Y)
# print(clf)
# print(clf.predict([[0.5, 0.5]]))
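## Output of the classification script above (X.head(), X.shape, type(X), fit time in seconds, predictions, type(clf)):
##    Sepal.Length  Sepal.Width  Petal.Length  Petal.Width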
## 0           5.1          3.5           1.4          0.2
## 1           4.9          3.0           1.4          0.2
## 2           4.7          3.2           1.3          0.2
## 3           4.6          3.1           1.5          0.2
## 4           5.0          3.6           1.4          0.2
## (150, 4)
## <class 'pandas.core.frame.DataFrame'>
## 0.11530709266662598
## ['setosa' 'setosa' 'setosa' 'setosa' 'setosa' 'setosa' 'setosa' 'setosa'
##  'setosa' 'setosa' 'setosa' 'setosa' 'setosa' 'setosa' 'setosa' 'setosa'
##  'setosa' 'setosa' 'setosa' 'setosa' 'setosa' 'setosa' 'setosa' 'setosa'
##  'setosa' 'setosa' 'setosa' 'setosa' 'setosa' 'setosa' 'setosa' 'setosa'
##  'setosa' 'setosa' 'setosa' 'setosa' 'setosa' 'setosa' 'setosa' 'setosa'
##  'setosa' 'setosa' 'setosa' 'setosa' 'setosa' 'setosa' 'setosa' 'setosa'
##  'setosa' 'setosa' 'versicolor' 'versicolor' 'versicolor' 'versicolor'
##  'versicolor' 'versicolor' 'versicolor' 'versicolor' 'versicolor'
##  'versicolor' 'versicolor' 'versicolor' 'versicolor' 'versicolor'
##  'versicolor' 'versicolor' 'versicolor' 'versicolor' 'versicolor'
##  'versicolor' 'versicolor' 'versicolor' 'versicolor' 'versicolor'
##  'versicolor' 'versicolor' 'versicolor' 'versicolor' 'versicolor'
##  'versicolor' 'versicolor' 'versicolor' 'versicolor' 'versicolor'
##  'versicolor' 'versicolor' 'versicolor' 'versicolor' 'versicolor'
##  'versicolor' 'versicolor' 'versicolor' 'versicolor' 'versicolor'
##  'versicolor' 'versicolor' 'versicolor' 'versicolor' 'versicolor'
##  'versicolor' 'virginica' 'virginica' 'virginica' 'virginica' 'virginica'
##  'virginica' 'virginica' 'virginica' 'virginica' 'virginica' 'virginica'
##  'virginica' 'virginica' 'virginica' 'virginica' 'virginica' 'virginica'
##  'virginica' 'virginica' 'virginica' 'virginica' 'virginica' 'virginica'
##  'virginica' 'virginica' 'virginica' 'virginica' 'virginica' 'virginica'
##  'virginica' 'virginica' 'virginica' 'virginica' 'virginica' 'virginica'
##  'virginica' 'virginica' 'virginica' 'virginica' 'virginica' 'virginica'
##  'virginica' 'virginica' 'virginica' 'virginica' 'virginica' 'virginica'
##  'virginica' 'virginica' 'virginica']
## <class 'sklearn.ensemble.forest.RandomForestClassifier'>

回归-RandomForestRegressor

http://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestRegressor.html#sklearn.ensemble.RandomForestRegressor

# !/usr/bin/env python
# -*- coding: utf-8 -*-
# __author__ = "abdata"


import time
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestRegressor
# from sklearn import datasets
# import matplotlib.pyplot as plt

# Train a random-forest regressor on the iris CSV and print its predictions
# for the very rows it was trained on (no train/test split is made here).
mdata = pd.read_csv('http://data.galaxystatistics.com/blog_data/regression/iris.csv')

# Features are selected by position (columns 3-5). Per the sample output
# recorded further down in this file, these are Sepal.Width, Petal.Length
# and Petal.Width.
X = mdata.iloc[:,3:6]
print(X.head())
print(X.shape)
print(type(X))
# Target is column 2 by position -- presumably Sepal.Length, the remaining
# measurement column; verify against the CSV header.
y = mdata.iloc[:,2]

# Time model construction + fitting only. perf_counter() is a monotonic,
# high-resolution clock, so the measured interval cannot be skewed by system
# clock adjustments the way a time.time() difference can.
t0 = time.perf_counter()
# NOTE: the name `clf` is kept for compatibility even though this object is
# a regressor, not a classifier.
clf = RandomForestRegressor(n_estimators=100)
clf = clf.fit(X, y)
t = time.perf_counter() - t0

# Elapsed fit time in seconds.
print(t)

# In-sample predictions: one value per input row.
pre = clf.predict(X)
print(pre)
print(pre.shape)

print(type(clf))
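## Output of the regression script above (X.head(), X.shape, type(X), fit time in seconds, predictions, predictions shape, type(clf)):
##    Sepal.Width  Petal.Length  Petal.Width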
## 0          3.5           1.4          0.2
## 1          3.0           1.4          0.2
## 2          3.2           1.3          0.2
## 3          3.1           1.5          0.2
## 4          3.6           1.4          0.2
## (150, 3)
## <class 'pandas.core.frame.DataFrame'>
## 0.09976530075073242
## [ 5.118       4.7915      4.5557      4.7512      5.036       5.4175
##   4.80016667  5.0558881   4.5185      4.83625     5.31638333  4.9         4.771
##   4.4946      5.65515     5.5415      5.436       5.054       5.5365
##   5.18841667  5.20841667  5.17883333  4.93        5.07        4.9755      4.931
##   5.049       5.17900476  5.08        4.743       4.8315      5.24696667
##   5.33575     5.46065     4.7512      4.8064      5.348       4.968       4.5081
##   5.0558881   5.126       4.505       4.5557      5.04        5.2205      4.762
##   5.14273333  4.632       5.31638333  4.89766667  6.793       6.31133333
##   6.794       5.596       6.34733333  5.91266667  6.366       4.967       6.444
##   5.452       5.19        5.763       5.826       6.266       5.632       6.595
##   5.61326667  5.777       5.944       5.581       6.118       5.924       6.23
##   6.173       6.209       6.38853333  6.591       6.503       5.9101      5.48
##   5.53        5.54        5.73        6.052       5.61326667  6.09766667
##   6.69133333  6.13316667  5.692       5.536       5.72316667  6.25243333
##   5.749       5.          5.68633333  5.73566667  5.76483333  6.209       5.077
##   5.766       6.604       5.839       6.956       6.355       6.667       7.615
##   5.33133333  7.387       6.711       7.308       6.45        6.277       6.721
##   5.788       5.879       6.516       6.508       7.663       7.662       6.105
##   6.814       5.806       7.683       6.202       6.69        7.059       6.203
##   6.093       6.417       6.917       7.393       7.7         6.412       6.302
##   6.232       7.519       6.322       6.512       6.03333333  6.767       6.693
##   6.737       5.839       6.859       6.619       6.588       6.092       6.431
##   6.27        6.066     ]
## (150,)
## <class 'sklearn.ensemble.forest.RandomForestRegressor'>