diff options
Diffstat (limited to 'sklearn_pandas/cross_validation.py')
-rw-r--r-- | sklearn_pandas/cross_validation.py | 37 |
1 files changed, 37 insertions, 0 deletions
"""Wrappers around scikit-learn's cross-validation and parameter search.

Each entry point wraps its ``X`` argument in a :class:`DataWrapper` and then
delegates to the scikit-learn implementation of the same name.
"""
from sklearn import cross_validation
from sklearn import grid_search


class DataWrapper(object):
    """Adapter exposing ``len()`` and ``[]`` on top of a wrapped object.

    Item access is forwarded to the wrapped object's ``.iloc`` indexer, so
    the wrapped object is assumed to be a pandas object (anything exposing
    ``.iloc``) -- TODO confirm against callers.
    NOTE(review): presumably this exists so that scikit-learn's index-based
    splitting selects rows rather than columns; verify.
    """

    def __init__(self, df):
        # Keep a reference only; the wrapped object is not copied.
        self.df = df

    def __len__(self):
        """Return ``len()`` of the wrapped object."""
        wrapped = self.df
        return len(wrapped)

    def __getitem__(self, key):
        """Return ``self.df.iloc[key]``."""
        indexer = self.df.iloc
        return indexer[key]


def cross_val_score(model, X, *args, **kwargs):
    """Call ``cross_validation.cross_val_score`` with ``X`` wrapped.

    ``model`` and any extra positional/keyword arguments are passed through
    unchanged; the delegate's return value is returned as-is.
    """
    return cross_validation.cross_val_score(
        model, DataWrapper(X), *args, **kwargs)


class GridSearchCV(grid_search.GridSearchCV):
    """``grid_search.GridSearchCV`` whose ``fit``/``predict`` wrap ``X``."""

    def fit(self, X, *params, **kwparams):
        """Wrap ``X`` in a ``DataWrapper`` and delegate to the parent ``fit``."""
        wrapped = DataWrapper(X)
        parent = super(GridSearchCV, self)
        return parent.fit(wrapped, *params, **kwparams)

    def predict(self, X, *params, **kwparams):
        """Wrap ``X`` in a ``DataWrapper`` and delegate to the parent ``predict``."""
        wrapped = DataWrapper(X)
        parent = super(GridSearchCV, self)
        return parent.predict(wrapped, *params, **kwparams)


# ``RandomizedSearchCV`` is only defined when ``grid_search`` provides a
# class of that name; otherwise this module simply does not export it.
if hasattr(grid_search, 'RandomizedSearchCV'):
    class RandomizedSearchCV(grid_search.RandomizedSearchCV):
        """``grid_search.RandomizedSearchCV`` whose ``fit``/``predict`` wrap ``X``."""

        def fit(self, X, *params, **kwparams):
            """Wrap ``X`` in a ``DataWrapper`` and delegate to the parent ``fit``."""
            wrapped = DataWrapper(X)
            parent = super(RandomizedSearchCV, self)
            return parent.fit(wrapped, *params, **kwparams)

        def predict(self, X, *params, **kwparams):
            """Wrap ``X`` in a ``DataWrapper`` and delegate to the parent ``predict``."""
            wrapped = DataWrapper(X)
            parent = super(RandomizedSearchCV, self)
            return parent.predict(wrapped, *params, **kwparams)