From: Mike Bayer Date: Mon, 30 Mar 2009 22:32:36 +0000 (+0000) Subject: - jython support. works OK for expressions, there's a major weakref bug in ORM tho X-Git-Tag: rel_0_6_6~252 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=551d3a3ca50a14bca4f5ef63b1da5105984fcbe5;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - jython support. works OK for expressions, there's a major weakref bug in ORM tho - reraises of exceptions pass along the original stack trace --- diff --git a/lib/sqlalchemy/connectors/zxJDBC.py b/lib/sqlalchemy/connectors/zxJDBC.py index e69de29bb2..eb5c95b160 100644 --- a/lib/sqlalchemy/connectors/zxJDBC.py +++ b/lib/sqlalchemy/connectors/zxJDBC.py @@ -0,0 +1,45 @@ +from sqlalchemy.connectors import Connector + +import sys +import re +import urllib + +class ZxJDBCConnector(Connector): + driver='zxjdbc' + + supports_sane_rowcount = True + supports_sane_multi_rowcount = True + supports_unicode_binds = True + supports_unicode_statements = False + default_paramstyle = 'qmark' + + jdbc_db_name = None + jdbc_driver_name = None + + @classmethod + def dbapi(cls): + from com.ziclix.python.sql import zxJDBC + return zxJDBC + + def _driver_kwargs(self): + """return kw arg dict to be sent to connect().""" + return {} + + def create_connect_args(self, url): + hostname = url.host + dbname = url.database + d, u, p, v = "jdbc:%s://%s/%s" % (self.jdbc_db_name, hostname, dbname), url.username, url.password, self.jdbc_driver_name + return [[d, u, p, v], self._driver_kwargs()] + + def is_disconnect(self, e): + if isinstance(e, self.dbapi.ProgrammingError): + return "The cursor's connection has been closed." in str(e) or 'Attempt to use a closed connection.' in str(e) + elif isinstance(e, self.dbapi.Error): + return '[08S01]' in str(e) + else: + return False + + def _get_server_version_info(self, connection): + # use connection.connection.dbversion, and parse appropriately + # to get a tuple + raise NotImplementedError() diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index 9dd2bfe715..2290167939 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -1689,7 +1689,10 @@ class MySQLDialect(default.DefaultDialect): supports_alter = True # identifiers are 64, however aliases can be 255... max_identifier_length = 255 + supports_sane_rowcount = True + supports_sane_multi_rowcount = False + default_paramstyle = 'format' colspecs = colspecs @@ -1701,11 +1704,6 @@ class MySQLDialect(default.DefaultDialect): def __init__(self, use_ansiquotes=None, **kwargs): default.DefaultDialect.__init__(self, **kwargs) - def do_executemany(self, cursor, statement, parameters, context=None): - rowcount = cursor.executemany(statement, parameters) - if context is not None: - context._rowcount = rowcount - def do_commit(self, connection): """Execute a COMMIT.""" @@ -1848,6 +1846,7 @@ class MySQLDialect(default.DefaultDialect): charset = self._connection_charset rp = connection.execute("SHOW FULL TABLES FROM %s" % self.identifier_preparer.quote_identifier(schema)) + return [row[0] for row in self._compat_fetchall(rp, charset=charset)\ if row[1] == 'BASE TABLE'] @@ -1973,7 +1972,7 @@ class MySQLDialect(default.DefaultDialect): except AttributeError: preparer = self.identifier_preparer if (self.server_version_info < (4, 1) and - self._server_use_ansiquotes): + self._server_ansiquotes): # ANSI_QUOTES doesn't affect SHOW CREATE TABLE on < 4.1 preparer = MySQLIdentifierPreparer(self) self.parser = parser = MySQLTableDefinitionParser(self, preparer) diff --git a/lib/sqlalchemy/dialects/mysql/mysqldb.py b/lib/sqlalchemy/dialects/mysql/mysqldb.py index 5f7636bba9..937c112404 100644 --- a/lib/sqlalchemy/dialects/mysql/mysqldb.py +++ b/lib/sqlalchemy/dialects/mysql/mysqldb.py @@ -45,6 +45,9 @@ class MySQL_mysqldbCompiler(MySQLCompiler): class MySQL_mysqldb(MySQLDialect): driver = 'mysqldb' supports_unicode_statements = False + supports_sane_rowcount = True + supports_sane_multi_rowcount = True + default_paramstyle = 'format' execution_ctx_cls = MySQL_mysqldbExecutionContext statement_compiler = MySQL_mysqldbCompiler @@ -53,6 +56,11 @@ class MySQL_mysqldb(MySQLDialect): def dbapi(cls): return __import__('MySQLdb') + def do_executemany(self, cursor, statement, parameters, context=None): + rowcount = cursor.executemany(statement, parameters) + if context is not None: + context._rowcount = rowcount + def create_connect_args(self, url): opts = url.translate_connect_args(database='db', username='user', password='passwd') diff --git a/lib/sqlalchemy/dialects/mysql/pyodbc.py b/lib/sqlalchemy/dialects/mysql/pyodbc.py index 426b23cfdf..de419fbd89 100644 --- a/lib/sqlalchemy/dialects/mysql/pyodbc.py +++ b/lib/sqlalchemy/dialects/mysql/pyodbc.py @@ -24,10 +24,6 @@ class MySQL_pyodbc(PyODBCConnector, MySQLDialect): def _detect_charset(self, connection): """Sniff out the character set in use for connection results.""" - # Allow user override, won't sniff if force_charset is set. - if ('mysql', 'force_charset') in connection.info: - return connection.info[('mysql', 'force_charset')] - # Prefer 'character_set_results' for the current connection over the # value in the driver. SET NAMES or individual variable SETs will # change the charset without updating the driver's view of the world. diff --git a/lib/sqlalchemy/dialects/mysql/zxjdbc.py b/lib/sqlalchemy/dialects/mysql/zxjdbc.py new file mode 100644 index 0000000000..7d6e3703de --- /dev/null +++ b/lib/sqlalchemy/dialects/mysql/zxjdbc.py @@ -0,0 +1,67 @@ +from sqlalchemy.dialects.mysql.base import MySQLDialect, MySQLExecutionContext +from sqlalchemy.connectors.zxJDBC import ZxJDBCConnector +from sqlalchemy import util +import re + +class MySQL_jdbcExecutionContext(MySQLExecutionContext): + def _real_lastrowid(self, cursor): + return cursor.lastrowid + + def _lastrowid(self, cursor): + cursor.execute("SELECT LAST_INSERT_ID()") + return cursor.fetchone()[0] + +class MySQL_jdbc(ZxJDBCConnector, MySQLDialect): + execution_ctx_cls = MySQL_jdbcExecutionContext + + supports_sane_rowcount = False + supports_sane_multi_rowcount = False + + jdbc_db_name = 'mysql' + jdbc_driver_name = "org.gjt.mm.mysql.Driver" + + def _detect_charset(self, connection): + """Sniff out the character set in use for connection results.""" + + # Prefer 'character_set_results' for the current connection over the + # value in the driver. SET NAMES or individual variable SETs will + # change the charset without updating the driver's view of the world. + # + # If it's decided that issuing that sort of SQL leaves you SOL, then + # this can prefer the driver value. + rs = connection.execute("SHOW VARIABLES LIKE 'character_set%%'") + opts = dict([(row[0], row[1]) for row in self._compat_fetchall(rs)]) + for key in ('character_set_connection', 'character_set'): + if opts.get(key, None): + return opts[key] + + util.warn("Could not detect the connection character set. Assuming latin1.") + return 'latin1' + + def _driver_kwargs(self): + """return kw arg dict to be sent to connect().""" + + return {'CHARSET':self.encoding} + + def _extract_error_code(self, exception): + # e.g.: DBAPIError: (Error) Table 'test.u2' doesn't exist [SQLCode: 1146], [SQLState: 42S02] 'DESCRIBE `u2`' () + + m = re.compile(r"\[SQLCode\: (\d+)\]").search(str(exception.orig.args)) + c = m.group(1) + if c: + return int(c) + else: + return None + + def _get_server_version_info(self,connection): + dbapi_con = connection.connection + version = [] + r = re.compile('[.\-]') + for n in r.split(dbapi_con.dbversion): + try: + version.append(int(n)) + except ValueError: + version.append(n) + return tuple(version) + +dialect = MySQL_jdbc \ No newline at end of file diff --git a/lib/sqlalchemy/dialects/postgres/base.py b/lib/sqlalchemy/dialects/postgres/base.py index d96efd2dad..7ab1ac7a48 100644 --- a/lib/sqlalchemy/dialects/postgres/base.py +++ b/lib/sqlalchemy/dialects/postgres/base.py @@ -481,7 +481,7 @@ class PGDialect(default.DefaultDialect): @base.connection_memoize(('dialect', 'default_schema_name')) def get_default_schema_name(self, connection): - return connection.scalar("select current_schema()", None) + return connection.scalar("select current_schema()") def has_table(self, connection, table_name, schema=None): # seems like case gets folded in pg_class... diff --git a/lib/sqlalchemy/dialects/postgres/zxjdbc.py b/lib/sqlalchemy/dialects/postgres/zxjdbc.py new file mode 100644 index 0000000000..f968ac9851 --- /dev/null +++ b/lib/sqlalchemy/dialects/postgres/zxjdbc.py @@ -0,0 +1,18 @@ +from sqlalchemy.dialects.postgres.base import PGDialect +from sqlalchemy.connectors.zxJDBC import ZxJDBCConnector +from sqlalchemy.engine import default + +class Postgres_jdbcExecutionContext(default.DefaultExecutionContext): + pass + +class Postgres_jdbc(ZxJDBCConnector, PGDialect): + execution_ctx_cls = Postgres_jdbcExecutionContext + + jdbc_db_name = 'postgresql' + jdbc_driver_name = "org.postgresql.Driver" + + + def _get_server_version_info(self, connection): + return tuple(int(x) for x in connection.connection.dbversion.split('.')) + +dialect = Postgres_jdbc \ No newline at end of file diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 77f481028a..7daf5dbd31 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -17,7 +17,7 @@ __all__ = ['BufferedColumnResultProxy', 'BufferedColumnRow', 'BufferedRowResultP 'Connection', 'DefaultRunner', 'Dialect', 'Engine', 'ExecutionContext', 'NestedTransaction', 'ResultProxy', 'RootTransaction', 'RowProxy', 'SchemaIterator', 'StringIO', 'Transaction', 'TwoPhaseTransaction', 'connection_memoize'] -import inspect, StringIO +import inspect, StringIO, sys from sqlalchemy import exc, schema, util, types, log from sqlalchemy.sql import expression @@ -1046,7 +1046,7 @@ class Connection(Connectable): def _handle_dbapi_exception(self, e, statement, parameters, cursor, context): if getattr(self, '_reentrant_error', False): - raise exc.DBAPIError.instance(None, None, e) + raise exc.DBAPIError.instance_cls(e), (None, None, e), sys.exc_info()[2] self._reentrant_error = True try: if not isinstance(e, self.dialect.dbapi.Error): @@ -1065,7 +1065,7 @@ class Connection(Connectable): self._autorollback() if self.__close_with_result: self.close() - raise exc.DBAPIError.instance(statement, parameters, e, connection_invalidated=is_disconnect) + raise exc.DBAPIError.instance_cls(e), (statement, parameters, e, is_disconnect), sys.exc_info()[2] finally: del self._reentrant_error @@ -1581,7 +1581,7 @@ class ResultProxy(object): self._rowcount = self.context.get_rowcount() self.close() return - + self._rowcount = None self._props = util.populate_column_dict(None) self._props.creator = self.__key_fallback() diff --git a/lib/sqlalchemy/engine/reflection.py b/lib/sqlalchemy/engine/reflection.py index 6edf2ae9c4..611e97bcf0 100644 --- a/lib/sqlalchemy/engine/reflection.py +++ b/lib/sqlalchemy/engine/reflection.py @@ -65,10 +65,10 @@ class Inspector(object): if hasattr(engine.dialect, 'inspector'): return engine.dialect.inspector(engine) return Inspector(engine) - + + @property def default_schema_name(self): return self.dialect.get_default_schema_name(self.conn) - default_schema_name = property(default_schema_name) def get_schema_names(self): """Return all schema names. diff --git a/lib/sqlalchemy/engine/strategies.py b/lib/sqlalchemy/engine/strategies.py index 5187ab1927..b1db8625f8 100644 --- a/lib/sqlalchemy/engine/strategies.py +++ b/lib/sqlalchemy/engine/strategies.py @@ -77,7 +77,8 @@ class DefaultEngineStrategy(EngineStrategy): try: return dbapi.connect(*cargs, **cparams) except Exception, e: - raise exc.DBAPIError.instance(None, None, e) + import sys + raise exc.DBAPIError.instance_cls(e), (None, None, e), sys.exc_info()[2] creator = kwargs.pop('creator', connect) poolclass = (kwargs.pop('poolclass', None) or diff --git a/lib/sqlalchemy/exc.py b/lib/sqlalchemy/exc.py index d1af6d385d..6cc43d7f26 100644 --- a/lib/sqlalchemy/exc.py +++ b/lib/sqlalchemy/exc.py @@ -103,7 +103,8 @@ class DBAPIError(SQLAlchemyError): """ - def instance(cls, statement, params, orig, connection_invalidated=False): + @classmethod + def instance_cls(cls, orig): # Don't ever wrap these, just return them directly as if # DBAPIError didn't exist. if isinstance(orig, (KeyboardInterrupt, SystemExit)): @@ -114,8 +115,7 @@ class DBAPIError(SQLAlchemyError): if name in glob and issubclass(glob[name], DBAPIError): cls = glob[name] - return cls(statement, params, orig, connection_invalidated) - instance = classmethod(instance) + return cls def __init__(self, statement, params, orig, connection_invalidated=False): try: diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index 1642e7394a..aa012f2775 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -139,10 +139,11 @@ class QueryableAttribute(interfaces.PropComparator): def __str__(self): return repr(self.parententity) + "." + self.property.key - @property - def property(self): - return self.comparator.property - +# @property +# def property(self): +# return self.comparator.property + +QueryableAttribute.property = property(lambda self:self.comparator.property) class InstrumentedAttribute(QueryableAttribute): """Public-facing descriptor, placed in the mapped class dictionary.""" @@ -833,6 +834,7 @@ class InstanceState(object): def __init__(self, obj, manager): self.class_ = obj.__class__ self.manager = manager + self.obj = weakref.ref(obj, self._cleanup) self.dict = obj.__dict__ self.modified = False @@ -844,11 +846,17 @@ class InstanceState(object): def detach(self): if self.session_id: - del self.session_id + try: + del self.session_id + except AttributeError: + pass def dispose(self): if self.session_id: - del self.session_id + try: + del self.session_id + except AttributeError: + pass del self.obj del self.dict diff --git a/test/dialect/postgres.py b/test/dialect/postgres.py index 6e6500af15..bb4a10e605 100644 --- a/test/dialect/postgres.py +++ b/test/dialect/postgres.py @@ -340,6 +340,15 @@ class InsertTest(TestBase, AssertsExecutionResults): def _assert_data_noautoincrement(self, table): table.insert().execute({'id':30, 'data':'d1'}) + + if testing.db.driver == 'pg8000': + exception_cls = ProgrammingError + else: + exception_cls = IntegrityError + + self.assertRaisesMessage(exception_cls, "violates not-null constraint", table.insert().execute, {'data':'d2'}) + self.assertRaisesMessage(exception_cls, "violates not-null constraint", table.insert().execute, {'data':'d2'}, {'data':'d3'}) + try: table.insert().execute({'data':'d2'}) assert False @@ -367,16 +376,9 @@ class InsertTest(TestBase, AssertsExecutionResults): m2 = MetaData(testing.db) table = Table(table.name, m2, autoload=True) table.insert().execute({'id':30, 'data':'d1'}) - try: - table.insert().execute({'data':'d2'}) - assert False - except exc.IntegrityError, e: - assert "violates not-null constraint" in str(e) - try: - table.insert().execute({'data':'d2'}, {'data':'d3'}) - assert False - except exc.IntegrityError, e: - assert "violates not-null constraint" in str(e) + + self.assertRaisesMessage(exception_cls, "violates not-null constraint", table.insert().execute, {'data':'d2'}) + self.assertRaisesMessage(exception_cls, "violates not-null constraint", table.insert().execute, {'data':'d2'}, {'data':'d3'}) table.insert().execute({'id':31, 'data':'d2'}, {'id':32, 'data':'d3'}) table.insert(inline=True).execute({'id':33, 'data':'d4'}) @@ -858,7 +860,7 @@ class TimeStampTest(TestBase, AssertsExecutionResults): self.assertEqual(result[0], datetime.datetime(2007, 12, 25, 0, 0)) class ServerSideCursorsTest(TestBase, AssertsExecutionResults): - __only_on__ = 'postgres' + __only_on__ = 'postgres+psycopg2' def setUpAll(self): global ss_engine diff --git a/test/orm/query.py b/test/orm/query.py index e1e18896a8..a51d9823dc 100644 --- a/test/orm/query.py +++ b/test/orm/query.py @@ -190,7 +190,7 @@ class GetTest(QueryTest): assert u.addresses[0].email_address == 'jack@bean.com' assert u.orders[1].items[2].description == 'item 5' - @testing.fails_on_everything_except('sqlite', '+pyodbc') + @testing.fails_on_everything_except('sqlite', '+pyodbc', '+zxjdbc') def test_query_str(self): s = create_session() q = s.query(User).filter(User.id==1) @@ -1748,8 +1748,8 @@ class MixedEntitiesTest(QueryTest): sess = create_session() q = sess.query(User) - q2 = q.group_by([User.name.like('%j%')]).order_by(desc(User.name.like('%j%'))).values(User.name.like('%j%'), func.count(User.name.like('%j%'))) - self.assertEquals(list(q2), [(True, 1), (False, 3)]) + q2 = q.order_by(desc(User.name.like('%j%'))).values(User.name.like('%j%')) + self.assertEquals(list(q2), [(True,), (False,), (False,), (False,)]) def test_correlated_subquery(self): """test that a subquery constructed from ORM attributes doesn't leak out diff --git a/test/sql/query.py b/test/sql/query.py index d0da2bf054..700b2ca5f6 100644 --- a/test/sql/query.py +++ b/test/sql/query.py @@ -72,8 +72,11 @@ class QueryTest(TestBase): if result.lastrow_has_defaults(): criterion = and_(*[col==id for col, id in zip(table.primary_key, result.last_inserted_ids())]) row = table.select(criterion).execute().fetchone() - for c in table.c: - ret[c.key] = row[c] + try: + for c in table.c: + ret[c.key] = row[c] + finally: + row.close() return ret for supported, table, values, assertvalues in [ @@ -524,30 +527,46 @@ class QueryTest(TestBase): users.select().alias(users.name), ): row = s.select(use_labels=True).execute().fetchone() - assert row[s.c.user_id] == 7 - assert row[s.c.user_name] == 'ed' + try: + assert row[s.c.user_id] == 7 + assert row[s.c.user_name] == 'ed' + finally: + row.close() def test_keys(self): users.insert().execute(user_id=1, user_name='foo') r = users.select().execute().fetchone() - self.assertEqual([x.lower() for x in r.keys()], ['user_id', 'user_name']) + try: + self.assertEqual([x.lower() for x in r.keys()], ['user_id', 'user_name']) + finally: + r.close() def test_items(self): users.insert().execute(user_id=1, user_name='foo') r = users.select().execute().fetchone() - self.assertEqual([(x[0].lower(), x[1]) for x in r.items()], [('user_id', 1), ('user_name', 'foo')]) + try: + self.assertEqual([(x[0].lower(), x[1]) for x in r.items()], [('user_id', 1), ('user_name', 'foo')]) + finally: + r.close() def test_len(self): users.insert().execute(user_id=1, user_name='foo') - r = users.select().execute().fetchone() - self.assertEqual(len(r), 2) - r.close() + try: + r = users.select().execute().fetchone() + self.assertEqual(len(r), 2) + finally: + r.close() + r = testing.db.execute('select user_name, user_id from query_users').fetchone() - self.assertEqual(len(r), 2) - r.close() - r = testing.db.execute('select user_name from query_users').fetchone() - self.assertEqual(len(r), 1) - r.close() + try: + self.assertEqual(len(r), 2) + finally: + r.close() + try: + r = testing.db.execute('select user_name from query_users').fetchone() + self.assertEqual(len(r), 1) + finally: + r.close() def test_cant_execute_join(self): try: diff --git a/test/testlib/config.py b/test/testlib/config.py index cef4c6e1dc..5d01e9f4ed 100644 --- a/test/testlib/config.py +++ b/test/testlib/config.py @@ -266,30 +266,23 @@ def _prep_testing_database(options, file_config): from testlib import engines from sqlalchemy import schema - try: - # also create alt schemas etc. here? - if options.dropfirst: - e = engines.utf8_engine() - existing = e.table_names() - if existing: - if not options.quiet: - print "Dropping existing tables in database: " + db_url - try: - print "Tables: %s" % ', '.join(existing) - except: - pass - print "Abort within 5 seconds..." - time.sleep(5) - md = schema.MetaData(e, reflect=True) - md.drop_all() - e.dispose() - except (KeyboardInterrupt, SystemExit): - raise - except Exception, e: - if not options.quiet: - warnings.warn(RuntimeWarning( - "Error checking for existing tables in testing " - "database: %s" % e)) + # also create alt schemas etc. here? + if options.dropfirst: + e = engines.utf8_engine() + existing = e.table_names() + if existing: + if not options.quiet: + print "Dropping existing tables in database: " + db_url + try: + print "Tables: %s" % ', '.join(existing) + except: + pass + print "Abort within 5 seconds..." + time.sleep(5) + md = schema.MetaData(e, reflect=True) + md.drop_all() + e.dispose() + post_configure['prep_db'] = _prep_testing_database def _set_table_options(options, file_config): diff --git a/test/testlib/testing.py b/test/testlib/testing.py index 30500068ca..89c08ac47d 100644 --- a/test/testlib/testing.py +++ b/test/testlib/testing.py @@ -288,6 +288,10 @@ def _server_version(bind=None): if bind is None: bind = config.db + + # force metadata to be retrieved + bind.connect() + return getattr(bind.dialect, 'server_version_info', ()) def skip_if(predicate, reason=None):