aboutsummaryrefslogtreecommitdiff
path: root/sklearn_pandas/cross_validation.py
diff options
context:
space:
mode:
Diffstat (limited to 'sklearn_pandas/cross_validation.py')
-rw-r--r--sklearn_pandas/cross_validation.py37
1 files changed, 37 insertions, 0 deletions
diff --git a/sklearn_pandas/cross_validation.py b/sklearn_pandas/cross_validation.py
new file mode 100644
index 0000000..9cd8cbe
--- /dev/null
+++ b/sklearn_pandas/cross_validation.py
@@ -0,0 +1,37 @@
+from sklearn import cross_validation
+from sklearn import grid_search
+
+
+def cross_val_score(model, X, *args, **kwargs):
+ X = DataWrapper(X)
+ return cross_validation.cross_val_score(model, X, *args, **kwargs)
+
+
+class GridSearchCV(grid_search.GridSearchCV):
+ def fit(self, X, *params, **kwparams):
+ return super(GridSearchCV, self).fit(DataWrapper(X), *params, **kwparams)
+
+ def predict(self, X, *params, **kwparams):
+ return super(GridSearchCV, self).predict(DataWrapper(X), *params, **kwparams)
+
+
+try:
+ class RandomizedSearchCV(grid_search.RandomizedSearchCV):
+ def fit(self, X, *params, **kwparams):
+ return super(RandomizedSearchCV, self).fit(DataWrapper(X), *params, **kwparams)
+
+ def predict(self, X, *params, **kwparams):
+ return super(RandomizedSearchCV, self).predict(DataWrapper(X), *params, **kwparams)
+except AttributeError:
+ pass
+
+
+class DataWrapper(object):
+ def __init__(self, df):
+ self.df = df
+
+ def __len__(self):
+ return len(self.df)
+
+ def __getitem__(self, key):
+ return self.df.iloc[key]