2016-10-03 211 views

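Using the LeaveOneGroupOut strategy with sklearn's cross_val_score

I want to evaluate my model with the LeaveOneGroupOut strategy. According to sklearn's documentation, cross_val_score seems convenient for this.

However, the following code does not work.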

from sklearn import datasets, svm
from sklearn.model_selection import LeaveOneGroupOut, ShuffleSplit, cross_val_score

iris = datasets.load_iris()
clf = svm.SVC(kernel='linear', C=1)
# cv = ShuffleSplit(n_splits=3, test_size=0.3, random_state=0)  # => this works
cv = LeaveOneGroupOut()  # => this does not work
scores = cross_val_score(clf, iris.data, iris.target, cv=cv)

The error message is:

ValueError        Traceback (most recent call last) 
<ipython-input-40-435a3a7fa16c> in <module>() 
     4 from sklearn.model_selection import cross_val_score 
     5 clf = sklearn.svm.SVC(kernel='linear', C=1) 
----> 6 scores = cross_val_score(clf, iris.data, iris.target, cv=LeaveOneGroupOut()) 
     7 scores 

/Users/xxx/.pyenv/versions/anaconda-2.0.1/lib/python2.7/site-packages/sklearn/model_selection/_validation.pyc in cross_val_score(estimator, X, y, groups, scoring, cv, n_jobs, verbose, fit_params, pre_dispatch) 
    138            train, test, verbose, None, 
    139            fit_params) 
--> 140      for train, test in cv.split(X, y, groups)) 
    141  return np.array(scores)[:, 0] 
    142 

/Users/xxx/.pyenv/versions/anaconda-2.0.1/lib/python2.7/site-packages/sklearn/externals/joblib/parallel.pyc in __call__(self, iterable) 
    756    # was dispatched. In particular this covers the edge 
    757    # case of Parallel used with an exhausted iterator. 
--> 758    while self.dispatch_one_batch(iterator): 
    759     self._iterating = True 
    760    else: 

/Users/xxx/.pyenv/versions/anaconda-2.0.1/lib/python2.7/site-packages/sklearn/externals/joblib/parallel.pyc in dispatch_one_batch(self, iterator) 
    601 
    602   with self._lock: 
--> 603    tasks = BatchedCalls(itertools.islice(iterator, batch_size)) 
    604    if len(tasks) == 0: 
    605     # No more tasks available in the iterator: tell caller to stop. 

/Users/xxx/.pyenv/versions/anaconda-2.0.1/lib/python2.7/site-packages/sklearn/externals/joblib/parallel.pyc in __init__(self, iterator_slice) 
    125 
    126  def __init__(self, iterator_slice): 
--> 127   self.items = list(iterator_slice) 
    128   self._size = len(self.items) 
    129 

/Users/xxx/.pyenv/versions/anaconda-2.0.1/lib/python2.7/site-packages/sklearn/model_selection/_validation.pyc in <genexpr>(***failed resolving arguments***) 
    135  parallel = Parallel(n_jobs=n_jobs, verbose=verbose, 
    136       pre_dispatch=pre_dispatch) 
--> 137  scores = parallel(delayed(_fit_and_score)(clone(estimator), X, y, scorer, 
    138            train, test, verbose, None, 
    139            fit_params) 

/Users/xxx/.pyenv/versions/anaconda-2.0.1/lib/python2.7/site-packages/sklearn/model_selection/_split.pyc in split(self, X, y, groups) 
    88   X, y, groups = indexable(X, y, groups) 
    89   indices = np.arange(_num_samples(X)) 
---> 90   for test_index in self._iter_test_masks(X, y, groups): 
    91    train_index = indices[np.logical_not(test_index)] 
    92    test_index = indices[test_index] 

/Users/xxx/.pyenv/versions/anaconda-2.0.1/lib/python2.7/site-packages/sklearn/model_selection/_split.pyc in _iter_test_masks(self, X, y, groups) 
    770  def _iter_test_masks(self, X, y, groups): 
    771   if groups is None: 
--> 772    raise ValueError("The groups parameter should not be None") 
    773   # We make a copy of groups to avoid side-effects during iteration 
    774   groups = np.array(groups, copy=True) 

ValueError: The groups parameter should not be None 
scores 

Answer


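You did not define the groups parameter, which specifies the group each sample belongs to and therefore how your data is split.

The error is raised because cross_val_score takes this parameter as an argument and forwards it to cv.split(X, y, groups): in your case it is left at its default, None.

Try following the example below: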

import numpy as np
from sklearn.model_selection import LeaveOneGroupOut

X = np.array([[1, 2], [3, 4], [5, 6], [7, 8]])
y = np.array([1, 2, 1, 2])
groups = np.array([1, 1, 2, 2])  # one group label per sample
lol = LeaveOneGroupOut()

This gives:

[In] lol.get_n_splits(X, y, groups) 
[Out] 2 

Then you can use:

lol.split(X, y, groups) 
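To see what split produces, here is a minimal sketch that iterates over the generator for the same toy data. The names logo and splits are only illustrative; each iteration holds out one group as the test set and trains on the rest.

```python
import numpy as np
from sklearn.model_selection import LeaveOneGroupOut

X = np.array([[1, 2], [3, 4], [5, 6], [7, 8]])
y = np.array([1, 2, 1, 2])
groups = np.array([1, 1, 2, 2])

logo = LeaveOneGroupOut()

# split() is a generator of (train_indices, test_indices) pairs,
# one pair per distinct value in groups.
splits = [(train.tolist(), test.tolist())
          for train, test in logo.split(X, y, groups)]

for train_idx, test_idx in splits:
    print("train:", train_idx, "test:", test_idx)
# train: [2, 3] test: [0, 1]
# train: [0, 1] test: [2, 3]
```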

Can cross_val_score not work with LeaveOneGroupOut? – rkjt50r983


@rkjt50r983 Define 'cv = LeaveOneGroupOut().split(X, y, groups)', then pass 'cv=cv' to 'cross_val_score()'. – Michael
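Putting the pieces together for the iris example, a sketch might look like the following. The iris data set has no natural grouping, so the groups array here is a made-up assumption (five round-robin groups) used purely for illustration; with real data you would supply your own group labels. Two routes are shown: passing groups straight to cross_val_score (the groups argument appears in the function signature in the traceback above), and pre-computing the splits as suggested in the comment.

```python
import numpy as np
from sklearn import datasets, svm
from sklearn.model_selection import LeaveOneGroupOut, cross_val_score

iris = datasets.load_iris()
clf = svm.SVC(kernel="linear", C=1)

# Hypothetical group labels: iris has no real groups, so samples are
# assigned to 5 invented groups in round-robin order for illustration.
groups = np.arange(len(iris.target)) % 5

logo = LeaveOneGroupOut()

# Route 1: hand the group labels to cross_val_score, which forwards
# them to logo.split(X, y, groups).
scores_a = cross_val_score(clf, iris.data, iris.target,
                           groups=groups, cv=logo)

# Route 2: build the (train, test) splits first and pass them as cv.
cv = logo.split(iris.data, iris.target, groups)
scores_b = cross_val_score(clf, iris.data, iris.target, cv=cv)

print(scores_a)
print(scores_b)
```

Both routes should evaluate the same five folds, one per group. Route 1 keeps the splitter reusable, whereas the generator in route 2 is consumed after a single call.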