summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--factory/declarations.py113
1 files changed, 113 insertions, 0 deletions
diff --git a/factory/declarations.py b/factory/declarations.py
index 41d99a3..d9d560a 100644
--- a/factory/declarations.py
+++ b/factory/declarations.py
@@ -250,6 +250,119 @@ class SubFactory(OrderedDeclaration):
return self.factory.build(**attrs)
+class PostGenerationDeclaration(object):
+ """Declarations to be called once the target object has been generated.
+
+ Attributes:
+ extract_prefix (str): prefix to use when extracting attributes from
+ the factory's declaration for this declaration. If empty, uses
+ the attribute name of the PostGenerationDeclaration.
+ """
+
+ def __init__(self, extract_prefix=None):
+ self.extract_prefix = extract_prefix
+
+ def extract(self, name, attrs):
+ """Extract relevant attributes from a dict.
+
+ Args:
+ name (str): the name at which this PostGenerationDeclaration was
+ defined in the declarations
+ attrs (dict): the attribute dict from which values should be
+ extracted
+
+ Returns:
+ (object, dict): a tuple containing the attribute at 'name' (if
+ provided) and a dict of extracted attributes
+ """
+ extracted = attrs.get(name)
+ if self.extract_prefix:
+ extract_prefix = self.extract_prefix
+ else:
+ extract_prefix = name
+ kwargs = utils.extract_dict(extract_prefix)
+ return extracted, kwargs
+
+ def call(self, obj, create, extracted=None, **kwargs):
+ """Call this hook; no return value is expected.
+
+ Args:
+ obj (object): the newly generated object
+ create (bool): whether the object was 'built' or 'created'
+ extracted (object): the value given for <extract_prefix> in the
+ object definition, or None if not provided.
+ kwargs (dict): declarations extracted from the object
+ definition for this hook
+ """
+ raise NotImplementedError()
+
+
+class PostGeneration(PostGenerationDeclaration):
+ """Calls a given function once the object has been generated."""
+ def __init__(self, function, extract_prefix=None):
+ super(PostGeneration, self).__init__(extract_prefix)
+ self.function = function
+
+ def call(self, obj, create, extracted=None, **kwargs):
+ self.function(obj, create, extracted, **kwargs)
+
+
+def post_declaration(extract_prefix=None):
+ def decorator(fun):
+ return PostGeneration(fun, extract_prefix=extract_prefix)
+ return decorator
+
+
+class RelatedFactory(PostGenerationDeclaration):
+ """Calls a factory once the object has been generated.
+
+ Attributes:
+ factory (Factory): the factory to call
+ defaults (dict): extra declarations for calling the related factory
+ name (str): the name to use to refer to the generated object when
+ calling the related factory
+ """
+
+ def __init__(self, factory, name='', **defaults):
+ super(RelatedFactory, self).__init__(extract_prefix=None)
+ self.factory = factory
+ self.name = name
+ self.defaults = defaults
+
+ def call(self, obj, create, extracted=None, **kwargs):
+ passed_kwargs = dict(self.defaults)
+ passed_kwargs.update(kwargs)
+ if self.name:
+ passed_kwargs[self.name] = obj
+ self.factory.simple_generate(create, **passed_kwargs)
+
+
+class PostGenerationMethodCall(PostGenerationDeclaration):
+ """Calls a method of the generated object.
+
+ Attributes:
+ method_name (str): the method to call
+ method_args (list): arguments to pass to the method
+ method_kwargs (dict): keyword arguments to pass to the method
+
+ Example:
+ class UserFactory(factory.Factory):
+ ...
+ password = factory.PostGenerationMethodCall('set_password', password='')
+ """
+ def __init__(self, method_name, extract_prefix=None, *args, **kwargs):
+ super(RelatedFactory, self).__init__(extract_prefix)
+ self.method_name = method_name
+ self.method_args = args
+ self.method_kwargs = kwargs
+
+ def call(self, obj, create, extracted=None, **kwargs):
+ passed_kwargs = dict(self.method_kwargs)
+ passed_kwargs.update(kwargs)
+ method = getattr(obj, self.method_name)
+ method(*self.method_args, **passed_kwargs)
+
+
# Decorators... in case lambdas don't cut it
def lazy_attribute(func):