diff options
author | Raphaël Barrois <raphael.barrois@polyconseil.fr> | 2012-04-13 17:51:01 +0200 |
---|---|---|
committer | Raphaël Barrois <raphael.barrois@polyconseil.fr> | 2012-04-15 10:05:19 +0200 |
commit | 3ba8ed544fa9e866f97efc41155ee296f022e9b4 (patch) | |
tree | 61dbeb8c0e34971c868181427757a7bdb87c83ac | |
parent | 184dd0516267c58370d6a88afb1c1ce894b2b7c1 (diff) | |
download | factory-boy-3ba8ed544fa9e866f97efc41155ee296f022e9b4.tar factory-boy-3ba8ed544fa9e866f97efc41155ee296f022e9b4.tar.gz |
Add basis for PostGenerationDeclaration
Signed-off-by: Raphaël Barrois <raphael.barrois@polyconseil.fr>
-rw-r--r-- | factory/declarations.py | 113 |
1 files changed, 113 insertions, 0 deletions
diff --git a/factory/declarations.py b/factory/declarations.py index 41d99a3..d9d560a 100644 --- a/factory/declarations.py +++ b/factory/declarations.py @@ -250,6 +250,119 @@ class SubFactory(OrderedDeclaration): return self.factory.build(**attrs) +class PostGenerationDeclaration(object): + """Declarations to be called once the target object has been generated. + + Attributes: + extract_prefix (str): prefix to use when extracting attributes from + the factory's declaration for this declaration. If empty, uses + the attribute name of the PostGenerationDeclaration. + """ + + def __init__(self, extract_prefix=None): + self.extract_prefix = extract_prefix + + def extract(self, name, attrs): + """Extract relevant attributes from a dict. + + Args: + name (str): the name at which this PostGenerationDeclaration was + defined in the declarations + attrs (dict): the attribute dict from which values should be + extracted + + Returns: + (object, dict): a tuple containing the attribute at 'name' (if + provided) and a dict of extracted attributes + """ + extracted = attrs.get(name) + if self.extract_prefix: + extract_prefix = self.extract_prefix + else: + extract_prefix = name + kwargs = utils.extract_dict(extract_prefix) + return extracted, kwargs + + def call(self, obj, create, extracted=None, **kwargs): + """Call this hook; no return value is expected. + + Args: + obj (object): the newly generated object + create (bool): whether the object was 'built' or 'created' + extracted (object): the value given for <extract_prefix> in the + object definition, or None if not provided. + kwargs (dict): declarations extracted from the object + definition for this hook + """ + raise NotImplementedError() + + +class PostGeneration(PostGenerationDeclaration): + """Calls a given function once the object has been generated.""" + def __init__(self, function, extract_prefix=None): + super(PostGeneration, self).__init__(extract_prefix) + self.function = function + + def call(self, obj, create, extracted=None, **kwargs): + self.function(obj, create, extracted, **kwargs) + + +def post_declaration(extract_prefix=None): + def decorator(fun): + return PostGeneration(fun, extract_prefix=extract_prefix) + return decorator + + +class RelatedFactory(PostGenerationDeclaration): + """Calls a factory once the object has been generated. + + Attributes: + factory (Factory): the factory to call + defaults (dict): extra declarations for calling the related factory + name (str): the name to use to refer to the generated object when + calling the related factory + """ + + def __init__(self, factory, name='', **defaults): + super(RelatedFactory, self).__init__(extract_prefix=None) + self.factory = factory + self.name = name + self.defaults = defaults + + def call(self, obj, create, extracted=None, **kwargs): + passed_kwargs = dict(self.defaults) + passed_kwargs.update(kwargs) + if self.name: + passed_kwargs[self.name] = obj + self.factory.simple_generate(create, **passed_kwargs) + + +class PostGenerationMethodCall(PostGenerationDeclaration): + """Calls a method of the generated object. + + Attributes: + method_name (str): the method to call + method_args (list): arguments to pass to the method + method_kwargs (dict): keyword arguments to pass to the method + + Example: + class UserFactory(factory.Factory): + ... + password = factory.PostGenerationMethodCall('set_password', password='') + """ + def __init__(self, method_name, extract_prefix=None, *args, **kwargs): + super(RelatedFactory, self).__init__(extract_prefix) + self.method_name = method_name + self.method_args = args + self.method_kwargs = kwargs + + def call(self, obj, create, extracted=None, **kwargs): + passed_kwargs = dict(self.method_kwargs) + passed_kwargs.update(kwargs) + method = getattr(obj, self.method_name) + method(*self.method_args, **passed_kwargs) + + # Decorators... in case lambdas don't cut it def lazy_attribute(func): |