2012-05-15 53 views
110

如何训练有素的朴素贝叶斯分类保存,并用它来预测数据分类保存到磁盘scikit学习

我从scikit学习网站下面的示例程序:

from sklearn import datasets 
iris = datasets.load_iris() 
from sklearn.naive_bayes import GaussianNB 
gnb = GaussianNB() 
y_pred = gnb.fit(iris.data, iris.target).predict(iris.data) 
print "Number of mislabeled points : %d" % (iris.target != y_pred).sum() 

回答

123

分类器仅仅可以腌制和倾倒像任何其他对象。要继续例如:

import cPickle 
# save the classifier 
with open('my_dumped_classifier.pkl', 'wb') as fid: 
    cPickle.dump(gnb, fid)  

# load it again 
with open('my_dumped_classifier.pkl', 'rb') as fid: 
    gnb_loaded = cPickle.load(fid) 
+0

就像一个魅力!我试图使用np.savez并一直加载它,并从未帮助过。非常感谢。 – Kartos

156

您还可以使用joblib.dumpjoblib.load这是在处理数字阵列比默认的Python Pickler会更有效。

JOBLIB包括在scikit学习:

>>> from sklearn.externals import joblib 
>>> from sklearn.datasets import load_digits 
>>> from sklearn.linear_model import SGDClassifier 

>>> digits = load_digits() 
>>> clf = SGDClassifier().fit(digits.data, digits.target) 
>>> clf.score(digits.data, digits.target) # evaluate training error 
0.9526989426822482 

>>> filename = '/tmp/digits_classifier.joblib.pkl' 
>>> _ = joblib.dump(clf, filename, compress=9) 

>>> clf2 = joblib.load(filename) 
>>> clf2 
SGDClassifier(alpha=0.0001, class_weight=None, epsilon=0.1, eta0=0.0, 
     fit_intercept=True, learning_rate='optimal', loss='hinge', n_iter=5, 
     n_jobs=1, penalty='l2', power_t=0.5, rho=0.85, seed=0, 
     shuffle=False, verbose=0, warm_start=False) 
>>> clf2.score(digits.data, digits.target) 
0.9526989426822482 
+0

但从我的理解流水作品,如果它的一部分工作流。如果我想将模型存储在磁盘上并停止执行。然后我回来一个星期后,尝试从磁盘加载模型,它会抛出一个错误: – venuktan

+0

如果这是您正在查找的内容,则无法停止并恢复“fit”方法的执行。也就是说,如果您使用相同版本的scikit-learn库从Python中调用joblib.load,那么joblib.dump'成功后不应该引发异常。 – ogrisel

+7

如果您使用IPython,请不要使用'--pylab'命令行标志或'%pylab'魔术,因为已知隐式命名空间重载会中断酸洗过程。相反,请使用显式导入和'%matplotlib inline'魔术。 – ogrisel

49

你要找的是所谓模型持久在sklearn的话,它是在introductionmodel persistence部分记录。

所以,你必须初始化你的分类,并与

clf = some.classifier() 
clf.fit(X, y) 

之后训练的很长一段时间,你有两种选择:

1)用泡椒

import pickle 
# now you can save it to a file 
with open('filename.pkl', 'wb') as f: 
    pickle.dump(clf, f) 

# and later you can load it 
with open('filename.pkl', 'rb') as f: 
    clf = pickle.load(f) 

2)使用Joblib

from sklearn.externals import joblib 
# now you can save it to a file 
joblib.dump(clf, 'filename.pkl') 
# and later you can load it 
clf = joblib.load('filename.pkl') 

一个更多的时间是有帮助的阅读上述链接

5

在许多情况下,尤其是文本分类是不足够的存储分类,但你需要存储的矢量化,以及以便将来可以矢量化您的输入。

import pickle 
with open('model.pkl', 'wb') as fout: 
    pickle.dump((vectorizer, clf), fout) 

将来使用情况:

with open('model.pkl', 'rb') as fin: 
    vectorizer, clf = pickle.load(fin) 

X_new = vectorizer.transform(new_samples) 
X_new_preds = clf.predict(X_new) 

在转储向量化,可以通过删除矢量化的stop_words_属性:

vectorizer.stop_words_ = None 

作出倾销更为有效。 此外,如果您的分类器参数是稀疏的(如在大多数文本分类示例中),则可以将参数从密集转换为稀疏,这会在内存消耗,加载和转储方面产生巨大差异。 Sparsify由模型:

clf.sparsify() 

,它会自动为SGDClassifier工作,但如果你知道你的模型是稀疏(地段clf.coef_零),那么你可以手动转换CLF。coef_通过:

clf.coef_ = scipy.sparse.csr_matrix(clf.coef_) 

然后你可以更有效地存储它。

1

sklearn估计器实现方法,使您可以轻松保存估算器的相关训练属性。一些估计实现__getstate__方法本身,而是其他人,像GMM只是使用它只是保存对象内部字典中的base implementation

def __getstate__(self): 
    try: 
     state = super(BaseEstimator, self).__getstate__() 
    except AttributeError: 
     state = self.__dict__.copy() 

    if type(self).__module__.startswith('sklearn.'): 
     return dict(state.items(), _sklearn_version=__version__) 
    else: 
     return state 

推荐的方法来保存你的模型到光盘是使用pickle模块:

from sklearn import datasets 
from sklearn.svm import SVC 
iris = datasets.load_iris() 
X = iris.data[:100, :2] 
y = iris.target[:100] 
model = SVC() 
model.fit(X,y) 
import pickle 
with open('mymodel','wb') as f: 
    pickle.dump(model,f) 

但是,您应该保存额外的数据,以便将来可以重新训练模型,或遭受可怕的后果(例如锁定到旧版sklearn)

documentation

训练数据,例如:

为了重建与 未来版本scikit学习一个类似的模型,额外的元数据应该沿着腌制 模型保存得分训练数据获得一个不变的快照

用于生成模型

的scikit学习的版本和它的依赖

交叉验证的蟒源代码的引用

对于依赖于用Cython编写的tree.pyx模块(如IsolationForest)的合成估计器尤其如此,因为它创建了一个耦合到实现在sklearn版本之间不保证稳定。它在过去看到了倒退不相容的变化。

如果您的模型变得非常大并且加载变得麻烦,您还可以使用效率更高的joblib。从文档:

在scikit的特定情况下,它可能是更有趣使用 JOBLIB的更换picklejoblib.dump & joblib.load),这是 上在内部作为携带大numpy的数组的对象更高效 往往是安装scikit学习估计的情况下,却只能 咸菜到磁盘,而不是一个字符串: