summaryrefslogtreecommitdiff
path: root/factory/base.py
diff options
context:
space:
mode:
authorminimumserious <commande.romain@gmail.com>2013-06-02 14:59:31 +0200
committerRaphaƫl Barrois <raphael.barrois@polytechnique.org>2013-06-20 01:40:14 +0200
commit4bf58811b8a4aa79564afa2ac221306821d5c3d1 (patch)
treee12a79ae3368cf3f20aba5fa6e723ccd4b5d8471 /factory/base.py
parentae3e8e67490a80c5f7cf405fe8ee7cddedb8c7a9 (diff)
downloadfactory-boy-4bf58811b8a4aa79564afa2ac221306821d5c3d1.tar
factory-boy-4bf58811b8a4aa79564afa2ac221306821d5c3d1.tar.gz
Added SQLAlchemy support
Diffstat (limited to 'factory/base.py')
-rw-r--r--factory/base.py33
1 files changed, 33 insertions, 0 deletions
diff --git a/factory/base.py b/factory/base.py
index 0429231..029185b 100644
--- a/factory/base.py
+++ b/factory/base.py
@@ -635,3 +635,36 @@ def use_strategy(new_strategy):
klass.FACTORY_STRATEGY = new_strategy
return klass
return wrapped_class
+
+
+class SQLAlchemyModelFactory(Factory):
+ """Factory for SQLAlchemy models. """
+
+ ABSTRACT_FACTORY = True
+ FACTORY_HIDDEN_ARGS=('SESSION',)
+
+ def __init__(self, session):
+ self.session = session
+
+ @classmethod
+ def _get_function(cls, function_name):
+ session = cls._declarations['SESSION']
+ sqlalchemy = __import__(session.__module__)
+ max = getattr(sqlalchemy.sql.functions, function_name)
+
+ @classmethod
+ def _setup_next_sequence(cls, *args, **kwargs):
+ """Compute the next available PK, based on the 'pk' database field."""
+ max = cls._get_function('max')
+ session = cls._declarations['SESSION']
+ pk = cls.FACTORY_FOR.__table__.primary_key.columns.values()[0].key
+ max_pk = session.query(max(getattr(cls.FACTORY_FOR, pk))).one()
+ return max_pk[0] + 1 if max_pk[0] else 1
+
+ @classmethod
+ def _create(cls, target_class, *args, **kwargs):
+ """Create an instance of the model, and save it to the database."""
+ session = cls._declarations['SESSION']
+ obj = target_class(*args, **kwargs)
+ session.add(obj)
+ return obj