diff options
author | Christopher Baines <mail@cbaines.net> | 2015-12-13 16:20:50 +0000 |
---|---|---|
committer | Christopher Baines <mail@cbaines.net> | 2015-12-13 16:20:50 +0000 |
commit | 31d70519b84ea5d4b6df194d6f251ace6bc74ffc (patch) | |
tree | 25561e8ac7b2faa9dc3a7a72a224050f1d74f99f /sklearn_pandas/cross_validation.py | |
parent | 147d916d9cc641d496b8bbb32b7db99701038491 (diff) | |
download | sklearn-pandas-upstream.tar sklearn-pandas-upstream.tar.gz |
Imported Upstream version 1.1.0upstream/1.1.0upstream
Diffstat (limited to 'sklearn_pandas/cross_validation.py')
-rw-r--r-- | sklearn_pandas/cross_validation.py | 37 |
1 files changed, 37 insertions, 0 deletions
diff --git a/sklearn_pandas/cross_validation.py b/sklearn_pandas/cross_validation.py new file mode 100644 index 0000000..9cd8cbe --- /dev/null +++ b/sklearn_pandas/cross_validation.py @@ -0,0 +1,37 @@ +from sklearn import cross_validation +from sklearn import grid_search + + +def cross_val_score(model, X, *args, **kwargs): + X = DataWrapper(X) + return cross_validation.cross_val_score(model, X, *args, **kwargs) + + +class GridSearchCV(grid_search.GridSearchCV): + def fit(self, X, *params, **kwparams): + return super(GridSearchCV, self).fit(DataWrapper(X), *params, **kwparams) + + def predict(self, X, *params, **kwparams): + return super(GridSearchCV, self).predict(DataWrapper(X), *params, **kwparams) + + +try: + class RandomizedSearchCV(grid_search.RandomizedSearchCV): + def fit(self, X, *params, **kwparams): + return super(RandomizedSearchCV, self).fit(DataWrapper(X), *params, **kwparams) + + def predict(self, X, *params, **kwparams): + return super(RandomizedSearchCV, self).predict(DataWrapper(X), *params, **kwparams) +except AttributeError: + pass + + +class DataWrapper(object): + def __init__(self, df): + self.df = df + + def __len__(self): + return len(self.df) + + def __getitem__(self, key): + return self.df.iloc[key] |