1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
|
import six
from sklearn.pipeline import _name_estimators, Pipeline
from sklearn.utils import tosequence
class TransformerPipeline(Pipeline):
"""
Pipeline that expects all steps to be transformers taking a single argument
and having fit and transform methods.
Code is copied from sklearn's Pipeline, leaving out the `y=None` argument.
"""
def __init__(self, steps):
names, estimators = zip(*steps)
if len(dict(steps)) != len(steps):
raise ValueError("Provided step names are not unique: %s" % (names,))
# shallow copy of steps
self.steps = tosequence(steps)
estimator = estimators[-1]
for e in estimators:
if (not (hasattr(e, "fit") or hasattr(e, "fit_transform")) or not
hasattr(e, "transform")):
raise TypeError("All steps of the chain should "
"be transforms and implement fit and transform"
" '%s' (type %s) doesn't)" % (e, type(e)))
if not hasattr(estimator, "fit"):
raise TypeError("Last step of chain should implement fit "
"'%s' (type %s) doesn't)"
% (estimator, type(estimator)))
def _pre_transform(self, X, **fit_params):
fit_params_steps = dict((step, {}) for step, _ in self.steps)
for pname, pval in six.iteritems(fit_params):
step, param = pname.split('__', 1)
fit_params_steps[step][param] = pval
Xt = X
for name, transform in self.steps[:-1]:
if hasattr(transform, "fit_transform"):
Xt = transform.fit_transform(Xt, **fit_params_steps[name])
else:
Xt = transform.fit(Xt, **fit_params_steps[name]) \
.transform(Xt)
return Xt, fit_params_steps[self.steps[-1][0]]
def fit(self, X, **fit_params):
Xt, fit_params = self._pre_transform(X, **fit_params)
self.steps[-1][-1].fit(Xt, **fit_params)
return self
def fit_transform(self, X, **fit_params):
Xt, fit_params = self._pre_transform(X, **fit_params)
if hasattr(self.steps[-1][-1], 'fit_transform'):
return self.steps[-1][-1].fit_transform(Xt, **fit_params)
else:
return self.steps[-1][-1].fit(Xt, **fit_params).transform(Xt)
def make_transformer_pipeline(*steps):
"""Construct a TransformerPipeline from the given estimators.
"""
return TransformerPipeline(_name_estimators(steps))
|