]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Allowed column types to be callables. Fixes #1165.
authorMichael Trier <mtrier@gmail.com>
Sat, 4 Oct 2008 01:49:14 +0000 (01:49 +0000)
committerMichael Trier <mtrier@gmail.com>
Sat, 4 Oct 2008 01:49:14 +0000 (01:49 +0000)
lib/sqlalchemy/schema.py
lib/sqlalchemy/types.py
test/sql/testtypes.py

index bb97943ebe004c868107cd04aac6f033dd1a191a..d859b9061090cd2462bdd7a314bd0328e21d29a0 100644 (file)
@@ -547,9 +547,13 @@ class Column(SchemaItem, expression._ColumnClause):
                         "May not pass name positionally and as a keyword.")
                 name = args.pop(0)
         if args:
-            if (isinstance(args[0], types.AbstractType) or
-                (isinstance(args[0], type) and
-                 issubclass(args[0], types.AbstractType))):
+            coltype = args[0]
+            if callable(coltype):
+                coltype = args[0]()
+
+            if (isinstance(coltype, types.AbstractType) or
+                (isinstance(coltype, type) and
+                 issubclass(coltype, types.AbstractType))):
                 if type_ is not None:
                     raise exc.ArgumentError(
                         "May not pass type_ positionally and as a keyword.")
index a7243f279460af36c228c13a28d9880491bf6e2b..fe6cc7b534e7febe9ddca8cba6753cca6468fde3 100644 (file)
@@ -246,9 +246,10 @@ class MutableType(object):
 def to_instance(typeobj):
     if typeobj is None:
         return NULLTYPE
-    elif isinstance(typeobj, type):
+
+    try: 
         return typeobj()
-    else:
+    except TypeError:
         return typeobj
 
 def adapt_type(typeobj, colspecs):
index 2ed8f37d757a194149e0d20b469fb775d37a4208..c8b9d7f394c644b799cf589797ff82b1d90c7967 100644 (file)
@@ -797,5 +797,29 @@ class BooleanTest(TestBase, AssertsExecutionResults):
         print res2
         assert(res2==[(2, False)])
 
+class CallableTest(TestBase, AssertsExecutionResults):
+    def setUpAll(self):
+        global meta
+        meta = MetaData(testing.db)
+
+    def tearDownAll(self):
+        meta.drop_all()
+
+    def test_callable_as_arg(self):
+        from functools import partial
+        ucode = partial(Unicode, assert_unicode=None)
+
+        thing_table = Table('thing', meta,
+            Column('name', ucode, primary_key=True)
+        )
+
+    def test_callable_as_kwarg(self):
+        from functools import partial
+        ucode = partial(Unicode, assert_unicode=None)
+
+        thang_table = Table('thang', meta,
+            Column('name', type_=ucode, primary_key=True)
+        )
+
 if __name__ == "__main__":
     testenv.main()