aboutsummaryrefslogtreecommitdiff
path: root/sklearn_pandas/cross_validation.py
diff options
context:
space:
mode:
authorChristopher Baines <mail@cbaines.net>2015-12-13 16:20:50 +0000
committerChristopher Baines <mail@cbaines.net>2015-12-13 16:20:50 +0000
commit31d70519b84ea5d4b6df194d6f251ace6bc74ffc (patch)
tree25561e8ac7b2faa9dc3a7a72a224050f1d74f99f /sklearn_pandas/cross_validation.py
parent147d916d9cc641d496b8bbb32b7db99701038491 (diff)
downloadsklearn-pandas-31d70519b84ea5d4b6df194d6f251ace6bc74ffc.tar
sklearn-pandas-31d70519b84ea5d4b6df194d6f251ace6bc74ffc.tar.gz
Imported Upstream version 1.1.0upstream/1.1.0upstream
Diffstat (limited to 'sklearn_pandas/cross_validation.py')
-rw-r--r--sklearn_pandas/cross_validation.py37
1 files changed, 37 insertions, 0 deletions
diff --git a/sklearn_pandas/cross_validation.py b/sklearn_pandas/cross_validation.py
new file mode 100644
index 0000000..9cd8cbe
--- /dev/null
+++ b/sklearn_pandas/cross_validation.py
@@ -0,0 +1,37 @@
+from sklearn import cross_validation
+from sklearn import grid_search
+
+
+def cross_val_score(model, X, *args, **kwargs):
+ X = DataWrapper(X)
+ return cross_validation.cross_val_score(model, X, *args, **kwargs)
+
+
+class GridSearchCV(grid_search.GridSearchCV):
+ def fit(self, X, *params, **kwparams):
+ return super(GridSearchCV, self).fit(DataWrapper(X), *params, **kwparams)
+
+ def predict(self, X, *params, **kwparams):
+ return super(GridSearchCV, self).predict(DataWrapper(X), *params, **kwparams)
+
+
+try:
+ class RandomizedSearchCV(grid_search.RandomizedSearchCV):
+ def fit(self, X, *params, **kwparams):
+ return super(RandomizedSearchCV, self).fit(DataWrapper(X), *params, **kwparams)
+
+ def predict(self, X, *params, **kwparams):
+ return super(RandomizedSearchCV, self).predict(DataWrapper(X), *params, **kwparams)
+except AttributeError:
+ pass
+
+
+class DataWrapper(object):
+ def __init__(self, df):
+ self.df = df
+
+ def __len__(self):
+ return len(self.df)
+
+ def __getitem__(self, key):
+ return self.df.iloc[key]