]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- further fix to new TypeDecorator, so that subclasses of TypeDecorators work properly
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 2 Jan 2008 01:29:38 +0000 (01:29 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 2 Jan 2008 01:29:38 +0000 (01:29 +0000)
- _handle_dbapi_exception() usage changed so that unwrapped exceptions can be rethrown with the original stack trace

lib/sqlalchemy/engine/base.py
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/types.py
test/sql/testtypes.py

index e3433a06be9f037ff0e996cc84f54d40f3b07255..f2a1cd286443ff0243270c1f76e6065e4963a775 100644 (file)
@@ -732,7 +732,8 @@ class Connection(Connectable):
         try:
             self.engine.dialect.do_begin(self.connection)
         except Exception, e:
-            raise self._handle_dbapi_exception(e, None, None, None)
+            self._handle_dbapi_exception(e, None, None, None)
+            raise
 
     def _rollback_impl(self):
         if not self.closed and not self.invalidated and self.__connection.is_valid:
@@ -742,7 +743,8 @@ class Connection(Connectable):
                 self.engine.dialect.do_rollback(self.connection)
                 self.__transaction = None
             except Exception, e:
-                raise self._handle_dbapi_exception(e, None, None, None)
+                self._handle_dbapi_exception(e, None, None, None)
+                raise
         else:
             self.__transaction = None
 
@@ -753,7 +755,8 @@ class Connection(Connectable):
             self.engine.dialect.do_commit(self.connection)
             self.__transaction = None
         except Exception, e:
-            raise self._handle_dbapi_exception(e, None, None, None)
+            self._handle_dbapi_exception(e, None, None, None)
+            raise
         
     def _savepoint_impl(self, name=None):
         if name is None:
@@ -914,11 +917,11 @@ class Connection(Connectable):
 
     def _handle_dbapi_exception(self, e, statement, parameters, cursor):
         if getattr(self, '_reentrant_error', False):
-            return exceptions.DBAPIError.instance(None, None, e)
+            raise exceptions.DBAPIError.instance(None, None, e)
         self._reentrant_error = True
         try:
             if not isinstance(e, self.dialect.dbapi.Error):
-                return e
+                return
             is_disconnect = self.dialect.is_disconnect(e)
             if is_disconnect:
                 self.invalidate(e)
@@ -929,7 +932,7 @@ class Connection(Connectable):
                 self._autorollback()
                 if self.__close_with_result:
                     self.close()
-            return exceptions.DBAPIError.instance(statement, parameters, e, connection_invalidated=is_disconnect)
+            raise exceptions.DBAPIError.instance(statement, parameters, e, connection_invalidated=is_disconnect)
         finally:
             del self._reentrant_error
         
@@ -937,7 +940,8 @@ class Connection(Connectable):
         try:
             return self.engine.dialect.create_execution_context(connection=self, **kwargs)
         except Exception, e:
-            raise self._handle_dbapi_exception(e, kwargs.get('statement', None), kwargs.get('parameters', None), None)
+            self._handle_dbapi_exception(e, kwargs.get('statement', None), kwargs.get('parameters', None), None)
+            raise
 
     def _cursor_execute(self, cursor, statement, parameters, context=None):
         if self.engine._should_log_info:
@@ -946,7 +950,8 @@ class Connection(Connectable):
         try:
             self.dialect.do_execute(cursor, statement, parameters, context=context)
         except Exception, e:
-            raise self._handle_dbapi_exception(e, statement, parameters, cursor)
+            self._handle_dbapi_exception(e, statement, parameters, cursor)
+            raise
 
     def _cursor_executemany(self, cursor, statement, parameters, context=None):
         if self.engine._should_log_info:
@@ -955,7 +960,8 @@ class Connection(Connectable):
         try:
             self.dialect.do_executemany(cursor, statement, parameters, context=context)
         except Exception, e:
-            raise self._handle_dbapi_exception(e, statement, parameters, cursor)
+            self._handle_dbapi_exception(e, statement, parameters, cursor)
+            raise
 
     # poor man's multimethod/generic function thingy
     executors = {
index 16d55e5b8ea7bfa3909e835f9c0073b658a0189a..e78eedd5c8c465bc790db705bfdc8e582cad117f 100644 (file)
@@ -337,7 +337,8 @@ class DefaultExecutionContext(base.ExecutionContext):
             try:
                 self.cursor.setinputsizes(*inputsizes)
             except Exception, e:
-                raise self._connection._handle_dbapi_exception(e, None, None, None)
+                self._connection._handle_dbapi_exception(e, None, None, None)
+                raise
         else:
             inputsizes = {}
             for key in self.compiled.bind_names.values():
@@ -348,7 +349,8 @@ class DefaultExecutionContext(base.ExecutionContext):
             try:
                 self.cursor.setinputsizes(**inputsizes)
             except Exception, e:
-                raise self._connection._handle_dbapi_exception(e, None, None, None)
+                self._connection._handle_dbapi_exception(e, None, None, None)
+                raise
 
     def __process_defaults(self):
         """generate default values for compiled insert/update statements,
index cf6d147140a48bff99f6bd79f132b4fbe99e0896..9e8b0f488076914e1c64a87a69c30d4220c2d637 100644 (file)
@@ -218,8 +218,8 @@ class DefaultCompiler(engine.Compiled):
             return pd
         else:
             return dict([(self.bind_names[bindparam], bindparam.value) for bindparam in self.bind_names])
-
-    params = property(lambda self:self.construct_params(), doc="""return a dictionary of bind parameter keys and values""")
+    
+    params = property(construct_params)
         
     def default_from(self):
         """Called when a SELECT statement has no froms, and no FROM clause is to be appended.
index 14262d6e06e0a49e9397d9d49748b9da1d3535e1..5ab9ad45095c37c5f6a868fcd128f0a8a9a5a70c 100644 (file)
@@ -239,7 +239,7 @@ class TypeDecorator(AbstractType):
         raise NotImplementedError()
         
     def bind_processor(self, dialect):
-        if 'process_bind_param' in self.__class__.__dict__:
+        if self.__class__.process_bind_param.func_code is not TypeDecorator.process_bind_param.func_code:
             impl_processor = self.impl.bind_processor(dialect)
             if impl_processor:
                 def process(value):
@@ -253,7 +253,7 @@ class TypeDecorator(AbstractType):
             return self.impl.bind_processor(dialect)
 
     def result_processor(self, dialect):
-        if 'process_result_value' in self.__class__.__dict__:
+        if self.__class__.process_result_value.func_code is not TypeDecorator.process_result_value.func_code:
             impl_processor = self.impl.result_processor(dialect)
             if impl_processor:
                 def process(value):
index eeb4a373f3e02d9c35ba06c6d690ceb8ba2ed36d..fc1da5578b87ad9af4f4c526be20f39dacf6661e 100644 (file)
@@ -9,112 +9,6 @@ from sqlalchemy.databases import mssql, oracle, mysql, postgres, firebird
 from testlib import *
 
 
-class MyType(types.TypeEngine):
-    def get_col_spec(self):
-        return "VARCHAR(100)"
-    def bind_processor(self, dialect):
-        def process(value):
-            return "BIND_IN"+ value
-        return process
-    def result_processor(self, dialect):
-        def process(value):
-            return value + "BIND_OUT"
-        return process
-    def adapt(self, typeobj):
-        return typeobj()
-
-class MyDecoratedType(types.TypeDecorator):
-    impl = String
-    def bind_processor(self, dialect):
-        impl_processor = super(MyDecoratedType, self).bind_processor(dialect) or (lambda value:value)
-        def process(value):
-            return "BIND_IN"+ impl_processor(value)
-        return process
-    def result_processor(self, dialect):
-        impl_processor = super(MyDecoratedType, self).result_processor(dialect) or (lambda value:value)
-        def process(value):
-            return impl_processor(value) + "BIND_OUT"
-        return process
-    def copy(self):
-        return MyDecoratedType()
-
-class MyNewUnicodeType(types.TypeDecorator):
-    impl = Unicode
-
-    def process_bind_param(self, value, dialect):
-        return "BIND_IN" + value
-
-    def process_result_value(self, value, dialect):
-        return value + "BIND_OUT"
-
-    def copy(self):
-        return MyNewUnicodeType(self.impl.length)
-
-class MyNewIntType(types.TypeDecorator):
-    impl = Integer
-
-    def process_bind_param(self, value, dialect):
-        return value * 10
-
-    def process_result_value(self, value, dialect):
-        return value * 10
-
-    def copy(self):
-        return MyNewIntType()
-
-class MyUnicodeType(types.TypeDecorator):
-    impl = Unicode
-
-    def bind_processor(self, dialect):
-        impl_processor = super(MyUnicodeType, self).bind_processor(dialect) or (lambda value:value)
-        
-        def process(value):
-            return "BIND_IN"+ impl_processor(value)
-        return process
-
-    def result_processor(self, dialect):
-        impl_processor = super(MyUnicodeType, self).result_processor(dialect) or (lambda value:value)
-        def process(value):
-            return impl_processor(value) + "BIND_OUT"
-        return process
-
-    def copy(self):
-        return MyUnicodeType(self.impl.length)
-
-class MyPickleType(types.TypeDecorator):
-    impl = PickleType
-    
-    def process_bind_param(self, value, dialect):
-        if value:
-            value.stuff = 'this is modified stuff'
-        return value
-    
-    def process_result_value(self, value, dialect):
-        if value:
-            value.stuff = 'this is the right stuff'
-        return value
-        
-class LegacyType(types.TypeEngine):
-    def get_col_spec(self):
-        return "VARCHAR(100)"
-    def convert_bind_param(self, value, dialect):
-        return "BIND_IN"+ value
-    def convert_result_value(self, value, dialect):
-        return value + "BIND_OUT"
-    def adapt(self, typeobj):
-        return typeobj()
-
-class LegacyUnicodeType(types.TypeDecorator):
-    impl = Unicode
-
-    def convert_bind_param(self, value, dialect):
-        return "BIND_IN" + super(LegacyUnicodeType, self).convert_bind_param(value, dialect)
-
-    def convert_result_value(self, value, dialect):
-        return super(LegacyUnicodeType, self).convert_result_value(value, dialect) + "BIND_OUT"
-
-    def copy(self):
-        return LegacyUnicodeType(self.impl.length)
 
 class AdaptTest(PersistTest):
     def testadapt(self):
@@ -149,6 +43,11 @@ class AdaptTest(PersistTest):
 
     def testoracletext(self):
         dialect = oracle.OracleDialect()
+        class MyDecoratedType(types.TypeDecorator):
+            impl = String
+            def copy(self):
+                return MyDecoratedType()
+            
         col = Column('', MyDecoratedType)
         dialect_type = col.type.dialect_impl(dialect)
         assert isinstance(dialect_type.impl, oracle.OracleText), repr(dialect_type.impl)
@@ -215,25 +114,129 @@ class UserDefinedTest(PersistTest):
     def testprocessing(self):
 
         global users
-        users.insert().execute(user_id = 2, goofy = 'jack', goofy2='jack', goofy3='jack', goofy4=u'jack', goofy5=u'jack', goofy6='jack', goofy7=u'jack', goofy8=12)
-        users.insert().execute(user_id = 3, goofy = 'lala', goofy2='lala', goofy3='lala', goofy4=u'lala', goofy5=u'lala', goofy6='lala', goofy7=u'lala', goofy8=15)
-        users.insert().execute(user_id = 4, goofy = 'fred', goofy2='fred', goofy3='fred', goofy4=u'fred', goofy5=u'fred', goofy6='fred', goofy7=u'fred', goofy8=9)
+        users.insert().execute(user_id = 2, goofy = 'jack', goofy2='jack', goofy3='jack', goofy4=u'jack', goofy5=u'jack', goofy6='jack', goofy7=u'jack', goofy8=12, goofy9=12)
+        users.insert().execute(user_id = 3, goofy = 'lala', goofy2='lala', goofy3='lala', goofy4=u'lala', goofy5=u'lala', goofy6='lala', goofy7=u'lala', goofy8=15, goofy9=15)
+        users.insert().execute(user_id = 4, goofy = 'fred', goofy2='fred', goofy3='fred', goofy4=u'fred', goofy5=u'fred', goofy6='fred', goofy7=u'fred', goofy8=9, goofy9=9)
 
         l = users.select().execute().fetchall()
-        for assertstr, assertint, row in zip(
+        for assertstr, assertint, assertint2, row in zip(
             ["BIND_INjackBIND_OUT", "BIND_INlalaBIND_OUT", "BIND_INfredBIND_OUT"],
             [1200, 1500, 900],
+            [1800, 2250, 1350],
             l
             
         ):
             for col in row[1:8]:
                 self.assertEquals(col, assertstr)
             self.assertEquals(row[8], assertint)
+            self.assertEquals(row[9], assertint2)
             for col in (row[4], row[5], row[7]):
                 assert isinstance(col, unicode)
                 
     def setUpAll(self):
         global users, metadata
+
+        class MyType(types.TypeEngine):
+            def get_col_spec(self):
+                return "VARCHAR(100)"
+            def bind_processor(self, dialect):
+                def process(value):
+                    return "BIND_IN"+ value
+                return process
+            def result_processor(self, dialect):
+                def process(value):
+                    return value + "BIND_OUT"
+                return process
+            def adapt(self, typeobj):
+                return typeobj()
+
+        class MyDecoratedType(types.TypeDecorator):
+            impl = String
+            def bind_processor(self, dialect):
+                impl_processor = super(MyDecoratedType, self).bind_processor(dialect) or (lambda value:value)
+                def process(value):
+                    return "BIND_IN"+ impl_processor(value)
+                return process
+            def result_processor(self, dialect):
+                impl_processor = super(MyDecoratedType, self).result_processor(dialect) or (lambda value:value)
+                def process(value):
+                    return impl_processor(value) + "BIND_OUT"
+                return process
+            def copy(self):
+                return MyDecoratedType()
+
+        class MyNewUnicodeType(types.TypeDecorator):
+            impl = Unicode
+
+            def process_bind_param(self, value, dialect):
+                return "BIND_IN" + value
+
+            def process_result_value(self, value, dialect):
+                return value + "BIND_OUT"
+
+            def copy(self):
+                return MyNewUnicodeType(self.impl.length)
+
+        class MyNewIntType(types.TypeDecorator):
+            impl = Integer
+
+            def process_bind_param(self, value, dialect):
+                return value * 10
+
+            def process_result_value(self, value, dialect):
+                return value * 10
+
+            def copy(self):
+                return MyNewIntType()
+
+        class MyNewIntSubClass(MyNewIntType):
+            def process_result_value(self, value, dialect):
+                return value * 15
+
+            def copy(self):
+                return MyNewIntSubClass()
+
+        class MyUnicodeType(types.TypeDecorator):
+            impl = Unicode
+
+            def bind_processor(self, dialect):
+                impl_processor = super(MyUnicodeType, self).bind_processor(dialect) or (lambda value:value)
+
+                def process(value):
+                    return "BIND_IN"+ impl_processor(value)
+                return process
+
+            def result_processor(self, dialect):
+                impl_processor = super(MyUnicodeType, self).result_processor(dialect) or (lambda value:value)
+                def process(value):
+                    return impl_processor(value) + "BIND_OUT"
+                return process
+
+            def copy(self):
+                return MyUnicodeType(self.impl.length)
+
+        class LegacyType(types.TypeEngine):
+            def get_col_spec(self):
+                return "VARCHAR(100)"
+            def convert_bind_param(self, value, dialect):
+                return "BIND_IN"+ value
+            def convert_result_value(self, value, dialect):
+                return value + "BIND_OUT"
+            def adapt(self, typeobj):
+                return typeobj()
+
+        class LegacyUnicodeType(types.TypeDecorator):
+            impl = Unicode
+
+            def convert_bind_param(self, value, dialect):
+                return "BIND_IN" + super(LegacyUnicodeType, self).convert_bind_param(value, dialect)
+
+            def convert_result_value(self, value, dialect):
+                return super(LegacyUnicodeType, self).convert_result_value(value, dialect) + "BIND_OUT"
+
+            def copy(self):
+                return LegacyUnicodeType(self.impl.length)
+
         metadata = MetaData(testbase.db)
         users = Table('type_users', metadata,
             Column('user_id', Integer, primary_key = True),
@@ -251,6 +254,7 @@ class UserDefinedTest(PersistTest):
             Column('goofy6', LegacyType, nullable = False),
             Column('goofy7', MyNewUnicodeType, nullable = False),
             Column('goofy8', MyNewIntType, nullable = False),
+            Column('goofy9', MyNewIntSubClass, nullable = False),
 
         )
 
@@ -396,7 +400,21 @@ class UnicodeTest(AssertMixin):
 
 class BinaryTest(AssertMixin):
     def setUpAll(self):
-        global binary_table
+        global binary_table, MyPickleType
+
+        class MyPickleType(types.TypeDecorator):
+            impl = PickleType
+
+            def process_bind_param(self, value, dialect):
+                if value:
+                    value.stuff = 'this is modified stuff'
+                return value
+
+            def process_result_value(self, value, dialect):
+                if value:
+                    value.stuff = 'this is the right stuff'
+                return value
+
         binary_table = Table('binary_table', MetaData(testbase.db),
         Column('primary_id', Integer, Sequence('binary_id_seq', optional=True), primary_key=True),
         Column('data', Binary),