git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- jython support. works OK for expressions, there's a major weakref bug in ORM tho
author     Mike Bayer <mike_mp@zzzcomputing.com>
           Mon, 30 Mar 2009 22:32:36 +0000 (22:32 +0000)
committer  Mike Bayer <mike_mp@zzzcomputing.com>
           Mon, 30 Mar 2009 22:32:36 +0000 (22:32 +0000)
- reraises of exceptions pass along the original stack trace
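
Note: the stack-trace change below (engine/base.py, engine/strategies.py, exc.py) relies on Python 2's three-argument raise statement: the exception class comes first, the constructor argument tuple second, and the traceback captured via sys.exc_info() third, so the re-raised exception still points at the original failing frame. A minimal standalone sketch of the idiom (WrappedError and failing_call are illustrative names only, not part of this commit):

    import sys

    class WrappedError(Exception):
        pass

    def failing_call():
        try:
            1 / 0
        except Exception, e:
            # class, argument tuple, traceback: WrappedError(str(e)) is raised,
            # but the reported traceback still shows the ZeroDivisionError frame.
            raise WrappedError, (str(e),), sys.exc_info()[2]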

17 files changed:
lib/sqlalchemy/connectors/zxJDBC.py
lib/sqlalchemy/dialects/mysql/base.py
lib/sqlalchemy/dialects/mysql/mysqldb.py
lib/sqlalchemy/dialects/mysql/pyodbc.py
lib/sqlalchemy/dialects/mysql/zxjdbc.py [new file with mode: 0644]
lib/sqlalchemy/dialects/postgres/base.py
lib/sqlalchemy/dialects/postgres/zxjdbc.py [new file with mode: 0644]
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/engine/reflection.py
lib/sqlalchemy/engine/strategies.py
lib/sqlalchemy/exc.py
lib/sqlalchemy/orm/attributes.py
test/dialect/postgres.py
test/orm/query.py
test/sql/query.py
test/testlib/config.py
test/testlib/testing.py

diff --git a/lib/sqlalchemy/connectors/zxJDBC.py b/lib/sqlalchemy/connectors/zxJDBC.py
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..eb5c95b1607ccb540af3accd6e3c48a4b4cd8ae0 100644 (file)
@@ -0,0 +1,45 @@
+from sqlalchemy.connectors import Connector
+
+import sys
+import re
+import urllib
+
+class ZxJDBCConnector(Connector):
+    driver='zxjdbc'
+    
+    supports_sane_rowcount = True
+    supports_sane_multi_rowcount = True
+    supports_unicode_binds = True
+    supports_unicode_statements = False
+    default_paramstyle = 'qmark'
+    
+    jdbc_db_name = None
+    jdbc_driver_name = None
+    
+    @classmethod
+    def dbapi(cls):
+        from com.ziclix.python.sql import zxJDBC
+        return zxJDBC
+
+    def _driver_kwargs(self):
+        """return kw arg dict to be sent to connect()."""
+        return {}
+        
+    def create_connect_args(self, url):
+        hostname = url.host
+        dbname = url.database
+        d, u, p, v = "jdbc:%s://%s/%s" % (self.jdbc_db_name, hostname, dbname), url.username, url.password, self.jdbc_driver_name
+        return [[d, u, p, v], self._driver_kwargs()]
+        
+    def is_disconnect(self, e):
+        if isinstance(e, self.dbapi.ProgrammingError):
+            return "The cursor's connection has been closed." in str(e) or 'Attempt to use a closed connection.' in str(e)
+        elif isinstance(e, self.dbapi.Error):
+            return '[08S01]' in str(e)
+        else:
+            return False
+
+    def _get_server_version_info(self, connection):
+        # use connection.connection.dbversion, and parse appropriately
+        # to get a tuple
+        raise NotImplementedError()
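
The base connector leaves jdbc_db_name and jdbc_driver_name to the concrete dialects added later in this commit; create_connect_args() then rewrites an ordinary SQLAlchemy URL into the positional arguments that zxJDBC.connect() takes. A hedged sketch of how this looks under Jython, assuming the driver-qualified URL form implied by the '+zxjdbc' markers in the tests below and using the MySQL values defined further down:

    # Jython only: the JDBC driver (here org.gjt.mm.mysql.Driver) must be
    # on the CLASSPATH.
    from sqlalchemy import create_engine

    engine = create_engine('mysql+zxjdbc://scott:tiger@localhost/test')
    # create_connect_args() turns that URL into roughly
    #   [['jdbc:mysql://localhost/test', 'scott', 'tiger',
    #     'org.gjt.mm.mysql.Driver'], {'CHARSET': ...}]
    # which is handed to com.ziclix.python.sql.zxJDBC.connect().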
diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py
index 9dd2bfe7151978d4421c1956a5f04c2c0e4450ae..229016793952c5121789817443eccaced046b32a 100644 (file)
@@ -1689,7 +1689,10 @@ class MySQLDialect(default.DefaultDialect):
     supports_alter = True
     # identifiers are 64, however aliases can be 255...
     max_identifier_length = 255
+    
     supports_sane_rowcount = True
+    supports_sane_multi_rowcount = False
+    
     default_paramstyle = 'format'
     colspecs = colspecs
     
@@ -1701,11 +1704,6 @@ class MySQLDialect(default.DefaultDialect):
     def __init__(self, use_ansiquotes=None, **kwargs):
         default.DefaultDialect.__init__(self, **kwargs)
 
-    def do_executemany(self, cursor, statement, parameters, context=None):
-        rowcount = cursor.executemany(statement, parameters)
-        if context is not None:
-            context._rowcount = rowcount
-
     def do_commit(self, connection):
         """Execute a COMMIT."""
 
@@ -1848,6 +1846,7 @@ class MySQLDialect(default.DefaultDialect):
         charset = self._connection_charset
         rp = connection.execute("SHOW FULL TABLES FROM %s" %
                 self.identifier_preparer.quote_identifier(schema))
+        
         return [row[0] for row in self._compat_fetchall(rp, charset=charset)\
                                                     if row[1] == 'BASE TABLE']
 
@@ -1973,7 +1972,7 @@ class MySQLDialect(default.DefaultDialect):
         except AttributeError:
             preparer = self.identifier_preparer
             if (self.server_version_info < (4, 1) and
-                self._server_use_ansiquotes):
+                self._server_ansiquotes):
                 # ANSI_QUOTES doesn't affect SHOW CREATE TABLE on < 4.1
                 preparer = MySQLIdentifierPreparer(self)
             self.parser = parser = MySQLTableDefinitionParser(self, preparer)
diff --git a/lib/sqlalchemy/dialects/mysql/mysqldb.py b/lib/sqlalchemy/dialects/mysql/mysqldb.py
index 5f7636bba94ff39cfef0647d604f9a45b7a92a5f..937c11240462ae999f5ae12502d7a445956f65c1 100644 (file)
@@ -45,6 +45,9 @@ class MySQL_mysqldbCompiler(MySQLCompiler):
 class MySQL_mysqldb(MySQLDialect):
     driver = 'mysqldb'
     supports_unicode_statements = False
+    supports_sane_rowcount = True
+    supports_sane_multi_rowcount = True
+
     default_paramstyle = 'format'
     execution_ctx_cls = MySQL_mysqldbExecutionContext
     statement_compiler = MySQL_mysqldbCompiler
@@ -53,6 +56,11 @@ class MySQL_mysqldb(MySQLDialect):
     def dbapi(cls):
         return __import__('MySQLdb')
 
+    def do_executemany(self, cursor, statement, parameters, context=None):
+        rowcount = cursor.executemany(statement, parameters)
+        if context is not None:
+            context._rowcount = rowcount
+
     def create_connect_args(self, url):
         opts = url.translate_connect_args(database='db', username='user',
                                           password='passwd')
diff --git a/lib/sqlalchemy/dialects/mysql/pyodbc.py b/lib/sqlalchemy/dialects/mysql/pyodbc.py
index 426b23cfdf25fb9b5e62f2259c40a9b9bf418465..de419fbd89d0fba898249211205ca9e273a8b42f 100644 (file)
@@ -24,10 +24,6 @@ class MySQL_pyodbc(PyODBCConnector, MySQLDialect):
     def _detect_charset(self, connection):
         """Sniff out the character set in use for connection results."""
 
-        # Allow user override, won't sniff if force_charset is set.
-        if ('mysql', 'force_charset') in connection.info:
-            return connection.info[('mysql', 'force_charset')]
-
         # Prefer 'character_set_results' for the current connection over the
         # value in the driver.  SET NAMES or individual variable SETs will
         # change the charset without updating the driver's view of the world.
diff --git a/lib/sqlalchemy/dialects/mysql/zxjdbc.py b/lib/sqlalchemy/dialects/mysql/zxjdbc.py
new file mode 100644 (file)
index 0000000..7d6e370
--- /dev/null
@@ -0,0 +1,67 @@
+from sqlalchemy.dialects.mysql.base import MySQLDialect, MySQLExecutionContext
+from sqlalchemy.connectors.zxJDBC import ZxJDBCConnector
+from sqlalchemy import util
+import re
+
+class MySQL_jdbcExecutionContext(MySQLExecutionContext):
+    def _real_lastrowid(self, cursor):
+        return cursor.lastrowid
+
+    def _lastrowid(self, cursor):
+        cursor.execute("SELECT LAST_INSERT_ID()")
+        return cursor.fetchone()[0]
+
+class MySQL_jdbc(ZxJDBCConnector, MySQLDialect):
+    execution_ctx_cls = MySQL_jdbcExecutionContext
+
+    supports_sane_rowcount = False
+    supports_sane_multi_rowcount = False
+
+    jdbc_db_name = 'mysql'
+    jdbc_driver_name = "org.gjt.mm.mysql.Driver"
+    
+    def _detect_charset(self, connection):
+        """Sniff out the character set in use for connection results."""
+
+        # Prefer 'character_set_results' for the current connection over the
+        # value in the driver.  SET NAMES or individual variable SETs will
+        # change the charset without updating the driver's view of the world.
+        #
+        # If it's decided that issuing that sort of SQL leaves you SOL, then
+        # this can prefer the driver value.
+        rs = connection.execute("SHOW VARIABLES LIKE 'character_set%%'")
+        opts = dict([(row[0], row[1]) for row in self._compat_fetchall(rs)])
+        for key in ('character_set_connection', 'character_set'):
+            if opts.get(key, None):
+                return opts[key]
+
+        util.warn("Could not detect the connection character set.  Assuming latin1.")
+        return 'latin1'
+
+    def _driver_kwargs(self):
+        """return kw arg dict to be sent to connect()."""
+        
+        return {'CHARSET':self.encoding}
+    
+    def _extract_error_code(self, exception):
+        # e.g.: DBAPIError: (Error) Table 'test.u2' doesn't exist [SQLCode: 1146], [SQLState: 42S02] 'DESCRIBE `u2`' ()
+        
+        m = re.compile(r"\[SQLCode\: (\d+)\]").search(str(exception.orig.args))
+        c = m.group(1)
+        if c:
+            return int(c)
+        else:
+            return None
+
+    def _get_server_version_info(self,connection):
+        dbapi_con = connection.connection
+        version = []
+        r = re.compile('[.\-]')
+        for n in r.split(dbapi_con.dbversion):
+            try:
+                version.append(int(n))
+            except ValueError:
+                version.append(n)
+        return tuple(version)
+
+dialect = MySQL_jdbc
\ No newline at end of file
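
The error-code extraction in the new MySQL dialect works purely on the text of the zxJDBC error message. Taking the sample message quoted in the comment above, the regex behaves as follows (standalone illustration only):

    import re

    msg = ("(Error) Table 'test.u2' doesn't exist "
           "[SQLCode: 1146], [SQLState: 42S02] 'DESCRIBE `u2`' ()")
    m = re.compile(r"\[SQLCode\: (\d+)\]").search(msg)
    if m:
        print int(m.group(1))   # 1146, MySQL's ER_NO_SUCH_TABLE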
diff --git a/lib/sqlalchemy/dialects/postgres/base.py b/lib/sqlalchemy/dialects/postgres/base.py
index d96efd2dad18628c22ffb0cd5e2b0f7dfc70f9ff..7ab1ac7a483dad587429c89d0fc8c51cb7dc81cb 100644 (file)
@@ -481,7 +481,7 @@ class PGDialect(default.DefaultDialect):
 
     @base.connection_memoize(('dialect', 'default_schema_name'))
     def get_default_schema_name(self, connection):
-        return connection.scalar("select current_schema()", None)
+        return connection.scalar("select current_schema()")
 
     def has_table(self, connection, table_name, schema=None):
         # seems like case gets folded in pg_class...
diff --git a/lib/sqlalchemy/dialects/postgres/zxjdbc.py b/lib/sqlalchemy/dialects/postgres/zxjdbc.py
new file mode 100644 (file)
index 0000000..f968ac9
--- /dev/null
@@ -0,0 +1,18 @@
+from sqlalchemy.dialects.postgres.base import PGDialect
+from sqlalchemy.connectors.zxJDBC import ZxJDBCConnector
+from sqlalchemy.engine import default
+
+class Postgres_jdbcExecutionContext(default.DefaultExecutionContext):
+    pass
+
+class Postgres_jdbc(ZxJDBCConnector, PGDialect):
+    execution_ctx_cls = Postgres_jdbcExecutionContext
+
+    jdbc_db_name = 'postgresql'
+    jdbc_driver_name = "org.postgresql.Driver"
+    
+
+    def _get_server_version_info(self, connection):
+        return tuple(int(x) for x in connection.connection.dbversion.split('.'))
+        
+dialect = Postgres_jdbc
\ No newline at end of file
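
The PostgreSQL dialect can split dbversion directly, assuming the driver reports a plain dotted numeric string, whereas the MySQL dialect above needs the more tolerant [.\-] split to cope with non-numeric suffixes. A trivial illustration with a made-up value:

    # dbversion as exposed on a zxJDBC connection object, e.g. '8.3.7'
    dbversion = '8.3.7'
    print tuple(int(x) for x in dbversion.split('.'))   # (8, 3, 7)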
diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py
index 77f481028a0868911258922d6675b720cb302bd5..7daf5dbd31657feb0e554ba8634901285b2e1c9f 100644 (file)
@@ -17,7 +17,7 @@ __all__ = ['BufferedColumnResultProxy', 'BufferedColumnRow', 'BufferedRowResultP
         'Connection', 'DefaultRunner', 'Dialect', 'Engine', 'ExecutionContext', 'NestedTransaction', 'ResultProxy', 
         'RootTransaction', 'RowProxy', 'SchemaIterator', 'StringIO', 'Transaction', 'TwoPhaseTransaction', 'connection_memoize']
 
-import inspect, StringIO
+import inspect, StringIO, sys
 from sqlalchemy import exc, schema, util, types, log
 from sqlalchemy.sql import expression
 
@@ -1046,7 +1046,7 @@ class Connection(Connectable):
         
     def _handle_dbapi_exception(self, e, statement, parameters, cursor, context):
         if getattr(self, '_reentrant_error', False):
-            raise exc.DBAPIError.instance(None, None, e)
+            raise exc.DBAPIError.instance_cls(e), (None, None, e), sys.exc_info()[2]
         self._reentrant_error = True
         try:
             if not isinstance(e, self.dialect.dbapi.Error):
@@ -1065,7 +1065,7 @@ class Connection(Connectable):
                 self._autorollback()
                 if self.__close_with_result:
                     self.close()
-            raise exc.DBAPIError.instance(statement, parameters, e, connection_invalidated=is_disconnect)
+            raise exc.DBAPIError.instance_cls(e), (statement, parameters, e, is_disconnect), sys.exc_info()[2]
         finally:
             del self._reentrant_error
 
@@ -1581,7 +1581,7 @@ class ResultProxy(object):
             self._rowcount = self.context.get_rowcount()
             self.close()
             return
-            
+                
         self._rowcount = None
         self._props = util.populate_column_dict(None)
         self._props.creator = self.__key_fallback()
diff --git a/lib/sqlalchemy/engine/reflection.py b/lib/sqlalchemy/engine/reflection.py
index 6edf2ae9c408b35e962db45d3c9a656084adf7f4..611e97bcf00b6f4db0df550aaf20bdac1f1c5d3f 100644 (file)
@@ -65,10 +65,10 @@ class Inspector(object):
         if hasattr(engine.dialect, 'inspector'):
             return engine.dialect.inspector(engine)
         return Inspector(engine)
-
+    
+    @property
     def default_schema_name(self):
         return self.dialect.get_default_schema_name(self.conn)
-    default_schema_name = property(default_schema_name)
 
     def get_schema_names(self):
         """Return all schema names.
diff --git a/lib/sqlalchemy/engine/strategies.py b/lib/sqlalchemy/engine/strategies.py
index 5187ab1927ad0eb22f6256490dd93be6f3b23be1..b1db8625f8a774df391ae7789c6a257948096a83 100644 (file)
@@ -77,7 +77,8 @@ class DefaultEngineStrategy(EngineStrategy):
                 try:
                     return dbapi.connect(*cargs, **cparams)
                 except Exception, e:
-                    raise exc.DBAPIError.instance(None, None, e)
+                    import sys
+                    raise exc.DBAPIError.instance_cls(e), (None, None, e), sys.exc_info()[2]
             creator = kwargs.pop('creator', connect)
 
             poolclass = (kwargs.pop('poolclass', None) or
diff --git a/lib/sqlalchemy/exc.py b/lib/sqlalchemy/exc.py
index d1af6d385dbce5e6d604345bb9a4ea33144aad93..6cc43d7f267807a2f5f71fcfe4f71ae1349f97de 100644 (file)
@@ -103,7 +103,8 @@ class DBAPIError(SQLAlchemyError):
 
     """
 
-    def instance(cls, statement, params, orig, connection_invalidated=False):
+    @classmethod
+    def instance_cls(cls, orig):
         # Don't ever wrap these, just return them directly as if
         # DBAPIError didn't exist.
         if isinstance(orig, (KeyboardInterrupt, SystemExit)):
@@ -114,8 +115,7 @@ class DBAPIError(SQLAlchemyError):
             if name in glob and issubclass(glob[name], DBAPIError):
                 cls = glob[name]
 
-        return cls(statement, params, orig, connection_invalidated)
-    instance = classmethod(instance)
+        return cls
 
     def __init__(self, statement, params, orig, connection_invalidated=False):
         try:
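
Splitting instance() into instance_cls() separates the class lookup from constructing the wrapped exception, which is what lets the callers in engine/base.py and engine/strategies.py hand the original traceback to a three-argument raise. A condensed sketch of the calling pattern (execute_guarded is an illustrative name, not part of the library):

    import sys
    from sqlalchemy import exc

    def execute_guarded(cursor, statement, parameters):
        try:
            cursor.execute(statement, parameters)
        except Exception, e:
            # A DB-API IntegrityError maps to sqlalchemy.exc.IntegrityError
            # because a DBAPIError subclass of that name exists in exc's
            # module globals; anything unrecognized falls back to DBAPIError.
            raise exc.DBAPIError.instance_cls(e), \
                  (statement, parameters, e, False), \
                  sys.exc_info()[2]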
diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py
index 1642e7394ad5b2a15f4f74950cd75c8a58811340..aa012f27750a5338ec51f2026677f64318fa802f 100644 (file)
@@ -139,10 +139,11 @@ class QueryableAttribute(interfaces.PropComparator):
     def __str__(self):
         return repr(self.parententity) + "." + self.property.key
 
-    @property
-    def property(self):
-        return self.comparator.property
-
+#    @property
+#    def property(self):
+#        return self.comparator.property
+    
+QueryableAttribute.property = property(lambda self:self.comparator.property)
 
 class InstrumentedAttribute(QueryableAttribute):
     """Public-facing descriptor, placed in the mapped class dictionary."""
@@ -833,6 +834,7 @@ class InstanceState(object):
     def __init__(self, obj, manager):
         self.class_ = obj.__class__
         self.manager = manager
+        
         self.obj = weakref.ref(obj, self._cleanup)
         self.dict = obj.__dict__
         self.modified = False
@@ -844,11 +846,17 @@ class InstanceState(object):
         
     def detach(self):
         if self.session_id:
-            del self.session_id
+            try:
+                del self.session_id
+            except AttributeError:
+                pass
 
     def dispose(self):
         if self.session_id:
-            del self.session_id
+            try:
+                del self.session_id
+            except AttributeError:
+                pass
         del self.obj
         del self.dict
     
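
The guarded deletes in detach() and dispose() only matter when session_id looks set but the instance attribute itself is already gone, in which case del raises AttributeError. The diff does not say exactly how that state arises (the commit title blames a weakref problem under Jython), so the sketch below merely shows the failure mode the try/except tolerates, using a hypothetical class-level default to stand in for whatever clears the instance attribute:

    class State(object):
        session_id = 'stale-session'      # hypothetical class-level default

        def detach(self):
            if self.session_id:           # truthy via the class attribute...
                try:
                    del self.session_id   # ...but absent from the instance dict
                except AttributeError:
                    pass

    State().detach()   # a silent no-op; without the guard this raised AttributeError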
diff --git a/test/dialect/postgres.py b/test/dialect/postgres.py
index 6e6500af152173291bd4150fb5537461381b05fe..bb4a10e605a2be049ec8a81a31aaa38f815fae31 100644 (file)
@@ -340,6 +340,15 @@ class InsertTest(TestBase, AssertsExecutionResults):
 
     def _assert_data_noautoincrement(self, table):
         table.insert().execute({'id':30, 'data':'d1'})
+        
+        if testing.db.driver == 'pg8000':
+            exception_cls = ProgrammingError
+        else:
+            exception_cls = IntegrityError
+        
+        self.assertRaisesMessage(exception_cls, "violates not-null constraint", table.insert().execute, {'data':'d2'})
+        self.assertRaisesMessage(exception_cls, "violates not-null constraint", table.insert().execute, {'data':'d2'}, {'data':'d3'})
+
         try:
             table.insert().execute({'data':'d2'})
             assert False
@@ -367,16 +376,9 @@ class InsertTest(TestBase, AssertsExecutionResults):
         m2 = MetaData(testing.db)
         table = Table(table.name, m2, autoload=True)
         table.insert().execute({'id':30, 'data':'d1'})
-        try:
-            table.insert().execute({'data':'d2'})
-            assert False
-        except exc.IntegrityError, e:
-            assert "violates not-null constraint" in str(e)
-        try:
-            table.insert().execute({'data':'d2'}, {'data':'d3'})
-            assert False
-        except exc.IntegrityError, e:
-            assert "violates not-null constraint" in str(e)
+
+        self.assertRaisesMessage(exception_cls, "violates not-null constraint", table.insert().execute, {'data':'d2'})
+        self.assertRaisesMessage(exception_cls, "violates not-null constraint", table.insert().execute, {'data':'d2'}, {'data':'d3'})
 
         table.insert().execute({'id':31, 'data':'d2'}, {'id':32, 'data':'d3'})
         table.insert(inline=True).execute({'id':33, 'data':'d4'})
@@ -858,7 +860,7 @@ class TimeStampTest(TestBase, AssertsExecutionResults):
         self.assertEqual(result[0], datetime.datetime(2007, 12, 25, 0, 0))
 
 class ServerSideCursorsTest(TestBase, AssertsExecutionResults):
-    __only_on__ = 'postgres'
+    __only_on__ = 'postgres+psycopg2'
 
     def setUpAll(self):
         global ss_engine
diff --git a/test/orm/query.py b/test/orm/query.py
index e1e18896a829f3abc544772504bac4a60be34474..a51d9823dc39e4a9caea780d64bf4d95e305023c 100644 (file)
@@ -190,7 +190,7 @@ class GetTest(QueryTest):
         assert u.addresses[0].email_address == 'jack@bean.com'
         assert u.orders[1].items[2].description == 'item 5'
 
-    @testing.fails_on_everything_except('sqlite', '+pyodbc')
+    @testing.fails_on_everything_except('sqlite', '+pyodbc', '+zxjdbc')
     def test_query_str(self):
         s = create_session()
         q = s.query(User).filter(User.id==1)
@@ -1748,8 +1748,8 @@ class MixedEntitiesTest(QueryTest):
         sess = create_session()
 
         q = sess.query(User)
-        q2 = q.group_by([User.name.like('%j%')]).order_by(desc(User.name.like('%j%'))).values(User.name.like('%j%'), func.count(User.name.like('%j%')))
-        self.assertEquals(list(q2), [(True, 1), (False, 3)])
+        q2 = q.order_by(desc(User.name.like('%j%'))).values(User.name.like('%j%'))
+        self.assertEquals(list(q2), [(True,), (False,), (False,), (False,)])
 
     def test_correlated_subquery(self):
         """test that a subquery constructed from ORM attributes doesn't leak out 
diff --git a/test/sql/query.py b/test/sql/query.py
index d0da2bf054e07e097122700e2ac3a6390bc479a1..700b2ca5f626f9ef0cf2180f7e2c7aca4e7b7532 100644 (file)
@@ -72,8 +72,11 @@ class QueryTest(TestBase):
             if result.lastrow_has_defaults():
                 criterion = and_(*[col==id for col, id in zip(table.primary_key, result.last_inserted_ids())])
                 row = table.select(criterion).execute().fetchone()
-                for c in table.c:
-                    ret[c.key] = row[c]
+                try:
+                    for c in table.c:
+                        ret[c.key] = row[c]
+                finally:
+                    row.close()
             return ret
 
         for supported, table, values, assertvalues in [
@@ -524,30 +527,46 @@ class QueryTest(TestBase):
             users.select().alias(users.name),
         ):
             row = s.select(use_labels=True).execute().fetchone()
-            assert row[s.c.user_id] == 7
-            assert row[s.c.user_name] == 'ed'
+            try:
+                assert row[s.c.user_id] == 7
+                assert row[s.c.user_name] == 'ed'
+            finally:
+                row.close()
 
     def test_keys(self):
         users.insert().execute(user_id=1, user_name='foo')
         r = users.select().execute().fetchone()
-        self.assertEqual([x.lower() for x in r.keys()], ['user_id', 'user_name'])
+        try:
+            self.assertEqual([x.lower() for x in r.keys()], ['user_id', 'user_name'])
+        finally:
+            r.close()
 
     def test_items(self):
         users.insert().execute(user_id=1, user_name='foo')
         r = users.select().execute().fetchone()
-        self.assertEqual([(x[0].lower(), x[1]) for x in r.items()], [('user_id', 1), ('user_name', 'foo')])
+        try:
+            self.assertEqual([(x[0].lower(), x[1]) for x in r.items()], [('user_id', 1), ('user_name', 'foo')])
+        finally:
+            r.close()
 
     def test_len(self):
         users.insert().execute(user_id=1, user_name='foo')
-        r = users.select().execute().fetchone()
-        self.assertEqual(len(r), 2)
-        r.close()
+        try:
+            r = users.select().execute().fetchone()
+            self.assertEqual(len(r), 2)
+        finally:
+            r.close()
+            
         r = testing.db.execute('select user_name, user_id from query_users').fetchone()
-        self.assertEqual(len(r), 2)
-        r.close()
-        r = testing.db.execute('select user_name from query_users').fetchone()
-        self.assertEqual(len(r), 1)
-        r.close()
+        try:
+            self.assertEqual(len(r), 2)
+        finally:
+            r.close()
+        try:
+            r = testing.db.execute('select user_name from query_users').fetchone()
+            self.assertEqual(len(r), 1)
+        finally:
+            r.close()
 
     def test_cant_execute_join(self):
         try:
diff --git a/test/testlib/config.py b/test/testlib/config.py
index cef4c6e1dcf641808034fe455f8aaca1584d2a82..5d01e9f4ed04b65ca9316d5c26273ed9388bf033 100644 (file)
@@ -266,30 +266,23 @@ def _prep_testing_database(options, file_config):
     from testlib import engines
     from sqlalchemy import schema
 
-    try:
-        # also create alt schemas etc. here?
-        if options.dropfirst:
-            e = engines.utf8_engine()
-            existing = e.table_names()
-            if existing:
-                if not options.quiet:
-                    print "Dropping existing tables in database: " + db_url
-                    try:
-                        print "Tables: %s" % ', '.join(existing)
-                    except:
-                        pass
-                    print "Abort within 5 seconds..."
-                    time.sleep(5)
-                md = schema.MetaData(e, reflect=True)
-                md.drop_all()
-            e.dispose()
-    except (KeyboardInterrupt, SystemExit):
-        raise
-    except Exception, e:
-        if not options.quiet:
-            warnings.warn(RuntimeWarning(
-                "Error checking for existing tables in testing "
-                "database: %s" % e))
+    # also create alt schemas etc. here?
+    if options.dropfirst:
+        e = engines.utf8_engine()
+        existing = e.table_names()
+        if existing:
+            if not options.quiet:
+                print "Dropping existing tables in database: " + db_url
+                try:
+                    print "Tables: %s" % ', '.join(existing)
+                except:
+                    pass
+                print "Abort within 5 seconds..."
+                time.sleep(5)
+            md = schema.MetaData(e, reflect=True)
+            md.drop_all()
+        e.dispose()
+
 post_configure['prep_db'] = _prep_testing_database
 
 def _set_table_options(options, file_config):
diff --git a/test/testlib/testing.py b/test/testlib/testing.py
index 30500068ca16b9221c13d54c2c408294f0432e2e..89c08ac47dfb689e68b312da99ea6195ff730b99 100644 (file)
@@ -288,6 +288,10 @@ def _server_version(bind=None):
 
     if bind is None:
         bind = config.db
+    
+    # force metadata to be retrieved
+    bind.connect()
+    
     return getattr(bind.dialect, 'server_version_info', ())
 
 def skip_if(predicate, reason=None):