merge 0.6 series to trunk.
author    Mike Bayer <mike_mp@zzzcomputing.com>
Thu, 6 Aug 2009 21:11:27 +0000 (21:11 +0000)
committer Mike Bayer <mike_mp@zzzcomputing.com>
Thu, 6 Aug 2009 21:11:27 +0000 (21:11 +0000)
234 files changed:
06CHANGES [new file with mode: 0644]
CHANGES
README.unittests
README_THIS_IS_06 [new file with mode: 0644]
convert.py
doc/build/copyright.rst
doc/build/dbengine.rst
doc/build/metadata.rst
doc/build/reference/dialects/access.rst
doc/build/reference/dialects/firebird.rst
doc/build/reference/dialects/index.rst
doc/build/reference/dialects/informix.rst
doc/build/reference/dialects/maxdb.rst
doc/build/reference/dialects/mssql.rst
doc/build/reference/dialects/mysql.rst
doc/build/reference/dialects/oracle.rst
doc/build/reference/dialects/postgres.rst [deleted file]
doc/build/reference/dialects/postgresql.rst [new file with mode: 0644]
doc/build/reference/dialects/sqlite.rst
doc/build/reference/dialects/sybase.rst
doc/build/reference/sqlalchemy/connections.rst
doc/build/reference/sqlalchemy/pooling.rst
doc/build/reference/sqlalchemy/types.rst
doc/build/session.rst
doc/build/sqlexpression.rst
doc/build/testdocs.py
examples/postgis/postgis.py
examples/query_caching/query_caching.py
ez_setup.py [new file with mode: 0644]
lib/sqlalchemy/__init__.py
lib/sqlalchemy/connectors/__init__.py [new file with mode: 0644]
lib/sqlalchemy/connectors/mxodbc.py [new file with mode: 0644]
lib/sqlalchemy/connectors/pyodbc.py [new file with mode: 0644]
lib/sqlalchemy/connectors/zxJDBC.py [new file with mode: 0644]
lib/sqlalchemy/databases/__init__.py
lib/sqlalchemy/databases/firebird.py [deleted file]
lib/sqlalchemy/databases/information_schema.py [deleted file]
lib/sqlalchemy/databases/mssql.py [deleted file]
lib/sqlalchemy/databases/mxODBC.py [deleted file]
lib/sqlalchemy/databases/oracle.py [deleted file]
lib/sqlalchemy/databases/sqlite.py [deleted file]
lib/sqlalchemy/databases/sybase.py [deleted file]
lib/sqlalchemy/dialects/__init__.py [new file with mode: 0644]
lib/sqlalchemy/dialects/access/__init__.py [new file with mode: 0644]
lib/sqlalchemy/dialects/access/base.py [moved from lib/sqlalchemy/databases/access.py with 96% similarity]
lib/sqlalchemy/dialects/firebird/__init__.py [new file with mode: 0644]
lib/sqlalchemy/dialects/firebird/base.py [new file with mode: 0644]
lib/sqlalchemy/dialects/firebird/kinterbasdb.py [new file with mode: 0644]
lib/sqlalchemy/dialects/informix/__init__.py [new file with mode: 0644]
lib/sqlalchemy/dialects/informix/base.py [moved from lib/sqlalchemy/databases/informix.py with 59% similarity]
lib/sqlalchemy/dialects/informix/informixdb.py [new file with mode: 0644]
lib/sqlalchemy/dialects/maxdb/__init__.py [new file with mode: 0644]
lib/sqlalchemy/dialects/maxdb/base.py [moved from lib/sqlalchemy/databases/maxdb.py with 90% similarity]
lib/sqlalchemy/dialects/maxdb/sapdb.py [new file with mode: 0644]
lib/sqlalchemy/dialects/mssql/__init__.py [new file with mode: 0644]
lib/sqlalchemy/dialects/mssql/adodbapi.py [new file with mode: 0644]
lib/sqlalchemy/dialects/mssql/base.py [new file with mode: 0644]
lib/sqlalchemy/dialects/mssql/information_schema.py [new file with mode: 0644]
lib/sqlalchemy/dialects/mssql/pymssql.py [new file with mode: 0644]
lib/sqlalchemy/dialects/mssql/pyodbc.py [new file with mode: 0644]
lib/sqlalchemy/dialects/mysql/__init__.py [new file with mode: 0644]
lib/sqlalchemy/dialects/mysql/base.py [moved from lib/sqlalchemy/databases/mysql.py with 66% similarity]
lib/sqlalchemy/dialects/mysql/mysqldb.py [new file with mode: 0644]
lib/sqlalchemy/dialects/mysql/pyodbc.py [new file with mode: 0644]
lib/sqlalchemy/dialects/mysql/zxjdbc.py [new file with mode: 0644]
lib/sqlalchemy/dialects/oracle/__init__.py [new file with mode: 0644]
lib/sqlalchemy/dialects/oracle/base.py [new file with mode: 0644]
lib/sqlalchemy/dialects/oracle/cx_oracle.py [new file with mode: 0644]
lib/sqlalchemy/dialects/oracle/zxjdbc.py [new file with mode: 0644]
lib/sqlalchemy/dialects/postgres.py [new file with mode: 0644]
lib/sqlalchemy/dialects/postgresql/__init__.py [new file with mode: 0644]
lib/sqlalchemy/dialects/postgresql/base.py [moved from lib/sqlalchemy/databases/postgres.py with 51% similarity]
lib/sqlalchemy/dialects/postgresql/pg8000.py [new file with mode: 0644]
lib/sqlalchemy/dialects/postgresql/psycopg2.py [new file with mode: 0644]
lib/sqlalchemy/dialects/postgresql/pypostgresql.py [new file with mode: 0644]
lib/sqlalchemy/dialects/postgresql/zxjdbc.py [new file with mode: 0644]
lib/sqlalchemy/dialects/sqlite/__init__.py [new file with mode: 0644]
lib/sqlalchemy/dialects/sqlite/base.py [new file with mode: 0644]
lib/sqlalchemy/dialects/sqlite/pysqlite.py [new file with mode: 0644]
lib/sqlalchemy/dialects/sybase/__init__.py [new file with mode: 0644]
lib/sqlalchemy/dialects/sybase/base.py [new file with mode: 0644]
lib/sqlalchemy/dialects/sybase/mxodbc.py [new file with mode: 0644]
lib/sqlalchemy/dialects/sybase/pyodbc.py [new file with mode: 0644]
lib/sqlalchemy/dialects/sybase/schema.py [new file with mode: 0644]
lib/sqlalchemy/dialects/type_migration_guidelines.txt [new file with mode: 0644]
lib/sqlalchemy/engine/__init__.py
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/engine/ddl.py [new file with mode: 0644]
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/engine/reflection.py [new file with mode: 0644]
lib/sqlalchemy/engine/strategies.py
lib/sqlalchemy/engine/threadlocal.py
lib/sqlalchemy/engine/url.py
lib/sqlalchemy/exc.py
lib/sqlalchemy/ext/compiler.py
lib/sqlalchemy/ext/declarative.py
lib/sqlalchemy/ext/orderinglist.py
lib/sqlalchemy/ext/serializer.py
lib/sqlalchemy/ext/sqlsoup.py
lib/sqlalchemy/interfaces.py
lib/sqlalchemy/orm/__init__.py
lib/sqlalchemy/orm/attributes.py
lib/sqlalchemy/orm/collections.py
lib/sqlalchemy/orm/dependency.py
lib/sqlalchemy/orm/evaluator.py
lib/sqlalchemy/orm/identity.py
lib/sqlalchemy/orm/interfaces.py
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/properties.py
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/orm/scoping.py
lib/sqlalchemy/orm/session.py
lib/sqlalchemy/orm/state.py
lib/sqlalchemy/orm/strategies.py
lib/sqlalchemy/orm/unitofwork.py
lib/sqlalchemy/pool.py
lib/sqlalchemy/queue.py
lib/sqlalchemy/schema.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/expression.py
lib/sqlalchemy/sql/operators.py
lib/sqlalchemy/sql/visitors.py
lib/sqlalchemy/test/assertsql.py
lib/sqlalchemy/test/config.py
lib/sqlalchemy/test/engines.py
lib/sqlalchemy/test/noseplugin.py
lib/sqlalchemy/test/profiling.py
lib/sqlalchemy/test/requires.py
lib/sqlalchemy/test/schema.py
lib/sqlalchemy/test/testing.py
lib/sqlalchemy/test/util.py [new file with mode: 0644]
lib/sqlalchemy/topological.py
lib/sqlalchemy/types.py
lib/sqlalchemy/util.py
sa2to3.py [new file with mode: 0644]
setup.py
sqla_nose.py [new file with mode: 0644]
test/aaa_profiling/test_compiler.py
test/aaa_profiling/test_memusage.py
test/aaa_profiling/test_pool.py
test/aaa_profiling/test_zoomark.py
test/aaa_profiling/test_zoomark_orm.py
test/base/test_dependency.py
test/base/test_except.py
test/base/test_utils.py
test/dialect/test_firebird.py
test/dialect/test_informix.py
test/dialect/test_maxdb.py
test/dialect/test_mssql.py
test/dialect/test_mysql.py
test/dialect/test_oracle.py
test/dialect/test_postgresql.py [moved from test/dialect/test_postgres.py with 66% similarity]
test/dialect/test_sqlite.py
test/engine/test_bind.py
test/engine/test_ddlevents.py
test/engine/test_execute.py
test/engine/test_metadata.py
test/engine/test_parseconnect.py
test/engine/test_pool.py
test/engine/test_reconnect.py
test/engine/test_reflection.py
test/engine/test_transaction.py
test/ext/test_associationproxy.py
test/ext/test_compiler.py
test/ext/test_declarative.py
test/ext/test_serializer.py
test/orm/_base.py
test/orm/_fixtures.py
test/orm/inheritance/test_abc_inheritance.py
test/orm/inheritance/test_abc_polymorphic.py
test/orm/inheritance/test_basic.py
test/orm/inheritance/test_concrete.py
test/orm/inheritance/test_magazine.py
test/orm/inheritance/test_manytomany.py
test/orm/inheritance/test_poly_linked_list.py
test/orm/inheritance/test_polymorph2.py
test/orm/inheritance/test_productspec.py
test/orm/inheritance/test_query.py
test/orm/inheritance/test_selects.py
test/orm/inheritance/test_single.py
test/orm/sharding/test_shard.py
test/orm/test_association.py
test/orm/test_assorted_eager.py
test/orm/test_attributes.py
test/orm/test_cascade.py
test/orm/test_collection.py
test/orm/test_cycles.py
test/orm/test_defaults.py
test/orm/test_dynamic.py
test/orm/test_eager_relations.py
test/orm/test_expire.py
test/orm/test_generative.py
test/orm/test_instrumentation.py
test/orm/test_lazy_relations.py
test/orm/test_mapper.py
test/orm/test_merge.py
test/orm/test_naturalpks.py
test/orm/test_onetoone.py
test/orm/test_pickled.py
test/orm/test_query.py
test/orm/test_relationships.py
test/orm/test_scoping.py
test/orm/test_selectable.py
test/orm/test_session.py
test/orm/test_transaction.py
test/orm/test_unitofwork.py
test/orm/test_utils.py
test/perf/insertspeed.py
test/perf/masscreate.py
test/perf/masscreate2.py
test/perf/masseagerload.py
test/perf/massload.py
test/perf/masssave.py
test/perf/objselectspeed.py
test/perf/objupdatespeed.py
test/perf/ormsession.py
test/perf/poolload.py
test/perf/sessions.py
test/perf/wsgi.py
test/sql/test_constraints.py
test/sql/test_defaults.py
test/sql/test_functions.py
test/sql/test_labels.py
test/sql/test_query.py
test/sql/test_quote.py
test/sql/test_returning.py [new file with mode: 0644]
test/sql/test_select.py
test/sql/test_selectable.py
test/sql/test_types.py
test/sql/test_unicode.py
test/zblog/mappers.py
test/zblog/tables.py
test/zblog/test_zblog.py
test/zblog/user.py

diff --git a/06CHANGES b/06CHANGES
new file mode 100644 (file)
index 0000000..4c7f9ca
--- /dev/null
+++ b/06CHANGES
@@ -0,0 +1,177 @@
+- orm
+    - the 'expire' option on query.update() has been renamed to 'fetch', thus matching
+      that of query.delete()
+    - query.update() and query.delete() both default to 'evaluate' for the synchronize 
+      strategy.
+    - the 'synchronize' strategy for update() and delete() raises an error on failure.  
+      There is no implicit fallback onto "fetch".   Failure of evaluation is based
+      on the structure of criteria, so success/failure is deterministic based on 
+      code structure.
+    - the "dont_load=True" flag on Session.merge() is deprecated and is now 
+      "load=False".
+      
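+    A minimal sketch of the above (assuming a mapped class "User", a working
+    "session", and a detached instance "detached_user"; names are
+    illustrative only):
+
+        # 'evaluate' is now the default synchronize strategy
+        session.query(User).filter(User.name == 'jack').\
+            update({'name': 'ed'}, synchronize_session='evaluate')
+
+        # 'fetch' replaces the old 'expire' option on update()
+        session.query(User).filter(User.name == 'ed').\
+            delete(synchronize_session='fetch')
+
+        # load=False replaces dont_load=True
+        merged = session.merge(detached_user, load=False)
+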
+- sql
+    - returning() support is native to insert(), update(), delete().  Implementations
+      of varying levels of functionality exist for Postgresql, Firebird, MSSQL and
+      Oracle.   returning() can be called explicitly with column expressions which
+      are then returned in the resultset, usually via fetchone() or first(), as
+      sketched below.
+      
+      insert() constructs will also use RETURNING implicitly to get newly
+      generated primary key values, if the database version in use supports it
+      (a version number check is performed).   This occurs if no end-user
+      returning() was specified.
+      
+    - Databases which rely upon postfetch of "last inserted id" to get at a 
+      generated sequence value (i.e. MySQL, MS-SQL) now work correctly
+      when there is a composite primary key where the "autoincrement" column
+      is not the first primary key column in the table.
+
+    - the last_inserted_ids() method has been renamed to the descriptor
+      "inserted_primary_key".
+      
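+    A sketch of explicit returning() (assuming a Table "mytable" with "id"
+    and "data" columns and a Connection "conn"):
+
+        result = conn.execute(
+            mytable.insert().values(data='some data').
+                returning(mytable.c.id)
+        )
+        new_id = result.first()[0]
+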
+- engines
+    - transaction isolation level may be specified with
+      create_engine(... isolation_level="..."); available on
+      postgresql and sqlite. [ticket:443]
+    - added first() method to ResultProxy, returns first row and closes
+      result set immediately (see the sketch below).
+
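+    A sketch of both features (the Postgresql URL is illustrative):
+
+        from sqlalchemy import create_engine
+
+        engine = create_engine(
+            'postgresql://scott:tiger@localhost/test',
+            isolation_level='SERIALIZABLE')
+
+        # first() fetches one row, then closes the result set
+        row = engine.execute("select 1").first()
+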
+- schema
+    - deprecated metadata.connect() and ThreadLocalMetaData.connect() have been 
+      removed - use the "bind" attribute/argument to bind a metadata.
+    - deprecated metadata.table_iterator() method removed (use sorted_tables)
+    - the "metadata" argument is removed from DefaultGenerator and subclasses,
+      but remains locally present on Sequence, which is a standalone construct
+      in DDL.
+    - Removed public mutability from Index and Constraint objects:
+        - ForeignKeyConstraint.append_element()
+        - Index.append_column()
+        - UniqueConstraint.append_column()
+        - PrimaryKeyConstraint.add()
+        - PrimaryKeyConstraint.remove()
+      These should be constructed declaratively (i.e. in one construction).
+    - UniqueConstraint, Index, PrimaryKeyConstraint all accept lists
+      of column names or column objects as arguments (see the sketch
+      at the end of this section).
+    - Other removed things:
+        - Table.key (no idea what this was for)
+        - Table.primary_key is not assignable - use table.append_constraint(PrimaryKeyConstraint(...))
+        - Column.bind       (get via column.table.bind)
+        - Column.metadata   (get via column.table.metadata)
+    - the use_alter flag on ForeignKey is now a shortcut option for operations that 
+      can be hand-constructed using the DDL() event system.  A side effect of this refactor
+      is that ForeignKeyConstraint objects with use_alter=True will *not* be emitted on
+      SQLite, which does not support ALTER for foreign keys.  This has no effect on SQLite's 
+      behavior since SQLite does not actually honor FOREIGN KEY constraints.
+    
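+    A sketch of declarative construction (names are illustrative):
+
+        from sqlalchemy import (MetaData, Table, Column, Integer,
+                                String, UniqueConstraint, Index)
+
+        metadata = MetaData()
+        mytable = Table('mytable', metadata,
+            Column('id', Integer, primary_key=True),
+            Column('name', String(50)),
+            UniqueConstraint('name')            # by column name
+        )
+        idx = Index('ix_name', mytable.c.name)  # by Column object
+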
+- DDL
+    - the DDL() system has been greatly expanded:
+        - CreateTable()
+        - DropTable()
+        - AddConstraint()
+        - DropConstraint()
+        - CreateIndex()
+        - DropIndex()
+        - CreateSequence()
+        - DropSequence()
+        - these support "on" and "execute_at()" just like
+          plain DDL() does; see the sketch below.
+    - the "on" callable passed to DDL() needs to accept **kw arguments.
+      In the case of MetaData before/after create/drop, the list of 
+      Table objects for which CREATE/DROP DDL is to be issued is passed
+      as the kw argument "tables".   This is necessary for metadata-level
+      DDL that is dependent on the presence of specific tables.
+    
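+    A sketch using the new constructs (assuming an "engine" plus the
+    "mytable" and "idx" objects from the schema sketch above):
+
+        from sqlalchemy.schema import CreateTable, CreateIndex
+
+        engine.execute(CreateTable(mytable))
+        engine.execute(CreateIndex(idx))
+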
+- dialect refactor
+    - the "owner" keyword argument is removed from Table.  Use "schema" to 
+      represent any namespaces to be prepended to the table name.
+    - server_version_info becomes a static attribute.
+    - dialects receive an initialize() event on initial connection to
+      determine connection properties.
+    - dialects receive a visit_pool event and have an opportunity to
+      establish pool listeners.
+    - TypeEngine classes are cached per dialect class instead of per dialect instance.
+    - Deprecated Dialect.get_params() removed.
+    - Dialect.get_rowcount() has been renamed to a descriptor "rowcount", and calls 
+      cursor.rowcount directly.  Dialects which need to hardwire a rowcount in for 
+      certain calls should override the method to provide different behavior.
+    - functions and operators generated by the compiler now use (almost) regular
+      dispatch functions of the form "visit_<opname>" and "visit_<funcname>_fn" 
+      to provide customized processing.  This replaces the need to copy the "functions" 
+      and "operators" dictionaries in compiler subclasses with straightforward
+      visitor methods, and also allows compiler subclasses complete control over 
+      rendering, as the full _Function or _BinaryExpression object is passed in
+      (see the sketch below).
+
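+    E.g., a hypothetical compiler subclass following the naming scheme
+    described above (a sketch only, not tested against this revision):
+
+        from sqlalchemy.sql.compiler import SQLCompiler
+
+        class MyCompiler(SQLCompiler):
+            def visit_now_fn(self, fn, **kw):
+                # render func.now() as CURRENT_TIMESTAMP
+                return "CURRENT_TIMESTAMP"
+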
+- postgresql
+    - the "postgres" dialect is now named "postgresql" !   Connection strings look
+      like:
+      
+           postgresql://scott:tiger@localhost/test
+           postgresql+pg8000://scott:tiger@localhost/test
+    
+       The "postgres" name remains for backwards compatiblity in the following ways:
+       
+           - There is a "postgres.py" dummy dialect which allows old URLs to work,
+           i.e.  postgres://scott:tiger@localhost/test
+           
+           - The "postgres" name can be imported from the old "databases" module,
+           i.e. "from sqlalchemy.databases import postgres" as well as "dialects",
+           "from sqlalchemy.dialects.postgres import base as pg", will send 
+           a deprecation warning.
+           
+           - Special expression arguments are now named "postgresql_returning"
+           and "postgresql_where", but the older "postgres_returning" and
+           "postgres_where" names still work with a deprecation warning.
+       
+- mysql
+    - all the _detect_XXX() functions now run once underneath dialect.initialize()
+    
+- oracle
+    - support for cx_Oracle's "native unicode" mode which does not require NLS_LANG
+      to be set.  Use version 5.0.2 or later of cx_Oracle.
+    - an NCLOB type is added to the base types.
+    - func.char_length is a generic function for LENGTH (see the sketch below)
+    - a ForeignKey() which includes onupdate=<value> will emit a warning rather than
+      emitting ON UPDATE CASCADE, which is unsupported by Oracle
+    - the keys() method of RowProxy() now returns the result column names *normalized*
+      to be SQLAlchemy case insensitive names.  This means they will be lower case 
+      for case insensitive names, whereas the DBAPI would normally return them 
+      as UPPERCASE names.  This allows row keys() to be compatible with further
+      SQLAlchemy operations.
+
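+    A sketch of the generic function (renders as LENGTH(...) on Oracle):
+
+        from sqlalchemy import func, select
+
+        stmt = select([func.char_length('some string')])
+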
+- firebird
+    - the keys() method of RowProxy() now returns the result column names *normalized*
+      to be SQLAlchemy case insensitive names. This means they will be lower case 
+      for case insensitive names, whereas the DBAPI would normally return them 
+      as UPPERCASE names.  This allows row keys() to be compatible with further
+      SQLAlchemy operations.
+      
+- new dialects
+    - postgresql+pg8000
+    - postgresql+pypostgresql (partial)
+    - postgresql+zxjdbc
+    - mysql+pyodbc
+    - mysql+zxjdbc
+
+- mssql
+    - the "has_window_funcs" flag is removed.  LIMIT/OFFSET usage will use ROW NUMBER as always,
+      and if on an older version of SQL Server, the operation fails.  The behavior is exactly
+      the same except the error is raised by SQL server instead of the dialect, and no
+      flag setting is required to enable it.
+    - the "auto_identity_insert" flag is removed.  This feature always takes effect
+      when an INSERT statement overrides a column that is known to have a sequence on it.
+      As with "has_window_funcs", if the underlying driver doesn't support this, then you 
+      can't do this operation in any case, so there's no point in having a flag.
+    - using new dialect.initialize() feature to set up version-dependent behavior.
+    
+- types
+    - PickleType now uses == for comparison of values when mutable=True, 
+      unless the "comparator" argument with a comparison function is specified to
+      the type (see the sketch below).  Objects being pickled will be compared based 
+      on identity (which defeats the purpose of mutable=True) if __eq__() is not 
+      overridden or a comparison function is not provided.
+    - The default "precision" and "scale" arguments of Numeric and Float have been removed 
+      and now default to None.   NUMERIC and FLOAT will be rendered with no numeric arguments
+      by default unless these values are provided.
+    - AbstractType.get_search_list() is removed - the games that it was used for are no 
+      longer necessary.
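+
+    A sketch of supplying a comparator to PickleType ("comparator" takes a
+    two-argument comparison function; names are illustrative):
+
+        from sqlalchemy import PickleType
+
+        def deep_compare(x, y):
+            return x.data == y.data
+
+        mytype = PickleType(mutable=True, comparator=deep_compare)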
+      
+      
\ No newline at end of file
diff --git a/CHANGES b/CHANGES
index 3ff1ee032b5ead9a6c6afd9bb16d87b912ad44eb..c073159da8087d9b9a6bdc7e11291227953cd414 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -199,7 +199,7 @@ CHANGES
     - Repaired the printing of SQL exceptions which are not 
       based on parameters or are not executemany() style.
       
-- postgres
+- postgresql
     - Deprecated the hardcoded TIMESTAMP function, which when
       used as func.TIMESTAMP(value) would render "TIMESTAMP value".
       This breaks on some platforms as PostgreSQL doesn't allow
@@ -524,7 +524,7 @@ CHANGES
       fail on recent versions of pysqlite which raise 
       an error when fetchone() called with no rows present.
       
-- postgres
+- postgresql
     - Index reflection won't fail when an index with 
       multiple expressions is encountered.
       
index 99edfcacec5d62f5bb37c3dea9466a04565c03c3..92a7521d0280ed21c0128ac192a478e852abb2f9 100644 (file)
@@ -10,10 +10,19 @@ downloads for nose are available at:
 
 http://somethingaboutorange.com/mrl/projects/nose/0.11.1/index.html
 
-
 SQLAlchemy implements a nose plugin that must be present when tests are run.
 This plugin is available when SQLAlchemy is installed via setuptools.
 
+INSTANT TEST RUNNER
+-------------------
+
+A plain vanilla run of all tests using sqlite can be run via setup.py:
+
+    $ python setup.py test
+    
+Setuptools will take care of the rest !   To run nose directly and have
+its full set of options available, read on...
+
 SETUP
 -----
 
@@ -67,11 +76,14 @@ DATABASE TARGETS
 Tests will target an in-memory SQLite database by default.  To test against
 another database, use the --dburi option with any standard SQLAlchemy URL:
 
-    --dburi=postgres://user:password@localhost/test
+    --dburi=postgresql://user:password@localhost/test
 
-Use an empty database and a database user with general DBA privileges.  The
-test suite will be creating and dropping many tables and other DDL, and
-preexisting tables will interfere with the tests
+Use an empty database and a database user with general DBA privileges.  
+The test suite will be creating and dropping many tables and other DDL, and
+preexisting tables will interfere with the tests.
+
+IMPORTANT !: please see TIPS at the end if you are testing on POSTGRESQL,
+ORACLE, or MSSQL - additional steps are required to prepare a test database.
 
 If you'll be running the tests frequently, database aliases can save a lot of
 typing.  The --dbs option lists the built-in aliases and their matching URLs:
@@ -80,19 +92,19 @@ typing.  The --dbs option lists the built-in aliases and their matching URLs:
     Available --db options (use --dburi to override)
                mysql    mysql://scott:tiger@127.0.0.1:3306/test
               oracle    oracle://scott:tiger@127.0.0.1:1521
-            postgres    postgres://scott:tiger@127.0.0.1:5432/test
+            postgresql    postgresql://scott:tiger@127.0.0.1:5432/test
     [...]
 
 To run tests against an aliased database:
 
-    $ nosetests --db=postgres
+    $ nosetests --db=postgresql
 
 To customize the URLs with your own users or hostnames, make a simple .ini
 file called `test.cfg` at the top level of the SQLAlchemy source distribution
 or a `.satest.cfg` in your home directory:
 
     [db]
-    postgres=postgres://myuser:mypass@localhost/mydb
+    postgresql=postgresql://myuser:mypass@localhost/mydb
 
 Your custom entries will override the defaults and you'll see them reflected
 in the output of --dbs.
@@ -159,13 +171,23 @@ IRC!
 
 TIPS
 ----
-PostgreSQL: The tests require an 'alt_schema' and 'alt_schema_2' to be present in
+
+PostgreSQL: The tests require a 'test_schema' and a 'test_schema_2' schema to be present in
 the testing database.
 
-PostgreSQL: When running the tests on postgres, postgres can get slower and
-slower each time you run the tests.  This seems to be related to the constant
-creation/dropping of tables.  Running a "VACUUM FULL" on the database will
-speed it up again.
+Oracle: the database owner should be named "scott" (this will be fixed),
+and an additional "owner" named "ed" is required:
+
+1. create a user 'ed' in the oracle database.
+2. in 'ed', issue the following statements:
+    create table parent(id integer primary key, data varchar2(50));
+    create table child(id integer primary key, data varchar2(50), parent_id integer references parent(id));
+    create synonym ptable for parent;
+    create synonym ctable for child;
+    grant all on parent to scott;  (or to whoever you run the oracle tests as)
+    grant all on child to scott;  (same)
+    grant all on ptable to scott;
+    grant all on ctable to scott;
 
 MSSQL: Tests that involve multiple connections require Snapshot Isolation
 ability implemented on the test database in order to prevent deadlocks that will
diff --git a/README_THIS_IS_06 b/README_THIS_IS_06
new file mode 100644 (file)
index 0000000..b83f886
--- /dev/null
@@ -0,0 +1,11 @@
+Trunk is now moved to SQLAlchemy 0.6.
+
+An ongoing wiki page of changes etc. is at:
+
+http://www.sqlalchemy.org/trac/wiki/06Migration
+
+SQLAlchemy 0.5 is now in a maintenance branch.  Get it at:
+
+http://svn.sqlalchemy.org/sqlalchemy/branches/rel_0_5
+
+
index b574c27a92a2d0ff8db90a9be8fc019712754ead..cb2c8c1a75e28053c38953fdbed1d144e4b3ed7c 100644 (file)
@@ -224,5 +224,7 @@ handlers.append((re.compile(r".*"), default))
 
 
 if __name__ == '__main__':
-    convert("test/orm/inheritance/abc_inheritance.py")
+    import sys
+    for f in sys.argv[1:]:
+        convert(f)
 #    walk()
index 227a54c9c83b8bba6c31259fc3aa429f8de04ddf..501b4ee757b4df081a09e9e84522248695b868f4 100644 (file)
@@ -4,7 +4,7 @@ Appendix:  Copyright
 
 This is the MIT license: `<http://www.opensource.org/licenses/mit-license.php>`_
 
-Copyright (c) 2005, 2006, 2007, 2008 Michael Bayer and contributors. SQLAlchemy is a trademark of Michael
+Copyright (c) 2005, 2006, 2007, 2008, 2009 Michael Bayer and contributors. SQLAlchemy is a trademark of Michael
 Bayer.
 
 Permission is hereby granted, free of charge, to any person obtaining a copy of this
index c7f924a9dc2d0eb16a5d296dee4648b450eecf13..df1088bcd2c481812a83e9664af93a6dc8f717a8 100644 (file)
@@ -19,9 +19,9 @@ Where above, a :class:`~sqlalchemy.engine.Engine` references both a  :class:`~sq
 
 Creating an engine is just a matter of issuing a single call, :func:`create_engine()`::
 
-    engine = create_engine('postgres://scott:tiger@localhost:5432/mydatabase')
+    engine = create_engine('postgresql://scott:tiger@localhost:5432/mydatabase')
     
-The above engine invokes the ``postgres`` dialect and a connection pool which references ``localhost:5432``.
+The above engine invokes the ``postgresql`` dialect and a connection pool which references ``localhost:5432``.
 
 The engine can be used directly to issue SQL to the database.  The most generic way is to use connections, which you get via the ``connect()`` method::
 
@@ -52,7 +52,7 @@ The ``Engine`` and ``Connection`` can do a lot more than what we illustrated abo
 
 Supported Databases 
 ====================
-Recall that the ``Dialect`` is used to describe how to talk to a specific kind of database.  Dialects are included with SQLAlchemy for many different backends; these can be seen as a Python package within the :mod:`~sqlalchemy.databases` package.  Each dialect requires the appropriate DBAPI drivers to be installed separately.
+Recall that the ``Dialect`` is used to describe how to talk to a specific kind of database.  Dialects are included with SQLAlchemy for many different backends; these can be seen as a Python package within the :mod:`~sqlalchemy.dialects` package.  Each dialect requires the appropriate DBAPI drivers to be installed separately.
 
 Dialects included with SQLAlchemy fall under one of three categories: supported, experimental, and third party.  Supported drivers are those which work against the most common databases available in the open source world, including SQLite, PostgreSQL, MySQL, and Firebird.   Very popular commercial databases which provide easy access to test platforms are also supported, these currently include MSSQL and Oracle.   These dialects are tested frequently and the level of support should be close to 100% for each.
 
@@ -63,23 +63,22 @@ There are also third-party dialects available - currently IBM offers a DB2/Infor
 Downloads for each DBAPI at the time of this writing are as follows:
 
 * Supported Dialects
-
- - PostgreSQL:  `psycopg2 <http://www.initd.org/tracker/psycopg/wiki/PsycopgTwo>`_ 
+ - PostgreSQL:  `psycopg2 <http://www.initd.org/tracker/psycopg/wiki/PsycopgTwo>`_ `pg8000 <http://pybrary.net/pg8000/>`_
+ - PostgreSQL on Jython: `PostgreSQL JDBC Driver <http://jdbc.postgresql.org/>`_
  - SQLite:  `sqlite3 <http://www.python.org/doc/2.5.2/lib/module-sqlite3.html>`_ (included in Python 2.5 or greater) `pysqlite <http://initd.org/tracker/pysqlite>`_
  - MySQL:   `MySQLDB (a.k.a. mysql-python) <http://sourceforge.net/projects/mysql-python>`_
+ - MySQL on Jython: `JDBC Driver for MySQL <http://www.mysql.com/products/connector/>`_
  - Oracle:  `cx_Oracle <http://cx-oracle.sourceforge.net/>`_
  - Firebird:  `kinterbasdb <http://kinterbasdb.sourceforge.net/>`_
  - MS-SQL, MSAccess:  `pyodbc <http://pyodbc.sourceforge.net/>`_ (recommended) `adodbapi <http://adodbapi.sourceforge.net/>`_  `pymssql <http://pymssql.sourceforge.net/>`_
 
 * Experimental Dialects
-
  - MSAccess:  `pyodbc <http://pyodbc.sourceforge.net/>`_
  - Informix:  `informixdb <http://informixdb.sourceforge.net/>`_
  - Sybase:   TODO
  - MAXDB:    TODO
 
 * Third Party Dialects
-
  - DB2/Informix IDS: `ibm-db <http://code.google.com/p/ibm-db/>`_
 
 The SQLAlchemy Wiki contains a page of database notes, describing whatever quirks and behaviors have been observed.  It's a good place to check for issues with specific databases.  `Database Notes <http://www.sqlalchemy.org/trac/wiki/DatabaseNotes>`_
@@ -89,31 +88,42 @@ create_engine() URL Arguments
 
 SQLAlchemy indicates the source of an Engine strictly via `RFC-1738 <http://rfc.net/rfc1738.html>`_ style URLs, combined with optional keyword arguments to specify options for the Engine.  The form of the URL is:
 
-    driver://username:password@host:port/database
+    dialect+driver://username:password@host:port/database
+
+Dialect names include the identifying name of the SQLAlchemy dialect, such as ``sqlite``, ``mysql``, ``postgresql``, ``oracle``, ``mssql``, and ``firebird``.  The drivername is the name of the DBAPI to be used to connect to the database, using all lowercase letters.   If not specified, a "default" DBAPI will be imported if available - this default is typically the most widely known driver available for that backend (i.e. cx_oracle, pysqlite/sqlite3, psycopg2, mysqldb).   For Jython connections, the driver is always `zxjdbc`, which is the JDBC-DBAPI bridge included with Jython.
+
+.. sourcecode:: python+sql
+
+    # postgresql - psycopg2 is the default driver.
+    pg_db = create_engine('postgresql://scott:tiger@localhost/mydatabase')
+    pg_db = create_engine('postgresql+psycopg2://scott:tiger@localhost/mydatabase')
+    pg_db = create_engine('postgresql+pg8000://scott:tiger@localhost/mydatabase')
 
-Dialect names include the identifying name of the SQLAlchemy dialect which include ``sqlite``, ``mysql``, ``postgres``, ``oracle``, ``mssql``, and ``firebird``.  In SQLAlchemy 0.5 and earlier, the DBAPI implementation is automatically selected if more than one are available - currently this includes only MSSQL (pyodbc is the default, then adodbapi, then pymssql) and SQLite (sqlite3 is the default, or pysqlite if sqlite3 is not availble).   When using MSSQL, ``create_engine()`` accepts a ``module`` argument which specifies the name of the desired DBAPI to be used, overriding the default behavior.   
+    # postgresql on Jython
+    pg_db = create_engine('postgresql+zxjdbc://scott:tiger@localhost/mydatabase')
+    
+    # mysql - MySQLdb (mysql-python) is the default driver
+    mysql_db = create_engine('mysql://scott:tiger@localhost/foo')
+    mysql_db = create_engine('mysql+mysqldb://scott:tiger@localhost/foo')
+
+    # mysql on Jython
+    mysql_db = create_engine('mysql+zxjdbc://localhost/foo')
 
-  .. sourcecode:: python+sql
-  
-    # postgresql
-    pg_db = create_engine('postgres://scott:tiger@localhost/mydatabase')
+    # mysql with pyodbc (buggy)
+    mysql_db = create_engine('mysql+pyodbc://scott:tiger@some_dsn')
 
-    # mysql
-    mysql_db = create_engine('mysql://scott:tiger@localhost/mydatabase')
-  
-    # oracle
+    # oracle - cx_oracle is the default driver
     oracle_db = create_engine('oracle://scott:tiger@127.0.0.1:1521/sidname')
-  
+
     # oracle via TNS name
-    oracle_db = create_engine('oracle://scott:tiger@tnsname')
-  
+    oracle_db = create_engine('oracle+cx_oracle://scott:tiger@tnsname')
+
     # mssql using ODBC datasource names.  PyODBC is the default driver.
     mssql_db = create_engine('mssql://mydsn')
-    mssql_db = create_engine('mssql://scott:tiger@mydsn')
-    
-    # firebird
-    firebird_db = create_engine('firebird://scott:tiger@localhost/sometest.gdm')
-  
+    mssql_db = create_engine('mssql+pyodbc://mydsn')
+    mssql_db = create_engine('mssql+adodbapi://mydsn')
+    mssql_db = create_engine('mssql+pyodbc://username:password@mydsn')
+
 SQLite connects to file based databases.   The same URL format is used, omitting the hostname, and using the "file" portion as the filename of the database.   This has the effect of four slashes being present for an absolute file path::
 
     # sqlite://<nohostname>/<path>
@@ -132,12 +142,11 @@ The :class:`~sqlalchemy.engine.base.Engine` will ask the connection pool for a c
 Custom DBAPI connect() arguments
 --------------------------------
 
-
 Custom arguments used when issuing the ``connect()`` call to the underlying DBAPI may be issued in three distinct ways.  String-based arguments can be passed directly from the URL string as query arguments:
 
 .. sourcecode:: python+sql
 
-    db = create_engine('postgres://scott:tiger@localhost/test?argument1=foo&argument2=bar')
+    db = create_engine('postgresql://scott:tiger@localhost/test?argument1=foo&argument2=bar')
 
 If SQLAlchemy's database connector is aware of a particular query argument, it may convert its type from string to its proper type.
     
@@ -145,7 +154,7 @@ If SQLAlchemy's database connector is aware of a particular query argument, it m
 
 .. sourcecode:: python+sql
 
-    db = create_engine('postgres://scott:tiger@localhost/test', connect_args = {'argument1':17, 'argument2':'bar'})
+    db = create_engine('postgresql://scott:tiger@localhost/test', connect_args = {'argument1':17, 'argument2':'bar'})
 
 The most customizable connection method of all is to pass a ``creator`` argument, which specifies a callable that returns a DBAPI connection:
 
@@ -154,7 +163,7 @@ The most customizable connection method of all is to pass a ``creator`` argument
     def connect():
         return psycopg.connect(user='scott', host='localhost')
 
-    db = create_engine('postgres://', creator=connect)
+    db = create_engine('postgresql://', creator=connect)
 
 .. _create_engine_args:
 
@@ -165,7 +174,7 @@ Keyword options can also be specified to ``create_engine()``, following the stri
 
 .. sourcecode:: python+sql
 
-    db = create_engine('postgres://...', encoding='latin1', echo=True)
+    db = create_engine('postgresql://...', encoding='latin1', echo=True)
 
 Options common to all database dialects are described at :func:`~sqlalchemy.create_engine`.
 
index f79c637ee0a38c41d3926205fd773774179ce9e2..464b764bf18ac004bc70f2e793008e4bc3355f02 100644 (file)
@@ -396,7 +396,7 @@ The above SQL functions are usually executed "inline" with the INSERT or UPDATE
 
 * the ``inline=True`` flag is not set on the ``Insert()`` or ``Update()`` construct.
 
-For a statement execution which is not an executemany, the returned ``ResultProxy`` will contain a collection accessible via ``result.postfetch_cols()`` which contains a list of all ``Column`` objects which had an inline-executed default.  Similarly, all parameters which were bound to the statement, including all Python and SQL expressions which were pre-executed, are present in the ``last_inserted_params()`` or ``last_updated_params()`` collections on ``ResultProxy``.  The ``last_inserted_ids()`` collection contains a list of primary key values for the row inserted.  
+For a statement execution which is not an executemany, the returned ``ResultProxy`` will contain a collection accessible via ``result.postfetch_cols()`` which contains a list of all ``Column`` objects which had an inline-executed default.  Similarly, all parameters which were bound to the statement, including all Python and SQL expressions which were pre-executed, are present in the ``last_inserted_params()`` or ``last_updated_params()`` collections on ``ResultProxy``.  The ``inserted_primary_key`` collection contains a list of primary key values for the row inserted.  
 
 DDL-Level Defaults 
 -------------------
index cd635aaa09417717dd51f86c33e3a797b7573ac3..52a2ee3710a47fc6783aacef20c2c433ca1d1213 100644 (file)
@@ -1,4 +1,4 @@
-Access
-======
+Microsoft Access
+================
 
-.. automodule:: sqlalchemy.databases.access
+.. automodule:: sqlalchemy.dialects.access.base
index 19a2c4f918dd905597c13b454646c1474224b225..54c38f49b0e84b1993c365cd774da2e517643f5e 100644 (file)
@@ -1,4 +1,4 @@
 Firebird
 ========
 
-.. automodule:: sqlalchemy.databases.firebird
+.. automodule:: sqlalchemy.dialects.firebird.base
index fe9f2539523043854bfa3dbdc4fa0ef82eb779cd..f9c4df5ce87b13e8af178a1795b394d27ead14fa 100644 (file)
@@ -3,17 +3,33 @@
 sqlalchemy.databases
 ====================
 
+Supported Databases
+-------------------
+
+These backends are fully operational with 
+current versions of SQLAlchemy.
+
 .. toctree::
     :glob:
 
-    access
     firebird
-    informix
-    maxdb
     mssql
     mysql
     oracle
-    postgres
+    postgresql
     sqlite
+
+Unsupported Databases
+---------------------
+
+These backends are untested and may not be completely
+ported to current versions of SQLAlchemy.
+
+.. toctree::
+    :glob:
+
+    access
+    informix
+    maxdb
     sybase
 
index 9f787e3c2916fcd5861e3db616727b4c8087602e..7cf271d0b7903a6212126161dfbf60dabe020ae4 100644 (file)
@@ -1,4 +1,4 @@
 Informix
 ========
 
-.. automodule:: sqlalchemy.databases.informix
+.. automodule:: sqlalchemy.dialects.informix.base
index b137da917c671e09ea49b56a2eb3214a8fc60007..3edd55a775c415844df3085f7db2e2333987ab51 100644 (file)
@@ -1,4 +1,4 @@
 MaxDB
 =====
 
-.. automodule:: sqlalchemy.databases.maxdb
+.. automodule:: sqlalchemy.dialects.maxdb.base
index a55ab85a95f2205f94c9b59732b98529a9a5667b..68c0f0462c1a76d02a264b165b0d9f231e11dd9d 100644 (file)
@@ -1,4 +1,18 @@
-SQL Server
-==========
+Microsoft SQL Server
+====================
+
+.. automodule:: sqlalchemy.dialects.mssql.base
+
+PyODBC
+------
+.. automodule:: sqlalchemy.dialects.mssql.pyodbc
+
+AdoDBAPI
+--------
+.. automodule:: sqlalchemy.dialects.mssql.adodbapi
+
+pymssql
+-------
+.. automodule:: sqlalchemy.dialects.mssql.pymssql
+
 
-.. automodule:: sqlalchemy.databases.mssql
index 28f905343f4cede51cd448ce028f0c570ea02833..839b8cae075f2f9828423259f229ec5bb797bf2b 100644 (file)
 MySQL
 =====
 
-.. automodule:: sqlalchemy.databases.mysql
+.. automodule:: sqlalchemy.dialects.mysql.base
 
 MySQL Column Types
 ------------------
 
-.. autoclass:: MSNumeric
+.. autoclass:: NUMERIC
    :members: __init__
    :show-inheritance:
 
-.. autoclass:: MSDecimal
+.. autoclass:: DECIMAL
    :members: __init__
    :show-inheritance:
 
-.. autoclass:: MSDouble
+.. autoclass:: DOUBLE
    :members: __init__
    :show-inheritance:
 
-.. autoclass:: MSReal
+.. autoclass:: REAL
    :members: __init__
    :show-inheritance:
 
-.. autoclass:: MSFloat
+.. autoclass:: FLOAT
    :members: __init__
    :show-inheritance:
 
-.. autoclass:: MSInteger
+.. autoclass:: INTEGER
    :members: __init__
    :show-inheritance:
 
-.. autoclass:: MSBigInteger
+.. autoclass:: BIGINT
    :members: __init__
    :show-inheritance:
 
-.. autoclass:: MSMediumInteger
+.. autoclass:: MEDIUMINT
    :members: __init__
    :show-inheritance:
 
-.. autoclass:: MSTinyInteger
+.. autoclass:: TINYINT
    :members: __init__
    :show-inheritance:
 
-.. autoclass:: MSSmallInteger
+.. autoclass:: SMALLINT
    :members: __init__
    :show-inheritance:
 
-.. autoclass:: MSBit
+.. autoclass:: BIT
    :members: __init__
    :show-inheritance:
 
-.. autoclass:: MSDateTime
+.. autoclass:: DATETIME
    :members: __init__
    :show-inheritance:
 
-.. autoclass:: MSDate
+.. autoclass:: DATE
    :members: __init__
    :show-inheritance:
 
-.. autoclass:: MSTime
+.. autoclass:: TIME
    :members: __init__
    :show-inheritance:
 
-.. autoclass:: MSTimeStamp
+.. autoclass:: TIMESTAMP
    :members: __init__
    :show-inheritance:
 
-.. autoclass:: MSYear
+.. autoclass:: YEAR
    :members: __init__
    :show-inheritance:
 
-.. autoclass:: MSText
+.. autoclass:: TEXT
    :members: __init__
    :show-inheritance:
 
-.. autoclass:: MSTinyText
+.. autoclass:: TINYTEXT
    :members: __init__
    :show-inheritance:
 
-.. autoclass:: MSMediumText
+.. autoclass:: MEDIUMTEXT
    :members: __init__
    :show-inheritance:
 
-.. autoclass:: MSLongText
+.. autoclass:: LONGTEXT
    :members: __init__
    :show-inheritance:
 
-.. autoclass:: MSString
+.. autoclass:: VARCHAR
    :members: __init__
    :show-inheritance:
 
-.. autoclass:: MSChar
+.. autoclass:: CHAR
    :members: __init__
    :show-inheritance:
 
-.. autoclass:: MSNVarChar
+.. autoclass:: NVARCHAR
    :members: __init__
    :show-inheritance:
 
-.. autoclass:: MSNChar
+.. autoclass:: NCHAR
    :members: __init__
    :show-inheritance:
 
-.. autoclass:: MSVarBinary
+.. autoclass:: VARBINARY
    :members: __init__
    :show-inheritance:
 
-.. autoclass:: MSBinary
+.. autoclass:: BINARY
    :members: __init__
    :show-inheritance:
 
-.. autoclass:: MSBlob
+.. autoclass:: BLOB
    :members: __init__
    :show-inheritance:
 
-.. autoclass:: MSTinyBlob
+.. autoclass:: TINYBLOB
    :members: __init__
    :show-inheritance:
 
-.. autoclass:: MSMediumBlob
+.. autoclass:: MEDIUMBLOB
    :members: __init__
    :show-inheritance:
 
-.. autoclass:: MSLongBlob
+.. autoclass:: LONGBLOB
    :members: __init__
    :show-inheritance:
 
-.. autoclass:: MSEnum
+.. autoclass:: ENUM
    :members: __init__
    :show-inheritance:
 
-.. autoclass:: MSSet
+.. autoclass:: SET
    :members: __init__
    :show-inheritance:
 
-.. autoclass:: MSBoolean
+.. autoclass:: BOOLEAN
    :members: __init__
    :show-inheritance:
 
+MySQLdb Notes
+--------------
+
+.. automodule:: sqlalchemy.dialects.mysql.mysqldb
+
+zxjdbc Notes
+--------------
+
+.. automodule:: sqlalchemy.dialects.mysql.zxjdbc
index 188f6f438314a4e0e43e57ac3665c4057a554f19..584dfbf8145ac060ba5267afe98ed5e8811badf0 100644 (file)
@@ -1,4 +1,10 @@
 Oracle
 ======
 
-.. automodule:: sqlalchemy.databases.oracle
+.. automodule:: sqlalchemy.dialects.oracle.base
+
+cx_Oracle Notes
+---------------
+
+.. automodule:: sqlalchemy.dialects.oracle.cx_oracle
+
diff --git a/doc/build/reference/dialects/postgres.rst b/doc/build/reference/dialects/postgres.rst
deleted file mode 100644 (file)
index 7cf0723..0000000
+++ /dev/null
@@ -1,4 +0,0 @@
-PostgreSQL
-==========
-
-.. automodule:: sqlalchemy.databases.postgres
diff --git a/doc/build/reference/dialects/postgresql.rst b/doc/build/reference/dialects/postgresql.rst
new file mode 100644 (file)
index 0000000..7e00645
--- /dev/null
@@ -0,0 +1,20 @@
+PostgreSQL
+==========
+
+.. automodule:: sqlalchemy.dialects.postgresql.base
+
+psycopg2 Notes
+--------------
+
+.. automodule:: sqlalchemy.dialects.postgresql.psycopg2
+
+
+pg8000 Notes
+--------------
+
+.. automodule:: sqlalchemy.dialects.postgresql.pg8000
+
+zxjdbc Notes
+--------------
+
+.. automodule:: sqlalchemy.dialects.postgresql.zxjdbc
index 118c239b1db4e584815898be437b70ba5102835f..8361876c38ebfba80bda3ceb758c0463ca85ce13 100644 (file)
@@ -1,5 +1,9 @@
 SQLite
 ======
 
-.. automodule:: sqlalchemy.databases.sqlite
+.. automodule:: sqlalchemy.dialects.sqlite.base
 
+Pysqlite
+--------
+
+.. automodule:: sqlalchemy.dialects.sqlite.pysqlite
\ No newline at end of file
index fac1a1f6b4e74ea706b601ab49dedb3387d02fcb..1b7651d2cff2a169fb810ecb5573f9b04b6d351f 100644 (file)
@@ -1,4 +1,4 @@
 Sybase
 ======
 
-.. automodule:: sqlalchemy.databases.sybase
+.. automodule:: sqlalchemy.dialects.sybase.base
index 2f861816c3d2713564846bd257c7d64f69150332..394fa864ce5e8e089679c893032dcc402a4774c6 100644 (file)
@@ -65,7 +65,3 @@ Internals
 .. autoclass:: ExecutionContext
     :members:
 
-.. autoclass:: SchemaIterator
-    :members:
-    :show-inheritance:
-    
index 91e96819780b1371087e9f4ede6f6eb7bd3043fd..d37425e3a6e1fae3898844b60c4b06c319d31707 100644 (file)
@@ -32,7 +32,7 @@ directly to :func:`~sqlalchemy.create_engine` as keyword arguments:
 ``pool_size``, ``max_overflow``, ``pool_recycle`` and
 ``pool_timeout``.  For example::
 
-  engine = create_engine('postgres://me@localhost/mydb',
+  engine = create_engine('postgresql://me@localhost/mydb',
                          pool_size=20, max_overflow=0)
 
 In the case of SQLite, a :class:`SingletonThreadPool` is provided instead,
index afe509d74b6f8ed2aec1de9f33602ff99acd9b09..6eb532fe18f28c446ebcc1a4b7aefe58fd8513a6 100644 (file)
@@ -153,22 +153,36 @@ reference for the database you're interested in.
 For example, MySQL has a ``BIGINT`` type and PostgreSQL has an
 ``INET`` type.  To use these, import them from the module explicitly::
 
-    from sqlalchemy.databases.mysql import MSBigInteger, MSEnum
+    from sqlalchemy.dialects.mysql import dialect as mysql
 
     table = Table('foo', meta,
-        Column('id', MSBigInteger),
-        Column('enumerates', MSEnum('a', 'b', 'c'))
+        Column('id', mysql.BIGINT),
+        Column('enumerates', mysql.ENUM('a', 'b', 'c'))
     )
 
 Or some PostgreSQL types::
 
-    from sqlalchemy.databases.postgres import PGInet, PGArray
+    from sqlalchemy.dialects.postgresql import dialect as postgresql
 
     table = Table('foo', meta,
-        Column('ipaddress', PGInet),
-        Column('elements', PGArray(str))
+        Column('ipaddress', postgresql.INET),
+        Column('elements', postgresql.ARRAY(str))
         )
 
+Each dialect should provide the full set of typenames supported by
+that backend, so that a backend-specific schema can be created without
+the need to locate types::
+
+    from sqlalchemy.dialects.postgresql import dialect as pg
+
+    t = Table('mytable', metadata,
+               Column('id', pg.INTEGER, primary_key=True),
+               Column('name', pg.VARCHAR(300)),
+               Column('inetaddr', pg.INET)
+    )
+
+Where above, the INTEGER and VARCHAR types are ultimately from 
+sqlalchemy.types, but the Postgresql dialect makes them available.
 
 Custom Types
 ------------
@@ -181,7 +195,7 @@ The simplest method is implementing a :class:`TypeDecorator`, a helper
 class that makes it easy to augment the bind parameter and result
 processing capabilities of one of the built in types.
 
-To build a type object from scratch, subclass `:class:TypeEngine`.
+To build a type object from scratch, subclass :class:`UserDefinedType`.
 
 .. autoclass:: TypeDecorator
    :members:
@@ -189,6 +203,12 @@ To build a type object from scratch, subclass `:class:TypeEngine`.
    :inherited-members:
    :show-inheritance:
 
+.. autoclass:: UserDefinedType
+   :members:
+   :undoc-members:
+   :inherited-members:
+   :show-inheritance:
+
 .. autoclass:: TypeEngine
    :members:
    :undoc-members:
index b2b66c32fe3739f24de5f5106d267d573fbb0129..c704dc7928377f5ab9f85b14c8381397751ef9e6 100644 (file)
@@ -54,7 +54,7 @@ In our previous example regarding ``sessionmaker()``, we specified a ``bind`` fo
     Session = sessionmaker()
 
     # later, we create the engine
-    engine = create_engine('postgres://...')
+    engine = create_engine('postgresql://...')
     
     # associate it with our custom Session class
     Session.configure(bind=engine)
@@ -74,7 +74,7 @@ The ``Session`` can also be explicitly bound to an individual database ``Connect
     # global application scope.  create Session class, engine
     Session = sessionmaker()
 
-    engine = create_engine('postgres://...')
+    engine = create_engine('postgresql://...')
     
     ...
     
@@ -219,7 +219,7 @@ With ``merge()``, the given instance is not placed within the session, and can b
   * An application which reads an object structure from a file and wishes to save it to the database might parse the file, build up the structure, and then use ``merge()`` to save it to the database, ensuring that the data within the file is used to formulate the primary key of each element of the structure.  Later, when the file has changed, the same process can be re-run, producing a slightly different object structure, which can then be ``merged()`` in again, and the ``Session`` will automatically update the database to reflect those changes.
   * A web application stores mapped entities within an HTTP session object.  When each request starts up, the serialized data can be merged into the session, so that the original entity may be safely shared among requests and threads.
 
-``merge()`` is frequently used by applications which implement their own second level caches.  This refers to an application which uses an in memory dictionary, or an tool like Memcached to store objects over long running spans of time.  When such an object needs to exist within a ``Session``, ``merge()`` is a good choice since it leaves the original cached object untouched.  For this use case, merge provides a keyword option called ``dont_load=True``.  When this boolean flag is set to ``True``, ``merge()`` will not issue any SQL to reconcile the given object against the current state of the database, thereby reducing query overhead.   The limitation is that the given object and all of its children may not contain any pending changes, and it's also of course possible that newer information in the database will not be present on the merged object, since no load is issued.
+``merge()`` is frequently used by applications which implement their own second level caches.  This refers to an application which uses an in-memory dictionary, or a tool like Memcached, to store objects over long running spans of time.  When such an object needs to exist within a ``Session``, ``merge()`` is a good choice since it leaves the original cached object untouched.  For this use case, merge provides a keyword option called ``load=False``.  When this boolean flag is set to ``False``, ``merge()`` will not issue any SQL to reconcile the given object against the current state of the database, thereby reducing query overhead.   The limitation is that the given object and all of its children may not contain any pending changes, and it's also of course possible that newer information in the database will not be present on the merged object, since no load is issued.
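+
+A minimal sketch of this flag in use (``cached_obj`` here is assumed to be a
+detached, fully loaded instance)::
+
+    obj = session.merge(cached_obj, load=False)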
 
 Deleting
 --------
@@ -459,8 +459,8 @@ Enabling Two-Phase Commit
 
 Finally, for MySQL, PostgreSQL, and soon Oracle as well, the session can be instructed to use two-phase commit semantics. This will coordinate the committing of transactions across databases so that the transaction is either committed or rolled back in all databases. You can also ``prepare()`` the session for interacting with transactions not managed by SQLAlchemy. To use two phase transactions set the flag ``twophase=True`` on the session::
 
-    engine1 = create_engine('postgres://db1')
-    engine2 = create_engine('postgres://db2')
+    engine1 = create_engine('postgresql://db1')
+    engine2 = create_engine('postgresql://db2')
     
     Session = sessionmaker(twophase=True)
 
@@ -549,7 +549,7 @@ Note that above, we issue a ``commit()`` both on the ``Session`` as well as the
 
 When using the ``threadlocal`` engine context, the process above is simplified; the ``Session`` uses the same connection/transaction as everyone else in the current thread, whether or not you explicitly bind it::
 
-    engine = create_engine('postgres://mydb', strategy="threadlocal")
+    engine = create_engine('postgresql://mydb', strategy="threadlocal")
     engine.begin()
     
     session = Session()  # session takes place in the transaction like everyone else
@@ -652,8 +652,8 @@ Vertical Partitioning
 
 Vertical partitioning places different kinds of objects, or different tables, across multiple databases::
 
-    engine1 = create_engine('postgres://db1')
-    engine2 = create_engine('postgres://db2')
+    engine1 = create_engine('postgresql://db1')
+    engine2 = create_engine('postgresql://db2')
 
     Session = sessionmaker(twophase=True)
 
index 387013cacc9bee02bdbc53c9c2a7974b7fb27529..2bcaa631d22652a5d0e5bd6f9950ece398bf909e 100644 (file)
@@ -143,10 +143,10 @@ What about the ``result`` variable we got when we called ``execute()`` ?  As the
 
 .. sourcecode:: pycon+sql
 
-    >>> result.last_inserted_ids()
+    >>> result.inserted_primary_key
     [1]
     
-The value of ``1`` was automatically generated by SQLite, but only because we did not specify the ``id`` column in our ``Insert`` statement; otherwise, our explicit value would have been used.   In either case, SQLAlchemy always knows how to get at a newly generated primary key value, even though the method of generating them is different across different databases; each databases' ``Dialect`` knows the specific steps needed to determine the correct value (or values; note that ``last_inserted_ids()`` returns a list so that it supports composite primary keys).
+The value of ``1`` was automatically generated by SQLite, but only because we did not specify the ``id`` column in our ``Insert`` statement; otherwise, our explicit value would have been used.   In either case, SQLAlchemy always knows how to get at a newly generated primary key value, even though the method of generating them is different across different databases; each database's ``Dialect`` knows the specific steps needed to determine the correct value (or values; note that ``inserted_primary_key`` returns a list so that it supports composite primary keys).
 
 Executing Multiple Statements 
 ==============================
index 0a344e98efd0871f1cbd951621fe062fc0528297..1f57e327207d249868ec1c9b9b56929c2a79905d 100644 (file)
@@ -55,7 +55,7 @@ def teststring(s, name, globs=None, verbose=None, report=True,
     return runner.failures, runner.tries
 
 def replace_file(s, newfile):
-    engine = r"'(sqlite|postgres|mysql):///.*'"
+    engine = r"'(sqlite|postgresql|mysql):///.*'"
     engine = re.compile(engine, re.MULTILINE)
     s, n = re.subn(engine, "'sqlite:///" + newfile + "'", s)
     if not n:
index c482d82560f321ca0fb357aaa1fc042050213dc9..8e687d7f8caa6990db7f664d880235c863fd93e1 100644 (file)
@@ -231,7 +231,7 @@ if __name__ == '__main__':
     from sqlalchemy.orm import sessionmaker, column_property
     from sqlalchemy.ext.declarative import declarative_base
 
-    engine = create_engine('postgres://scott:tiger@localhost/gistest', echo=True)
+    engine = create_engine('postgresql://scott:tiger@localhost/gistest', echo=True)
     metadata = MetaData(engine)
     Base = declarative_base(metadata=metadata)
 
index 92d48e2d787b87ea43b94c1fb70cfab886f84a70..00a4cc3ec2d24be190a2721f593949ae7142885a 100644 (file)
@@ -27,7 +27,7 @@ class CachingQuery(Query):
                     self.session.expunge(x)
                 _cache[self.cachekey] = ret
 
-            return iter(self.session.merge(x, dont_load=True) for x in ret)
+            return iter(self.session.merge(x, load=False) for x in ret)
 
         else:
             return Query.__iter__(self)
diff --git a/ez_setup.py b/ez_setup.py
new file mode 100644 (file)
index 0000000..d24e845
--- /dev/null
@@ -0,0 +1,276 @@
+#!python
+"""Bootstrap setuptools installation
+
+If you want to use setuptools in your package's setup.py, just include this
+file in the same directory with it, and add this to the top of your setup.py::
+
+    from ez_setup import use_setuptools
+    use_setuptools()
+
+If you want to require a specific version of setuptools, set a download
+mirror, or use an alternate download directory, you can do so by supplying
+the appropriate options to ``use_setuptools()``.
+
+This file can also be run as a script to install or upgrade setuptools.
+"""
+import sys
+DEFAULT_VERSION = "0.6c9"
+DEFAULT_URL     = "http://pypi.python.org/packages/%s/s/setuptools/" % sys.version[:3]
+
+md5_data = {
+    'setuptools-0.6b1-py2.3.egg': '8822caf901250d848b996b7f25c6e6ca',
+    'setuptools-0.6b1-py2.4.egg': 'b79a8a403e4502fbb85ee3f1941735cb',
+    'setuptools-0.6b2-py2.3.egg': '5657759d8a6d8fc44070a9d07272d99b',
+    'setuptools-0.6b2-py2.4.egg': '4996a8d169d2be661fa32a6e52e4f82a',
+    'setuptools-0.6b3-py2.3.egg': 'bb31c0fc7399a63579975cad9f5a0618',
+    'setuptools-0.6b3-py2.4.egg': '38a8c6b3d6ecd22247f179f7da669fac',
+    'setuptools-0.6b4-py2.3.egg': '62045a24ed4e1ebc77fe039aa4e6f7e5',
+    'setuptools-0.6b4-py2.4.egg': '4cb2a185d228dacffb2d17f103b3b1c4',
+    'setuptools-0.6c1-py2.3.egg': 'b3f2b5539d65cb7f74ad79127f1a908c',
+    'setuptools-0.6c1-py2.4.egg': 'b45adeda0667d2d2ffe14009364f2a4b',
+    'setuptools-0.6c2-py2.3.egg': 'f0064bf6aa2b7d0f3ba0b43f20817c27',
+    'setuptools-0.6c2-py2.4.egg': '616192eec35f47e8ea16cd6a122b7277',
+    'setuptools-0.6c3-py2.3.egg': 'f181fa125dfe85a259c9cd6f1d7b78fa',
+    'setuptools-0.6c3-py2.4.egg': 'e0ed74682c998bfb73bf803a50e7b71e',
+    'setuptools-0.6c3-py2.5.egg': 'abef16fdd61955514841c7c6bd98965e',
+    'setuptools-0.6c4-py2.3.egg': 'b0b9131acab32022bfac7f44c5d7971f',
+    'setuptools-0.6c4-py2.4.egg': '2a1f9656d4fbf3c97bf946c0a124e6e2',
+    'setuptools-0.6c4-py2.5.egg': '8f5a052e32cdb9c72bcf4b5526f28afc',
+    'setuptools-0.6c5-py2.3.egg': 'ee9fd80965da04f2f3e6b3576e9d8167',
+    'setuptools-0.6c5-py2.4.egg': 'afe2adf1c01701ee841761f5bcd8aa64',
+    'setuptools-0.6c5-py2.5.egg': 'a8d3f61494ccaa8714dfed37bccd3d5d',
+    'setuptools-0.6c6-py2.3.egg': '35686b78116a668847237b69d549ec20',
+    'setuptools-0.6c6-py2.4.egg': '3c56af57be3225019260a644430065ab',
+    'setuptools-0.6c6-py2.5.egg': 'b2f8a7520709a5b34f80946de5f02f53',
+    'setuptools-0.6c7-py2.3.egg': '209fdf9adc3a615e5115b725658e13e2',
+    'setuptools-0.6c7-py2.4.egg': '5a8f954807d46a0fb67cf1f26c55a82e',
+    'setuptools-0.6c7-py2.5.egg': '45d2ad28f9750e7434111fde831e8372',
+    'setuptools-0.6c8-py2.3.egg': '50759d29b349db8cfd807ba8303f1902',
+    'setuptools-0.6c8-py2.4.egg': 'cba38d74f7d483c06e9daa6070cce6de',
+    'setuptools-0.6c8-py2.5.egg': '1721747ee329dc150590a58b3e1ac95b',
+    'setuptools-0.6c9-py2.3.egg': 'a83c4020414807b496e4cfbe08507c03',
+    'setuptools-0.6c9-py2.4.egg': '260a2be2e5388d66bdaee06abec6342a',
+    'setuptools-0.6c9-py2.5.egg': 'fe67c3e5a17b12c0e7c541b7ea43a8e6',
+    'setuptools-0.6c9-py2.6.egg': 'ca37b1ff16fa2ede6e19383e7b59245a',
+}
+
+import sys, os
+try: from hashlib import md5
+except ImportError: from md5 import md5
+
+def _validate_md5(egg_name, data):
+    if egg_name in md5_data:
+        digest = md5(data).hexdigest()
+        if digest != md5_data[egg_name]:
+            print >>sys.stderr, (
+                "md5 validation of %s failed!  (Possible download problem?)"
+                % egg_name
+            )
+            sys.exit(2)
+    return data
+
+def use_setuptools(
+    version=DEFAULT_VERSION, download_base=DEFAULT_URL, to_dir=os.curdir,
+    download_delay=15
+):
+    """Automatically find/download setuptools and make it available on sys.path
+
+    `version` should be a valid setuptools version number that is available
+    as an egg for download under the `download_base` URL (which should end with
+    a '/').  `to_dir` is the directory where setuptools will be downloaded, if
+    it is not already available.  If `download_delay` is specified, it should
+    be the number of seconds that will be paused before initiating a download,
+    should one be required.  If an older version of setuptools is installed,
+    this routine will print a message to ``sys.stderr`` and raise SystemExit in
+    an attempt to abort the calling script.
+    """
+    was_imported = 'pkg_resources' in sys.modules or 'setuptools' in sys.modules
+    def do_download():
+        egg = download_setuptools(version, download_base, to_dir, download_delay)
+        sys.path.insert(0, egg)
+        import setuptools; setuptools.bootstrap_install_from = egg
+    try:
+        import pkg_resources
+    except ImportError:
+        return do_download()       
+    try:
+        pkg_resources.require("setuptools>="+version); return
+    except pkg_resources.VersionConflict, e:
+        if was_imported:
+            print >>sys.stderr, (
+            "The required version of setuptools (>=%s) is not available, and\n"
+            "can't be installed while this script is running. Please install\n"
+            " a more recent version first, using 'easy_install -U setuptools'."
+            "\n\n(Currently using %r)"
+            ) % (version, e.args[0])
+            sys.exit(2)
+        else:
+            del pkg_resources, sys.modules['pkg_resources']    # reload ok
+            return do_download()
+    except pkg_resources.DistributionNotFound:
+        return do_download()
+
+def download_setuptools(
+    version=DEFAULT_VERSION, download_base=DEFAULT_URL, to_dir=os.curdir,
+    delay = 15
+):
+    """Download setuptools from a specified location and return its filename
+
+    `version` should be a valid setuptools version number that is available
+    as an egg for download under the `download_base` URL (which should end
+    with a '/'). `to_dir` is the directory where the egg will be downloaded.
+    `delay` is the number of seconds to pause before an actual download attempt.
+    """
+    import urllib2, shutil
+    egg_name = "setuptools-%s-py%s.egg" % (version,sys.version[:3])
+    url = download_base + egg_name
+    saveto = os.path.join(to_dir, egg_name)
+    src = dst = None
+    if not os.path.exists(saveto):  # Avoid repeated downloads
+        try:
+            from distutils import log
+            if delay:
+                log.warn("""
+---------------------------------------------------------------------------
+This script requires setuptools version %s to run (even to display
+help).  I will attempt to download it for you (from
+%s), but
+you may need to enable firewall access for this script first.
+I will start the download in %d seconds.
+
+(Note: if this machine does not have network access, please obtain the file
+
+   %s
+
+and place it in this directory before rerunning this script.)
+---------------------------------------------------------------------------""",
+                    version, download_base, delay, url
+                ); from time import sleep; sleep(delay)
+            log.warn("Downloading %s", url)
+            src = urllib2.urlopen(url)
+            # Read/write all in one block, so we don't create a corrupt file
+            # if the download is interrupted.
+            data = _validate_md5(egg_name, src.read())
+            dst = open(saveto,"wb"); dst.write(data)
+        finally:
+            if src: src.close()
+            if dst: dst.close()
+    return os.path.realpath(saveto)
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+def main(argv, version=DEFAULT_VERSION):
+    """Install or upgrade setuptools and EasyInstall"""
+    try:
+        import setuptools
+    except ImportError:
+        egg = None
+        try:
+            egg = download_setuptools(version, delay=0)
+            sys.path.insert(0,egg)
+            from setuptools.command.easy_install import main
+            return main(list(argv)+[egg])   # we're done here
+        finally:
+            if egg and os.path.exists(egg):
+                os.unlink(egg)
+    else:
+        if setuptools.__version__ == '0.0.1':
+            print >>sys.stderr, (
+            "You have an obsolete version of setuptools installed.  Please\n"
+            "remove it from your system entirely before rerunning this script."
+            )
+            sys.exit(2)
+
+    req = "setuptools>="+version
+    import pkg_resources
+    try:
+        pkg_resources.require(req)
+    except pkg_resources.VersionConflict:
+        try:
+            from setuptools.command.easy_install import main
+        except ImportError:
+            from easy_install import main
+        main(list(argv)+[download_setuptools(delay=0)])
+        sys.exit(0) # try to force an exit
+    else:
+        if argv:
+            from setuptools.command.easy_install import main
+            main(argv)
+        else:
+            print "Setuptools version",version,"or greater has been installed."
+            print '(Run "ez_setup.py -U setuptools" to reinstall or upgrade.)'
+
+def update_md5(filenames):
+    """Update our built-in md5 registry"""
+
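+    # Typically reached via the __main__ block at the bottom of this file:
+    #     python ez_setup.py --md5update setuptools-0.6c9-py2.5.egg ...
+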
+    import re
+
+    for name in filenames:
+        base = os.path.basename(name)
+        f = open(name,'rb')
+        md5_data[base] = md5(f.read()).hexdigest()
+        f.close()
+
+    data = ["    %r: %r,\n" % it for it in md5_data.items()]
+    data.sort()
+    repl = "".join(data)
+
+    import inspect
+    srcfile = inspect.getsourcefile(sys.modules[__name__])
+    f = open(srcfile, 'rb'); src = f.read(); f.close()
+
+    match = re.search("\nmd5_data = {\n([^}]+)}", src)
+    if not match:
+        print >>sys.stderr, "Internal error!"
+        sys.exit(2)
+
+    src = src[:match.start(1)] + repl + src[match.end(1):]
+    f = open(srcfile,'w')
+    f.write(src)
+    f.close()
+
+
+if __name__=='__main__':
+    if len(sys.argv)>2 and sys.argv[1]=='--md5update':
+        update_md5(sys.argv[2:])
+    else:
+        main(sys.argv[1:])
+
+
+
+
+
+
index ddbbb7b7eddb5cd0e12e54c9766d0ca25cb33e5e..31469ee5ae45e186334aa1291948e59326e030fb 100644 (file)
@@ -10,40 +10,6 @@ import sys
 import sqlalchemy.exc as exceptions
 sys.modules['sqlalchemy.exceptions'] = exceptions
 
-from sqlalchemy.types import (
-    BLOB,
-    BOOLEAN,
-    Binary,
-    Boolean,
-    CHAR,
-    CLOB,
-    DATE,
-    DATETIME,
-    DECIMAL,
-    Date,
-    DateTime,
-    FLOAT,
-    Float,
-    INT,
-    Integer,
-    Interval,
-    NCHAR,
-    NUMERIC,
-    Numeric,
-    PickleType,
-    SMALLINT,
-    SmallInteger,
-    String,
-    TEXT,
-    TIME,
-    TIMESTAMP,
-    Text,
-    Time,
-    Unicode,
-    UnicodeText,
-    VARCHAR,
-    )
-
 from sqlalchemy.sql import (
     alias,
     and_,
@@ -81,6 +47,43 @@ from sqlalchemy.sql import (
     update,
     )
 
+from sqlalchemy.types import (
+    BLOB,
+    BOOLEAN,
+    Binary,
+    Boolean,
+    CHAR,
+    CLOB,
+    DATE,
+    DATETIME,
+    DECIMAL,
+    Date,
+    DateTime,
+    FLOAT,
+    Float,
+    INT,
+    INTEGER,
+    Integer,
+    Interval,
+    NCHAR,
+    NVARCHAR,
+    NUMERIC,
+    Numeric,
+    PickleType,
+    SMALLINT,
+    SmallInteger,
+    String,
+    TEXT,
+    TIME,
+    TIMESTAMP,
+    Text,
+    Time,
+    Unicode,
+    UnicodeText,
+    VARCHAR,
+    )
+
+
 from sqlalchemy.schema import (
     CheckConstraint,
     Column,
@@ -107,6 +110,6 @@ from sqlalchemy.engine import create_engine, engine_from_config
 __all__ = sorted(name for name, obj in locals().items()
                  if not (name.startswith('_') or inspect.ismodule(obj)))
                  
-__version__ = '0.5.5'
+__version__ = '0.6beta1'
 
 del inspect, sys
diff --git a/lib/sqlalchemy/connectors/__init__.py b/lib/sqlalchemy/connectors/__init__.py
new file mode 100644 (file)
index 0000000..f1383ad
--- /dev/null
@@ -0,0 +1,6 @@
+
+
+class Connector(object):
+    """Common base for DBAPI 'connector' mixins shared across dialects."""
+    pass
+    
+    
\ No newline at end of file
diff --git a/lib/sqlalchemy/connectors/mxodbc.py b/lib/sqlalchemy/connectors/mxodbc.py
new file mode 100644 (file)
index 0000000..a0f3f02
--- /dev/null
@@ -0,0 +1,24 @@
+from sqlalchemy.connectors import Connector
+
+class MxODBCConnector(Connector):
+    driver='mxodbc'
+    supports_sane_rowcount = False
+    supports_sane_multi_rowcount = False
+    supports_unicode_statements = False
+    supports_unicode_binds = False
+
+    @classmethod
+    def import_dbapi(cls):
+        import mxODBC as module
+        return module
+
+    def create_connect_args(self, url):
+        """Return a tuple of (*args, **kwargs)."""
+        # FIXME: handle mx.odbc.Windows proprietary args
+        opts = url.translate_connect_args(username='user')
+        opts.update(url.query)
+        argsDict = {}
+        argsDict['user'] = opts['user']
+        argsDict['password'] = opts['password']
+        connArgs = [[opts['dsn']], argsDict]
+        return connArgs
diff --git a/lib/sqlalchemy/connectors/pyodbc.py b/lib/sqlalchemy/connectors/pyodbc.py
new file mode 100644 (file)
index 0000000..4f8d6d5
--- /dev/null
@@ -0,0 +1,80 @@
+from sqlalchemy.connectors import Connector
+
+import sys
+import re
+import urllib
+
+class PyODBCConnector(Connector):
+    driver='pyodbc'
+    supports_sane_rowcount = False
+    supports_sane_multi_rowcount = False
+    # PyODBC unicode is broken on UCS-4 builds
+    supports_unicode = sys.maxunicode == 65535
+    supports_unicode_statements = supports_unicode
+    default_paramstyle = 'named'
+    
+    # for non-DSN connections, this should
+    # hold the desired driver name
+    pyodbc_driver_name = None
+    
+    @classmethod
+    def dbapi(cls):
+        return __import__('pyodbc')
+
+    def create_connect_args(self, url):
+        opts = url.translate_connect_args(username='user')
+        opts.update(url.query)
+        
+        keys = opts
+        query = url.query
+
+        if 'odbc_connect' in keys:
+            connectors = [urllib.unquote_plus(keys.pop('odbc_connect'))]
+        else:
+            dsn_connection = 'dsn' in keys or ('host' in keys and 'database' not in keys)
+            if dsn_connection:
+                connectors= ['dsn=%s' % (keys.pop('host', '') or keys.pop('dsn', ''))]
+            else:
+                port = ''
+                if 'port' in keys and not 'port' in query:
+                    port = ',%d' % int(keys.pop('port'))
+
+                connectors = ["DRIVER={%s}" % keys.pop('driver', self.pyodbc_driver_name),
+                              'Server=%s%s' % (keys.pop('host', ''), port),
+                              'Database=%s' % keys.pop('database', '') ]
+
+            user = keys.pop("user", None)
+            if user:
+                connectors.append("UID=%s" % user)
+                connectors.append("PWD=%s" % keys.pop('password', ''))
+            else:
+                connectors.append("TrustedConnection=Yes")
+
+            # if set to 'Yes', the ODBC layer will try to automagically convert 
+            # textual data from your database encoding to your client encoding 
+            # This should obviously be set to 'No' if you query a cp1253 encoded 
+            # database from a latin1 client... 
+            if 'odbc_autotranslate' in keys:
+                connectors.append("AutoTranslate=%s" % keys.pop("odbc_autotranslate"))
+
+            connectors.extend(['%s=%s' % (k,v) for k,v in keys.iteritems()])
+        return [[";".join (connectors)], {}]
+
+    def is_disconnect(self, e):
+        if isinstance(e, self.dbapi.ProgrammingError):
+            return "The cursor's connection has been closed." in str(e) or 'Attempt to use a closed connection.' in str(e)
+        elif isinstance(e, self.dbapi.Error):
+            return '[08S01]' in str(e)
+        else:
+            return False
+
+    def _get_server_version_info(self, connection):
+        dbapi_con = connection.connection
+        version = []
+        r = re.compile('[.\-]')
+        for n in r.split(dbapi_con.getinfo(self.dbapi.SQL_DBMS_VER)):
+            try:
+                version.append(int(n))
+            except ValueError:
+                version.append(n)
+        return tuple(version)
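+
+# A sketch of the resulting connect args (assuming a host-style URL of the
+# form user:pass@host/db and a dialect that sets pyodbc_driver_name, e.g.
+# 'SQL Server'): create_connect_args() would return
+#     [['DRIVER={SQL Server};Server=host;Database=db;UID=user;PWD=pass'], {}]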
diff --git a/lib/sqlalchemy/connectors/zxJDBC.py b/lib/sqlalchemy/connectors/zxJDBC.py
new file mode 100644 (file)
index 0000000..3cdfeb3
--- /dev/null
@@ -0,0 +1,43 @@
+import sys
+from sqlalchemy.connectors import Connector
+
+class ZxJDBCConnector(Connector):
+    driver = 'zxjdbc'
+    
+    supports_sane_rowcount = False
+    supports_sane_multi_rowcount = False
+    
+    supports_unicode_binds = True
+    supports_unicode_statements = sys.version > '2.5.0+'
+    description_encoding = None
+    default_paramstyle = 'qmark'
+    
+    jdbc_db_name = None
+    jdbc_driver_name = None
+    
+    @classmethod
+    def dbapi(cls):
+        from com.ziclix.python.sql import zxJDBC
+        return zxJDBC
+
+    def _driver_kwargs(self):
+        """return kw arg dict to be sent to connect()."""
+        return {}
+        
+    def create_connect_args(self, url):
+        hostname = url.host
+        dbname = url.database
+        d, u, p, v = ("jdbc:%s://%s/%s" % (self.jdbc_db_name, hostname, dbname),
+                      url.username, url.password, self.jdbc_driver_name)
+        return [[d, u, p, v], self._driver_kwargs()]
+        
+    def is_disconnect(self, e):
+        if not isinstance(e, self.dbapi.ProgrammingError):
+            return False
+        e = str(e)
+        return 'connection is closed' in e or 'cursor is closed' in e
+
+    def _get_server_version_info(self, connection):
+        # use connection.connection.dbversion, and parse appropriately
+        # to get a tuple
+        raise NotImplementedError()
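+        # A sketch of what a subclass might do here (assuming dbversion
+        # is a dotted string such as '5.1.45'):
+        #     return tuple(int(x) for x in
+        #                  connection.connection.dbversion.split('.'))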
index 6588be0ae71410a180f9e13912c37c50e4a2088a..16cabd47f8b8cd179e4a24d25f70861cefa12da0 100644 (file)
@@ -4,6 +4,20 @@
 # This module is part of SQLAlchemy and is released under
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
 
+from sqlalchemy.dialects.sqlite import base as sqlite
+from sqlalchemy.dialects.postgresql import base as postgresql
+postgres = postgresql   # backwards-compatible alias for the pre-0.6 name
+from sqlalchemy.dialects.mysql import base as mysql
+from sqlalchemy.dialects.oracle import base as oracle
+from sqlalchemy.dialects.firebird import base as firebird
+from sqlalchemy.dialects.maxdb import base as maxdb
+from sqlalchemy.dialects.informix import base as informix
+from sqlalchemy.dialects.mssql import base as mssql
+from sqlalchemy.dialects.access import base as access
+from sqlalchemy.dialects.sybase import base as sybase
+
+
+
 
 __all__ = (
     'access',
@@ -12,8 +26,8 @@ __all__ = (
     'maxdb',
     'mssql',
     'mysql',
-    'oracle',
-    'postgres',
+    'postgresql',
     'sqlite',
+    'oracle',
     'sybase',
     )
diff --git a/lib/sqlalchemy/databases/firebird.py b/lib/sqlalchemy/databases/firebird.py
deleted file mode 100644 (file)
index 8a8d02d..0000000
+++ /dev/null
@@ -1,768 +0,0 @@
-# firebird.py
-# Copyright (C) 2005, 2006, 2007, 2008, 2009 Michael Bayer mike_mp@zzzcomputing.com
-#
-# This module is part of SQLAlchemy and is released under
-# the MIT License: http://www.opensource.org/licenses/mit-license.php
-
-"""
-Firebird backend
-================
-
-This module implements the Firebird backend, through the kinterbasdb_
-DBAPI module.
-
-Firebird dialects
------------------
-
-Firebird offers two distinct dialects_ (not to be confused with the
-SA ``Dialect`` thing):
-
-dialect 1
-  This is the old syntax and behaviour, inherited from Interbase pre-6.0.
-
-dialect 3
-  This is the newer and supported syntax, introduced in Interbase 6.0.
-
-From the user's point of view, the biggest change is in date/time
-handling: under dialect 1 there's a single kind of field, ``DATE``,
-with a synonym ``DATETIME``, that holds a `timestamp` value, that is,
-a date with hour, minute and second. Under dialect 3 there are three
-kinds: a ``DATE`` that holds a date, a ``TIME`` that holds a *time of
-the day* value, and a ``TIMESTAMP``, equivalent to the old ``DATE``.
-
-The problem is that the dialect of a Firebird database is a property
-of the database itself [#]_ (that is, any single database has been
-created with one dialect or the other: there is no way to change it
-after creation). SQLAlchemy has a single instance of the class that
-controls all the connections to a particular kind of database, so it
-cannot easily differentiate between the two modes, and in particular
-it **cannot** simultaneously talk with two distinct Firebird databases
-with different dialects.
-
-By default this module is biased toward dialect 3, but you can easily
-tweak it to handle dialect 1 if needed::
-
-  from sqlalchemy import types as sqltypes
-  from sqlalchemy.databases.firebird import FBDate, colspecs, ischema_names
-
-  # Adjust the mapping of the timestamp kind
-  ischema_names['TIMESTAMP'] = FBDate
-  colspecs[sqltypes.DateTime] = FBDate,
-
-Other aspects may be version-specific. You can use the ``server_version_info()`` method
-on the ``FBDialect`` class to do whatever is needed::
-
-  from sqlalchemy.databases.firebird import FBCompiler
-
-  if engine.dialect.server_version_info(connection) < (2,0):
-      # Change the name of the function ``length`` to use the UDF version
-      # instead of ``char_length``
-      FBCompiler.LENGTH_FUNCTION_NAME = 'strlen'
-
-Pooling connections
--------------------
-
-The default strategy used by SQLAlchemy to pool database connections
-may, in particular cases, raise an ``OperationalError`` with a message
-`"object XYZ is in use"`. This happens on Firebird when there are two
-connections to the database, where one is using, or has used, a
-particular table and the other tries to drop or alter the same table.
-To guarantee that DDL operations succeed, Firebird recommends performing
-them as the single connected user.
-
-If your SA application actually needs to do DDL operations while other
-connections are active, the following setting may alleviate the problem::
-
-  from sqlalchemy import pool
-  from sqlalchemy.databases.firebird import dialect
-
-  # Force SA to use a single connection per thread
-  dialect.poolclass = pool.SingletonThreadPool
-
-RETURNING support
------------------
-
-Firebird 2.0 supports returning a result set from inserts, and 2.1 extends
-that to deletes and updates.
-
-To use this, pass the column/expression list to the ``firebird_returning``
-parameter when creating the queries::
-
-  raises = empl.update(empl.c.sales > 100, values=dict(salary=empl.c.salary * 1.1),
-                       firebird_returning=[empl.c.id, empl.c.salary]).execute().fetchall()
-
-
-.. [#] Well, that is not the whole story, as the client may still ask
-       for a different (lower) dialect...
-
-.. _dialects: http://mc-computing.com/Databases/Firebird/SQL_Dialect.html
-.. _kinterbasdb: http://sourceforge.net/projects/kinterbasdb
-"""
-
-
-import datetime, decimal, re
-
-from sqlalchemy import exc, schema, types as sqltypes, sql, util
-from sqlalchemy.engine import base, default
-
-
-_initialized_kb = False
-
-
-class FBNumeric(sqltypes.Numeric):
-    """Handle ``NUMERIC(precision,scale)`` datatype."""
-
-    def get_col_spec(self):
-        if self.precision is None:
-            return "NUMERIC"
-        else:
-            return "NUMERIC(%(precision)s, %(scale)s)" % { 'precision': self.precision,
-                                                            'scale' : self.scale }
-
-    def bind_processor(self, dialect):
-        return None
-
-    def result_processor(self, dialect):
-        if self.asdecimal:
-            return None
-        else:
-            def process(value):
-                if isinstance(value, decimal.Decimal):
-                    return float(value)
-                else:
-                    return value
-            return process
-
-
-class FBFloat(sqltypes.Float):
-    """Handle ``FLOAT(precision)`` datatype."""
-
-    def get_col_spec(self):
-        if not self.precision:
-            return "FLOAT"
-        else:
-            return "FLOAT(%(precision)s)" % {'precision': self.precision}
-
-
-class FBInteger(sqltypes.Integer):
-    """Handle ``INTEGER`` datatype."""
-
-    def get_col_spec(self):
-        return "INTEGER"
-
-
-class FBSmallInteger(sqltypes.Smallinteger):
-    """Handle ``SMALLINT`` datatype."""
-
-    def get_col_spec(self):
-        return "SMALLINT"
-
-
-class FBDateTime(sqltypes.DateTime):
-    """Handle ``TIMESTAMP`` datatype."""
-
-    def get_col_spec(self):
-        return "TIMESTAMP"
-
-    def bind_processor(self, dialect):
-        def process(value):
-            if value is None or isinstance(value, datetime.datetime):
-                return value
-            else:
-                return datetime.datetime(year=value.year,
-                                         month=value.month,
-                                         day=value.day)
-        return process
-
-
-class FBDate(sqltypes.DateTime):
-    """Handle ``DATE`` datatype."""
-
-    def get_col_spec(self):
-        return "DATE"
-
-
-class FBTime(sqltypes.Time):
-    """Handle ``TIME`` datatype."""
-
-    def get_col_spec(self):
-        return "TIME"
-
-
-class FBText(sqltypes.Text):
-    """Handle ``BLOB SUB_TYPE 1`` datatype (aka *textual* blob)."""
-
-    def get_col_spec(self):
-        return "BLOB SUB_TYPE 1"
-
-
-class FBString(sqltypes.String):
-    """Handle ``VARCHAR(length)`` datatype."""
-
-    def get_col_spec(self):
-        if self.length:
-            return "VARCHAR(%(length)s)" % {'length' : self.length}
-        else:
-            return "BLOB SUB_TYPE 1"
-
-
-class FBChar(sqltypes.CHAR):
-    """Handle ``CHAR(length)`` datatype."""
-
-    def get_col_spec(self):
-        if self.length:
-            return "CHAR(%(length)s)" % {'length' : self.length}
-        else:
-            return "BLOB SUB_TYPE 1"
-
-
-class FBBinary(sqltypes.Binary):
-    """Handle ``BLOB SUB_TYPE 0`` datatype (aka *binary* blob)."""
-
-    def get_col_spec(self):
-        return "BLOB SUB_TYPE 0"
-
-
-class FBBoolean(sqltypes.Boolean):
-    """Handle boolean values as a ``SMALLINT`` datatype."""
-
-    def get_col_spec(self):
-        return "SMALLINT"
-
-
-colspecs = {
-    sqltypes.Integer : FBInteger,
-    sqltypes.Smallinteger : FBSmallInteger,
-    sqltypes.Numeric : FBNumeric,
-    sqltypes.Float : FBFloat,
-    sqltypes.DateTime : FBDateTime,
-    sqltypes.Date : FBDate,
-    sqltypes.Time : FBTime,
-    sqltypes.String : FBString,
-    sqltypes.Binary : FBBinary,
-    sqltypes.Boolean : FBBoolean,
-    sqltypes.Text : FBText,
-    sqltypes.CHAR: FBChar,
-}
-
-
-ischema_names = {
-      'SHORT': lambda r: FBSmallInteger(),
-       'LONG': lambda r: FBInteger(),
-       'QUAD': lambda r: FBFloat(),
-      'FLOAT': lambda r: FBFloat(),
-       'DATE': lambda r: FBDate(),
-       'TIME': lambda r: FBTime(),
-       'TEXT': lambda r: FBString(r['flen']),
-      'INT64': lambda r: FBNumeric(precision=r['fprec'], scale=r['fscale'] * -1), # This generically handles NUMERIC()
-     'DOUBLE': lambda r: FBFloat(),
-  'TIMESTAMP': lambda r: FBDateTime(),
-    'VARYING': lambda r: FBString(r['flen']),
-    'CSTRING': lambda r: FBChar(r['flen']),
-       'BLOB': lambda r: r['stype']==1 and FBText() or FBBinary()
-      }
-
-RETURNING_KW_NAME = 'firebird_returning'
-
-class FBExecutionContext(default.DefaultExecutionContext):
-    pass
-
-
-class FBDialect(default.DefaultDialect):
-    """Firebird dialect"""
-    name = 'firebird'
-    supports_sane_rowcount = False
-    supports_sane_multi_rowcount = False
-    max_identifier_length = 31
-    preexecute_pk_sequences = True
-    supports_pk_autoincrement = False
-
-    def __init__(self, type_conv=200, concurrency_level=1, **kwargs):
-        default.DefaultDialect.__init__(self, **kwargs)
-
-        self.type_conv = type_conv
-        self.concurrency_level = concurrency_level
-
-    def dbapi(cls):
-        import kinterbasdb
-        return kinterbasdb
-    dbapi = classmethod(dbapi)
-
-    def create_connect_args(self, url):
-        opts = url.translate_connect_args(username='user')
-        if opts.get('port'):
-            opts['host'] = "%s/%s" % (opts['host'], opts['port'])
-            del opts['port']
-        opts.update(url.query)
-
-        type_conv = opts.pop('type_conv', self.type_conv)
-        concurrency_level = opts.pop('concurrency_level', self.concurrency_level)
-        global _initialized_kb
-        if not _initialized_kb and self.dbapi is not None:
-            _initialized_kb = True
-            self.dbapi.init(type_conv=type_conv, concurrency_level=concurrency_level)
-        return ([], opts)
-
-    def type_descriptor(self, typeobj):
-        return sqltypes.adapt_type(typeobj, colspecs)
-
-    def server_version_info(self, connection):
-        """Get the version of the Firebird server used by a connection.
-
-        Returns a tuple of (`major`, `minor`, `build`), three integers
-        representing the version of the attached server.
-        """
-
-        # This is the simpler approach (the other uses the services api),
-        # that for backward compatibility reasons returns a string like
-        #   LI-V6.3.3.12981 Firebird 2.0
-        # where the first version is a fake one resembling the old
-        # Interbase signature. This is more than enough for our purposes,
-        # as this is mainly (only?) used by the testsuite.
-
-        from re import match
-
-        fbconn = connection.connection.connection
-        version = fbconn.server_version
-        m = match('\w+-V(\d+)\.(\d+)\.(\d+)\.(\d+) \w+ (\d+)\.(\d+)', version)
-        if not m:
-            raise AssertionError("Could not determine version from string '%s'" % version)
-        return tuple([int(x) for x in m.group(5, 6, 4)])
-
-    def _normalize_name(self, name):
-        """Convert the name to lowercase if it is possible"""
-
-        # Remove trailing spaces: FB uses a CHAR() type,
-        # that is padded with spaces
-        name = name and name.rstrip()
-        if name is None:
-            return None
-        elif name.upper() == name and not self.identifier_preparer._requires_quotes(name.lower()):
-            return name.lower()
-        else:
-            return name
-
-    def _denormalize_name(self, name):
-        """Revert a *normalized* name to its uppercase equivalent"""
-
-        if name is None:
-            return None
-        elif name.lower() == name and not self.identifier_preparer._requires_quotes(name.lower()):
-            return name.upper()
-        else:
-            return name
-
-    def table_names(self, connection, schema):
-        """Return a list of *normalized* table names omitting system relations."""
-
-        s = """
-        SELECT r.rdb$relation_name
-        FROM rdb$relations r
-        WHERE r.rdb$system_flag=0
-        """
-        return [self._normalize_name(row[0]) for row in connection.execute(s)]
-
-    def has_table(self, connection, table_name, schema=None):
-        """Return ``True`` if the given table exists, ignoring the `schema`."""
-
-        tblqry = """
-        SELECT 1 FROM rdb$database
-        WHERE EXISTS (SELECT rdb$relation_name
-                      FROM rdb$relations
-                      WHERE rdb$relation_name=?)
-        """
-        c = connection.execute(tblqry, [self._denormalize_name(table_name)])
-        row = c.fetchone()
-        if row is not None:
-            return True
-        else:
-            return False
-
-    def has_sequence(self, connection, sequence_name):
-        """Return ``True`` if the given sequence (generator) exists."""
-
-        genqry = """
-        SELECT 1 FROM rdb$database
-        WHERE EXISTS (SELECT rdb$generator_name
-                      FROM rdb$generators
-                      WHERE rdb$generator_name=?)
-        """
-        c = connection.execute(genqry, [self._denormalize_name(sequence_name)])
-        row = c.fetchone()
-        if row is not None:
-            return True
-        else:
-            return False
-
-    def is_disconnect(self, e):
-        if isinstance(e, self.dbapi.OperationalError):
-            return 'Unable to complete network request to host' in str(e)
-        elif isinstance(e, self.dbapi.ProgrammingError):
-            msg = str(e)
-            return ('Invalid connection state' in msg or
-                    'Invalid cursor state' in msg)
-        else:
-            return False
-
-    def reflecttable(self, connection, table, include_columns):
-        # Query to extract the details of all the fields of the given table
-        tblqry = """
-        SELECT DISTINCT r.rdb$field_name AS fname,
-                        r.rdb$null_flag AS null_flag,
-                        t.rdb$type_name AS ftype,
-                        f.rdb$field_sub_type AS stype,
-                        f.rdb$field_length AS flen,
-                        f.rdb$field_precision AS fprec,
-                        f.rdb$field_scale AS fscale,
-                        COALESCE(r.rdb$default_source, f.rdb$default_source) AS fdefault
-        FROM rdb$relation_fields r
-             JOIN rdb$fields f ON r.rdb$field_source=f.rdb$field_name
-             JOIN rdb$types t ON t.rdb$type=f.rdb$field_type AND t.rdb$field_name='RDB$FIELD_TYPE'
-        WHERE f.rdb$system_flag=0 AND r.rdb$relation_name=?
-        ORDER BY r.rdb$field_position
-        """
-        # Query to extract the PK/FK constrained fields of the given table
-        keyqry = """
-        SELECT se.rdb$field_name AS fname
-        FROM rdb$relation_constraints rc
-             JOIN rdb$index_segments se ON rc.rdb$index_name=se.rdb$index_name
-        WHERE rc.rdb$constraint_type=? AND rc.rdb$relation_name=?
-        """
-        # Query to extract the details of each UK/FK of the given table
-        fkqry = """
-        SELECT rc.rdb$constraint_name AS cname,
-               cse.rdb$field_name AS fname,
-               ix2.rdb$relation_name AS targetrname,
-               se.rdb$field_name AS targetfname
-        FROM rdb$relation_constraints rc
-             JOIN rdb$indices ix1 ON ix1.rdb$index_name=rc.rdb$index_name
-             JOIN rdb$indices ix2 ON ix2.rdb$index_name=ix1.rdb$foreign_key
-             JOIN rdb$index_segments cse ON cse.rdb$index_name=ix1.rdb$index_name
-             JOIN rdb$index_segments se ON se.rdb$index_name=ix2.rdb$index_name AND se.rdb$field_position=cse.rdb$field_position
-        WHERE rc.rdb$constraint_type=? AND rc.rdb$relation_name=?
-        ORDER BY se.rdb$index_name, se.rdb$field_position
-        """
-        # Heuristic-query to determine the generator associated to a PK field
-        genqry = """
-        SELECT trigdep.rdb$depended_on_name AS fgenerator
-        FROM rdb$dependencies tabdep
-             JOIN rdb$dependencies trigdep ON (tabdep.rdb$dependent_name=trigdep.rdb$dependent_name
-                                               AND trigdep.rdb$depended_on_type=14
-                                               AND trigdep.rdb$dependent_type=2)
-             JOIN rdb$triggers trig ON (trig.rdb$trigger_name=tabdep.rdb$dependent_name)
-        WHERE tabdep.rdb$depended_on_name=?
-          AND tabdep.rdb$depended_on_type=0
-          AND trig.rdb$trigger_type=1
-          AND tabdep.rdb$field_name=?
-          AND (SELECT count(*)
-               FROM rdb$dependencies trigdep2
-               WHERE trigdep2.rdb$dependent_name = trigdep.rdb$dependent_name) = 2
-        """
-
-        tablename = self._denormalize_name(table.name)
-
-        # get primary key fields
-        c = connection.execute(keyqry, ["PRIMARY KEY", tablename])
-        pkfields = [self._normalize_name(r['fname']) for r in c.fetchall()]
-
-        # get all of the fields for this table
-        c = connection.execute(tblqry, [tablename])
-
-        found_table = False
-        while True:
-            row = c.fetchone()
-            if row is None:
-                break
-            found_table = True
-
-            name = self._normalize_name(row['fname'])
-            if include_columns and name not in include_columns:
-                continue
-            args = [name]
-
-            kw = {}
-            # get the data type
-            coltype = ischema_names.get(row['ftype'].rstrip())
-            if coltype is None:
-                util.warn("Did not recognize type '%s' of column '%s'" %
-                          (str(row['ftype']), name))
-                coltype = sqltypes.NULLTYPE
-            else:
-                coltype = coltype(row)
-            args.append(coltype)
-
-            # is it a primary key?
-            kw['primary_key'] = name in pkfields
-
-            # is it nullable?
-            kw['nullable'] = not bool(row['null_flag'])
-
-            # does it have a default value?
-            if row['fdefault'] is not None:
-                # the value comes down as "DEFAULT 'value'"
-                assert row['fdefault'].upper().startswith('DEFAULT '), row
-                defvalue = row['fdefault'][8:]
-                args.append(schema.DefaultClause(sql.text(defvalue)))
-
-            col = schema.Column(*args, **kw)
-            if kw['primary_key']:
-                # if the PK is a single field, try to see if it's linked to
-                # a sequence through a trigger
-                if len(pkfields)==1:
-                    genc = connection.execute(genqry, [tablename, row['fname']])
-                    genr = genc.fetchone()
-                    if genr is not None:
-                        col.sequence = schema.Sequence(self._normalize_name(genr['fgenerator']))
-
-            table.append_column(col)
-
-        if not found_table:
-            raise exc.NoSuchTableError(table.name)
-
-        # get the foreign keys
-        c = connection.execute(fkqry, ["FOREIGN KEY", tablename])
-        fks = {}
-        while True:
-            row = c.fetchone()
-            if not row:
-                break
-
-            cname = self._normalize_name(row['cname'])
-            try:
-                fk = fks[cname]
-            except KeyError:
-                fks[cname] = fk = ([], [])
-            rname = self._normalize_name(row['targetrname'])
-            schema.Table(rname, table.metadata, autoload=True, autoload_with=connection)
-            fname = self._normalize_name(row['fname'])
-            refspec = rname + '.' + self._normalize_name(row['targetfname'])
-            fk[0].append(fname)
-            fk[1].append(refspec)
-
-        for name, value in fks.iteritems():
-            table.append_constraint(schema.ForeignKeyConstraint(value[0], value[1], name=name, link_to_name=True))
-
-    def do_execute(self, cursor, statement, parameters, **kwargs):
-        # kinterbasdb does not accept a None, but wants an empty list
-        # when there are no arguments.
-        cursor.execute(statement, parameters or [])
-
-    def do_rollback(self, connection):
-        # Use the retaining feature, that keeps the transaction going
-        connection.rollback(True)
-
-    def do_commit(self, connection):
-        # Use the retaining feature, that keeps the transaction going
-        connection.commit(True)
-
-
-def _substring(s, start, length=None):
-    "Helper function to handle Firebird 2 SUBSTRING builtin"
-
-    if length is None:
-        return "SUBSTRING(%s FROM %s)" % (s, start)
-    else:
-        return "SUBSTRING(%s FROM %s FOR %s)" % (s, start, length)
-
-
-class FBCompiler(sql.compiler.DefaultCompiler):
-    """Firebird specific idiosincrasies"""
-
-    # Firebird lacks a builtin modulo operator, but there is
-    # an equivalent function in the ib_udf library.
-    operators = sql.compiler.DefaultCompiler.operators.copy()
-    operators.update({
-        sql.operators.mod : lambda x, y:"mod(%s, %s)" % (x, y)
-        })
-
-    def visit_alias(self, alias, asfrom=False, **kwargs):
-        # Override to not use the AS keyword which FB 1.5 does not like
-        if asfrom:
-            return self.process(alias.original, asfrom=True, **kwargs) + " " + self.preparer.format_alias(alias, self._anonymize(alias.name))
-        else:
-            return self.process(alias.original, **kwargs)
-
-    functions = sql.compiler.DefaultCompiler.functions.copy()
-    functions['substring'] = _substring
-
-    def function_argspec(self, func):
-        if func.clauses:
-            return self.process(func.clause_expr)
-        else:
-            return ""
-
-    def default_from(self):
-        return " FROM rdb$database"
-
-    def visit_sequence(self, seq):
-        return "gen_id(%s, 1)" % self.preparer.format_sequence(seq)
-
-    def get_select_precolumns(self, select):
-        """Called when building a ``SELECT`` statement, position is just
-        before column list Firebird puts the limit and offset right
-        after the ``SELECT``...
-        """
-
-        result = ""
-        if select._limit:
-            result += "FIRST %d "  % select._limit
-        if select._offset:
-            result +="SKIP %d "  %  select._offset
-        if select._distinct:
-            result += "DISTINCT "
-        return result
-
-    def limit_clause(self, select):
-        """Already taken care of in the `get_select_precolumns` method."""
-
-        return ""
-
-    LENGTH_FUNCTION_NAME = 'char_length'
-    def function_string(self, func):
-        """Substitute the ``length`` function.
-
-        On newer FB there is a ``char_length`` function, while older
-        ones need the ``strlen`` UDF.
-        """
-
-        if func.name == 'length':
-            return self.LENGTH_FUNCTION_NAME + '%(expr)s'
-        return super(FBCompiler, self).function_string(func)
-
-    def _append_returning(self, text, stmt):
-        returning_cols = stmt.kwargs[RETURNING_KW_NAME]
-        def flatten_columnlist(collist):
-            for c in collist:
-                if isinstance(c, sql.expression.Selectable):
-                    for co in c.columns:
-                        yield co
-                else:
-                    yield c
-        columns = [self.process(c, within_columns_clause=True)
-                   for c in flatten_columnlist(returning_cols)]
-        text += ' RETURNING ' + ', '.join(columns)
-        return text
-
-    def visit_update(self, update_stmt):
-        text = super(FBCompiler, self).visit_update(update_stmt)
-        if RETURNING_KW_NAME in update_stmt.kwargs:
-            return self._append_returning(text, update_stmt)
-        else:
-            return text
-
-    def visit_insert(self, insert_stmt):
-        text = super(FBCompiler, self).visit_insert(insert_stmt)
-        if RETURNING_KW_NAME in insert_stmt.kwargs:
-            return self._append_returning(text, insert_stmt)
-        else:
-            return text
-
-    def visit_delete(self, delete_stmt):
-        text = super(FBCompiler, self).visit_delete(delete_stmt)
-        if RETURNING_KW_NAME in delete_stmt.kwargs:
-            return self._append_returning(text, delete_stmt)
-        else:
-            return text
-
-
-class FBSchemaGenerator(sql.compiler.SchemaGenerator):
-    """Firebird syntactic idiosincrasies"""
-
-    def get_column_specification(self, column, **kwargs):
-        colspec = self.preparer.format_column(column)
-        colspec += " " + column.type.dialect_impl(self.dialect).get_col_spec()
-
-        default = self.get_column_default_string(column)
-        if default is not None:
-            colspec += " DEFAULT " + default
-
-        if not column.nullable or column.primary_key:
-            colspec += " NOT NULL"
-
-        return colspec
-
-    def visit_sequence(self, sequence):
-        """Generate a ``CREATE GENERATOR`` statement for the sequence."""
-
-        if not self.checkfirst or not self.dialect.has_sequence(self.connection, sequence.name):
-            self.append("CREATE GENERATOR %s" % self.preparer.format_sequence(sequence))
-            self.execute()
-
-
-class FBSchemaDropper(sql.compiler.SchemaDropper):
-    """Firebird syntactic idiosincrasies"""
-
-    def visit_sequence(self, sequence):
-        """Generate a ``DROP GENERATOR`` statement for the sequence."""
-
-        if not self.checkfirst or self.dialect.has_sequence(self.connection, sequence.name):
-            self.append("DROP GENERATOR %s" % self.preparer.format_sequence(sequence))
-            self.execute()
-
-
-class FBDefaultRunner(base.DefaultRunner):
-    """Firebird specific idiosincrasies"""
-
-    def visit_sequence(self, seq):
-        """Get the next value from the sequence using ``gen_id()``."""
-
-        return self.execute_string("SELECT gen_id(%s, 1) FROM rdb$database" % \
-            self.dialect.identifier_preparer.format_sequence(seq))
-
-
-RESERVED_WORDS = set(
-    ["action", "active", "add", "admin", "after", "all", "alter", "and", "any",
-     "as", "asc", "ascending", "at", "auto", "autoddl", "avg", "based", "basename",
-     "base_name", "before", "begin", "between", "bigint", "blob", "blobedit", "buffer",
-     "by", "cache", "cascade", "case", "cast", "char", "character", "character_length",
-     "char_length", "check", "check_point_len", "check_point_length", "close", "collate",
-     "collation", "column", "commit", "committed", "compiletime", "computed", "conditional",
-     "connect", "constraint", "containing", "continue", "count", "create", "cstring",
-     "current", "current_connection", "current_date", "current_role", "current_time",
-     "current_timestamp", "current_transaction", "current_user", "cursor", "database",
-     "date", "day", "db_key", "debug", "dec", "decimal", "declare", "default", "delete",
-     "desc", "descending", "describe", "descriptor", "disconnect", "display", "distinct",
-     "do", "domain", "double", "drop", "echo", "edit", "else", "end", "entry_point",
-     "escape", "event", "exception", "execute", "exists", "exit", "extern", "external",
-     "extract", "fetch", "file", "filter", "float", "for", "foreign", "found", "free_it",
-     "from", "full", "function", "gdscode", "generator", "gen_id", "global", "goto",
-     "grant", "group", "group_commit_", "group_commit_wait", "having", "help", "hour",
-     "if", "immediate", "in", "inactive", "index", "indicator", "init", "inner", "input",
-     "input_type", "insert", "int", "integer", "into", "is", "isolation", "isql", "join",
-     "key", "lc_messages", "lc_type", "left", "length", "lev", "level", "like", "logfile",
-     "log_buffer_size", "log_buf_size", "long", "manual", "max", "maximum", "maximum_segment",
-     "max_segment", "merge", "message", "min", "minimum", "minute", "module_name", "month",
-     "names", "national", "natural", "nchar", "no", "noauto", "not", "null", "numeric",
-     "num_log_buffers", "num_log_bufs", "octet_length", "of", "on", "only", "open", "option",
-     "or", "order", "outer", "output", "output_type", "overflow", "page", "pagelength",
-     "pages", "page_size", "parameter", "password", "plan", "position", "post_event",
-     "precision", "prepare", "primary", "privileges", "procedure", "protected", "public",
-     "quit", "raw_partitions", "rdb$db_key", "read", "real", "record_version", "recreate",
-     "references", "release", "release", "reserv", "reserving", "restrict", "retain",
-     "return", "returning_values", "returns", "revoke", "right", "role", "rollback",
-     "row_count", "runtime", "savepoint", "schema", "second", "segment", "select",
-     "set", "shadow", "shared", "shell", "show", "singular", "size", "smallint",
-     "snapshot", "some", "sort", "sqlcode", "sqlerror", "sqlwarning", "stability",
-     "starting", "starts", "statement", "static", "statistics", "sub_type", "sum",
-     "suspend", "table", "terminator", "then", "time", "timestamp", "to", "transaction",
-     "translate", "translation", "trigger", "trim", "type", "uncommitted", "union",
-     "unique", "update", "upper", "user", "using", "value", "values", "varchar",
-     "variable", "varying", "version", "view", "wait", "wait_time", "weekday", "when",
-     "whenever", "where", "while", "with", "work", "write", "year", "yearday" ])
-
-
-class FBIdentifierPreparer(sql.compiler.IdentifierPreparer):
-    """Install Firebird specific reserved words."""
-
-    reserved_words = RESERVED_WORDS
-
-    def __init__(self, dialect):
-        super(FBIdentifierPreparer, self).__init__(dialect, omit_schema=True)
-
-
-dialect = FBDialect
-dialect.statement_compiler = FBCompiler
-dialect.schemagenerator = FBSchemaGenerator
-dialect.schemadropper = FBSchemaDropper
-dialect.defaultrunner = FBDefaultRunner
-dialect.preparer = FBIdentifierPreparer
-dialect.execution_ctx_cls = FBExecutionContext
diff --git a/lib/sqlalchemy/databases/information_schema.py b/lib/sqlalchemy/databases/information_schema.py
deleted file mode 100644 (file)
index a7d4101..0000000
+++ /dev/null
@@ -1,193 +0,0 @@
-"""
-Information schema implementation.
-
-This module is deprecated and will not be present in this form in SQLAlchemy 0.6.
-
-"""
-from sqlalchemy import util
-
-util.warn_deprecated("the information_schema module is deprecated.")
-
-import sqlalchemy.sql as sql
-import sqlalchemy.exc as exc
-from sqlalchemy import select, MetaData, Table, Column, String, Integer
-from sqlalchemy.schema import DefaultClause, ForeignKeyConstraint
-
-ischema = MetaData()
-
-schemata = Table("schemata", ischema,
-    Column("catalog_name", String),
-    Column("schema_name", String),
-    Column("schema_owner", String),
-    schema="information_schema")
-
-tables = Table("tables", ischema,
-    Column("table_catalog", String),
-    Column("table_schema", String),
-    Column("table_name", String),
-    Column("table_type", String),
-    schema="information_schema")
-
-columns = Table("columns", ischema,
-    Column("table_schema", String),
-    Column("table_name", String),
-    Column("column_name", String),
-    Column("is_nullable", Integer),
-    Column("data_type", String),
-    Column("ordinal_position", Integer),
-    Column("character_maximum_length", Integer),
-    Column("numeric_precision", Integer),
-    Column("numeric_scale", Integer),
-    Column("column_default", Integer),
-    Column("collation_name", String),
-    schema="information_schema")
-
-constraints = Table("table_constraints", ischema,
-    Column("table_schema", String),
-    Column("table_name", String),
-    Column("constraint_name", String),
-    Column("constraint_type", String),
-    schema="information_schema")
-
-column_constraints = Table("constraint_column_usage", ischema,
-    Column("table_schema", String),
-    Column("table_name", String),
-    Column("column_name", String),
-    Column("constraint_name", String),
-    schema="information_schema")
-
-pg_key_constraints = Table("key_column_usage", ischema,
-    Column("table_schema", String),
-    Column("table_name", String),
-    Column("column_name", String),
-    Column("constraint_name", String),
-    Column("ordinal_position", Integer),
-    schema="information_schema")
-
-#mysql_key_constraints = Table("key_column_usage", ischema,
-#    Column("table_schema", String),
-#    Column("table_name", String),
-#    Column("column_name", String),
-#    Column("constraint_name", String),
-#    Column("referenced_table_schema", String),
-#    Column("referenced_table_name", String),
-#    Column("referenced_column_name", String),
-#    schema="information_schema")
-
-key_constraints = pg_key_constraints
-
-ref_constraints = Table("referential_constraints", ischema,
-    Column("constraint_catalog", String),
-    Column("constraint_schema", String),
-    Column("constraint_name", String),
-    Column("unique_constraint_catlog", String),
-    Column("unique_constraint_schema", String),
-    Column("unique_constraint_name", String),
-    Column("match_option", String),
-    Column("update_rule", String),
-    Column("delete_rule", String),
-    schema="information_schema")
-
-
-def table_names(connection, schema):
-    s = select([tables.c.table_name], tables.c.table_schema==schema)
-    return [row[0] for row in connection.execute(s)]
-
-
-def reflecttable(connection, table, include_columns, ischema_names):
-    key_constraints = pg_key_constraints
-
-    if table.schema is not None:
-        current_schema = table.schema
-    else:
-        current_schema = connection.default_schema_name()
-
-    s = select([columns],
-        sql.and_(columns.c.table_name==table.name,
-        columns.c.table_schema==current_schema),
-        order_by=[columns.c.ordinal_position])
-
-    c = connection.execute(s)
-    found_table = False
-    while True:
-        row = c.fetchone()
-        if row is None:
-            break
-        #print "row! " + repr(row)
- #       continue
-        found_table = True
-        (name, type, nullable, charlen, numericprec, numericscale, default) = (
-            row[columns.c.column_name],
-            row[columns.c.data_type],
-            row[columns.c.is_nullable] == 'YES',
-            row[columns.c.character_maximum_length],
-            row[columns.c.numeric_precision],
-            row[columns.c.numeric_scale],
-            row[columns.c.column_default]
-            )
-        if include_columns and name not in include_columns:
-            continue
-
-        args = []
-        for a in (charlen, numericprec, numericscale):
-            if a is not None:
-                args.append(a)
-        coltype = ischema_names[type]
-        #print "coltype " + repr(coltype) + " args " +  repr(args)
-        coltype = coltype(*args)
-        colargs = []
-        if default is not None:
-            colargs.append(DefaultClause(sql.text(default)))
-        table.append_column(Column(name, coltype, nullable=nullable, *colargs))
-
-    if not found_table:
-        raise exc.NoSuchTableError(table.name)
-
-    # we are relying on the natural ordering of the constraint_column_usage table to return the referenced columns
-    # in an order that corresponds to the ordinal_position in the key_constraints table, otherwise composite foreign keys
-    # won't reflect properly.  don't see a way around this based on what's available from information_schema
-    s = select([constraints.c.constraint_name, constraints.c.constraint_type, constraints.c.table_name, key_constraints], use_labels=True, from_obj=[constraints.join(column_constraints, column_constraints.c.constraint_name==constraints.c.constraint_name).join(key_constraints, key_constraints.c.constraint_name==column_constraints.c.constraint_name)], order_by=[key_constraints.c.ordinal_position])
-    s.append_column(column_constraints)
-    s.append_whereclause(constraints.c.table_name==table.name)
-    s.append_whereclause(constraints.c.table_schema==current_schema)
-    colmap = [constraints.c.constraint_type, key_constraints.c.column_name, column_constraints.c.table_schema, column_constraints.c.table_name, column_constraints.c.column_name, constraints.c.constraint_name, key_constraints.c.ordinal_position]
-    c = connection.execute(s)
-
-    fks = {}
-    while True:
-        row = c.fetchone()
-        if row is None:
-            break
-        (type, constrained_column, referred_schema, referred_table, referred_column, constraint_name, ordinal_position) = (
-            row[colmap[0]],
-            row[colmap[1]],
-            row[colmap[2]],
-            row[colmap[3]],
-            row[colmap[4]],
-            row[colmap[5]],
-            row[colmap[6]]
-        )
-        #print "type %s on column %s to remote %s.%s.%s" % (type, constrained_column, referred_schema, referred_table, referred_column)
-        if type == 'PRIMARY KEY':
-            table.primary_key.add(table.c[constrained_column])
-        elif type == 'FOREIGN KEY':
-            try:
-                fk = fks[constraint_name]
-            except KeyError:
-                fk = ([], [])
-                fks[constraint_name] = fk
-            if current_schema == referred_schema:
-                referred_schema = table.schema
-            if referred_schema is not None:
-                Table(referred_table, table.metadata, autoload=True, schema=referred_schema, autoload_with=connection)
-                refspec = ".".join([referred_schema, referred_table, referred_column])
-            else:
-                Table(referred_table, table.metadata, autoload=True, autoload_with=connection)
-                refspec = ".".join([referred_table, referred_column])
-            if constrained_column not in fk[0]:
-                fk[0].append(constrained_column)
-            if refspec not in fk[1]:
-                fk[1].append(refspec)
-
-    for name, value in fks.iteritems():
-        table.append_constraint(ForeignKeyConstraint(value[0], value[1], name=name))
diff --git a/lib/sqlalchemy/databases/mssql.py b/lib/sqlalchemy/databases/mssql.py
deleted file mode 100644 (file)
index d963b74..0000000
+++ /dev/null
@@ -1,1771 +0,0 @@
-# mssql.py
-
-"""Support for the Microsoft SQL Server database.
-
-Driver
-------
-
-The MSSQL dialect will work with three different available drivers:
-
-* *pyodbc* - http://pyodbc.sourceforge.net/. This is the recommended
-  driver.
-
-* *pymssql* - http://pymssql.sourceforge.net/
-
-* *adodbapi* - http://adodbapi.sourceforge.net/
-
-Drivers are loaded in the order listed above based on availability.
-
-If you need to load a specific driver pass ``module_name`` when
-creating the engine::
-
-    engine = create_engine('mssql://dsn', module_name='pymssql')
-
-``module_name`` currently accepts: ``pyodbc``, ``pymssql``, and
-``adodbapi``.
-
-Currently the pyodbc driver offers the greatest level of
-compatibility.
-
-Connecting
-----------
-
-Connecting with create_engine() uses the standard URL approach of
-``mssql://user:pass@host/dbname[?key=value&key=value...]``.
-
-If the database name is present, the tokens are converted to a
-connection string with the specified values. If the database is not
-present, then the host token is taken directly as the DSN name.
-
-Examples of pyodbc connection string URLs:
-
-* *mssql://mydsn* - connects using the specified DSN named ``mydsn``.
-  The connection string that is created will appear like::
-
-    dsn=mydsn;TrustedConnection=Yes
-
-* *mssql://user:pass@mydsn* - connects using the DSN named
-  ``mydsn`` passing in the ``UID`` and ``PWD`` information. The
-  connection string that is created will appear like::
-
-    dsn=mydsn;UID=user;PWD=pass
-
-* *mssql://user:pass@mydsn/?LANGUAGE=us_english* - connects
-  using the DSN named ``mydsn`` passing in the ``UID`` and ``PWD``
-  information, plus the additional connection configuration option
-  ``LANGUAGE``. The connection string that is created will appear
-  like::
-
-    dsn=mydsn;UID=user;PWD=pass;LANGUAGE=us_english
-
-* *mssql://user:pass@host/db* - connects using a connection string
-  dynamically created that would appear like::
-
-    DRIVER={SQL Server};Server=host;Database=db;UID=user;PWD=pass
-
-* *mssql://user:pass@host:123/db* - connects using a dynamically
-  created connection string that also includes the port, using the
-  comma syntax. If your connection string requires the port to be
-  passed as a ``port`` keyword, see the next example. This will
-  create the following connection string::
-
-    DRIVER={SQL Server};Server=host,123;Database=db;UID=user;PWD=pass
-
-* *mssql://user:pass@host/db?port=123* - connects using a dynamically
-  created connection string that includes the port as a separate
-  ``port`` keyword. This will create the
-  following connection string::
-
-    DRIVER={SQL Server};Server=host;Database=db;UID=user;PWD=pass;port=123
-
-If you require a connection string that is outside the options
-presented above, use the ``odbc_connect`` keyword to pass in a
-urlencoded connection string. What gets passed in will be urldecoded
-and passed directly.
-
-For example::
-
-    mssql:///?odbc_connect=dsn%3Dmydsn%3BDatabase%3Ddb
-
-would create the following connection string::
-
-    dsn=mydsn;Database=db
-
-Encoding your connection string is easily accomplished via the
-Python shell. For example::
-
-    >>> import urllib
-    >>> urllib.quote_plus('dsn=mydsn;Database=db')
-    'dsn%3Dmydsn%3BDatabase%3Ddb'
-
-Additional arguments which may be specified either as query string
-arguments on the URL, or as keyword arguments to
-:func:`~sqlalchemy.create_engine()` (an example follows this list) are:
-
-* *auto_identity_insert* - enables support for IDENTITY inserts by
-  automatically turning IDENTITY INSERT ON and OFF as required.
-  Defaults to ``True``.
-
-* *query_timeout* - allows you to override the default query timeout.
-  Defaults to ``None``. This is only supported on pymssql.
-
-* *text_as_varchar* - if enabled this will treat all TEXT column
-  types as their equivalent VARCHAR(max) type. This is often used if
-  you need to compare a VARCHAR to a TEXT field, which is not
-  supported directly on MSSQL. Defaults to ``False``.
-
-* *use_scope_identity* - allows you to specify that SCOPE_IDENTITY
-  should be used in place of the non-scoped version @@IDENTITY.
-  Defaults to ``False``. On pymssql this defaults to ``True``, and on
-  pyodbc this defaults to ``True`` if the version of pyodbc being
-  used supports it.
-
-* *has_window_funcs* - indicates whether or not window functions
-  (used to implement LIMIT with OFFSET) are supported on the version
-  of MSSQL being used. If you're running MSSQL 2005 or later turn
-  this on to get OFFSET support. Defaults to ``False``.
-
-* *max_identifier_length* - allows you to set the maximum length of
-  identifiers supported by the database. Defaults to 128. For pymssql
-  the default is 30.
-
-* *schema_name* - used to set the schema name. Defaults to ``dbo``.
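-
-For example (a sketch; the option values shown are illustrative), the
-same options may be passed directly to
-:func:`~sqlalchemy.create_engine()`::
-
-    engine = create_engine('mssql://user:pass@host/db',
-                           text_as_varchar=True,
-                           has_window_funcs=True,
-                           schema_name='dbo')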
-
-Auto Increment Behavior
------------------------
-
-``IDENTITY`` columns are supported by using SQLAlchemy
-``schema.Sequence()`` objects. In other words::
-
-    Table('test', metadata,
-           Column('id', Integer,
-                  Sequence('blah',100,10), primary_key=True),
-           Column('name', String(20))
-         ).create(bind=mss_engine)
-
-would yield::
-
-   CREATE TABLE test (
-     id INTEGER NOT NULL IDENTITY(100,10) PRIMARY KEY,
-     name VARCHAR(20) NULL
-   )
-
-Note that the ``start`` and ``increment`` values for sequences are
-optional and will default to 1,1.
-
-* Support for ``SET IDENTITY_INSERT ON`` mode (automatic ON/OFF
-  around ``INSERT`` statements; see the example below)
-
-* Support for auto-fetching of ``@@IDENTITY/@@SCOPE_IDENTITY()`` on
-  ``INSERT``
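-
-For example (a sketch, assuming the ``Table`` from the example above
-is assigned to a variable ``test`` and ``auto_identity_insert`` is
-left at its default of ``True``), supplying an explicit value for the
-``IDENTITY`` column causes the dialect to wrap the statement in
-``SET IDENTITY_INSERT test ON`` / ``OFF``::
-
-    mss_engine.execute(test.insert(), id=5, name='some name')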
-
-Collation Support
------------------
-
-MSSQL specific string types support a collation parameter that
-creates a column-level specific collation for the column. The
-collation parameter accepts a Windows Collation Name or a SQL
-Collation Name. Supported types are MSChar, MSNChar, MSString,
-MSNVarchar, MSText, and MSNText. For example::
-
-    Column('login', String(32, collation='Latin1_General_CI_AS'))
-
-will yield::
-
-    login VARCHAR(32) COLLATE Latin1_General_CI_AS NULL
-
-LIMIT/OFFSET Support
---------------------
-
-MSSQL has no support for the LIMIT or OFFSET keywords. LIMIT is
-supported directly through the ``TOP`` Transact-SQL keyword::
-
-    select.limit
-
-will yield::
-
-    SELECT TOP n
-
-If the ``has_window_funcs`` flag is set then LIMIT with OFFSET
-support is available through the ``ROW_NUMBER() OVER`` construct. This
-construct requires an ``ORDER BY`` to be specified as well and is
-only available on MSSQL 2005 and later.
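-
-A minimal sketch (``accounts`` is an illustrative table; the flag is
-passed at engine creation)::
-
-    engine = create_engine('mssql://user:pass@host/db',
-                           has_window_funcs=True)
-    stmt = select([accounts]).order_by(accounts.c.id).limit(10).offset(20)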
-
-Nullability
------------
-MSSQL has support for three levels of column nullability. The default
-nullability allows nulls and is explicit in the CREATE TABLE
-construct::
-
-    name VARCHAR(20) NULL
-
-If ``nullable=None`` is specified then no specification is made. In
-other words the database's configured default is used. This will
-render::
-
-    name VARCHAR(20)
-
-If ``nullable`` is ``True`` or ``False`` then the column will be
-``NULL`` or ``NOT NULL`` respectively.
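-
-For example::
-
-    Column('name', String(20), nullable=None)
-
-renders simply ``name VARCHAR(20)``, deferring nullability to the
-database's configured default.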
-
-Date / Time Handling
---------------------
-For MSSQL versions that support the ``DATE`` and ``TIME`` types
-(MSSQL 2008+) the data type is used. For versions that do not
-support the ``DATE`` and ``TIME`` types a ``DATETIME`` type is used
-instead and the MSSQL dialect handles converting the results
-properly. This means ``Date()`` and ``Time()`` are fully supported
-on all versions of MSSQL. If you do not desire this behavior then
-do not use the ``Date()`` or ``Time()`` types.
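-
-For example, a column declared as::
-
-    Column('due_date', Date)
-
-is created as ``DATE`` on MSSQL 2008 and later, and as ``DATETIME``
-on earlier versions, with result rows truncated back to
-``datetime.date`` values by the dialect.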
-
-Compatibility Levels
---------------------
-MSSQL supports the notion of setting compatibility levels at the
-database level. This allows, for instance, a database that is
-compatible with SQL2000 to run on a SQL2005 database
-server. ``server_version_info`` will always return the database
-server version information (in this case SQL2005) and not the
-compatibility level information. Because of this, if running under
-a backwards compatibility mode SQLAlchemy may attempt to use T-SQL
-statements that the database server cannot parse.
-
-Known Issues
-------------
-
-* No support for more than one ``IDENTITY`` column per table
-
-* pymssql has problems with binary and unicode data that this module
-  does **not** work around
-
-"""
-import datetime, decimal, inspect, operator, re, sys, urllib
-
-from sqlalchemy import sql, schema, exc, util
-from sqlalchemy import Table, MetaData, Column, ForeignKey, String, Integer
-from sqlalchemy.sql import select, compiler, expression, operators as sql_operators, functions as sql_functions
-from sqlalchemy.engine import default, base
-from sqlalchemy import types as sqltypes
-from decimal import Decimal as _python_Decimal
-
-
-RESERVED_WORDS = set(
-    ['add', 'all', 'alter', 'and', 'any', 'as', 'asc', 'authorization',
-     'backup', 'begin', 'between', 'break', 'browse', 'bulk', 'by', 'cascade',
-     'case', 'check', 'checkpoint', 'close', 'clustered', 'coalesce',
-     'collate', 'column', 'commit', 'compute', 'constraint', 'contains',
-     'containstable', 'continue', 'convert', 'create', 'cross', 'current',
-     'current_date', 'current_time', 'current_timestamp', 'current_user',
-     'cursor', 'database', 'dbcc', 'deallocate', 'declare', 'default',
-     'delete', 'deny', 'desc', 'disk', 'distinct', 'distributed', 'double',
-     'drop', 'dump', 'else', 'end', 'errlvl', 'escape', 'except', 'exec',
-     'execute', 'exists', 'exit', 'external', 'fetch', 'file', 'fillfactor',
-     'for', 'foreign', 'freetext', 'freetexttable', 'from', 'full',
-     'function', 'goto', 'grant', 'group', 'having', 'holdlock', 'identity',
-     'identity_insert', 'identitycol', 'if', 'in', 'index', 'inner', 'insert',
-     'intersect', 'into', 'is', 'join', 'key', 'kill', 'left', 'like',
-     'lineno', 'load', 'merge', 'national', 'nocheck', 'nonclustered', 'not',
-     'null', 'nullif', 'of', 'off', 'offsets', 'on', 'open', 'opendatasource',
-     'openquery', 'openrowset', 'openxml', 'option', 'or', 'order', 'outer',
-     'over', 'percent', 'pivot', 'plan', 'precision', 'primary', 'print',
-     'proc', 'procedure', 'public', 'raiserror', 'read', 'readtext',
-     'reconfigure', 'references', 'replication', 'restore', 'restrict',
-     'return', 'revert', 'revoke', 'right', 'rollback', 'rowcount',
-     'rowguidcol', 'rule', 'save', 'schema', 'securityaudit', 'select',
-     'session_user', 'set', 'setuser', 'shutdown', 'some', 'statistics',
-     'system_user', 'table', 'tablesample', 'textsize', 'then', 'to', 'top',
-     'tran', 'transaction', 'trigger', 'truncate', 'tsequal', 'union',
-     'unique', 'unpivot', 'update', 'updatetext', 'use', 'user', 'values',
-     'varying', 'view', 'waitfor', 'when', 'where', 'while', 'with',
-     'writetext',
-    ])
-
-
-class _StringType(object):
-    """Base for MSSQL string types."""
-
-    def __init__(self, collation=None, **kwargs):
-        self.collation = kwargs.get('collate', collation)
-
-    def _extend(self, spec):
-        """Extend a string-type declaration with standard SQL
-        COLLATE annotations.
-        """
-
-        if self.collation:
-            collation = 'COLLATE %s' % self.collation
-        else:
-            collation = None
-
-        return ' '.join([c for c in (spec, collation)
-                         if c is not None])
-
-    def __repr__(self):
-        attributes = inspect.getargspec(self.__init__)[0][1:]
-        attributes.extend(inspect.getargspec(_StringType.__init__)[0][1:])
-
-        params = {}
-        for attr in attributes:
-            val = getattr(self, attr)
-            if val is not None and val is not False:
-                params[attr] = val
-
-        return "%s(%s)" % (self.__class__.__name__,
-                           ', '.join(['%s=%r' % (k, params[k]) for k in params]))
-
-    def bind_processor(self, dialect):
-        if self.convert_unicode or dialect.convert_unicode:
-            if self.assert_unicode is None:
-                assert_unicode = dialect.assert_unicode
-            else:
-                assert_unicode = self.assert_unicode
-
-            if not assert_unicode:
-                return None
-
-            def process(value):
-                if not isinstance(value, (unicode, sqltypes.NoneType)):
-                    if assert_unicode == 'warn':
-                        util.warn("Unicode type received non-unicode bind "
-                                  "param value %r" % value)
-                        return value
-                    else:
-                        raise exc.InvalidRequestError("Unicode type received non-unicode bind param value %r" % value)
-                else:
-                    return value
-            return process
-        else:
-            return None
-
-
-class MSNumeric(sqltypes.Numeric):
-    def result_processor(self, dialect):
-        if self.asdecimal:
-            def process(value):
-                if value is not None:
-                    return _python_Decimal(str(value))
-                else:
-                    return value
-            return process
-        else:
-            def process(value):
-                return float(value)
-            return process
-
-    def bind_processor(self, dialect):
-        def process(value):
-            if value is None:
-                # None passes through unchanged
-                return value
-
-            elif isinstance(value, decimal.Decimal):
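-                # format the Decimal as a plain digit string; str() may
-                # yield scientific notation, which the drivers mishandle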
-                if value.adjusted() < 0:
-                    result = "%s0.%s%s" % (
-                            (value < 0 and '-' or ''),
-                            '0' * (abs(value.adjusted()) - 1),
-                            "".join([str(nint) for nint in value._int]))
-
-                else:
-                    if 'E' in str(value):
-                        result = "%s%s%s" % (
-                                (value < 0 and '-' or ''),
-                                "".join([str(s) for s in value._int]),
-                                "0" * (value.adjusted() - (len(value._int)-1)))
-                    else:
-                        if (len(value._int) - 1) > value.adjusted():
-                            result = "%s%s.%s" % (
-                                    (value < 0 and '-' or ''),
-                                    "".join([str(s) for s in value._int][0:value.adjusted() + 1]),
-                                    "".join([str(s) for s in value._int][value.adjusted() + 1:]))
-                        else:
-                            result = "%s%s" % (
-                                    (value < 0 and '-' or ''),
-                                    "".join([str(s) for s in value._int][0:value.adjusted() + 1]))
-
-                return result
-
-            else:
-                return value
-
-        return process
-
-    def get_col_spec(self):
-        if self.precision is None:
-            return "NUMERIC"
-        else:
-            return "NUMERIC(%(precision)s, %(scale)s)" % {'precision': self.precision, 'scale' : self.scale}
-
-
-class MSFloat(sqltypes.Float):
-    def get_col_spec(self):
-        if self.precision is None:
-            return "FLOAT"
-        else:
-            return "FLOAT(%(precision)s)" % {'precision': self.precision}
-
-
-class MSReal(MSFloat):
-    """A type for ``real`` numbers."""
-
-    def __init__(self):
-        """
-        Construct a Real.
-
-        """
-        super(MSReal, self).__init__(precision=24)
-
-    def adapt(self, impltype):
-        return impltype()
-
-    def get_col_spec(self):
-        return "REAL"
-
-
-class MSInteger(sqltypes.Integer):
-    def get_col_spec(self):
-        return "INTEGER"
-
-
-class MSBigInteger(MSInteger):
-    def get_col_spec(self):
-        return "BIGINT"
-
-
-class MSTinyInteger(MSInteger):
-    def get_col_spec(self):
-        return "TINYINT"
-
-
-class MSSmallInteger(MSInteger):
-    def get_col_spec(self):
-        return "SMALLINT"
-
-
-class _DateTimeType(object):
-    """Base for MSSQL datetime types."""
-
-    def bind_processor(self, dialect):
-        # if we receive just a date we can manipulate it
-        # into a datetime since the db-api may not do this.
-        def process(value):
-            if type(value) is datetime.date:
-                return datetime.datetime(value.year, value.month, value.day)
-            return value
-        return process
-
-
-class MSDateTime(_DateTimeType, sqltypes.DateTime):
-    def get_col_spec(self):
-        return "DATETIME"
-
-
-class MSDate(sqltypes.Date):
-    def get_col_spec(self):
-        return "DATE"
-
-
-class MSTime(sqltypes.Time):
-    def __init__(self, precision=None, **kwargs):
-        self.precision = precision
-        super(MSTime, self).__init__()
-
-    def get_col_spec(self):
-        if self.precision:
-            return "TIME(%s)" % self.precision
-        else:
-            return "TIME"
-
-
-class MSSmallDateTime(_DateTimeType, sqltypes.TypeEngine):
-    def get_col_spec(self):
-        return "SMALLDATETIME"
-
-
-class MSDateTime2(_DateTimeType, sqltypes.TypeEngine):
-    def __init__(self, precision=None, **kwargs):
-        self.precision = precision
-
-    def get_col_spec(self):
-        if self.precision:
-            return "DATETIME2(%s)" % self.precision
-        else:
-            return "DATETIME2"
-
-
-class MSDateTimeOffset(_DateTimeType, sqltypes.TypeEngine):
-    def __init__(self, precision=None, **kwargs):
-        self.precision = precision
-
-    def get_col_spec(self):
-        if self.precision:
-            return "DATETIMEOFFSET(%s)" % self.precision
-        else:
-            return "DATETIMEOFFSET"
-
-
-class MSDateTimeAsDate(_DateTimeType, MSDate):
-    """ This is an implementation of the Date type for versions of MSSQL that
-    do not support that specific type. In order to make it work a ``DATETIME``
-    column specification is used and the results get converted back to just
-    the date portion.
-
-    """
-
-    def get_col_spec(self):
-        return "DATETIME"
-
-    def result_processor(self, dialect):
-        def process(value):
-            # If the DBAPI returns the value as datetime.datetime(), truncate
-            # it back to datetime.date()
-            if type(value) is datetime.datetime:
-                return value.date()
-            return value
-        return process
-
-
-class MSDateTimeAsTime(MSTime):
-    """ This is an implementation of the Time type for versions of MSSQL that
-    do not support that specific type. In order to make it work a ``DATETIME``
-    column specification is used and the results get converted back to just
-    the time portion.
-
-    """
-
-    __zero_date = datetime.date(1900, 1, 1)
-
-    def get_col_spec(self):
-        return "DATETIME"
-
-    def bind_processor(self, dialect):
-        def process(value):
-            if type(value) is datetime.datetime:
-                value = datetime.datetime.combine(self.__zero_date, value.time())
-            elif type(value) is datetime.time:
-                value = datetime.datetime.combine(self.__zero_date, value)
-            return value
-        return process
-
-    def result_processor(self, dialect):
-        def process(value):
-            if type(value) is datetime.datetime:
-                return value.time()
-            elif type(value) is datetime.date:
-                return datetime.time(0, 0, 0)
-            return value
-        return process
-
-
-class MSDateTime_adodbapi(MSDateTime):
-    def result_processor(self, dialect):
-        def process(value):
-            # adodbapi will return datetimes with empty time values as datetime.date() objects.
-            # Promote them back to full datetime.datetime()
-            if type(value) is datetime.date:
-                return datetime.datetime(value.year, value.month, value.day)
-            return value
-        return process
-
-
-class MSText(_StringType, sqltypes.Text):
-    """MSSQL TEXT type, for variable-length text up to 2^31 characters."""
-
-    def __init__(self, *args, **kwargs):
-        """Construct a TEXT.
-
-        :param collation: Optional, a column-level collation for this string
-          value. Accepts a Windows Collation Name or a SQL Collation Name.
-
-        """
-        _StringType.__init__(self, **kwargs)
-        sqltypes.Text.__init__(self, None,
-                convert_unicode=kwargs.get('convert_unicode', False),
-                assert_unicode=kwargs.get('assert_unicode', None))
-
-    def get_col_spec(self):
-        if self.dialect.text_as_varchar:
-            return self._extend("VARCHAR(max)")
-        else:
-            return self._extend("TEXT")
-
-
-class MSNText(_StringType, sqltypes.UnicodeText):
-    """MSSQL NTEXT type, for variable-length unicode text up to 2^30
-    characters."""
-
-    def __init__(self, *args, **kwargs):
-        """Construct a NTEXT.
-
-        :param collation: Optional, a column-level collation for this string
-          value. Accepts a Windows Collation Name or a SQL Collation Name.
-
-        """
-        _StringType.__init__(self, **kwargs)
-        sqltypes.UnicodeText.__init__(self, None,
-                convert_unicode=kwargs.get('convert_unicode', True),
-                assert_unicode=kwargs.get('assert_unicode', 'warn'))
-
-    def get_col_spec(self):
-        if self.dialect.text_as_varchar:
-            return self._extend("NVARCHAR(max)")
-        else:
-            return self._extend("NTEXT")
-
-
-class MSString(_StringType, sqltypes.String):
-    """MSSQL VARCHAR type, for variable-length non-Unicode data with a maximum
-    of 8,000 characters."""
-
-    def __init__(self, length=None, convert_unicode=False, assert_unicode=None, **kwargs):
-        """Construct a VARCHAR.
-
-        :param length: Optional, maximum data length, in characters.
-
-        :param convert_unicode: defaults to False.  If True, convert
-          ``unicode`` data sent to the database to a ``str``
-          bytestring, and convert bytestrings coming back from the
-          database into ``unicode``.
-
-          Bytestrings are encoded using the dialect's
-          :attr:`~sqlalchemy.engine.base.Dialect.encoding`, which
-          defaults to `utf-8`.
-
-          If False, may be overridden by
-          :attr:`sqlalchemy.engine.base.Dialect.convert_unicode`.
-
-        :param assert_unicode:
-
-          If None (the default), no assertion will take place unless
-          overridden by :attr:`sqlalchemy.engine.base.Dialect.assert_unicode`.
-
-          If 'warn', will issue a runtime warning if a ``str``
-          instance is used as a bind value.
-
-          If true, will raise an :exc:`sqlalchemy.exc.InvalidRequestError`.
-
-        :param collation: Optional, a column-level collation for this string
-          value. Accepts a Windows Collation Name or a SQL Collation Name.
-
-        """
-        _StringType.__init__(self, **kwargs)
-        sqltypes.String.__init__(self, length=length,
-                convert_unicode=convert_unicode,
-                assert_unicode=assert_unicode)
-
-    def get_col_spec(self):
-        if self.length:
-            return self._extend("VARCHAR(%s)" % self.length)
-        else:
-            return self._extend("VARCHAR")
-
-
-class MSNVarchar(_StringType, sqltypes.Unicode):
-    """MSSQL NVARCHAR type.
-
-    For variable-length unicode character data up to 4,000 characters."""
-
-    def __init__(self, length=None, **kwargs):
-        """Construct a NVARCHAR.
-
-        :param length: Optional, Maximum data length, in characters.
-
-        :param collation: Optional, a column-level collation for this string
-          value. Accepts a Windows Collation Name or a SQL Collation Name.
-
-        """
-        _StringType.__init__(self, **kwargs)
-        sqltypes.Unicode.__init__(self, length=length,
-                convert_unicode=kwargs.get('convert_unicode', True),
-                assert_unicode=kwargs.get('assert_unicode', 'warn'))
-
-    def adapt(self, impltype):
-        return impltype(length=self.length,
-                        convert_unicode=self.convert_unicode,
-                        assert_unicode=self.assert_unicode,
-                        collation=self.collation)
-
-    def get_col_spec(self):
-        if self.length:
-            return self._extend("NVARCHAR(%(length)s)" % {'length' : self.length})
-        else:
-            return self._extend("NVARCHAR")
-
-
-class MSChar(_StringType, sqltypes.CHAR):
-    """MSSQL CHAR type, for fixed-length non-Unicode data with a maximum
-    of 8,000 characters."""
-
-    def __init__(self, length=None, convert_unicode=False, assert_unicode=None, **kwargs):
-        """Construct a CHAR.
-
-        :param length: Optional, maximum data length, in characters.
-
-        :param convert_unicode: defaults to False.  If True, convert
-          ``unicode`` data sent to the database to a ``str``
-          bytestring, and convert bytestrings coming back from the
-          database into ``unicode``.
-
-          Bytestrings are encoded using the dialect's
-          :attr:`~sqlalchemy.engine.base.Dialect.encoding`, which
-          defaults to `utf-8`.
-
-          If False, may be overridden by
-          :attr:`sqlalchemy.engine.base.Dialect.convert_unicode`.
-
-        :param assert_unicode:
-
-          If None (the default), no assertion will take place unless
-          overridden by :attr:`sqlalchemy.engine.base.Dialect.assert_unicode`.
-
-          If 'warn', will issue a runtime warning if a ``str``
-          instance is used as a bind value.
-
-          If true, will raise an :exc:`sqlalchemy.exc.InvalidRequestError`.
-
-        :param collation: Optional, a column-level collation for this string
-          value. Accepts a Windows Collation Name or a SQL Collation Name.
-
-        """
-        _StringType.__init__(self, **kwargs)
-        sqltypes.CHAR.__init__(self, length=length,
-                convert_unicode=convert_unicode,
-                assert_unicode=assert_unicode)
-
-    def get_col_spec(self):
-        if self.length:
-            return self._extend("CHAR(%s)" % self.length)
-        else:
-            return self._extend("CHAR")
-
-
-class MSNChar(_StringType, sqltypes.NCHAR):
-    """MSSQL NCHAR type.
-
-    For fixed-length unicode character data up to 4,000 characters."""
-
-    def __init__(self, length=None, **kwargs):
-        """Construct an NCHAR.
-
-        :param length: Optional, Maximum data length, in characters.
-
-        :param collation: Optional, a column-level collation for this string
-          value. Accepts a Windows Collation Name or a SQL Collation Name.
-
-        """
-        _StringType.__init__(self, **kwargs)
-        sqltypes.NCHAR.__init__(self, length=length,
-                convert_unicode=kwargs.get('convert_unicode', True),
-                assert_unicode=kwargs.get('assert_unicode', 'warn'))
-
-    def get_col_spec(self):
-        if self.length:
-            return self._extend("NCHAR(%(length)s)" % {'length' : self.length})
-        else:
-            return self._extend("NCHAR")
-
-
-class MSGenericBinary(sqltypes.Binary):
-    """The Binary type assumes that a Binary specification without a length
-    is an unbound Binary type whereas one with a length specification results
-    in a fixed length Binary type.
-
-    If you want standard MSSQL ``BINARY`` behavior use the ``MSBinary`` type.
-
-    """
-
-    def get_col_spec(self):
-        if self.length:
-            return "BINARY(%s)" % self.length
-        else:
-            return "IMAGE"
-
-
-class MSBinary(MSGenericBinary):
-    def get_col_spec(self):
-        if self.length:
-            return "BINARY(%s)" % self.length
-        else:
-            return "BINARY"
-
-
-class MSVarBinary(MSGenericBinary):
-    def get_col_spec(self):
-        if self.length:
-            return "VARBINARY(%s)" % self.length
-        else:
-            return "VARBINARY"
-
-
-class MSImage(MSGenericBinary):
-    def get_col_spec(self):
-        return "IMAGE"
-
-
-class MSBoolean(sqltypes.Boolean):
-    def get_col_spec(self):
-        return "BIT"
-
-    def result_processor(self, dialect):
-        def process(value):
-            if value is None:
-                return None
-            return value and True or False
-        return process
-
-    def bind_processor(self, dialect):
-        def process(value):
-            if value is True:
-                return 1
-            elif value is False:
-                return 0
-            elif value is None:
-                return None
-            else:
-                return value and True or False
-        return process
-
-
-class MSTimeStamp(sqltypes.TIMESTAMP):
-    def get_col_spec(self):
-        return "TIMESTAMP"
-
-
-class MSMoney(sqltypes.TypeEngine):
-    def get_col_spec(self):
-        return "MONEY"
-
-
-class MSSmallMoney(MSMoney):
-    def get_col_spec(self):
-        return "SMALLMONEY"
-
-
-class MSUniqueIdentifier(sqltypes.TypeEngine):
-    def get_col_spec(self):
-        return "UNIQUEIDENTIFIER"
-
-
-class MSVariant(sqltypes.TypeEngine):
-    def get_col_spec(self):
-        return "SQL_VARIANT"
-
-ischema = MetaData()
-
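-# Table objects mirroring the INFORMATION_SCHEMA views consulted during
-# reflection.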
-schemata = Table("SCHEMATA", ischema,
-    Column("CATALOG_NAME", String, key="catalog_name"),
-    Column("SCHEMA_NAME", String, key="schema_name"),
-    Column("SCHEMA_OWNER", String, key="schema_owner"),
-    schema="INFORMATION_SCHEMA")
-
-tables = Table("TABLES", ischema,
-    Column("TABLE_CATALOG", String, key="table_catalog"),
-    Column("TABLE_SCHEMA", String, key="table_schema"),
-    Column("TABLE_NAME", String, key="table_name"),
-    Column("TABLE_TYPE", String, key="table_type"),
-    schema="INFORMATION_SCHEMA")
-
-columns = Table("COLUMNS", ischema,
-    Column("TABLE_SCHEMA", String, key="table_schema"),
-    Column("TABLE_NAME", String, key="table_name"),
-    Column("COLUMN_NAME", String, key="column_name"),
-    Column("IS_NULLABLE", Integer, key="is_nullable"),
-    Column("DATA_TYPE", String, key="data_type"),
-    Column("ORDINAL_POSITION", Integer, key="ordinal_position"),
-    Column("CHARACTER_MAXIMUM_LENGTH", Integer, key="character_maximum_length"),
-    Column("NUMERIC_PRECISION", Integer, key="numeric_precision"),
-    Column("NUMERIC_SCALE", Integer, key="numeric_scale"),
-    Column("COLUMN_DEFAULT", Integer, key="column_default"),
-    Column("COLLATION_NAME", String, key="collation_name"),
-    schema="INFORMATION_SCHEMA")
-
-constraints = Table("TABLE_CONSTRAINTS", ischema,
-    Column("TABLE_SCHEMA", String, key="table_schema"),
-    Column("TABLE_NAME", String, key="table_name"),
-    Column("CONSTRAINT_NAME", String, key="constraint_name"),
-    Column("CONSTRAINT_TYPE", String, key="constraint_type"),
-    schema="INFORMATION_SCHEMA")
-
-column_constraints = Table("CONSTRAINT_COLUMN_USAGE", ischema,
-    Column("TABLE_SCHEMA", String, key="table_schema"),
-    Column("TABLE_NAME", String, key="table_name"),
-    Column("COLUMN_NAME", String, key="column_name"),
-    Column("CONSTRAINT_NAME", String, key="constraint_name"),
-    schema="INFORMATION_SCHEMA")
-
-key_constraints = Table("KEY_COLUMN_USAGE", ischema,
-    Column("TABLE_SCHEMA", String, key="table_schema"),
-    Column("TABLE_NAME", String, key="table_name"),
-    Column("COLUMN_NAME", String, key="column_name"),
-    Column("CONSTRAINT_NAME", String, key="constraint_name"),
-    Column("ORDINAL_POSITION", Integer, key="ordinal_position"),
-    schema="INFORMATION_SCHEMA")
-
-ref_constraints = Table("REFERENTIAL_CONSTRAINTS", ischema,
-    Column("CONSTRAINT_CATALOG", String, key="constraint_catalog"),
-    Column("CONSTRAINT_SCHEMA", String, key="constraint_schema"),
-    Column("CONSTRAINT_NAME", String, key="constraint_name"),
-    Column("UNIQUE_CONSTRAINT_CATLOG", String, key="unique_constraint_catalog"),
-    Column("UNIQUE_CONSTRAINT_SCHEMA", String, key="unique_constraint_schema"),
-    Column("UNIQUE_CONSTRAINT_NAME", String, key="unique_constraint_name"),
-    Column("MATCH_OPTION", String, key="match_option"),
-    Column("UPDATE_RULE", String, key="update_rule"),
-    Column("DELETE_RULE", String, key="delete_rule"),
-    schema="INFORMATION_SCHEMA")
-
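-# an integer primary key with no foreign keys and no default (or only an
-# optional Sequence) receives an implicit IDENTITY.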
-def _has_implicit_sequence(column):
-    return column.primary_key and  \
-        column.autoincrement and \
-        isinstance(column.type, sqltypes.Integer) and \
-        not column.foreign_keys and \
-        (
-            column.default is None or
-            (
-                isinstance(column.default, schema.Sequence) and
-                column.default.optional)
-            )
-
-def _table_sequence_column(tbl):
-    if not hasattr(tbl, '_ms_has_sequence'):
-        tbl._ms_has_sequence = None
-        for column in tbl.c:
-            if getattr(column, 'sequence', False) or _has_implicit_sequence(column):
-                tbl._ms_has_sequence = column
-                break
-    return tbl._ms_has_sequence
-
-class MSSQLExecutionContext(default.DefaultExecutionContext):
-    IINSERT = False
-    HASIDENT = False
-
-    def pre_exec(self):
-        """Activate IDENTITY_INSERT if needed."""
-
-        if self.compiled.isinsert:
-            tbl = self.compiled.statement.table
-            seq_column = _table_sequence_column(tbl)
-            self.HASIDENT = bool(seq_column)
-            if self.dialect.auto_identity_insert and self.HASIDENT:
-                self.IINSERT = tbl._ms_has_sequence.key in self.compiled_parameters[0]
-            else:
-                self.IINSERT = False
-
-            if self.IINSERT:
-                self.cursor.execute("SET IDENTITY_INSERT %s ON" %
-                    self.dialect.identifier_preparer.format_table(self.compiled.statement.table))
-
-    def handle_dbapi_exception(self, e):
-        if self.IINSERT:
-            try:
-                self.cursor.execute("SET IDENTITY_INSERT %s OFF" % self.dialect.identifier_preparer.format_table(self.compiled.statement.table))
-            except:
-                pass
-
-    def post_exec(self):
-        """Disable IDENTITY_INSERT if enabled."""
-
-        if self.compiled.isinsert and not self.executemany and self.HASIDENT and not self.IINSERT:
-            if not self._last_inserted_ids or self._last_inserted_ids[0] is None:
-                if self.dialect.use_scope_identity:
-                    self.cursor.execute("SELECT scope_identity() AS lastrowid")
-                else:
-                    self.cursor.execute("SELECT @@identity AS lastrowid")
-                row = self.cursor.fetchone()
-                self._last_inserted_ids = [int(row[0])] + self._last_inserted_ids[1:]
-
-        if self.IINSERT:
-            self.cursor.execute("SET IDENTITY_INSERT %s OFF" % self.dialect.identifier_preparer.format_table(self.compiled.statement.table))
-
-
-class MSSQLExecutionContext_pyodbc (MSSQLExecutionContext):
-    def pre_exec(self):
-        """where appropriate, issue "select scope_identity()" in the same statement"""
-        super(MSSQLExecutionContext_pyodbc, self).pre_exec()
-        if self.compiled.isinsert and self.HASIDENT and not self.IINSERT \
-                and len(self.parameters) == 1 and self.dialect.use_scope_identity:
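-            # piggyback the scope_identity() fetch onto the INSERT itself,
-            # avoiding a second round trip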
-            self.statement += "; select scope_identity()"
-
-    def post_exec(self):
-        if self.HASIDENT and not self.IINSERT and self.dialect.use_scope_identity and not self.executemany:
-            import pyodbc
-            # Fetch the last inserted id from the manipulated statement
-            # We may have to skip over a number of result sets with no data (due to triggers, etc.)
-            while True:
-                try:
-                    row = self.cursor.fetchone()
-                    break
-                except pyodbc.Error, e:
-                    self.cursor.nextset()
-            self._last_inserted_ids = [int(row[0])]
-        else:
-            super(MSSQLExecutionContext_pyodbc, self).post_exec()
-
-class MSSQLDialect(default.DefaultDialect):
-    name = 'mssql'
-    supports_default_values = True
-    supports_empty_insert = False
-    auto_identity_insert = True
-    execution_ctx_cls = MSSQLExecutionContext
-    text_as_varchar = False
-    use_scope_identity = False
-    has_window_funcs = False
-    max_identifier_length = 128
-    schema_name = "dbo"
-
-    colspecs = {
-        sqltypes.Unicode : MSNVarchar,
-        sqltypes.Integer : MSInteger,
-        sqltypes.Smallinteger: MSSmallInteger,
-        sqltypes.Numeric : MSNumeric,
-        sqltypes.Float : MSFloat,
-        sqltypes.DateTime : MSDateTime,
-        sqltypes.Date : MSDate,
-        sqltypes.Time : MSTime,
-        sqltypes.String : MSString,
-        sqltypes.Binary : MSGenericBinary,
-        sqltypes.Boolean : MSBoolean,
-        sqltypes.Text : MSText,
-        sqltypes.UnicodeText : MSNText,
-        sqltypes.CHAR: MSChar,
-        sqltypes.NCHAR: MSNChar,
-        sqltypes.TIMESTAMP: MSTimeStamp,
-    }
-
-    ischema_names = {
-        'int' : MSInteger,
-        'bigint': MSBigInteger,
-        'smallint' : MSSmallInteger,
-        'tinyint' : MSTinyInteger,
-        'varchar' : MSString,
-        'nvarchar' : MSNVarchar,
-        'char' : MSChar,
-        'nchar' : MSNChar,
-        'text' : MSText,
-        'ntext' : MSNText,
-        'decimal' : MSNumeric,
-        'numeric' : MSNumeric,
-        'float' : MSFloat,
-        'datetime' : MSDateTime,
-        'datetime2' : MSDateTime2,
-        'datetimeoffset' : MSDateTimeOffset,
-        'date': MSDate,
-        'time': MSTime,
-        'smalldatetime' : MSSmallDateTime,
-        'binary' : MSBinary,
-        'varbinary' : MSVarBinary,
-        'bit': MSBoolean,
-        'real' : MSFloat,
-        'image' : MSImage,
-        'timestamp': MSTimeStamp,
-        'money': MSMoney,
-        'smallmoney': MSSmallMoney,
-        'uniqueidentifier': MSUniqueIdentifier,
-        'sql_variant': MSVariant,
-    }
-
-    def __new__(cls, *args, **kwargs):
-        if cls is not MSSQLDialect:
-            # this gets called with the dialect specific class
-            return super(MSSQLDialect, cls).__new__(cls)
-        dbapi = kwargs.get('dbapi', None)
-        if dbapi:
-            dialect = dialect_mapping.get(dbapi.__name__)
-            return dialect(**kwargs)
-        else:
-            return object.__new__(cls)
-
-    def __init__(self,
-                 auto_identity_insert=True, query_timeout=None,
-                 text_as_varchar=False, use_scope_identity=False,
-                 has_window_funcs=False, max_identifier_length=None,
-                 schema_name="dbo", **opts):
-        self.auto_identity_insert = bool(auto_identity_insert)
-        self.query_timeout = int(query_timeout or 0)
-        self.schema_name = schema_name
-
-        # to-do: the options below should use server version introspection to set themselves on connection
-        self.text_as_varchar = bool(text_as_varchar)
-        self.use_scope_identity = bool(use_scope_identity)
-        self.has_window_funcs =  bool(has_window_funcs)
-        self.max_identifier_length = int(max_identifier_length or 0) or \
-                self.max_identifier_length
-        super(MSSQLDialect, self).__init__(**opts)
-
-    @classmethod
-    def dbapi(cls, module_name=None):
-        if module_name:
-            try:
-                dialect_cls = dialect_mapping[module_name]
-                return dialect_cls.import_dbapi()
-            except KeyError:
-                raise exc.InvalidRequestError("Unsupported MSSQL module '%s' requested (must be adodbpi, pymssql or pyodbc)" % module_name)
-        else:
-            for dialect_cls in [MSSQLDialect_pyodbc, MSSQLDialect_pymssql, MSSQLDialect_adodbapi]:
-                try:
-                    return dialect_cls.import_dbapi()
-                except ImportError, e:
-                    pass
-            else:
-                raise ImportError('No DBAPI module detected for MSSQL - please install pyodbc, pymssql, or adodbapi')
-
-    @base.connection_memoize(('mssql', 'server_version_info'))
-    def server_version_info(self, connection):
-        """A tuple of the database server version.
-
-        Formats the remote server version as a tuple of version values,
-        e.g. ``(9, 0, 1399)``.  If there are strings in the version number
-        they will be in the tuple too, so don't count on these all being
-        ``int`` values.
-
-        This is a fast check that does not require a round trip.  It is also
-        cached per-Connection.
-        """
-        return connection.dialect._server_version_info(connection.connection)
-
-    def _server_version_info(self, dbapi_con):
-        """Return a tuple of the database's version number."""
-        raise NotImplementedError()
-
-    def create_connect_args(self, url):
-        opts = url.translate_connect_args(username='user')
-        opts.update(url.query)
-        if 'auto_identity_insert' in opts:
-            self.auto_identity_insert = bool(int(opts.pop('auto_identity_insert')))
-        if 'query_timeout' in opts:
-            self.query_timeout = int(opts.pop('query_timeout'))
-        if 'text_as_varchar' in opts:
-            self.text_as_varchar = bool(int(opts.pop('text_as_varchar')))
-        if 'use_scope_identity' in opts:
-            self.use_scope_identity = bool(int(opts.pop('use_scope_identity')))
-        if 'has_window_funcs' in opts:
-            self.has_window_funcs =  bool(int(opts.pop('has_window_funcs')))
-        return self.make_connect_string(opts, url.query)
-
-    def type_descriptor(self, typeobj):
-        newobj = sqltypes.adapt_type(typeobj, self.colspecs)
-        # Some types need to know about the dialect
-        if isinstance(newobj, (MSText, MSNText)):
-            newobj.dialect = self
-        return newobj
-
-    def do_savepoint(self, connection, name):
-        util.warn("Savepoint support in mssql is experimental and may lead to data loss.")
-        connection.execute("IF @@TRANCOUNT = 0 BEGIN TRANSACTION")
-        connection.execute("SAVE TRANSACTION %s" % name)
-
-    def do_release_savepoint(self, connection, name):
-        pass
-
-    @base.connection_memoize(('dialect', 'default_schema_name'))
-    def get_default_schema_name(self, connection):
-        query = "SELECT user_name() as user_name;"
-        user_name = connection.scalar(sql.text(query))
-        if user_name is not None:
-            # now, get the default schema
-            query = """
-            SELECT default_schema_name FROM
-            sys.database_principals
-            WHERE name = :user_name
-            AND type = 'S'
-            """
-            try:
-                default_schema_name = connection.scalar(sql.text(query),
-                                                    user_name=user_name)
-                if default_schema_name is not None:
-                    return default_schema_name
-            except:
-                pass
-        return self.schema_name
-
-    def table_names(self, connection, schema):
-        s = select([tables.c.table_name], tables.c.table_schema==schema)
-        return [row[0] for row in connection.execute(s)]
-
-
-    def has_table(self, connection, tablename, schema=None):
-
-        current_schema = schema or self.get_default_schema_name(connection)
-        s = sql.select([columns],
-                   current_schema
-                       and sql.and_(columns.c.table_name==tablename, columns.c.table_schema==current_schema)
-                       or columns.c.table_name==tablename,
-                   )
-
-        c = connection.execute(s)
-        row  = c.fetchone()
-        return row is not None
-
-    def reflecttable(self, connection, table, include_columns):
-        # Get base columns
-        if table.schema is not None:
-            current_schema = table.schema
-        else:
-            current_schema = self.get_default_schema_name(connection)
-
-        s = sql.select([columns],
-                   current_schema
-                       and sql.and_(columns.c.table_name==table.name, columns.c.table_schema==current_schema)
-                       or columns.c.table_name==table.name,
-                   order_by=[columns.c.ordinal_position])
-
-        c = connection.execute(s)
-        found_table = False
-        while True:
-            row = c.fetchone()
-            if row is None:
-                break
-            found_table = True
-            (name, type, nullable, charlen, numericprec, numericscale, default, collation) = (
-                row[columns.c.column_name],
-                row[columns.c.data_type],
-                row[columns.c.is_nullable] == 'YES',
-                row[columns.c.character_maximum_length],
-                row[columns.c.numeric_precision],
-                row[columns.c.numeric_scale],
-                row[columns.c.column_default],
-                row[columns.c.collation_name]
-            )
-            if include_columns and name not in include_columns:
-                continue
-
-            coltype = self.ischema_names.get(type, None)
-
-            kwargs = {}
-            if coltype in (MSString, MSChar, MSNVarchar, MSNChar, MSText, MSNText, MSBinary, MSVarBinary, sqltypes.Binary):
-                kwargs['length'] = charlen
-                if collation:
-                    kwargs['collation'] = collation
-                if coltype == MSText or (coltype in (MSString, MSNVarchar) and charlen == -1):
-                    kwargs.pop('length')
-
-            if coltype is not None and issubclass(coltype, sqltypes.Numeric):
-                kwargs['scale'] = numericscale
-                kwargs['precision'] = numericprec
-
-            if coltype is None:
-                util.warn("Did not recognize type '%s' of column '%s'" % (type, name))
-                coltype = sqltypes.NULLTYPE
-
-            coltype = coltype(**kwargs)
-            colargs = []
-            if default is not None:
-                colargs.append(schema.DefaultClause(sql.text(default)))
-            table.append_column(schema.Column(name, coltype, nullable=nullable, autoincrement=False, *colargs))
-
-        if not found_table:
-            raise exc.NoSuchTableError(table.name)
-
-        # We also run an sp_columns to check for identity columns:
-        cursor = connection.execute("sp_columns @table_name = '%s', @table_owner = '%s'" % (table.name, current_schema))
-        ic = None
-        while True:
-            row = cursor.fetchone()
-            if row is None:
-                break
-            col_name, type_name = row[3], row[5]
-            if type_name.endswith("identity") and col_name in table.c:
-                ic = table.c[col_name]
-                ic.autoincrement = True
-                # set up a pseudo-sequence to represent the identity attribute;
-                # table.create() interprets this as IDENTITY
-                ic.sequence = schema.Sequence(ic.name + '_identity', 1, 1)
-                # MSSQL: only one identity per table allowed
-                cursor.close()
-                break
-        if ic is not None:
-            try:
-                cursor = connection.execute("select ident_seed(?), ident_incr(?)", table.fullname, table.fullname)
-                row = cursor.fetchone()
-                cursor.close()
-                if row is not None:
-                    ic.sequence.start = int(row[0])
-                    ic.sequence.increment = int(row[1])
-            except:
-                # ignoring it, works just like before
-                pass
-
-        # Add constraints
-        RR = ref_constraints
-        TC = constraints
-        C  = key_constraints.alias('C') #information_schema.key_column_usage: the constrained column
-        R  = key_constraints.alias('R') #information_schema.key_column_usage: the referenced column
-
-        # Primary key constraints
-        s = sql.select([C.c.column_name, TC.c.constraint_type], sql.and_(TC.c.constraint_name == C.c.constraint_name,
-                                                                         C.c.table_name == table.name,
-                                                                         C.c.table_schema == (table.schema or current_schema)))
-        c = connection.execute(s)
-        for row in c:
-            if 'PRIMARY' in row[TC.c.constraint_type.name] and row[0] in table.c:
-                table.primary_key.add(table.c[row[0]])
-
-        # Foreign key constraints
-        s = sql.select([C.c.column_name,
-                        R.c.table_schema, R.c.table_name, R.c.column_name,
-                        RR.c.constraint_name, RR.c.match_option, RR.c.update_rule, RR.c.delete_rule],
-                       sql.and_(C.c.table_name == table.name,
-                                C.c.table_schema == (table.schema or current_schema),
-                                C.c.constraint_name == RR.c.constraint_name,
-                                R.c.constraint_name == RR.c.unique_constraint_name,
-                                C.c.ordinal_position == R.c.ordinal_position
-                                ),
-                       order_by = [RR.c.constraint_name, R.c.ordinal_position])
-        rows = connection.execute(s).fetchall()
-
-        def _gen_fkref(table, rschema, rtbl, rcol):
-            if rschema == current_schema and not table.schema:
-                return '.'.join([rtbl, rcol])
-            else:
-                return '.'.join([rschema, rtbl, rcol])
-
-        # group rows by constraint ID, to handle multi-column FKs
-        fknm, scols, rcols = (None, [], [])
-        for r in rows:
-            scol, rschema, rtbl, rcol, rfknm, fkmatch, fkuprule, fkdelrule = r
-            # if the reflected schema is the default schema then don't set it because this will
-            # play into the metadata key causing duplicates.
-            if rschema == current_schema and not table.schema:
-                schema.Table(rtbl, table.metadata, autoload=True, autoload_with=connection)
-            else:
-                schema.Table(rtbl, table.metadata, schema=rschema, autoload=True, autoload_with=connection)
-            if rfknm != fknm:
-                if fknm:
-                    table.append_constraint(schema.ForeignKeyConstraint(scols, [_gen_fkref(table, s, t, c) for s, t, c in rcols], fknm, link_to_name=True))
-                fknm, scols, rcols = (rfknm, [], [])
-            if scol not in scols:
-                scols.append(scol)
-            if (rschema, rtbl, rcol) not in rcols:
-                rcols.append((rschema, rtbl, rcol))
-
-        if fknm and scols:
-            table.append_constraint(schema.ForeignKeyConstraint(scols, [_gen_fkref(table, s, t, c) for s, t, c in rcols], fknm, link_to_name=True))
-
-
-class MSSQLDialect_pymssql(MSSQLDialect):
-    supports_sane_rowcount = False
-    max_identifier_length = 30
-
-    @classmethod
-    def import_dbapi(cls):
-        import pymssql as module
-        # pymssql doesn't have a Binary method; we substitute str
-        # TODO: monkeypatching here is less than ideal
-        module.Binary = lambda st: str(st)
-        try:
-            module.version_info = tuple(map(int, module.__version__.split('.')))
-        except:
-            module.version_info = (0, 0, 0)
-        return module
-
-    def __init__(self, **params):
-        super(MSSQLDialect_pymssql, self).__init__(**params)
-        self.use_scope_identity = True
-
-        # pymssql understands only ascii
-        if self.convert_unicode:
-            util.warn("pymssql does not support unicode")
-            self.encoding = params.get('encoding', 'ascii')
-
-        self.colspecs = MSSQLDialect.colspecs.copy()
-        self.ischema_names = MSSQLDialect.ischema_names.copy()
-        self.ischema_names['date'] = MSDateTimeAsDate
-        self.colspecs[sqltypes.Date] = MSDateTimeAsDate
-        self.ischema_names['time'] = MSDateTimeAsTime
-        self.colspecs[sqltypes.Time] = MSDateTimeAsTime
-
-    def create_connect_args(self, url):
-        r = super(MSSQLDialect_pymssql, self).create_connect_args(url)
-        if hasattr(self, 'query_timeout'):
-            if self.dbapi.version_info > (0, 8, 0):
-                r[1]['timeout'] = self.query_timeout
-            else:
-                self.dbapi._mssql.set_query_timeout(self.query_timeout)
-        return r
-
-    def make_connect_string(self, keys, query):
-        if keys.get('port'):
-            # pymssql expects port as host:port, not a separate arg
-            keys['host'] = ''.join([keys.get('host', ''), ':', str(keys['port'])])
-            del keys['port']
-        return [[], keys]
-
-    def is_disconnect(self, e):
-        return isinstance(e, self.dbapi.DatabaseError) and "Error 10054" in str(e)
-
-    def do_begin(self, connection):
-        pass
-
-
-class MSSQLDialect_pyodbc(MSSQLDialect):
-    supports_sane_rowcount = False
-    supports_sane_multi_rowcount = False
-    # PyODBC unicode is broken on UCS-4 builds
-    supports_unicode = sys.maxunicode == 65535
-    supports_unicode_statements = supports_unicode
-    execution_ctx_cls = MSSQLExecutionContext_pyodbc
-
-    def __init__(self, description_encoding='latin-1', **params):
-        super(MSSQLDialect_pyodbc, self).__init__(**params)
-        self.description_encoding = description_encoding
-
-        if self.server_version_info < (10,):
-            self.colspecs = MSSQLDialect.colspecs.copy()
-            self.ischema_names = MSSQLDialect.ischema_names.copy()
-            self.ischema_names['date'] = MSDateTimeAsDate
-            self.colspecs[sqltypes.Date] = MSDateTimeAsDate
-            self.ischema_names['time'] = MSDateTimeAsTime
-            self.colspecs[sqltypes.Time] = MSDateTimeAsTime
-
-        # FIXME: scope_identity sniff should look at server version, not the ODBC driver
-        # whether use_scope_identity will work depends on the version of pyodbc
-        try:
-            import pyodbc
-            self.use_scope_identity = hasattr(pyodbc.Cursor, 'nextset')
-        except:
-            pass
-
-    @classmethod
-    def import_dbapi(cls):
-        import pyodbc as module
-        return module
-
-    def make_connect_string(self, keys, query):
-        if 'max_identifier_length' in keys:
-            self.max_identifier_length = int(keys.pop('max_identifier_length'))
-
-        if 'odbc_connect' in keys:
-            connectors = [urllib.unquote_plus(keys.pop('odbc_connect'))]
-        else:
-            dsn_connection = 'dsn' in keys or ('host' in keys and 'database' not in keys)
-            if dsn_connection:
-                connectors= ['dsn=%s' % (keys.pop('host', '') or keys.pop('dsn', ''))]
-            else:
-                port = ''
-                if 'port' in keys and 'port' not in query:
-                    port = ',%d' % int(keys.pop('port'))
-
-                connectors = ["DRIVER={%s}" % keys.pop('driver', 'SQL Server'),
-                              'Server=%s%s' % (keys.pop('host', ''), port),
-                              'Database=%s' % keys.pop('database', '') ]
-
-            user = keys.pop("user", None)
-            if user:
-                connectors.append("UID=%s" % user)
-                connectors.append("PWD=%s" % keys.pop('password', ''))
-            else:
-                connectors.append("TrustedConnection=Yes")
-
-            # if set to 'Yes', the ODBC layer will try to automagically convert
-            # textual data from your database encoding to your client encoding
-            # This should obviously be set to 'No' if you query a cp1253 encoded
-            # database from a latin1 client...
-            if 'odbc_autotranslate' in keys:
-                connectors.append("AutoTranslate=%s" % keys.pop("odbc_autotranslate"))
-
-            connectors.extend(['%s=%s' % (k,v) for k,v in keys.iteritems()])
-
-        return [[";".join (connectors)], {}]
-
-    def is_disconnect(self, e):
-        if isinstance(e, self.dbapi.ProgrammingError):
-            return "The cursor's connection has been closed." in str(e) or 'Attempt to use a closed connection.' in str(e)
-        elif isinstance(e, self.dbapi.Error):
-            return '[08S01]' in str(e)
-        else:
-            return False
-
-
-    def _server_version_info(self, dbapi_con):
-        """Convert a pyodbc SQL_DBMS_VER string into a tuple."""
-        version = []
-        r = re.compile('[.\-]')
-        for n in r.split(dbapi_con.getinfo(self.dbapi.SQL_DBMS_VER)):
-            try:
-                version.append(int(n))
-            except ValueError:
-                version.append(n)
-        return tuple(version)
-
-class MSSQLDialect_adodbapi(MSSQLDialect):
-    supports_sane_rowcount = True
-    supports_sane_multi_rowcount = True
-    supports_unicode = sys.maxunicode == 65535
-    supports_unicode_statements = True
-
-    @classmethod
-    def import_dbapi(cls):
-        import adodbapi as module
-        return module
-
-    colspecs = MSSQLDialect.colspecs.copy()
-    colspecs[sqltypes.DateTime] = MSDateTime_adodbapi
-
-    ischema_names = MSSQLDialect.ischema_names.copy()
-    ischema_names['datetime'] = MSDateTime_adodbapi
-
-    def make_connect_string(self, keys, query):
-        connectors = ["Provider=SQLOLEDB"]
-        if 'port' in keys:
-            connectors.append ("Data Source=%s, %s" % (keys.get("host"), keys.get("port")))
-        else:
-            connectors.append ("Data Source=%s" % keys.get("host"))
-        connectors.append ("Initial Catalog=%s" % keys.get("database"))
-        user = keys.get("user")
-        if user:
-            connectors.append("User Id=%s" % user)
-            connectors.append("Password=%s" % keys.get("password", ""))
-        else:
-            connectors.append("Integrated Security=SSPI")
-        return [[";".join (connectors)], {}]
-
-    def is_disconnect(self, e):
-        return isinstance(e, self.dbapi.adodbapi.DatabaseError) and "'connection failure'" in str(e)
-
-
-dialect_mapping = {
-    'pymssql':  MSSQLDialect_pymssql,
-    'pyodbc':   MSSQLDialect_pyodbc,
-    'adodbapi': MSSQLDialect_adodbapi
-    }
-
-
-class MSSQLCompiler(compiler.DefaultCompiler):
-    operators = compiler.OPERATORS.copy()
-    operators.update({
-        sql_operators.concat_op: '+',
-        sql_operators.match_op: lambda x, y: "CONTAINS (%s, %s)" % (x, y)
-    })
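-    # i.e. string concatenation compiles to "+", and match() compiles to
-    # the T-SQL full-text CONTAINS() predicate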
-
-    functions = compiler.DefaultCompiler.functions.copy()
-    functions.update(
-        {
-            sql_functions.now: 'CURRENT_TIMESTAMP',
-            sql_functions.current_date: 'GETDATE()',
-            'length': lambda x: "LEN(%s)" % x,
-            sql_functions.char_length: lambda x: "LEN(%s)" % x
-        }
-    )
-
-    extract_map = compiler.DefaultCompiler.extract_map.copy()
-    extract_map.update({
-        'doy': 'dayofyear',
-        'dow': 'weekday',
-        'milliseconds': 'millisecond',
-        'microseconds': 'microsecond'
-    })
-
-    def __init__(self, *args, **kwargs):
-        super(MSSQLCompiler, self).__init__(*args, **kwargs)
-        self.tablealiases = {}
-
-    def get_select_precolumns(self, select):
-        """ MS-SQL puts TOP, it's version of LIMIT here """
-        if select._distinct or select._limit:
-            s = select._distinct and "DISTINCT " or ""
-
-            if select._limit:
-                if not select._offset:
-                    s += "TOP %s " % (select._limit,)
-                else:
-                    if not self.dialect.has_window_funcs:
-                        raise exc.InvalidRequestError('MSSQL does not support LIMIT with an offset')
-            return s
-        return compiler.DefaultCompiler.get_select_precolumns(self, select)
-
-    def limit_clause(self, select):
-        # Limit in mssql is after the select keyword
-        return ""
-
-    def visit_select(self, select, **kwargs):
-        """Look for ``LIMIT`` and OFFSET in a select statement, and if
-        so tries to wrap it in a subquery with ``row_number()`` criterion.
-
-        """
-        if self.dialect.has_window_funcs and not getattr(select, '_mssql_visit', None) and select._offset:
-            # to use ROW_NUMBER(), an ORDER BY is required.
-            orderby = self.process(select._order_by_clause)
-            if not orderby:
-                raise exc.InvalidRequestError('MSSQL requires an order_by when using an offset.')
-
-            _offset = select._offset
-            _limit = select._limit
-            select._mssql_visit = True
-            select = select.column(sql.literal_column("ROW_NUMBER() OVER (ORDER BY %s)" % orderby).label("mssql_rn")).order_by(None).alias()
-
-            limitselect = sql.select([c for c in select.c if c.key!='mssql_rn'])
-            limitselect.append_whereclause("mssql_rn>%d" % _offset)
-            if _limit is not None:
-                limitselect.append_whereclause("mssql_rn<=%d" % (_limit + _offset))
-            return self.process(limitselect, iswrapper=True, **kwargs)
-        else:
-            return compiler.DefaultCompiler.visit_select(self, select, **kwargs)
-
-    def _schema_aliased_table(self, table):
-        if getattr(table, 'schema', None) is not None:
-            if table not in self.tablealiases:
-                self.tablealiases[table] = table.alias()
-            return self.tablealiases[table]
-        else:
-            return None
-
-    def visit_table(self, table, mssql_aliased=False, **kwargs):
-        if mssql_aliased:
-            return super(MSSQLCompiler, self).visit_table(table, **kwargs)
-
-        # alias schema-qualified tables
-        alias = self._schema_aliased_table(table)
-        if alias is not None:
-            return self.process(alias, mssql_aliased=True, **kwargs)
-        else:
-            return super(MSSQLCompiler, self).visit_table(table, **kwargs)
-
-    def visit_alias(self, alias, **kwargs):
-        # translate for schema-qualified table aliases
-        self.tablealiases[alias.original] = alias
-        kwargs['mssql_aliased'] = True
-        return super(MSSQLCompiler, self).visit_alias(alias, **kwargs)
-
-    def visit_extract(self, extract):
-        field = self.extract_map.get(extract.field, extract.field)
-        return 'DATEPART("%s", %s)' % (field, self.process(extract.expr))
-
-    def visit_rollback_to_savepoint(self, savepoint_stmt):
-        return "ROLLBACK TRANSACTION %s" % self.preparer.format_savepoint(savepoint_stmt)
-
-    def visit_column(self, column, result_map=None, **kwargs):
-        if column.table is not None and \
-            (not self.isupdate and not self.isdelete) or self.is_subquery():
-            # translate for schema-qualified table aliases
-            t = self._schema_aliased_table(column.table)
-            if t is not None:
-                converted = expression._corresponding_column_or_error(t, column)
-
-                if result_map is not None:
-                    result_map[column.name.lower()] = (column.name, (column, ), column.type)
-
-                return super(MSSQLCompiler, self).visit_column(converted, result_map=None, **kwargs)
-
-        return super(MSSQLCompiler, self).visit_column(column, result_map=result_map, **kwargs)
-
-    def visit_binary(self, binary, **kwargs):
-        """Move bind parameters to the right-hand side of an operator, where
-        possible.
-
-        """
-        if isinstance(binary.left, expression._BindParamClause) and binary.operator == operator.eq \
-            and not isinstance(binary.right, expression._BindParamClause):
-            return self.process(expression._BinaryExpression(binary.right, binary.left, binary.operator), **kwargs)
-        else:
-            if (binary.operator is operator.eq or binary.operator is operator.ne) and (
-                (isinstance(binary.left, expression._FromGrouping) and isinstance(binary.left.element, expression._ScalarSelect)) or \
-                (isinstance(binary.right, expression._FromGrouping) and isinstance(binary.right.element, expression._ScalarSelect)) or \
-                 isinstance(binary.left, expression._ScalarSelect) or isinstance(binary.right, expression._ScalarSelect)):
-                op = binary.operator == operator.eq and "IN" or "NOT IN"
-                return self.process(expression._BinaryExpression(binary.left, binary.right, op), **kwargs)
-            return super(MSSQLCompiler, self).visit_binary(binary, **kwargs)
-
-    def visit_insert(self, insert_stmt):
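-        # detect an INSERT whose values include a SELECT, e.g. (a sketch):
-        #   tbl.insert().values(col=select([...]))
-        # which is compiled as "INSERT INTO tbl (col) SELECT ..."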
-        insert_select = False
-        if insert_stmt.parameters:
-            insert_select = [p for p in insert_stmt.parameters.values() if isinstance(p, sql.Select)]
-        if insert_select:
-            self.isinsert = True
-            colparams = self._get_colparams(insert_stmt)
-            preparer = self.preparer
-
-            insert = ' '.join(["INSERT"] +
-                              [self.process(x) for x in insert_stmt._prefixes])
-
-            if not colparams and not self.dialect.supports_default_values and not self.dialect.supports_empty_insert:
-                raise exc.CompileError(
-                    "The version of %s you are using does not support empty inserts." % self.dialect.name)
-            elif not colparams and self.dialect.supports_default_values:
-                return (insert + " INTO %s DEFAULT VALUES" % (
-                    (preparer.format_table(insert_stmt.table),)))
-            else:
-                return (insert + " INTO %s (%s) SELECT %s" %
-                    (preparer.format_table(insert_stmt.table),
-                     ', '.join([preparer.format_column(c[0])
-                               for c in colparams]),
-                     ', '.join([c[1] for c in colparams])))
-        else:
-            return super(MSSQLCompiler, self).visit_insert(insert_stmt)
-
-    def label_select_column(self, select, column, asfrom):
-        if isinstance(column, expression.Function):
-            return column.label(None)
-        else:
-            return super(MSSQLCompiler, self).label_select_column(select, column, asfrom)
-
-    def for_update_clause(self, select):
-        # "FOR UPDATE" is only allowed on "DECLARE CURSOR" which SQLAlchemy doesn't use
-        return ''
-
-    def order_by_clause(self, select):
-        order_by = self.process(select._order_by_clause)
-
-        # MSSQL only allows ORDER BY in subqueries if there is a LIMIT
-        if order_by and (not self.is_subquery() or select._limit):
-            return " ORDER BY " + order_by
-        else:
-            return ""
-
-
-class MSSQLSchemaGenerator(compiler.SchemaGenerator):
-    def get_column_specification(self, column, **kwargs):
-        colspec = self.preparer.format_column(column) + " " + column.type.dialect_impl(self.dialect).get_col_spec()
-
-        if column.nullable is not None:
-            if not column.nullable or column.primary_key:
-                colspec += " NOT NULL"
-            else:
-                colspec += " NULL"
-
-        if not column.table:
-            raise exc.InvalidRequestError("mssql requires Table-bound columns in order to generate DDL")
-
-        seq_col = _table_sequence_column(column.table)
-
-        # install an IDENTITY Sequence if we have an implicit IDENTITY column
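-        # e.g. (a sketch) Column('id', Integer, Sequence('ids', start=100,
-        # increment=10), primary_key=True) renders
-        # "id INTEGER NOT NULL IDENTITY(100,10)"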
-        if seq_col is column:
-            sequence = getattr(column, 'sequence', None)
-            if sequence:
-                start, increment = sequence.start or 1, sequence.increment or 1
-            else:
-                start, increment = 1, 1
-            colspec += " IDENTITY(%s,%s)" % (start, increment)
-        else:
-            default = self.get_column_default_string(column)
-            if default is not None:
-                colspec += " DEFAULT " + default
-
-        return colspec
-
-class MSSQLSchemaDropper(compiler.SchemaDropper):
-    def visit_index(self, index):
-        self.append("\nDROP INDEX %s.%s" % (
-            self.preparer.quote_identifier(index.table.name),
-            self.preparer.quote(self._validate_identifier(index.name, False), index.quote)
-            ))
-        self.execute()
-
-
-class MSSQLIdentifierPreparer(compiler.IdentifierPreparer):
-    reserved_words = RESERVED_WORDS
-
-    def __init__(self, dialect):
-        super(MSSQLIdentifierPreparer, self).__init__(dialect, initial_quote='[', final_quote=']')
-
-    def _escape_identifier(self, value):
-        # T-SQL escapes a closing bracket inside a bracket-quoted
-        # identifier by doubling it
-        return value.replace(']', ']]')
-
-    def quote_schema(self, schema, force=True):
-        """Prepare a quoted table and schema name."""
-        result = '.'.join([self.quote(x, force) for x in schema.split('.')])
-        return result
-
-dialect = MSSQLDialect
-dialect.statement_compiler = MSSQLCompiler
-dialect.schemagenerator = MSSQLSchemaGenerator
-dialect.schemadropper = MSSQLSchemaDropper
-dialect.preparer = MSSQLIdentifierPreparer
-
diff --git a/lib/sqlalchemy/databases/mxODBC.py b/lib/sqlalchemy/databases/mxODBC.py
deleted file mode 100644 (file)
index 92f5336..0000000
+++ /dev/null
@@ -1,60 +0,0 @@
-# mxODBC.py
-# Copyright (C) 2007 Fisch Asset Management AG http://www.fam.ch
-# Coding: Alexander Houben alexander.houben@thor-solutions.ch
-#
-# This module is part of SQLAlchemy and is released under
-# the MIT License: http://www.opensource.org/licenses/mit-license.php
-
-"""
-A wrapper for a mx.ODBC.Windows DB-API connection.
-
-Makes sure the mx module is configured to return datetime objects instead
-of mx.DateTime.DateTime objects.
-"""
-
-from mx.ODBC.Windows import *
-
-
-class Cursor:
-    def __init__(self, cursor):
-        self.cursor = cursor
-
-    def __getattr__(self, attr):
-        res = getattr(self.cursor, attr)
-        return res
-
-    def execute(self, *args, **kwargs):
-        res = self.cursor.execute(*args, **kwargs)
-        return res
-
-
-class Connection:
-    def myErrorHandler(self, connection, cursor, errorclass, errorvalue):
-        err0, err1, err2, err3 = errorvalue
-        #print ", ".join(["Err%d: %s"%(x, errorvalue[x]) for x in range(4)])
-        if int(err1) == 109:
-            # Ignore "Null value eliminated in aggregate function"; this is a warning, not an error
-            return
-        raise errorclass, errorvalue
-
-    def __init__(self, conn):
-        self.conn = conn
-        # install a mx ODBC error handler
-        self.conn.errorhandler = self.myErrorHandler
-
-    def __getattr__(self, attr):
-        res = getattr(self.conn, attr)
-        return res
-
-    def cursor(self, *args, **kwargs):
-        res = Cursor(self.conn.cursor(*args, **kwargs))
-        return res
-
-
-# override 'connect' call
-def connect(*args, **kwargs):
-    import mx.ODBC.Windows
-    conn = mx.ODBC.Windows.Connect(*args, **kwargs)
-    conn.datetimeformat = mx.ODBC.Windows.PYDATETIME_DATETIMEFORMAT
-    return Connection(conn)
-Connect = connect
diff --git a/lib/sqlalchemy/databases/oracle.py b/lib/sqlalchemy/databases/oracle.py
deleted file mode 100644 (file)
index 852cab4..0000000
+++ /dev/null
@@ -1,904 +0,0 @@
-# oracle.py
-# Copyright (C) 2005, 2006, 2007, 2008, 2009 Michael Bayer mike_mp@zzzcomputing.com
-#
-# This module is part of SQLAlchemy and is released under
-# the MIT License: http://www.opensource.org/licenses/mit-license.php
-"""Support for the Oracle database.
-
-Oracle versions 8 through current (11g at the time of this writing) are supported.
-
-Driver
-------
-
-The Oracle dialect uses the cx_oracle driver, available at 
-http://cx-oracle.sourceforge.net/ .   The dialect has several behaviors 
-which are specifically tailored towards compatibility with this module.
-
-Connecting
-----------
-
-Connecting with create_engine() uses the standard URL approach of 
-``oracle://user:pass@host:port/dbname[?key=value&key=value...]``.  If dbname is present, the 
-host, port, and dbname tokens are converted to a TNS name using the cx_oracle 
-:func:`makedsn()` function.  Otherwise, the host token is taken directly as a TNS name.
-
-Additional arguments which may be specified either as query string arguments on the
-URL, or as keyword arguments to :func:`~sqlalchemy.create_engine()` are:
-
-* *allow_twophase* - enable two-phase transactions.  Defaults to ``True``.
-
-* *auto_convert_lobs* - defaults to ``True``; see the section on LOB objects.
-
-* *auto_setinputsizes* - the cx_oracle.setinputsizes() call is issued for all bind parameters.
-  This is required for LOB datatypes but can be disabled to reduce overhead.  Defaults
-  to ``True``.
-
-* *mode* - This is given the string value of SYSDBA or SYSOPER, or alternatively an
-  integer value.  This value is only available as a URL query string argument.
-
-* *threaded* - enable multithreaded access to cx_oracle connections.  Defaults
-  to ``True``.  Note that this is the opposite default of cx_oracle itself.
-
-* *use_ansi* - Use ANSI JOIN constructs (see the section on Oracle 8).  Defaults
-  to ``True``.  If ``False``, Oracle-8 compatible constructs are used for joins.
-
-* *optimize_limits* - defaults to ``False``; see the section on LIMIT/OFFSET.
-
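-For example, a hypothetical engine using several of these options (host,
-credentials and service name are placeholders)::
-
-  e = create_engine('oracle://scott:tiger@orahost:1521/orcl?mode=SYSDBA',
-                    threaded=False, optimize_limits=True)
-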
-Auto Increment Behavior
------------------------
-
-SQLAlchemy Table objects which include integer primary keys are usually assumed to have
-"autoincrementing" behavior, meaning they can generate their own primary key values upon
-INSERT.  Since Oracle has no "autoincrement" feature, SQLAlchemy relies upon sequences 
-to produce these values.   With the Oracle dialect, *a sequence must always be explicitly
-specified to enable autoincrement*.  This diverges from the majority of documentation 
-examples, which assume an autoincrement-capable database.   To specify sequences,
-use the sqlalchemy.schema.Sequence object which is passed to a Column construct::
-
-  t = Table('mytable', metadata, 
-        Column('id', Integer, Sequence('id_seq'), primary_key=True),
-        Column(...), ...
-  )
-
-This step is also required when using table reflection, i.e. autoload=True::
-
-  t = Table('mytable', metadata, 
-        Column('id', Integer, Sequence('id_seq'), primary_key=True),
-        autoload=True
-  ) 
-
-LOB Objects
------------
-
-cx_oracle presents some challenges when fetching LOB objects.  A LOB object in a result set
-is presented by cx_oracle as a cx_oracle.LOB object which has a read() method.  By default, 
-SQLAlchemy converts these LOB objects into Python strings.  This is for two reasons.  First,
-the LOB object requires an active cursor association, meaning if you were to fetch many rows
-at once such that cx_oracle had to go back to the database and fetch a new batch of rows,
-the LOB objects in the already-fetched rows are now unreadable and will raise an error. 
-SQLA "pre-reads" all LOBs so that their data is fetched before further rows are read.  
-The size of a "batch of rows" is controlled by the cursor.arraysize value, which SQLAlchemy
-defaults to 50 (cx_oracle normally defaults this to one).  
-
-Secondly, the LOB object is not a standard DBAPI return value so SQLAlchemy seeks to 
-"normalize" the results to look more like other DBAPIs.
-
-The conversion of LOB objects by this dialect is unique in SQLAlchemy in that it takes place
-for all statement executions, even plain string-based statements for which SQLA has no awareness
-of result typing.  This is so that calls like fetchmany() and fetchall() can work in all cases
-without raising cursor errors.  The conversion of LOB in all cases, as well as the "prefetch"
-of LOB objects, can be disabled using auto_convert_lobs=False.  
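-
-For example (a sketch; the URL is a placeholder)::
-
-  e = create_engine('oracle://scott:tiger@tnsname', auto_convert_lobs=False)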
-
-LIMIT/OFFSET Support
---------------------
-
-Oracle has no support for the LIMIT or OFFSET keywords.  Whereas previous versions of SQLAlchemy
-used the "ROW NUMBER OVER..." construct to simulate LIMIT/OFFSET, SQLAlchemy 0.5 now uses 
-a wrapped subquery approach in conjunction with ROWNUM.  The exact methodology is taken from
-http://www.oracle.com/technology/oramag/oracle/06-sep/o56asktom.html .  Note that the 
-"FIRST ROWS()" optimization keyword mentioned is not used by default, as the user community felt
-this was stepping into the bounds of optimization that is better left on the DBA side, but this
-prefix can be added by enabling the optimize_limits=True flag on create_engine().
-
-Two Phase Transaction Support
------------------------------
-
-Two Phase transactions are implemented using XA transactions.  They have been
-reported to work, but this should be regarded as an experimental feature.
-
-Oracle 8 Compatibility
-----------------------
-
-When using Oracle 8, a "use_ansi=False" flag is available which converts all
-JOIN phrases into the WHERE clause, and in the case of LEFT OUTER JOIN
-makes use of Oracle's (+) operator.
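-
-For example, a LEFT OUTER JOIN renders along the lines of (a sketch)::
-
-  SELECT a.id, b.a_id FROM a, b WHERE a.id = b.a_id(+)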
-
-Synonym/DBLINK Reflection
--------------------------
-
-When using reflection with Table objects, the dialect can optionally search for tables
-indicated by synonyms that reference DBLINK-ed tables by passing the flag 
-oracle_resolve_synonyms=True as a keyword argument to the Table construct.  If DBLINK 
-is not in use this flag should be left off.
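-
-For example (a sketch; names are placeholders)::
-
-  t = Table('remote_table', metadata, autoload=True,
-            autoload_with=engine, oracle_resolve_synonyms=True)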
-
-"""
-
-import datetime, random, re
-
-from sqlalchemy import util, sql, schema, log
-from sqlalchemy.engine import default, base
-from sqlalchemy.sql import compiler, visitors, expression
-from sqlalchemy.sql import operators as sql_operators, functions as sql_functions
-from sqlalchemy import types as sqltypes
-
-
-class OracleNumeric(sqltypes.Numeric):
-    def get_col_spec(self):
-        if self.precision is None:
-            return "NUMERIC"
-        else:
-            return "NUMERIC(%(precision)s, %(scale)s)" % {'precision': self.precision, 'scale' : self.scale}
-
-class OracleInteger(sqltypes.Integer):
-    def get_col_spec(self):
-        return "INTEGER"
-
-class OracleSmallInteger(sqltypes.Smallinteger):
-    def get_col_spec(self):
-        return "SMALLINT"
-
-class OracleDate(sqltypes.Date):
-    def get_col_spec(self):
-        return "DATE"
-    def bind_processor(self, dialect):
-        return None
-
-    def result_processor(self, dialect):
-        def process(value):
-            if not isinstance(value, datetime.datetime):
-                return value
-            else:
-                return value.date()
-        return process
-
-class OracleDateTime(sqltypes.DateTime):
-    def get_col_spec(self):
-        return "DATE"
-
-    def result_processor(self, dialect):
-        def process(value):
-            if value is None or isinstance(value, datetime.datetime):
-                return value
-            else:
-                # convert cx_oracle datetime object returned pre-python 2.4
-                return datetime.datetime(value.year, value.month,
-                    value.day,value.hour, value.minute, value.second)
-        return process
-
-# Note:
-# Oracle DATE == DATETIME
-# Oracle does not allow milliseconds in DATE
-# Oracle does not support TIME columns
-
-# only if cx_oracle contains TIMESTAMP
-class OracleTimestamp(sqltypes.TIMESTAMP):
-    def get_col_spec(self):
-        return "TIMESTAMP"
-
-    def get_dbapi_type(self, dialect):
-        return dialect.TIMESTAMP
-
-    def result_processor(self, dialect):
-        def process(value):
-            if value is None or isinstance(value, datetime.datetime):
-                return value
-            else:
-                # convert cx_oracle datetime object returned pre-python 2.4
-                return datetime.datetime(value.year, value.month,
-                    value.day,value.hour, value.minute, value.second)
-        return process
-
-class OracleString(sqltypes.String):
-    def get_col_spec(self):
-        return "VARCHAR(%(length)s)" % {'length' : self.length}
-
-class OracleNVarchar(sqltypes.Unicode, OracleString):
-    def get_col_spec(self):
-        return "NVARCHAR2(%(length)s)" % {'length' : self.length}
-
-class OracleText(sqltypes.Text):
-    def get_dbapi_type(self, dbapi):
-        return dbapi.CLOB
-
-    def get_col_spec(self):
-        return "CLOB"
-
-    def result_processor(self, dialect):
-        super_process = super(OracleText, self).result_processor(dialect)
-        if not dialect.auto_convert_lobs:
-            return super_process
-        lob = dialect.dbapi.LOB
-        def process(value):
-            if isinstance(value, lob):
-                if super_process:
-                    return super_process(value.read())
-                else:
-                    return value.read()
-            else:
-                if super_process:
-                    return super_process(value)
-                else:
-                    return value
-        return process
-
-
-class OracleChar(sqltypes.CHAR):
-    def get_col_spec(self):
-        return "CHAR(%(length)s)" % {'length' : self.length}
-
-class OracleBinary(sqltypes.Binary):
-    def get_dbapi_type(self, dbapi):
-        return dbapi.BLOB
-
-    def get_col_spec(self):
-        return "BLOB"
-
-    def bind_processor(self, dialect):
-        return None
-
-    def result_processor(self, dialect):
-        if not dialect.auto_convert_lobs:
-            return None
-        lob = dialect.dbapi.LOB
-        def process(value):
-            if isinstance(value, lob):
-                return value.read()
-            else:
-                return value
-        return process
-
-class OracleRaw(OracleBinary):
-    def get_col_spec(self):
-        return "RAW(%(length)s)" % {'length' : self.length}
-
-class OracleBoolean(sqltypes.Boolean):
-    def get_col_spec(self):
-        return "SMALLINT"
-
-    def result_processor(self, dialect):
-        def process(value):
-            if value is None:
-                return None
-            return value and True or False
-        return process
-
-    def bind_processor(self, dialect):
-        def process(value):
-            if value is True:
-                return 1
-            elif value is False:
-                return 0
-            elif value is None:
-                return None
-            else:
-                return value and True or False
-        return process
-
-colspecs = {
-    sqltypes.Integer : OracleInteger,
-    sqltypes.Smallinteger : OracleSmallInteger,
-    sqltypes.Numeric : OracleNumeric,
-    sqltypes.Float : OracleNumeric,
-    sqltypes.DateTime : OracleDateTime,
-    sqltypes.Date : OracleDate,
-    sqltypes.String : OracleString,
-    sqltypes.Binary : OracleBinary,
-    sqltypes.Boolean : OracleBoolean,
-    sqltypes.Text : OracleText,
-    sqltypes.TIMESTAMP : OracleTimestamp,
-    sqltypes.CHAR: OracleChar,
-}
-
-ischema_names = {
-    'VARCHAR2' : OracleString,
-    'NVARCHAR2' : OracleNVarchar,
-    'CHAR' : OracleString,
-    'DATE' : OracleDateTime,
-    'DATETIME' : OracleDateTime,
-    'NUMBER' : OracleNumeric,
-    'BLOB' : OracleBinary,
-    'BFILE' : OracleBinary,
-    'CLOB' : OracleText,
-    'TIMESTAMP' : OracleTimestamp,
-    'RAW' : OracleRaw,
-    'FLOAT' : OracleNumeric,
-    'DOUBLE PRECISION' : OracleNumeric,
-    'LONG' : OracleText,
-}
-
-class OracleExecutionContext(default.DefaultExecutionContext):
-    def pre_exec(self):
-        super(OracleExecutionContext, self).pre_exec()
-        if self.dialect.auto_setinputsizes:
-            self.set_input_sizes()
-        if self.compiled_parameters is not None and len(self.compiled_parameters) == 1:
-            for key in self.compiled.binds:
-                bindparam = self.compiled.binds[key]
-                name = self.compiled.bind_names[bindparam]
-                value = self.compiled_parameters[0][name]
-                if bindparam.isoutparam:
-                    dbtype = bindparam.type.dialect_impl(self.dialect).get_dbapi_type(self.dialect.dbapi)
-                    if not hasattr(self, 'out_parameters'):
-                        self.out_parameters = {}
-                    self.out_parameters[name] = self.cursor.var(dbtype)
-                    self.parameters[0][name] = self.out_parameters[name]
-
-    def create_cursor(self):
-        c = self._connection.connection.cursor()
-        if self.dialect.arraysize:
-            c.cursor.arraysize = self.dialect.arraysize
-        return c
-
-    def get_result_proxy(self):
-        if hasattr(self, 'out_parameters'):
-            if self.compiled_parameters is not None and len(self.compiled_parameters) == 1:
-                for bind, name in self.compiled.bind_names.iteritems():
-                    if name in self.out_parameters:
-                        type = bind.type
-                        result_processor = type.dialect_impl(self.dialect).result_processor(self.dialect)
-                        if result_processor is not None:
-                            self.out_parameters[name] = result_processor(self.out_parameters[name].getvalue())
-                        else:
-                            self.out_parameters[name] = self.out_parameters[name].getvalue()
-            else:
-                for k in self.out_parameters:
-                    self.out_parameters[k] = self.out_parameters[k].getvalue()
-
-        if self.cursor.description is not None:
-            for column in self.cursor.description:
-                type_code = column[1]
-                if type_code in self.dialect.ORACLE_BINARY_TYPES:
-                    return base.BufferedColumnResultProxy(self)
-
-        return base.ResultProxy(self)
-
-class OracleDialect(default.DefaultDialect):
-    name = 'oracle'
-    supports_alter = True
-    supports_unicode_statements = False
-    max_identifier_length = 30
-    supports_sane_rowcount = True
-    supports_sane_multi_rowcount = False
-    preexecute_pk_sequences = True
-    supports_pk_autoincrement = False
-    default_paramstyle = 'named'
-
-    def __init__(self, use_ansi=True, auto_setinputsizes=True, auto_convert_lobs=True, threaded=True, allow_twophase=True, optimize_limits=False, arraysize=50, **kwargs):
-        default.DefaultDialect.__init__(self, **kwargs)
-        self.use_ansi = use_ansi
-        self.threaded = threaded
-        self.arraysize = arraysize
-        self.allow_twophase = allow_twophase
-        self.optimize_limits = optimize_limits
-        self.supports_timestamp = self.dbapi is None or hasattr(self.dbapi, 'TIMESTAMP')
-        self.auto_setinputsizes = auto_setinputsizes
-        self.auto_convert_lobs = auto_convert_lobs
-        if self.dbapi is None or not self.auto_convert_lobs or 'CLOB' not in self.dbapi.__dict__:
-            self.dbapi_type_map = {}
-            self.ORACLE_BINARY_TYPES = []
-        else:
-            # only use this for LOB objects.  Using it for strings, dates,
-            # etc. leads to a little too much magic; reflection doesn't know
-            # whether it should expect encoded strings or unicodes, etc.
-            self.dbapi_type_map = {
-                self.dbapi.CLOB: OracleText(),
-                self.dbapi.BLOB: OracleBinary(),
-                self.dbapi.BINARY: OracleRaw(),
-            }
-            self.ORACLE_BINARY_TYPES = [getattr(self.dbapi, k) for k in ["BFILE", "CLOB", "NCLOB", "BLOB"] if hasattr(self.dbapi, k)]
-
-    def dbapi(cls):
-        import cx_Oracle
-        return cx_Oracle
-    dbapi = classmethod(dbapi)
-
-    def create_connect_args(self, url):
-        dialect_opts = dict(url.query)
-        for opt in ('use_ansi', 'auto_setinputsizes', 'auto_convert_lobs',
-                    'threaded', 'allow_twophase'):
-            if opt in dialect_opts:
-                util.coerce_kw_type(dialect_opts, opt, bool)
-                setattr(self, opt, dialect_opts[opt])
-
-        if url.database:
-            # if we have a database, then we have a remote host
-            port = url.port
-            if port:
-                port = int(port)
-            else:
-                port = 1521
-            dsn = self.dbapi.makedsn(url.host, port, url.database)
-        else:
-            # we have a local tnsname
-            dsn = url.host
-
-        opts = dict(
-            user=url.username,
-            password=url.password,
-            dsn=dsn,
-            threaded=self.threaded,
-            twophase=self.allow_twophase,
-            )
-        if 'mode' in url.query:
-            opts['mode'] = url.query['mode']
-            if isinstance(opts['mode'], basestring):
-                mode = opts['mode'].upper()
-                if mode == 'SYSDBA':
-                    opts['mode'] = self.dbapi.SYSDBA
-                elif mode == 'SYSOPER':
-                    opts['mode'] = self.dbapi.SYSOPER
-                else:
-                    util.coerce_kw_type(opts, 'mode', int)
-        # Can't set 'handle' or 'pool' via URL query args, use connect_args
-
-        return ([], opts)
-
-    def is_disconnect(self, e):
-        if isinstance(e, self.dbapi.InterfaceError):
-            return "not connected" in str(e)
-        else:
-            return "ORA-03114" in str(e) or "ORA-03113" in str(e)
-
-    def type_descriptor(self, typeobj):
-        return sqltypes.adapt_type(typeobj, colspecs)
-
-    def create_xid(self):
-        """create a two-phase transaction ID.
-
-        this id will be passed to do_begin_twophase(), do_rollback_twophase(),
-        do_commit_twophase().  its format is unspecified."""
-
-        id = random.randint(0, 2 ** 128)
-        return (0x1234, "%032x" % id, "%032x" % 9)
-        
-    def do_release_savepoint(self, connection, name):
-        # Oracle does not support RELEASE SAVEPOINT
-        pass
-
-    def do_begin_twophase(self, connection, xid):
-        connection.connection.begin(*xid)
-
-    def do_prepare_twophase(self, connection, xid):
-        connection.connection.prepare()
-
-    def do_rollback_twophase(self, connection, xid, is_prepared=True, recover=False):
-        self.do_rollback(connection.connection)
-
-    def do_commit_twophase(self, connection, xid, is_prepared=True, recover=False):
-        self.do_commit(connection.connection)
-
-    def do_recover_twophase(self, connection):
-        pass
-
-    def has_table(self, connection, table_name, schema=None):
-        if not schema:
-            schema = self.get_default_schema_name(connection)
-        cursor = connection.execute("""select table_name from all_tables where table_name=:name and owner=:schema_name""", {'name':self._denormalize_name(table_name), 'schema_name':self._denormalize_name(schema)})
-        return cursor.fetchone() is not None
-
-    def has_sequence(self, connection, sequence_name, schema=None):
-        if not schema:
-            schema = self.get_default_schema_name(connection)
-        cursor = connection.execute("""select sequence_name from all_sequences where sequence_name=:name and sequence_owner=:schema_name""", {'name':self._denormalize_name(sequence_name), 'schema_name':self._denormalize_name(schema)})
-        return cursor.fetchone() is not None
-
-    def _normalize_name(self, name):
-        if name is None:
-            return None
-        elif name.upper() == name and not self.identifier_preparer._requires_quotes(name.lower().decode(self.encoding)):
-            return name.lower().decode(self.encoding)
-        else:
-            return name.decode(self.encoding)
-
-    def _denormalize_name(self, name):
-        if name is None:
-            return None
-        elif name.lower() == name and not self.identifier_preparer._requires_quotes(name.lower()):
-            return name.upper().encode(self.encoding)
-        else:
-            return name.encode(self.encoding)
-
-    def get_default_schema_name(self, connection):
-        return self._normalize_name(connection.execute('SELECT USER FROM DUAL').scalar())
-    get_default_schema_name = base.connection_memoize(
-        ('dialect', 'default_schema_name'))(get_default_schema_name)
-
-    def table_names(self, connection, schema):
-        # note that table_names() does not load DBLINKed or synonym'ed tables
-        if schema is None:
-            s = "select table_name from all_tables where nvl(tablespace_name, 'no tablespace') NOT IN ('SYSTEM', 'SYSAUX')"
-            cursor = connection.execute(s)
-        else:
-            s = "select table_name from all_tables where nvl(tablespace_name, 'no tablespace') NOT IN ('SYSTEM','SYSAUX') AND OWNER = :owner"
-            cursor = connection.execute(s, {'owner': self._denormalize_name(schema)})
-        return [self._normalize_name(row[0]) for row in cursor]
-
-    def _resolve_synonym(self, connection, desired_owner=None, desired_synonym=None, desired_table=None):
-        """search for a local synonym matching the given desired owner/name.
-
-        if desired_owner is None, attempts to locate a distinct owner.
-
-        returns the actual name, owner, dblink name, and synonym name if found.
-        """
-
-        sql = """select OWNER, TABLE_OWNER, TABLE_NAME, DB_LINK, SYNONYM_NAME
-                   from   ALL_SYNONYMS WHERE """
-
-        clauses = []
-        params = {}
-        if desired_synonym:
-            clauses.append("SYNONYM_NAME=:synonym_name")
-            params['synonym_name'] = desired_synonym
-        if desired_owner:
-            clauses.append("TABLE_OWNER=:desired_owner")
-            params['desired_owner'] = desired_owner
-        if desired_table:
-            clauses.append("TABLE_NAME=:tname")
-            params['tname'] = desired_table
-
-        sql += " AND ".join(clauses)
-
-        result = connection.execute(sql, **params)
-        if desired_owner:
-            row = result.fetchone()
-            if row:
-                return row['TABLE_NAME'], row['TABLE_OWNER'], row['DB_LINK'], row['SYNONYM_NAME']
-            else:
-                return None, None, None, None
-        else:
-            rows = result.fetchall()
-            if len(rows) > 1:
-                raise AssertionError("There are multiple tables visible to the schema; you must specify owner")
-            elif len(rows) == 1:
-                row = rows[0]
-                return row['TABLE_NAME'], row['TABLE_OWNER'], row['DB_LINK'], row['SYNONYM_NAME']
-            else:
-                return None, None, None, None
-
-    def reflecttable(self, connection, table, include_columns):
-        preparer = self.identifier_preparer
-
-        resolve_synonyms = table.kwargs.get('oracle_resolve_synonyms', False)
-
-        if resolve_synonyms:
-            actual_name, owner, dblink, synonym = self._resolve_synonym(connection, desired_owner=self._denormalize_name(table.schema), desired_synonym=self._denormalize_name(table.name))
-        else:
-            actual_name, owner, dblink, synonym = None, None, None, None
-
-        if not actual_name:
-            actual_name = self._denormalize_name(table.name)
-        if not dblink:
-            dblink = ''
-        if not owner:
-            owner = self._denormalize_name(table.schema or self.get_default_schema_name(connection))
-
-        c = connection.execute(
-            "select COLUMN_NAME, DATA_TYPE, DATA_LENGTH, DATA_PRECISION, "
-            "DATA_SCALE, NULLABLE, DATA_DEFAULT from ALL_TAB_COLUMNS%(dblink)s "
-            "where TABLE_NAME = :table_name and OWNER = :owner" % {'dblink': dblink},
-            {'table_name': actual_name, 'owner': owner})
-
-        while True:
-            row = c.fetchone()
-            if row is None:
-                break
-
-            (colname, coltype, length, precision, scale, nullable, default) = (self._normalize_name(row[0]), row[1], row[2], row[3], row[4], row[5]=='Y', row[6])
-
-            if include_columns and colname not in include_columns:
-                continue
-
-            # INTEGER if the scale is 0 and precision is null
-            # NUMBER if the scale and precision are both null
-            # NUMBER(9,2) if the precision is 9 and the scale is 2
-            # NUMBER(3) if the precision is 3 and scale is 0
-            #length is ignored except for CHAR and VARCHAR2
-            if coltype == 'NUMBER':
-                if precision is None and scale is None:
-                    coltype = OracleNumeric
-                elif precision is None and scale == 0:
-                    coltype = OracleInteger
-                else:
-                    coltype = OracleNumeric(precision, scale)
-            elif coltype == 'CHAR' or coltype == 'VARCHAR2':
-                coltype = ischema_names.get(coltype, OracleString)(length)
-            else:
-                coltype = re.sub(r'\(\d+\)', '', coltype)
-                try:
-                    coltype = ischema_names[coltype]
-                except KeyError:
-                    util.warn("Did not recognize type '%s' of column '%s'" %
-                              (coltype, colname))
-                    coltype = sqltypes.NULLTYPE
-
-            colargs = []
-            if default is not None:
-                colargs.append(schema.DefaultClause(sql.text(default)))
-
-            table.append_column(schema.Column(colname, coltype, nullable=nullable, *colargs))
-
-        if not table.columns:
-            raise AssertionError("Couldn't find any column information for table %s" % actual_name)
-
-        c = connection.execute("""SELECT
-             ac.constraint_name,
-             ac.constraint_type,
-             loc.column_name AS local_column,
-             rem.table_name AS remote_table,
-             rem.column_name AS remote_column,
-             rem.owner AS remote_owner
-           FROM all_constraints%(dblink)s ac,
-             all_cons_columns%(dblink)s loc,
-             all_cons_columns%(dblink)s rem
-           WHERE ac.table_name = :table_name
-           AND ac.constraint_type IN ('R','P')
-           AND ac.owner = :owner
-           AND ac.owner = loc.owner
-           AND ac.constraint_name = loc.constraint_name
-           AND ac.r_owner = rem.owner(+)
-           AND ac.r_constraint_name = rem.constraint_name(+)
-           -- order multiple primary keys correctly
-           ORDER BY ac.constraint_name, loc.position, rem.position"""
-         % {'dblink':dblink}, {'table_name' : actual_name, 'owner' : owner})
-
-        fks = {}
-        while True:
-            row = c.fetchone()
-            if row is None:
-                break
-            #print "ROW:" , row
-            (cons_name, cons_type, local_column, remote_table, remote_column, remote_owner) = row[0:2] + tuple([self._normalize_name(x) for x in row[2:]])
-            if cons_type == 'P':
-                table.primary_key.add(table.c[local_column])
-            elif cons_type == 'R':
-                try:
-                    fk = fks[cons_name]
-                except KeyError:
-                    fk = ([], [])
-                    fks[cons_name] = fk
-                if remote_table is None:
-                    # ticket 363
-                    util.warn(
-                        ("Got 'None' querying 'table_name' from "
-                         "all_cons_columns%(dblink)s - does the user have "
-                         "proper rights to the table?") % {'dblink':dblink})
-                    continue
-
-                if resolve_synonyms:
-                    ref_remote_name, ref_remote_owner, ref_dblink, ref_synonym = self._resolve_synonym(connection, desired_owner=self._denormalize_name(remote_owner), desired_table=self._denormalize_name(remote_table))
-                    if ref_synonym:
-                        remote_table = self._normalize_name(ref_synonym)
-                        remote_owner = self._normalize_name(ref_remote_owner)
-
-                if not table.schema and self._denormalize_name(remote_owner) == owner:
-                    refspec =  ".".join([remote_table, remote_column])
-                    t = schema.Table(remote_table, table.metadata, autoload=True, autoload_with=connection, oracle_resolve_synonyms=resolve_synonyms, useexisting=True)
-                else:
-                    refspec =  ".".join([x for x in [remote_owner, remote_table, remote_column] if x])
-                    t = schema.Table(remote_table, table.metadata, autoload=True, autoload_with=connection, schema=remote_owner, oracle_resolve_synonyms=resolve_synonyms, useexisting=True)
-
-                if local_column not in fk[0]:
-                    fk[0].append(local_column)
-                if refspec not in fk[1]:
-                    fk[1].append(refspec)
-
-        for name, value in fks.iteritems():
-            table.append_constraint(schema.ForeignKeyConstraint(value[0], value[1], name=name, link_to_name=True))
-
-
-class _OuterJoinColumn(sql.ClauseElement):
-    __visit_name__ = 'outer_join_column'
-    
-    def __init__(self, column):
-        self.column = column
-
-class OracleCompiler(compiler.DefaultCompiler):
-    """Oracle compiler modifies the lexical structure of Select
-    statements to work under non-ANSI configured Oracle databases, if
-    the use_ansi flag is False.
-    """
-
-    operators = compiler.DefaultCompiler.operators.copy()
-    operators.update(
-        {
-            sql_operators.mod : lambda x, y:"mod(%s, %s)" % (x, y),
-            sql_operators.match_op: lambda x, y: "CONTAINS (%s, %s)" % (x, y)
-        }
-    )
-
-    functions = compiler.DefaultCompiler.functions.copy()
-    functions.update(
-        {
-            sql_functions.now : 'CURRENT_TIMESTAMP'
-        }
-    )
-
-    def __init__(self, *args, **kwargs):
-        super(OracleCompiler, self).__init__(*args, **kwargs)
-        self.__wheres = {}
-
-    def default_from(self):
-        """Called when a ``SELECT`` statement has no froms, and no ``FROM`` clause is to be appended.
-
-        The Oracle compiler tacks a "FROM DUAL" to the statement.
-        """
-
-        return " FROM DUAL"
-
-    def apply_function_parens(self, func):
-        return len(func.clauses) > 0
-
-    def visit_join(self, join, **kwargs):
-        if self.dialect.use_ansi:
-            return compiler.DefaultCompiler.visit_join(self, join, **kwargs)
-        else:
-            return self.process(join.left, asfrom=True) + ", " + self.process(join.right, asfrom=True)
-
-    def _get_nonansi_join_whereclause(self, froms):
-        clauses = []
-
-        def visit_join(join):
-            if join.isouter:
-                def visit_binary(binary):
-                    if binary.operator == sql_operators.eq:
-                        if binary.left.table is join.right:
-                            binary.left = _OuterJoinColumn(binary.left)
-                        elif binary.right.table is join.right:
-                            binary.right = _OuterJoinColumn(binary.right)
-                clauses.append(visitors.cloned_traverse(join.onclause, {}, {'binary':visit_binary}))
-            else:
-                clauses.append(join.onclause)
-
-        for f in froms:
-            visitors.traverse(f, {}, {'join':visit_join})
-        return sql.and_(*clauses)
-
-    def visit_outer_join_column(self, vc):
-        return self.process(vc.column) + "(+)"
-
-    def visit_sequence(self, seq):
-        return self.dialect.identifier_preparer.format_sequence(seq) + ".nextval"
-
-    def visit_alias(self, alias, asfrom=False, **kwargs):
-        """Oracle doesn't like ``FROM table AS alias``.  Is the AS standard SQL??"""
-
-        if asfrom:
-            alias_name = isinstance(alias.name, expression._generated_label) and \
-                            self._truncated_identifier("alias", alias.name) or alias.name
-            
-            return self.process(alias.original, asfrom=True, **kwargs) + " " +\
-                    self.preparer.format_alias(alias, alias_name)
-        else:
-            return self.process(alias.original, **kwargs)
-
-    def _TODO_visit_compound_select(self, select):
-        """Need to determine how to get ``LIMIT``/``OFFSET`` into a ``UNION`` for Oracle."""
-        pass
-
-    def visit_select(self, select, **kwargs):
-        """Look for ``LIMIT`` and OFFSET in a select statement, and if
-        so tries to wrap it in a subquery with ``rownum`` criterion.
-        """
-
-        if not getattr(select, '_oracle_visit', None):
-            if not self.dialect.use_ansi:
-                if self.stack and 'from' in self.stack[-1]:
-                    existingfroms = self.stack[-1]['from']
-                else:
-                    existingfroms = None
-
-                froms = select._get_display_froms(existingfroms)
-                whereclause = self._get_nonansi_join_whereclause(froms)
-                if whereclause:
-                    select = select.where(whereclause)
-                    select._oracle_visit = True
-
-            if select._limit is not None or select._offset is not None:
-                # See http://www.oracle.com/technology/oramag/oracle/06-sep/o56asktom.html
-                #
-                # Generalized form of an Oracle pagination query:
-                #   select ... from (
-                #     select /*+ FIRST_ROWS(N) */ ...., rownum as ora_rn from (
-                #         select distinct ... where ... order by ...
-                #     ) where ROWNUM <= :limit+:offset
-                #   ) where ora_rn > :offset
-                # Outer select and "ROWNUM as ora_rn" can be dropped if limit=0
-
-                # TODO: use annotations instead of clone + attr set ?
-                select = select._generate()
-                select._oracle_visit = True
-
-                # Wrap the middle select and add the hint
-                limitselect = sql.select([c for c in select.c])
-                if select._limit and self.dialect.optimize_limits:
-                    limitselect = limitselect.prefix_with("/*+ FIRST_ROWS(%d) */" % select._limit)
-
-                limitselect._oracle_visit = True
-                limitselect._is_wrapper = True
-
-                # If needed, add the limiting clause
-                if select._limit is not None:
-                    max_row = select._limit
-                    if select._offset is not None:
-                        max_row += select._offset
-                    limitselect.append_whereclause(
-                        sql.literal_column("ROWNUM") <= max_row)
-                # If needed, add the ora_rn, and wrap again with offset.
-                if select._offset is None:
-                    select = limitselect
-                else:
-                    limitselect = limitselect.column(
-                            sql.literal_column("ROWNUM").label("ora_rn"))
-                    limitselect._oracle_visit = True
-                    limitselect._is_wrapper = True
-                    offsetselect = sql.select(
-                            [c for c in limitselect.c if c.key != 'ora_rn'])
-                    offsetselect._oracle_visit = True
-                    offsetselect._is_wrapper = True
-                    offsetselect.append_whereclause(
-                            sql.literal_column("ora_rn") > select._offset)
-                    select = offsetselect
-
-        kwargs['iswrapper'] = getattr(select, '_is_wrapper', False)
-        return compiler.DefaultCompiler.visit_select(self, select, **kwargs)
-
-    def limit_clause(self, select):
-        return ""
-
-    def for_update_clause(self, select):
-        if select.for_update == "nowait":
-            return " FOR UPDATE NOWAIT"
-        else:
-            return super(OracleCompiler, self).for_update_clause(select)
-
-
-class OracleSchemaGenerator(compiler.SchemaGenerator):
-    def get_column_specification(self, column, **kwargs):
-        colspec = self.preparer.format_column(column)
-        colspec += " " + column.type.dialect_impl(self.dialect).get_col_spec()
-        default = self.get_column_default_string(column)
-        if default is not None:
-            colspec += " DEFAULT " + default
-
-        if not column.nullable:
-            colspec += " NOT NULL"
-        return colspec
-
-    def visit_sequence(self, sequence):
-        if not self.checkfirst or not self.dialect.has_sequence(self.connection, sequence.name, sequence.schema):
-            self.append("CREATE SEQUENCE %s" % self.preparer.format_sequence(sequence))
-            self.execute()
-
-class OracleSchemaDropper(compiler.SchemaDropper):
-    def visit_sequence(self, sequence):
-        if not self.checkfirst or self.dialect.has_sequence(self.connection, sequence.name, sequence.schema):
-            self.append("DROP SEQUENCE %s" % self.preparer.format_sequence(sequence))
-            self.execute()
-
-class OracleDefaultRunner(base.DefaultRunner):
-    def visit_sequence(self, seq):
-        return self.execute_string("SELECT " + self.dialect.identifier_preparer.format_sequence(seq) + ".nextval FROM DUAL", {})
-
-class OracleIdentifierPreparer(compiler.IdentifierPreparer):
-    def format_savepoint(self, savepoint):
-        name = re.sub(r'^_+', '', savepoint.ident)
-        return super(OracleIdentifierPreparer, self).format_savepoint(savepoint, name)
-
-
-dialect = OracleDialect
-dialect.statement_compiler = OracleCompiler
-dialect.schemagenerator = OracleSchemaGenerator
-dialect.schemadropper = OracleSchemaDropper
-dialect.preparer = OracleIdentifierPreparer
-dialect.defaultrunner = OracleDefaultRunner
-dialect.execution_ctx_cls = OracleExecutionContext
diff --git a/lib/sqlalchemy/databases/sqlite.py b/lib/sqlalchemy/databases/sqlite.py
deleted file mode 100644 (file)
index 8952b2b..0000000
+++ /dev/null
@@ -1,646 +0,0 @@
-# sqlite.py
-# Copyright (C) 2005, 2006, 2007, 2008, 2009 Michael Bayer mike_mp@zzzcomputing.com
-#
-# This module is part of SQLAlchemy and is released under
-# the MIT License: http://www.opensource.org/licenses/mit-license.php
-
-"""Support for the SQLite database.
-
-Driver
-------
-
-When using Python 2.5 and above, the built-in ``sqlite3`` driver is 
-already installed and no additional installation is needed.  Otherwise,
-the ``pysqlite2`` driver needs to be present.  This is the same driver as
-``sqlite3``, just with a different name.
-
-The ``pysqlite2`` driver will be loaded first, and if not found, ``sqlite3``
-is loaded.  This allows an explicitly installed pysqlite driver to take
-precedence over the built-in one.   As with all dialects, a specific 
-DBAPI module may be provided to :func:`~sqlalchemy.create_engine()` to control 
-this explicitly::
-
-    from sqlite3 import dbapi2 as sqlite
-    e = create_engine('sqlite:///file.db', module=sqlite)
-
-Full documentation on pysqlite is available at:
-`<http://www.initd.org/pub/software/pysqlite/doc/usage-guide.html>`_
-
-Connect Strings
----------------
-
-The file specification for the SQLite database is taken as the "database" portion of
-the URL.  Note that the format of a URL is::
-
-    driver://user:pass@host/database
-    
-This means that the actual filename to be used starts with the characters to the
-**right** of the third slash.   So connecting to a relative filepath looks like::
-
-    # relative path
-    e = create_engine('sqlite:///path/to/database.db')
-    
-An absolute path, which is denoted by starting with a slash, means you need **four**
-slashes::
-
-    # absolute path
-    e = create_engine('sqlite:////path/to/database.db')
-
-To use a Windows path, regular drive specifications and backslashes can be used.  
-Double backslashes are probably needed::
-
-    # absolute path on Windows
-    e = create_engine('sqlite:///C:\\\\path\\\\to\\\\database.db')
-
-The sqlite ``:memory:`` identifier is the default if no filepath is present.  Specify
-``sqlite://`` and nothing else::
-
-    # in-memory database
-    e = create_engine('sqlite://')
-
-Threading Behavior
-------------------
-
-Pysqlite connections do not support being moved between threads, unless
-the ``check_same_thread`` Pysqlite flag is set to ``False``.  In addition,
-when using an in-memory SQLite database, the full database exists only within 
-the scope of a single connection.  It is reported that an in-memory
-database does not support being shared between threads regardless of the 
-``check_same_thread`` flag, which means that a multithreaded
-application **cannot** share data from a ``:memory:`` database across threads
-unless access to the connection is limited to a single worker thread that
-communicates with concurrent threads through a queueing mechanism.
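-
-Pysqlite's flag can be passed through ``connect_args`` if needed (a sketch;
-use with care given the caveats above)::
-
-    e = create_engine('sqlite:///file.db',
-                      connect_args={'check_same_thread': False})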
-
-To provide a default which accommodates SQLite's threading capabilities
-somewhat reasonably, the SQLite dialect specifies that the :class:`~sqlalchemy.pool.SingletonThreadPool`
-be used by default.  This pool maintains a single SQLite connection per thread,
-holding connections open for up to five concurrent threads.  When more than five threads
-are used, a cleanup mechanism disposes of excess unused connections.   
-
-Two optional pool implementations may be appropriate for particular SQLite usage scenarios:
-
- * the :class:`sqlalchemy.pool.StaticPool` might be appropriate for a multithreaded
-   application using an in-memory database, assuming the threading issues inherent in 
-   pysqlite are somehow accommodated.  This pool holds persistently onto a single
-   connection which is never closed, and is returned for all requests.
-   
- * the :class:`sqlalchemy.pool.NullPool` might be appropriate for an application that
-   makes use of a file-based sqlite database.  This pool disables any actual "pooling"
-   behavior, and simply opens and closes real connections corresponding to the :func:`connect()`
-   and :func:`close()` methods.  SQLite can "connect" to a particular file with very high 
-   efficiency, so this option may actually perform better without the extra overhead
-   of :class:`SingletonThreadPool`.  NullPool will of course render a ``:memory:`` connection
-   useless since the database would be lost as soon as the connection is "returned" to the pool.
-
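-For example, a file-based database using ``NullPool`` (a sketch)::
-
-    from sqlalchemy import create_engine
-    from sqlalchemy.pool import NullPool
-
-    e = create_engine('sqlite:///file.db', poolclass=NullPool)
-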
-Date and Time Types
--------------------
-
-SQLite does not have built-in DATE, TIME, or DATETIME types, and pysqlite does not provide 
-out of the box functionality for translating values between Python `datetime` objects
-and a SQLite-supported format.  SQLAlchemy's own :class:`~sqlalchemy.types.DateTime`
-and related types provide date formatting and parsing functionality when SQLite is used.
-The implementation classes are :class:`SLDateTime`, :class:`SLDate` and :class:`SLTime`.
-These types represent dates and times as ISO formatted strings, which also nicely
-support ordering.   There's no reliance on typical "libc" internals for these functions,
-so historical dates are fully supported.
-
-Unicode
--------
-
-In contrast to SQLAlchemy's active handling of date and time types for pysqlite, pysqlite's 
-default behavior regarding Unicode is that all strings are returned as Python unicode objects
-in all cases.  So even if the :class:`~sqlalchemy.types.Unicode` type is 
-*not* used, you will still always receive unicode data back from a result set.  It is 
-**strongly** recommended that you do use the :class:`~sqlalchemy.types.Unicode` type
-to represent strings, since it will raise a warning if a non-unicode Python string is 
-passed from the user application.  Mixing non-unicode objects with returned unicode objects can
-quickly create confusion, particularly when using the ORM, as internal data is not 
-always represented by an actual database result string.
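-
-For example (a sketch)::
-
-    t = Table('data', metadata, Column('value', Unicode(100)))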
-
-"""
-
-
-import datetime, re, time
-
-from sqlalchemy import sql, schema, exc, pool, DefaultClause
-from sqlalchemy.engine import default
-import sqlalchemy.types as sqltypes
-import sqlalchemy.util as util
-from sqlalchemy.sql import compiler, functions as sql_functions
-from types import NoneType
-
-class SLNumeric(sqltypes.Numeric):
-    def bind_processor(self, dialect):
-        type_ = self.asdecimal and str or float
-        def process(value):
-            if value is not None:
-                return type_(value)
-            else:
-                return value
-        return process
-
-    def get_col_spec(self):
-        if self.precision is None:
-            return "NUMERIC"
-        else:
-            return "NUMERIC(%(precision)s, %(scale)s)" % {'precision': self.precision, 'scale' : self.scale}
-
-class SLFloat(sqltypes.Float):
-    def bind_processor(self, dialect):
-        type_ = self.asdecimal and str or float
-        def process(value):
-            if value is not None:
-                return type_(value)
-            else:
-                return value
-        return process
-
-    def get_col_spec(self):
-        return "FLOAT"
-    
-class SLInteger(sqltypes.Integer):
-    def get_col_spec(self):
-        return "INTEGER"
-
-class SLSmallInteger(sqltypes.Smallinteger):
-    def get_col_spec(self):
-        return "SMALLINT"
-
-class DateTimeMixin(object):
-    def _bind_processor(self, format, elements):
-        def process(value):
-            if not isinstance(value, (NoneType, datetime.date, datetime.datetime, datetime.time)):
-                raise TypeError("SQLite Date, Time, and DateTime types only accept Python datetime objects as input.")
-            elif value is not None:
-                return format % tuple([getattr(value, attr, 0) for attr in elements])
-            else:
-                return None
-        return process
-
-    def _result_processor(self, fn, regexp):
-        def process(value):
-            if value is not None:
-                return fn(*[int(x or 0) for x in regexp.match(value).groups()])
-            else:
-                return None
-        return process
-
-class SLDateTime(DateTimeMixin, sqltypes.DateTime):
-    __legacy_microseconds__ = False
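-    # when True, the microseconds value is rendered unpadded via "%s"
-    # (legacy format) rather than zero-padded to six digits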
-
-    def get_col_spec(self):
-        return "TIMESTAMP"
-
-    def bind_processor(self, dialect):
-        if self.__legacy_microseconds__:
-            return self._bind_processor(
-                        "%4.4d-%2.2d-%2.2d %2.2d:%2.2d:%2.2d.%s", 
-                        ("year", "month", "day", "hour", "minute", "second", "microsecond")
-                        )
-        else:
-            return self._bind_processor(
-                        "%4.4d-%2.2d-%2.2d %2.2d:%2.2d:%2.2d.%06d", 
-                        ("year", "month", "day", "hour", "minute", "second", "microsecond")
-                        )
-
-    _reg = re.compile(r"(\d+)-(\d+)-(\d+)(?: (\d+):(\d+):(\d+)(?:\.(\d+))?)?")
-    def result_processor(self, dialect):
-        return self._result_processor(datetime.datetime, self._reg)
-
-class SLDate(DateTimeMixin, sqltypes.Date):
-    def get_col_spec(self):
-        return "DATE"
-
-    def bind_processor(self, dialect):
-        return self._bind_processor(
-                        "%4.4d-%2.2d-%2.2d", 
-                        ("year", "month", "day")
-                )
-
-    _reg = re.compile(r"(\d+)-(\d+)-(\d+)")
-    def result_processor(self, dialect):
-        return self._result_processor(datetime.date, self._reg)
-
-class SLTime(DateTimeMixin, sqltypes.Time):
-    __legacy_microseconds__ = False
-
-    def get_col_spec(self):
-        return "TIME"
-
-    def bind_processor(self, dialect):
-        if self.__legacy_microseconds__:
-            return self._bind_processor(
-                            "%2.2d:%2.2d:%2.2d.%s", 
-                            ("hour", "minute", "second", "microsecond")
-                    )
-        else:
-            return self._bind_processor(
-                            "%2.2d:%2.2d:%2.2d.%06d", 
-                            ("hour", "minute", "second", "microsecond")
-                    )
-
-    _reg = re.compile(r"(\d+):(\d+):(\d+)(?:\.(\d+))?")
-    def result_processor(self, dialect):
-        return self._result_processor(datetime.time, self._reg)
-
-class SLUnicodeMixin(object):
-    def bind_processor(self, dialect):
-        if self.convert_unicode or dialect.convert_unicode:
-            if self.assert_unicode is None:
-                assert_unicode = dialect.assert_unicode
-            else:
-                assert_unicode = self.assert_unicode
-                
-            if not assert_unicode:
-                return None
-                
-            def process(value):
-                if not isinstance(value, (unicode, NoneType)):
-                    if assert_unicode == 'warn':
-                        util.warn("Unicode type received non-unicode bind "
-                                  "param value %r" % value)
-                        return value
-                    else:
-                        raise exc.InvalidRequestError("Unicode type received non-unicode bind param value %r" % value)
-                else:
-                    return value
-            return process
-        else:
-            return None
-
-    def result_processor(self, dialect):
-        return None
-    
-class SLText(SLUnicodeMixin, sqltypes.Text):
-    def get_col_spec(self):
-        return "TEXT"
-
-class SLString(SLUnicodeMixin, sqltypes.String):
-    def get_col_spec(self):
-        return "VARCHAR" + (self.length and "(%d)" % self.length or "")
-
-class SLChar(SLUnicodeMixin, sqltypes.CHAR):
-    def get_col_spec(self):
-        return "CHAR" + (self.length and "(%d)" % self.length or "")
-
-class SLBinary(sqltypes.Binary):
-    def get_col_spec(self):
-        return "BLOB"
-
-class SLBoolean(sqltypes.Boolean):
-    def get_col_spec(self):
-        return "BOOLEAN"
-
-    def bind_processor(self, dialect):
-        def process(value):
-            if value is None:
-                return None
-            return value and 1 or 0
-        return process
-
-    def result_processor(self, dialect):
-        def process(value):
-            if value is None:
-                return None
-            return value == 1
-        return process
-
-colspecs = {
-    sqltypes.Binary: SLBinary,
-    sqltypes.Boolean: SLBoolean,
-    sqltypes.CHAR: SLChar,
-    sqltypes.Date: SLDate,
-    sqltypes.DateTime: SLDateTime,
-    sqltypes.Float: SLFloat,
-    sqltypes.Integer: SLInteger,
-    sqltypes.NCHAR: SLChar,
-    sqltypes.Numeric: SLNumeric,
-    sqltypes.Smallinteger: SLSmallInteger,
-    sqltypes.String: SLString,
-    sqltypes.Text: SLText,
-    sqltypes.Time: SLTime,
-}
-
-ischema_names = {
-    'BLOB': SLBinary,
-    'BOOL': SLBoolean,
-    'BOOLEAN': SLBoolean,
-    'CHAR': SLChar,
-    'DATE': SLDate,
-    'DATETIME': SLDateTime,
-    'DECIMAL': SLNumeric,
-    'FLOAT': SLFloat,
-    'INT': SLInteger,
-    'INTEGER': SLInteger,
-    'NUMERIC': SLNumeric,
-    'REAL': SLNumeric,
-    'SMALLINT': SLSmallInteger,
-    'TEXT': SLText,
-    'TIME': SLTime,
-    'TIMESTAMP': SLDateTime,
-    'VARCHAR': SLString,
-}
-
-class SQLiteExecutionContext(default.DefaultExecutionContext):
-    def post_exec(self):
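-        # for a single-row INSERT, substitute the cursor's lastrowid for the
-        # first primary key value if it wasn't otherwise populated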
-        if self.compiled.isinsert and not self.executemany:
-            if not len(self._last_inserted_ids) or self._last_inserted_ids[0] is None:
-                self._last_inserted_ids = [self.cursor.lastrowid] + self._last_inserted_ids[1:]
-
-class SQLiteDialect(default.DefaultDialect):
-    name = 'sqlite'
-    supports_alter = False
-    supports_unicode_statements = True
-    default_paramstyle = 'qmark'
-    supports_default_values = True
-    supports_empty_insert = False
-
-    def __init__(self, **kwargs):
-        default.DefaultDialect.__init__(self, **kwargs)
-        def vers(num):
-            return tuple([int(x) for x in num.split('.')])
-        if self.dbapi is not None:
-            sqlite_ver = self.dbapi.version_info
-            if sqlite_ver < (2, 1, 3):
-                util.warn(
-                    ("The installed version of pysqlite2 (%s) is out-dated "
-                     "and will cause errors in some cases.  Version 2.1.3 "
-                     "or greater is recommended.") %
-                    '.'.join([str(subver) for subver in sqlite_ver]))
-            if self.dbapi.sqlite_version_info < (3, 3, 8):
-                self.supports_default_values = False
-        self.supports_cast = (self.dbapi is None or vers(self.dbapi.sqlite_version) >= vers("3.2.3"))
-
-    def dbapi(cls):
-        try:
-            from pysqlite2 import dbapi2 as sqlite
-        except ImportError, e:
-            try:
-                from sqlite3 import dbapi2 as sqlite #try the 2.5+ stdlib name.
-            except ImportError:
-                raise e
-        return sqlite
-    dbapi = classmethod(dbapi)
-
-    def server_version_info(self, connection):
-        return self.dbapi.sqlite_version_info
-
-    def create_connect_args(self, url):
-        if url.username or url.password or url.host or url.port:
-            raise exc.ArgumentError(
-                "Invalid SQLite URL: %s\n"
-                "Valid SQLite URL forms are:\n"
-                " sqlite:///:memory: (or, sqlite://)\n"
-                " sqlite:///relative/path/to/file.db\n"
-                " sqlite:////absolute/path/to/file.db" % (url,))
-        filename = url.database or ':memory:'
-
-        opts = url.query.copy()
-        util.coerce_kw_type(opts, 'timeout', float)
-        util.coerce_kw_type(opts, 'isolation_level', str)
-        util.coerce_kw_type(opts, 'detect_types', int)
-        util.coerce_kw_type(opts, 'check_same_thread', bool)
-        util.coerce_kw_type(opts, 'cached_statements', int)
-
-        return ([filename], opts)
-
-    def type_descriptor(self, typeobj):
-        return sqltypes.adapt_type(typeobj, colspecs)
-
-    def is_disconnect(self, e):
-        return isinstance(e, self.dbapi.ProgrammingError) and "Cannot operate on a closed database." in str(e)
-
-    def table_names(self, connection, schema):
-        if schema is not None:
-            qschema = self.identifier_preparer.quote_identifier(schema)
-            master = '%s.sqlite_master' % qschema
-            s = ("SELECT name FROM %s "
-                 "WHERE type='table' ORDER BY name") % (master,)
-            rs = connection.execute(s)
-        else:
-            try:
-                s = ("SELECT name FROM "
-                     " (SELECT * FROM sqlite_master UNION ALL "
-                     "  SELECT * FROM sqlite_temp_master) "
-                     "WHERE type='table' ORDER BY name")
-                rs = connection.execute(s)
-            except exc.DBAPIError:
-                # fall back to querying sqlite_master alone
-                s = ("SELECT name FROM sqlite_master "
-                     "WHERE type='table' ORDER BY name")
-                rs = connection.execute(s)
-
-        return [row[0] for row in rs]
-
-    def has_table(self, connection, table_name, schema=None):
-        quote = self.identifier_preparer.quote_identifier
-        if schema is not None:
-            pragma = "PRAGMA %s." % quote(schema)
-        else:
-            pragma = "PRAGMA "
-        qtable = quote(table_name)
-        cursor = _pragma_cursor(connection.execute("%stable_info(%s)" % (pragma, qtable)))
-            
-        row = cursor.fetchone()
-
-        # consume remaining rows, to work around
-        # http://www.sqlite.org/cvstrac/tktview?tn=1884
-        while cursor.fetchone() is not None:
-            pass
-
-        return (row is not None)
-
-    def reflecttable(self, connection, table, include_columns):
-        preparer = self.identifier_preparer
-        if table.schema is None:
-            pragma = "PRAGMA "
-        else:
-            pragma = "PRAGMA %s." % preparer.quote_identifier(table.schema)
-        qtable = preparer.format_table(table, False)
-
-        c = _pragma_cursor(connection.execute("%stable_info(%s)" % (pragma, qtable)))
-        found_table = False
-        while True:
-            row = c.fetchone()
-            if row is None:
-                break
-
-            found_table = True
-            (name, type_, nullable, default, has_default, primary_key) = (
-                row[1], row[2].upper(), not row[3],
-                row[4], row[4] is not None, row[5])
-            name = re.sub(r'^\"|\"$', '', name)
-            if include_columns and name not in include_columns:
-                continue
-            match = re.match(r'(\w+)(\(.*?\))?', type_)
-            if match:
-                coltype = match.group(1)
-                args = match.group(2)
-            else:
-                coltype = "VARCHAR"
-                args = ''
-
-            try:
-                coltype = ischema_names[coltype]
-            except KeyError:
-                util.warn("Did not recognize type '%s' of column '%s'" %
-                          (coltype, name))
-                coltype = sqltypes.NullType
-
-            if args is not None:
-                args = re.findall(r'(\d+)', args)
-                coltype = coltype(*[int(a) for a in args])
-
-            colargs = []
-            if has_default:
-                colargs.append(DefaultClause(sql.text(default)))
-            table.append_column(schema.Column(name, coltype, primary_key = primary_key, nullable = nullable, *colargs))
-
-        if not found_table:
-            raise exc.NoSuchTableError(table.name)
-
-        c = _pragma_cursor(connection.execute("%sforeign_key_list(%s)" % (pragma, qtable)))
-        fks = {}
-        while True:
-            row = c.fetchone()
-            if row is None:
-                break
-            (constraint_name, tablename, localcol, remotecol) = (row[0], row[2], row[3], row[4])
-            tablename = re.sub(r'^\"|\"$', '', tablename)
-            localcol = re.sub(r'^\"|\"$', '', localcol)
-            remotecol = re.sub(r'^\"|\"$', '', remotecol)
-            try:
-                fk = fks[constraint_name]
-            except KeyError:
-                fk = ([], [])
-                fks[constraint_name] = fk
-
-            # look up the table based on the given table's engine, not 'self',
-            # since it could be a ProxyEngine
-            remotetable = schema.Table(tablename, table.metadata, autoload=True, autoload_with=connection)
-            constrained_column = table.c[localcol].name
-            refspec = ".".join([tablename, remotecol])
-            if constrained_column not in fk[0]:
-                fk[0].append(constrained_column)
-            if refspec not in fk[1]:
-                fk[1].append(refspec)
-        for name, value in fks.iteritems():
-            table.append_constraint(schema.ForeignKeyConstraint(value[0], value[1], link_to_name=True))
-        # check for UNIQUE indexes
-        c = _pragma_cursor(connection.execute("%sindex_list(%s)" % (pragma, qtable)))
-        unique_indexes = []
-        while True:
-            row = c.fetchone()
-            if row is None:
-                break
-            if (row[2] == 1):
-                unique_indexes.append(row[1])
-        # loop thru unique indexes for one that includes the primary key
-        for idx in unique_indexes:
-            c = connection.execute("%sindex_info(%s)" % (pragma, idx))
-            cols = []
-            while True:
-                row = c.fetchone()
-                if row is None:
-                    break
-                cols.append(row[2])
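-            # note: the collected column names are not currently applied to
-            # the reflected Table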
-
-def _pragma_cursor(cursor):
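-    # pysqlite may return an already-closed cursor for PRAGMA statements
-    # which produce no rows; substitute a no-op fetch so callers can
-    # iterate safely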
-    if cursor.closed:
-        cursor._fetchone_impl = lambda: None
-    return cursor
-        
-class SQLiteCompiler(compiler.DefaultCompiler):
-    functions = compiler.DefaultCompiler.functions.copy()
-    functions.update(
-        {
-            sql_functions.now: 'CURRENT_TIMESTAMP',
-            sql_functions.char_length: 'length%(expr)s'
-        }
-    )
-
-    extract_map = compiler.DefaultCompiler.extract_map.copy()
-    extract_map.update({
-        'month': '%m',
-        'day': '%d',
-        'year': '%Y',
-        'second': '%S',
-        'hour': '%H',
-        'doy': '%j',
-        'minute': '%M',
-        'epoch': '%s',
-        'dow': '%w',
-        'week': '%W'
-    })
-
-    def visit_cast(self, cast, **kwargs):
-        if self.dialect.supports_cast:
-            return super(SQLiteCompiler, self).visit_cast(cast)
-        else:
-            return self.process(cast.clause)
-
-    def visit_extract(self, extract):
-        try:
-            return "CAST(STRFTIME('%s', %s) AS INTEGER)" % (
-                self.extract_map[extract.field], self.process(extract.expr))
-        except KeyError:
-            raise exc.ArgumentError(
-                "%s is not a valid extract argument." % extract.field)
-
-    def limit_clause(self, select):
-        text = ""
-        if select._limit is not None:
-            text +=  " \n LIMIT " + str(select._limit)
-        if select._offset is not None:
-            if select._limit is None:
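-                # SQLite requires a LIMIT in order to use OFFSET;
-                # a negative LIMIT means "no limit"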
-                text += " \n LIMIT -1"
-            text += " OFFSET " + str(select._offset)
-        else:
-            text += " OFFSET 0"
-        return text
-
-    def for_update_clause(self, select):
-        # sqlite has no "FOR UPDATE" AFAICT
-        return ''
-
-
-class SQLiteSchemaGenerator(compiler.SchemaGenerator):
-
-    def get_column_specification(self, column, **kwargs):
-        colspec = self.preparer.format_column(column) + " " + column.type.dialect_impl(self.dialect).get_col_spec()
-        default = self.get_column_default_string(column)
-        if default is not None:
-            colspec += " DEFAULT " + default
-
-        if not column.nullable:
-            colspec += " NOT NULL"
-        return colspec
-
-class SQLiteIdentifierPreparer(compiler.IdentifierPreparer):
-    reserved_words = set([
-        'add', 'after', 'all', 'alter', 'analyze', 'and', 'as', 'asc',
-        'attach', 'autoincrement', 'before', 'begin', 'between', 'by',
-        'cascade', 'case', 'cast', 'check', 'collate', 'column', 'commit',
-        'conflict', 'constraint', 'create', 'cross', 'current_date',
-        'current_time', 'current_timestamp', 'database', 'default',
-        'deferrable', 'deferred', 'delete', 'desc', 'detach', 'distinct',
-        'drop', 'each', 'else', 'end', 'escape', 'except', 'exclusive',
-        'explain', 'false', 'fail', 'for', 'foreign', 'from', 'full', 'glob',
-        'group', 'having', 'if', 'ignore', 'immediate', 'in', 'index',
-        'initially', 'inner', 'insert', 'instead', 'intersect', 'into', 'is',
-        'isnull', 'join', 'key', 'left', 'like', 'limit', 'match', 'natural',
-        'not', 'notnull', 'null', 'of', 'offset', 'on', 'or', 'order', 'outer',
-        'plan', 'pragma', 'primary', 'query', 'raise', 'references',
-        'reindex', 'rename', 'replace', 'restrict', 'right', 'rollback',
-        'row', 'select', 'set', 'table', 'temp', 'temporary', 'then', 'to',
-        'transaction', 'trigger', 'true', 'union', 'unique', 'update', 'using',
-        'vacuum', 'values', 'view', 'virtual', 'when', 'where', 'indexed',
-        ])
-
-    def __init__(self, dialect):
-        super(SQLiteIdentifierPreparer, self).__init__(dialect)
-
-dialect = SQLiteDialect
-dialect.poolclass = pool.SingletonThreadPool
-dialect.statement_compiler = SQLiteCompiler
-dialect.schemagenerator = SQLiteSchemaGenerator
-dialect.preparer = SQLiteIdentifierPreparer
-dialect.execution_ctx_cls = SQLiteExecutionContext
diff --git a/lib/sqlalchemy/databases/sybase.py b/lib/sqlalchemy/databases/sybase.py
deleted file mode 100644 (file)
index f5b48e1..0000000
+++ /dev/null
@@ -1,875 +0,0 @@
-# sybase.py
-# Copyright (C) 2007 Fisch Asset Management AG http://www.fam.ch
-# Coding: Alexander Houben alexander.houben@thor-solutions.ch
-#
-# This module is part of SQLAlchemy and is released under
-# the MIT License: http://www.opensource.org/licenses/mit-license.php
-
-"""
-Sybase database backend.
-
-Known issues / TODO:
-
- * Uses the mx.ODBC driver from egenix (version 2.1.0)
- * The current version of sqlalchemy.databases.sybase only supports
-   mx.ODBC.Windows (other platforms such as mx.ODBC.unixODBC still need
-   some development)
- * Support for pyodbc has been built in but is not yet complete (needs
-   further development)
- * Results of running tests/alltests.py:
-     Ran 934 tests in 287.032s
-     FAILED (failures=3, errors=1)
- * Tested on 'Adaptive Server Anywhere 9' (version 9.0.1.1751)
-"""
-
-import datetime, operator
-
-from sqlalchemy import util, sql, schema, exc
-from sqlalchemy.sql import compiler, expression
-from sqlalchemy.engine import default, base
-from sqlalchemy import types as sqltypes
-from sqlalchemy.sql import operators as sql_operators
-from sqlalchemy import MetaData, Table, Column
-from sqlalchemy import String, Integer, SMALLINT, CHAR, ForeignKey
-
-
-__all__ = [
-    'SybaseTypeError',
-    'SybaseNumeric', 'SybaseFloat', 'SybaseInteger', 'SybaseBigInteger',
-    'SybaseTinyInteger', 'SybaseSmallInteger',
-    'SybaseDateTime_mxodbc', 'SybaseDateTime_pyodbc',
-    'SybaseDate_mxodbc', 'SybaseDate_pyodbc',
-    'SybaseTime_mxodbc', 'SybaseTime_pyodbc',
-    'SybaseText', 'SybaseString', 'SybaseChar', 'SybaseBinary',
-    'SybaseBoolean', 'SybaseTimeStamp', 'SybaseMoney', 'SybaseSmallMoney',
-    'SybaseUniqueIdentifier',
-    ]
-
-
-RESERVED_WORDS = set([
-    "add", "all", "alter", "and",
-    "any", "as", "asc", "backup",
-    "begin", "between", "bigint", "binary",
-    "bit", "bottom", "break", "by",
-    "call", "capability", "cascade", "case",
-    "cast", "char", "char_convert", "character",
-    "check", "checkpoint", "close", "comment",
-    "commit", "connect", "constraint", "contains",
-    "continue", "convert", "create", "cross",
-    "cube", "current", "current_timestamp", "current_user",
-    "cursor", "date", "dbspace", "deallocate",
-    "dec", "decimal", "declare", "default",
-    "delete", "deleting", "desc", "distinct",
-    "do", "double", "drop", "dynamic",
-    "else", "elseif", "encrypted", "end",
-    "endif", "escape", "except", "exception",
-    "exec", "execute", "existing", "exists",
-    "externlogin", "fetch", "first", "float",
-    "for", "force", "foreign", "forward",
-    "from", "full", "goto", "grant",
-    "group", "having", "holdlock", "identified",
-    "if", "in", "index", "index_lparen",
-    "inner", "inout", "insensitive", "insert",
-    "inserting", "install", "instead", "int",
-    "integer", "integrated", "intersect", "into",
-    "iq", "is", "isolation", "join",
-    "key", "lateral", "left", "like",
-    "lock", "login", "long", "match",
-    "membership", "message", "mode", "modify",
-    "natural", "new", "no", "noholdlock",
-    "not", "notify", "null", "numeric",
-    "of", "off", "on", "open",
-    "option", "options", "or", "order",
-    "others", "out", "outer", "over",
-    "passthrough", "precision", "prepare", "primary",
-    "print", "privileges", "proc", "procedure",
-    "publication", "raiserror", "readtext", "real",
-    "reference", "references", "release", "remote",
-    "remove", "rename", "reorganize", "resource",
-    "restore", "restrict", "return", "revoke",
-    "right", "rollback", "rollup", "save",
-    "savepoint", "scroll", "select", "sensitive",
-    "session", "set", "setuser", "share",
-    "smallint", "some", "sqlcode", "sqlstate",
-    "start", "stop", "subtrans", "subtransaction",
-    "synchronize", "syntax_error", "table", "temporary",
-    "then", "time", "timestamp", "tinyint",
-    "to", "top", "tran", "trigger",
-    "truncate", "tsequal", "unbounded", "union",
-    "unique", "unknown", "unsigned", "update",
-    "updating", "user", "using", "validate",
-    "values", "varbinary", "varchar", "variable",
-    "varying", "view", "wait", "waitfor",
-    "when", "where", "while", "window",
-    "with", "with_cube", "with_lparen", "with_rollup",
-    "within", "work", "writetext",
-    ])
-
-ischema = MetaData()
-
-tables = Table("SYSTABLE", ischema,
-    Column("table_id", Integer, primary_key=True),
-    Column("file_id", SMALLINT),
-    Column("table_name", CHAR(128)),
-    Column("table_type", CHAR(10)),
-    Column("creator", Integer),
-    #schema="information_schema"
-    )
-
-domains = Table("SYSDOMAIN", ischema,
-    Column("domain_id", Integer, primary_key=True),
-    Column("domain_name", CHAR(128)),
-    Column("type_id", SMALLINT),
-    Column("precision", SMALLINT, quote=True),
-    #schema="information_schema"
-    )
-
-columns = Table("SYSCOLUMN", ischema,
-    Column("column_id", Integer, primary_key=True),
-    Column("table_id", Integer, ForeignKey(tables.c.table_id)),
-    Column("pkey", CHAR(1)),
-    Column("column_name", CHAR(128)),
-    Column("nulls", CHAR(1)),
-    Column("width", SMALLINT),
-    Column("domain_id", SMALLINT, ForeignKey(domains.c.domain_id)),
-    # FIXME: should be mx.BIGINT
-    Column("max_identity", Integer),
-    # FIXME: should be mx.ODBC.Windows.LONGVARCHAR
-    Column("default", String),
-    Column("scale", Integer),
-    #schema="information_schema"
-    )
-
-foreignkeys = Table("SYSFOREIGNKEY", ischema,
-    Column("foreign_table_id", Integer, ForeignKey(tables.c.table_id), primary_key=True),
-    Column("foreign_key_id", SMALLINT, primary_key=True),
-    Column("primary_table_id", Integer, ForeignKey(tables.c.table_id)),
-    #schema="information_schema"
-    )
-fkcols = Table("SYSFKCOL", ischema,
-    Column("foreign_table_id", Integer, ForeignKey(columns.c.table_id), primary_key=True),
-    Column("foreign_key_id", SMALLINT, ForeignKey(foreignkeys.c.foreign_key_id), primary_key=True),
-    Column("foreign_column_id", Integer, ForeignKey(columns.c.column_id), primary_key=True),
-    Column("primary_column_id", Integer),
-    #schema="information_schema"
-    )
-
-class SybaseTypeError(sqltypes.TypeEngine):
-    def result_processor(self, dialect):
-        return None
-
-    def bind_processor(self, dialect):
-        def process(value):
-            raise exc.InvalidRequestError("Data type not supported", [value])
-        return process
-
-    def get_col_spec(self):
-        raise exc.CompileError("Data type not supported")
-
-class SybaseNumeric(sqltypes.Numeric):
-    def get_col_spec(self):
-        if self.scale is None:
-            if self.precision is None:
-                return "NUMERIC"
-            else:
-                return "NUMERIC(%(precision)s)" % {'precision' : self.precision}
-        else:
-            return "NUMERIC(%(precision)s, %(scale)s)" % {'precision': self.precision, 'scale' : self.scale}
-
-class SybaseFloat(sqltypes.FLOAT, SybaseNumeric):
-    def __init__(self, precision = 10, asdecimal = False, scale = 2, **kwargs):
-        super(sqltypes.FLOAT, self).__init__(precision, asdecimal, **kwargs)
-        self.scale = scale
-
-    def get_col_spec(self):
-        # if asdecimal is True, handle same way as SybaseNumeric
-        if self.asdecimal:
-            return SybaseNumeric.get_col_spec(self)
-        if self.precision is None:
-            return "FLOAT"
-        else:
-            return "FLOAT(%(precision)s)" % {'precision': self.precision}
-
-    def result_processor(self, dialect):
-        def process(value):
-            if value is None:
-                return None
-            return float(value)
-        if self.asdecimal:
-            return SybaseNumeric.result_processor(self, dialect)
-        return process
-
-class SybaseInteger(sqltypes.Integer):
-    def get_col_spec(self):
-        return "INTEGER"
-
-class SybaseBigInteger(SybaseInteger):
-    def get_col_spec(self):
-        return "BIGINT"
-
-class SybaseTinyInteger(SybaseInteger):
-    def get_col_spec(self):
-        return "TINYINT"
-
-class SybaseSmallInteger(SybaseInteger):
-    def get_col_spec(self):
-        return "SMALLINT"
-
-class SybaseDateTime_mxodbc(sqltypes.DateTime):
-    def __init__(self, *a, **kw):
-        super(SybaseDateTime_mxodbc, self).__init__(False)
-
-    def get_col_spec(self):
-        return "DATETIME"
-
-class SybaseDateTime_pyodbc(sqltypes.DateTime):
-    def __init__(self, *a, **kw):
-        super(SybaseDateTime_pyodbc, self).__init__(False)
-
-    def get_col_spec(self):
-        return "DATETIME"
-
-    def result_processor(self, dialect):
-        def process(value):
-            if value is None:
-                return None
-            # the datetime.datetime value is passed through unchanged
-            return value
-        return process
-
-    def bind_processor(self, dialect):
-        def process(value):
-            if value is None:
-                return None
-            return value
-        return process
-
-class SybaseDate_mxodbc(sqltypes.Date):
-    def __init__(self, *a, **kw):
-        super(SybaseDate_mxodbc, self).__init__(False)
-
-    def get_col_spec(self):
-        return "DATE"
-
-class SybaseDate_pyodbc(sqltypes.Date):
-    def __init__(self, *a, **kw):
-        super(SybaseDate_pyodbc, self).__init__(False)
-
-    def get_col_spec(self):
-        return "DATE"
-
-class SybaseTime_mxodbc(sqltypes.Time):
-    def __init__(self, *a, **kw):
-        super(SybaseTime_mxodbc, self).__init__(False)
-
-    def get_col_spec(self):
-        return "DATETIME"
-
-    def result_processor(self, dialect):
-        def process(value):
-            if value is None:
-                return None
-            # Convert the datetime.datetime back to datetime.time
-            return datetime.time(value.hour, value.minute, value.second, value.microsecond)
-        return process
-
-class SybaseTime_pyodbc(sqltypes.Time):
-    def __init__(self, *a, **kw):
-        super(SybaseTime_pyodbc, self).__init__(False)
-
-    def get_col_spec(self):
-        return "DATETIME"
-
-    def result_processor(self, dialect):
-        def process(value):
-            if value is None:
-                return None
-            # Convert the datetime.datetime back to datetime.time
-            return datetime.time(value.hour, value.minute, value.second, value.microsecond)
-        return process
-
-    def bind_processor(self, dialect):
-        def process(value):
-            if value is None:
-                return None
-            return datetime.datetime(1970, 1, 1, value.hour, value.minute, value.second, value.microsecond)
-        return process
-
-class SybaseText(sqltypes.Text):
-    def get_col_spec(self):
-        return "TEXT"
-
-class SybaseString(sqltypes.String):
-    def get_col_spec(self):
-        return "VARCHAR(%(length)s)" % {'length' : self.length}
-
-class SybaseChar(sqltypes.CHAR):
-    def get_col_spec(self):
-        return "CHAR(%(length)s)" % {'length' : self.length}
-
-class SybaseBinary(sqltypes.Binary):
-    def get_col_spec(self):
-        return "IMAGE"
-
-class SybaseBoolean(sqltypes.Boolean):
-    def get_col_spec(self):
-        return "BIT"
-
-    def result_processor(self, dialect):
-        def process(value):
-            if value is None:
-                return None
-            return value and True or False
-        return process
-
-    def bind_processor(self, dialect):
-        def process(value):
-            if value is True:
-                return 1
-            elif value is False:
-                return 0
-            elif value is None:
-                return None
-            else:
-                return value and True or False
-        return process
-
-class SybaseTimeStamp(sqltypes.TIMESTAMP):
-    def get_col_spec(self):
-        return "TIMESTAMP"
-
-class SybaseMoney(sqltypes.TypeEngine):
-    def get_col_spec(self):
-        return "MONEY"
-
-class SybaseSmallMoney(SybaseMoney):
-    def get_col_spec(self):
-        return "SMALLMONEY"
-
-class SybaseUniqueIdentifier(sqltypes.TypeEngine):
-    def get_col_spec(self):
-        return "UNIQUEIDENTIFIER"
-
-class SybaseSQLExecutionContext(default.DefaultExecutionContext):
-    pass
-
-class SybaseSQLExecutionContext_mxodbc(SybaseSQLExecutionContext):
-
-    def __init__(self, dialect, connection, compiled=None, statement=None, parameters=None):
-        super(SybaseSQLExecutionContext_mxodbc, self).__init__(dialect, connection, compiled, statement, parameters)
-
-    def pre_exec(self):
-        super(SybaseSQLExecutionContext_mxodbc, self).pre_exec()
-
-    def post_exec(self):
-        if self.compiled.isinsert:
-            table = self.compiled.statement.table
-            # get the inserted values of the primary key
-
-            # get any sequence IDs first (using @@identity)
-            self.cursor.execute("SELECT @@identity AS lastrowid")
-            row = self.cursor.fetchone()
-            lastrowid = int(row[0])
-            if lastrowid > 0:
-                # an IDENTITY was inserted, fetch it
-                # FIXME: always insert in front ? This only works if the IDENTITY is the first column, no ?!
-                if not hasattr(self, '_last_inserted_ids') or self._last_inserted_ids is None:
-                    self._last_inserted_ids = [lastrowid]
-                else:
-                    self._last_inserted_ids = [lastrowid] + self._last_inserted_ids[1:]
-        super(SybaseSQLExecutionContext_mxodbc, self).post_exec()
-
-class SybaseSQLExecutionContext_pyodbc(SybaseSQLExecutionContext):
-    def __init__(self, dialect, connection, compiled=None, statement=None, parameters=None):
-        super(SybaseSQLExecutionContext_pyodbc, self).__init__(dialect, connection, compiled, statement, parameters)
-
-    def pre_exec(self):
-        super(SybaseSQLExecutionContext_pyodbc, self).pre_exec()
-
-    def post_exec(self):
-        if self.compiled.isinsert:
-            table = self.compiled.statement.table
-            # get the inserted values of the primary key
-
-            # get any sequence IDs first (using @@identity)
-            self.cursor.execute("SELECT @@identity AS lastrowid")
-            row = self.cursor.fetchone()
-            lastrowid = int(row[0])
-            if lastrowid > 0:
-                # an IDENTITY was inserted, fetch it
-                # FIXME: always insert in front ? This only works if the IDENTITY is the first column, no ?!
-                if not hasattr(self, '_last_inserted_ids') or self._last_inserted_ids is None:
-                    self._last_inserted_ids = [lastrowid]
-                else:
-                    self._last_inserted_ids = [lastrowid] + self._last_inserted_ids[1:]
-        super(SybaseSQLExecutionContext_pyodbc, self).post_exec()
-
-class SybaseSQLDialect(default.DefaultDialect):
-    colspecs = {
-        # FIXME: unicode support
-        #sqltypes.Unicode : SybaseUnicode,
-        sqltypes.Integer : SybaseInteger,
-        sqltypes.SmallInteger : SybaseSmallInteger,
-        sqltypes.Numeric : SybaseNumeric,
-        sqltypes.Float : SybaseFloat,
-        sqltypes.String : SybaseString,
-        sqltypes.Binary : SybaseBinary,
-        sqltypes.Boolean : SybaseBoolean,
-        sqltypes.Text : SybaseText,
-        sqltypes.CHAR : SybaseChar,
-        sqltypes.TIMESTAMP : SybaseTimeStamp,
-        sqltypes.FLOAT : SybaseFloat,
-    }
-
-    ischema_names = {
-        'integer' : SybaseInteger,
-        'unsigned int' : SybaseInteger,
-        'unsigned smallint' : SybaseInteger,
-        'unsigned bigint' : SybaseInteger,
-        'bigint': SybaseBigInteger,
-        'smallint' : SybaseSmallInteger,
-        'tinyint' : SybaseTinyInteger,
-        'varchar' : SybaseString,
-        'long varchar' : SybaseText,
-        'char' : SybaseChar,
-        'decimal' : SybaseNumeric,
-        'numeric' : SybaseNumeric,
-        'float' : SybaseFloat,
-        'double' : SybaseFloat,
-        'binary' : SybaseBinary,
-        'long binary' : SybaseBinary,
-        'varbinary' : SybaseBinary,
-        'bit': SybaseBoolean,
-        'image' : SybaseBinary,
-        'timestamp': SybaseTimeStamp,
-        'money': SybaseMoney,
-        'smallmoney': SybaseSmallMoney,
-        'uniqueidentifier': SybaseUniqueIdentifier,
-
-        'java.lang.Object' : SybaseTypeError,
-        'java serialization' : SybaseTypeError,
-    }
-
-    name = 'sybase'
-    # Sybase backend peculiarities
-    supports_unicode_statements = False
-    supports_sane_rowcount = False
-    supports_sane_multi_rowcount = False
-    execution_ctx_cls = SybaseSQLExecutionContext
-    
-    def __new__(cls, dbapi=None, *args, **kwargs):
-        if cls != SybaseSQLDialect:
-            return super(SybaseSQLDialect, cls).__new__(cls, *args, **kwargs)
-        if dbapi:
-            # dispatch to the concrete dialect class for this DBAPI module
-            dialect = dialect_mapping.get(dbapi.__name__)
-            return dialect(*args, **kwargs)
-        else:
-            return object.__new__(cls, *args, **kwargs)
-
-    def __init__(self, **params):
-        super(SybaseSQLDialect, self).__init__(**params)
-        self.text_as_varchar = False
-        # FIXME: what is the default schema for sybase connections (DBA?) ?
-        self.set_default_schema_name("dba")
-
-    def dbapi(cls, module_name=None):
-        if module_name:
-            try:
-                dialect_cls = dialect_mapping[module_name]
-                return dialect_cls.import_dbapi()
-            except KeyError:
-                raise exc.InvalidRequestError(
-                    "Unsupported SybaseSQL module '%s' requested (must be %s)" %
-                    (module_name, " or ".join(dialect_mapping.keys())))
-        else:
-            for dialect_cls in dialect_mapping.values():
-                try:
-                    return dialect_cls.import_dbapi()
-                except ImportError, e:
-                    pass
-            else:
-                raise ImportError('No DBAPI module detected for SybaseSQL - please install mxodbc')
-    dbapi = classmethod(dbapi)
-
-    def type_descriptor(self, typeobj):
-        newobj = sqltypes.adapt_type(typeobj, self.colspecs)
-        return newobj
-
-    def last_inserted_ids(self):
-        return self.context.last_inserted_ids
-
-    def get_default_schema_name(self, connection):
-        return self.schema_name
-
-    def set_default_schema_name(self, schema_name):
-        self.schema_name = schema_name
-
-    def do_execute(self, cursor, statement, params, **kwargs):
-        params = tuple(params)
-        super(SybaseSQLDialect, self).do_execute(cursor, statement, params, **kwargs)
-
-    # FIXME: remove ?
-    def _execute(self, c, statement, parameters):
-        try:
-            if parameters == {}:
-                parameters = ()
-            c.execute(statement, parameters)
-            self.context.rowcount = c.rowcount
-            c.DBPROP_COMMITPRESERVE = "Y"
-        except Exception, e:
-            raise exc.DBAPIError.instance(statement, parameters, e)
-
-    def table_names(self, connection, schema):
-        """Ignore the schema and the charset for now."""
-        s = sql.select([tables.c.table_name],
-                       sql.and_(sql.not_(tables.c.table_name.like("SYS%")),
-                                tables.c.creator >= 100)
-                       )
-        rp = connection.execute(s)
-        return [row[0] for row in rp.fetchall()]
-
-    def has_table(self, connection, tablename, schema=None):
-        # FIXME: ignore schemas for sybase
-        s = sql.select([tables.c.table_name], tables.c.table_name == tablename)
-
-        c = connection.execute(s)
-        row = c.fetchone()
-        return row is not None
-
-    def reflecttable(self, connection, table, include_columns):
-        # Get base columns
-        if table.schema is not None:
-            current_schema = table.schema
-        else:
-            current_schema = self.get_default_schema_name(connection)
-
-        s = sql.select([columns, domains], tables.c.table_name==table.name, from_obj=[columns.join(tables).join(domains)], order_by=[columns.c.column_id])
-
-        c = connection.execute(s)
-        found_table = False
-        # makes sure we append the columns in the correct order
-        while True:
-            row = c.fetchone()
-            if row is None:
-                break
-            found_table = True
-            (name, type, nullable, charlen, numericprec, numericscale, default, primary_key, max_identity, table_id, column_id) = (
-                row[columns.c.column_name],
-                row[domains.c.domain_name],
-                row[columns.c.nulls] == 'Y',
-                row[columns.c.width],
-                row[domains.c.precision],
-                row[columns.c.scale],
-                row[columns.c.default],
-                row[columns.c.pkey] == 'Y',
-                row[columns.c.max_identity],
-                row[tables.c.table_id],
-                row[columns.c.column_id],
-            )
-            if include_columns and name not in include_columns:
-                continue
-
-            # FIXME: else problems with SybaseBinary(size)
-            if numericscale == 0:
-                numericscale = None
-
-            args = []
-            for a in (charlen, numericprec, numericscale):
-                if a is not None:
-                    args.append(a)
-            coltype = self.ischema_names.get(type, None)
-            if coltype == SybaseString and charlen == -1:
-                coltype = SybaseText()
-            else:
-                if coltype is None:
-                    util.warn("Did not recognize type '%s' of column '%s'" %
-                              (type, name))
-                    coltype = sqltypes.NULLTYPE
-                coltype = coltype(*args)
-            colargs = []
-            if default is not None:
-                colargs.append(schema.DefaultClause(sql.text(default)))
-
-            # any sequences ?
-            col = schema.Column(name, coltype, nullable=nullable, primary_key=primary_key, *colargs)
-            if int(max_identity) > 0:
-                col.sequence = schema.Sequence(name + '_identity')
-                col.sequence.start = int(max_identity)
-                col.sequence.increment = 1
-
-            # append the column
-            table.append_column(col)
-
-        # any foreign key constraint for this table ?
-        # note: no multi-column foreign keys are considered
-        s = ("select st1.table_name, sc1.column_name, st2.table_name, sc2.column_name "
-             "from systable as st1 "
-             "join sysfkcol on st1.table_id=sysfkcol.foreign_table_id "
-             "join sysforeignkey "
-             "join systable as st2 on sysforeignkey.primary_table_id = st2.table_id "
-             "join syscolumn as sc1 on sysfkcol.foreign_column_id=sc1.column_id "
-             "and sc1.table_id=st1.table_id "
-             "join syscolumn as sc2 on sysfkcol.primary_column_id=sc2.column_id "
-             "and sc2.table_id=st2.table_id "
-             "where st1.table_name='%(table_name)s';" % {'table_name': table.name})
-        c = connection.execute(s)
-        foreignKeys = {}
-        while True:
-            row = c.fetchone()
-            if row is None:
-                break
-            (foreign_table, foreign_column, primary_table, primary_column) = (
-                row[0], row[1], row[2], row[3],
-            )
-            if primary_table not in foreignKeys:
-                foreignKeys[primary_table] = [['%s' % (foreign_column)], ['%s.%s'%(primary_table, primary_column)]]
-            else:
-                foreignKeys[primary_table][0].append('%s'%(foreign_column))
-                foreignKeys[primary_table][1].append('%s.%s'%(primary_table, primary_column))
-        for primary_table in foreignKeys.keys():
-            #table.append_constraint(schema.ForeignKeyConstraint(['%s.%s'%(foreign_table, foreign_column)], ['%s.%s'%(primary_table,primary_column)]))
-            table.append_constraint(schema.ForeignKeyConstraint(foreignKeys[primary_table][0], foreignKeys[primary_table][1], link_to_name=True))
-
-        if not found_table:
-            raise exc.NoSuchTableError(table.name)
-
-
-class SybaseSQLDialect_mxodbc(SybaseSQLDialect):
-    execution_ctx_cls = SybaseSQLExecutionContext_mxodbc
-    
-    def __init__(self, **params):
-        super(SybaseSQLDialect_mxodbc, self).__init__(**params)
-
-        self.dbapi_type_map = {'getdate' : SybaseDate_mxodbc()}
-
-    def import_dbapi(cls):
-        #import mx.ODBC.Windows as module
-        import mxODBC as module
-        return module
-    import_dbapi = classmethod(import_dbapi)
-
-    colspecs = SybaseSQLDialect.colspecs.copy()
-    colspecs[sqltypes.Time] = SybaseTime_mxodbc
-    colspecs[sqltypes.Date] = SybaseDate_mxodbc
-    colspecs[sqltypes.DateTime] = SybaseDateTime_mxodbc
-
-    ischema_names = SybaseSQLDialect.ischema_names.copy()
-    ischema_names['time'] = SybaseTime_mxodbc
-    ischema_names['date'] = SybaseDate_mxodbc
-    ischema_names['datetime'] = SybaseDateTime_mxodbc
-    ischema_names['smalldatetime'] = SybaseDateTime_mxodbc
-
-    def is_disconnect(self, e):
-        # FIXME: optimize
-        #return isinstance(e, self.dbapi.Error) and '[08S01]' in str(e)
-        #return True
-        return False
-
-    def do_execute(self, cursor, statement, parameters, context=None, **kwargs):
-        super(SybaseSQLDialect_mxodbc, self).do_execute(cursor, statement, parameters, context=context, **kwargs)
-
-    def create_connect_args(self, url):
-        '''Return a tuple of *args,**kwargs'''
-        # FIXME: handle mx.odbc.Windows proprietary args
-        opts = url.translate_connect_args(username='user')
-        opts.update(url.query)
-        argsDict = {}
-        argsDict['user'] = opts['user']
-        argsDict['password'] = opts['password']
-        connArgs = [[opts['dsn']], argsDict]
-        return connArgs
-
-
-class SybaseSQLDialect_pyodbc(SybaseSQLDialect):
-    execution_ctx_cls = SybaseSQLExecutionContext_pyodbc
-    
-    def __init__(self, **params):
-        super(SybaseSQLDialect_pyodbc, self).__init__(**params)
-        self.dbapi_type_map = {'getdate' : SybaseDate_pyodbc()}
-
-    def import_dbapi(cls):
-        import mypyodbc as module
-        return module
-    import_dbapi = classmethod(import_dbapi)
-
-    colspecs = SybaseSQLDialect.colspecs.copy()
-    colspecs[sqltypes.Time] = SybaseTime_pyodbc
-    colspecs[sqltypes.Date] = SybaseDate_pyodbc
-    colspecs[sqltypes.DateTime] = SybaseDateTime_pyodbc
-
-    ischema_names = SybaseSQLDialect.ischema_names.copy()
-    ischema_names['time'] = SybaseTime_pyodbc
-    ischema_names['date'] = SybaseDate_pyodbc
-    ischema_names['datetime'] = SybaseDateTime_pyodbc
-    ischema_names['smalldatetime'] = SybaseDateTime_pyodbc
-
-    def is_disconnect(self, e):
-        # FIXME: optimize
-        #return isinstance(e, self.dbapi.Error) and '[08S01]' in str(e)
-        #return True
-        return False
-
-    def do_execute(self, cursor, statement, parameters, context=None, **kwargs):
-        super(SybaseSQLDialect_pyodbc, self).do_execute(cursor, statement, parameters, context=context, **kwargs)
-
-    def create_connect_args(self, url):
-        '''Return a tuple of *args,**kwargs'''
-        # FIXME: handle pyodbc proprietary args
-        opts = url.translate_connect_args(username='user')
-        opts.update(url.query)
-
-        self.autocommit = False
-        if 'autocommit' in opts:
-            self.autocommit = bool(int(opts.pop('autocommit')))
-
-        argsDict = {}
-        argsDict['UID'] = opts['user']
-        argsDict['PWD'] = opts['password']
-        argsDict['DSN'] = opts['dsn']
-        connArgs = [[';'.join(["%s=%s"%(key, argsDict[key]) for key in argsDict])], {'autocommit' : self.autocommit}]
-        return connArgs
-
-
-dialect_mapping = {
-    'sqlalchemy.databases.mxODBC' : SybaseSQLDialect_mxodbc,
-#    'pyodbc' : SybaseSQLDialect_pyodbc,
-    }
-
-
-class SybaseSQLCompiler(compiler.DefaultCompiler):
-    operators = compiler.DefaultCompiler.operators.copy()
-    operators.update({
-        sql_operators.mod: lambda x, y: "MOD(%s, %s)" % (x, y),
-    })
-
-    extract_map = compiler.DefaultCompiler.extract_map.copy()
-    extract_map.update({
-        'doy': 'dayofyear',
-        'dow': 'weekday',
-        'milliseconds': 'millisecond'
-    })
-
-
-    def bindparam_string(self, name):
-        res = super(SybaseSQLCompiler, self).bindparam_string(name)
-        if name.lower().startswith('literal'):
-            res = 'STRING(%s)' % res
-        return res
-
-    def get_select_precolumns(self, select):
-        s = select._distinct and "DISTINCT " or ""
-        if select._limit:
-            #if select._limit == 1:
-                #s += "FIRST "
-            #else:
-                #s += "TOP %s " % (select._limit,)
-            s += "TOP %s " % (select._limit,)
-        if select._offset:
-            if not select._limit:
-                # FIXME: sybase doesn't allow an offset without a limit
-                # so use a huge value for TOP here
-                s += "TOP 1000000 "
-            s += "START AT %s " % (select._offset+1,)
-        return s
-
-    def limit_clause(self, select):
-        # Limit in sybase is after the select keyword
-        return ""
-
-    def visit_binary(self, binary):
-        """Move bind parameters to the right-hand side of an operator, where possible."""
-        if isinstance(binary.left, expression._BindParamClause) and binary.operator == operator.eq:
-            return self.process(expression._BinaryExpression(binary.right, binary.left, binary.operator))
-        else:
-            return super(SybaseSQLCompiler, self).visit_binary(binary)
-
-    def label_select_column(self, select, column, asfrom):
-        if isinstance(column, expression.Function):
-            return column.label(None)
-        else:
-            return super(SybaseSQLCompiler, self).label_select_column(select, column, asfrom)
-
-    function_rewrites =  {'current_date': 'getdate',
-                         }
-    def visit_function(self, func):
-        func.name = self.function_rewrites.get(func.name, func.name)
-        res = super(SybaseSQLCompiler, self).visit_function(func)
-        if func.name.lower() == 'getdate':
-            # apply CAST operator
-            # FIXME: what about _pyodbc ?
-            cast = expression._Cast(func, SybaseDate_mxodbc)
-            # infinite recursion
-            # res = self.visit_cast(cast)
-            res = "CAST(%s AS %s)" % (res, self.process(cast.typeclause))
-        return res
-
-    def visit_extract(self, extract):
-        field = self.extract_map.get(extract.field, extract.field)
-        return 'DATEPART("%s", %s)' % (field, self.process(extract.expr))
-
-    def for_update_clause(self, select):
-        # "FOR UPDATE" is only allowed on "DECLARE CURSOR" which SQLAlchemy doesn't use
-        return ''
-
-    def order_by_clause(self, select):
-        order_by = self.process(select._order_by_clause)
-
-        # SybaseSQL only allows ORDER BY in subqueries if there is a LIMIT
-        if order_by and (not self.is_subquery() or select._limit):
-            return " ORDER BY " + order_by
-        else:
-            return ""
-
-
-class SybaseSQLSchemaGenerator(compiler.SchemaGenerator):
-    def get_column_specification(self, column, **kwargs):
-
-        colspec = self.preparer.format_column(column)
-
-        if (not getattr(column.table, 'has_sequence', False)) and column.primary_key and \
-                column.autoincrement and isinstance(column.type, sqltypes.Integer):
-            if column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional):
-                column.sequence = schema.Sequence(column.name + '_seq')
-
-        if hasattr(column, 'sequence'):
-            column.table.has_sequence = column
-            #colspec += " numeric(30,0) IDENTITY"
-            colspec += " Integer IDENTITY"
-        else:
-            colspec += " " + column.type.dialect_impl(self.dialect).get_col_spec()
-
-        if not column.nullable:
-            colspec += " NOT NULL"
-
-        default = self.get_column_default_string(column)
-        if default is not None:
-            colspec += " DEFAULT " + default
-
-        return colspec
-
-
-class SybaseSQLSchemaDropper(compiler.SchemaDropper):
-    def visit_index(self, index):
-        self.append("\nDROP INDEX %s.%s" % (
-            self.preparer.quote_identifier(index.table.name),
-            self.preparer.quote(self._validate_identifier(index.name, False), index.quote)
-            ))
-        self.execute()
-
-
-class SybaseSQLDefaultRunner(base.DefaultRunner):
-    pass
-
-
-class SybaseSQLIdentifierPreparer(compiler.IdentifierPreparer):
-    reserved_words = RESERVED_WORDS
-
-    def __init__(self, dialect):
-        super(SybaseSQLIdentifierPreparer, self).__init__(dialect)
-
-    def _escape_identifier(self, value):
-        # TODO: determine SybaseSQL's escaping rules
-        return value
-
-    def _fold_identifier_case(self, value):
-        # TODO: determine SybaseSQL's case folding rules
-        return value
-
-
-dialect = SybaseSQLDialect
-dialect.statement_compiler = SybaseSQLCompiler
-dialect.schemagenerator = SybaseSQLSchemaGenerator
-dialect.schemadropper = SybaseSQLSchemaDropper
-dialect.preparer = SybaseSQLIdentifierPreparer
-dialect.defaultrunner = SybaseSQLDefaultRunner
diff --git a/lib/sqlalchemy/dialects/__init__.py b/lib/sqlalchemy/dialects/__init__.py
new file mode 100644 (file)
index 0000000..91ca91f
--- /dev/null
@@ -0,0 +1,12 @@
+__all__ = (
+#    'access',
+#    'firebird',
+#    'informix',
+#    'maxdb',
+#    'mssql',
+    'mysql',
+    'oracle',
+    'postgresql',
+    'sqlite',
+#    'sybase',
+    )
diff --git a/lib/sqlalchemy/dialects/access/__init__.py b/lib/sqlalchemy/dialects/access/__init__.py
new file mode 100644 (file)
index 0000000..e69de29
similarity index 96%
rename from lib/sqlalchemy/databases/access.py
rename to lib/sqlalchemy/dialects/access/base.py
index 56c28b8cc612152c006995db68de4ed3f252c006..ed8297137a2ed3921f17baa74b0eddb940c0aeba 100644 (file)
@@ -5,6 +5,13 @@
 # This module is part of SQLAlchemy and is released under
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
 
+"""
+Support for the Microsoft Access database.
+
+This dialect is *not* tested on SQLAlchemy 0.6.
+
+
+"""
 from sqlalchemy import sql, schema, types, exc, pool
 from sqlalchemy.sql import compiler, expression
 from sqlalchemy.engine import default, base
@@ -46,7 +53,7 @@ class AcTinyInteger(types.Integer):
     def get_col_spec(self):
         return "TINYINT"
 
-class AcSmallInteger(types.Smallinteger):
+class AcSmallInteger(types.SmallInteger):
     def get_col_spec(self):
         return "SMALLINT"
 
@@ -155,7 +162,7 @@ class AccessDialect(default.DefaultDialect):
     colspecs = {
         types.Unicode : AcUnicode,
         types.Integer : AcInteger,
-        types.Smallinteger: AcSmallInteger,
+        types.SmallInteger: AcSmallInteger,
         types.Numeric : AcNumeric,
         types.Float : AcFloat,
         types.DateTime : AcDateTime,
@@ -327,8 +334,8 @@ class AccessDialect(default.DefaultDialect):
         return names
 
 
-class AccessCompiler(compiler.DefaultCompiler):
-    extract_map = compiler.DefaultCompiler.extract_map.copy()
+class AccessCompiler(compiler.SQLCompiler):
+    extract_map = compiler.SQLCompiler.extract_map.copy()
     extract_map.update ({
             'month': 'm',
             'day': 'd',
@@ -341,7 +348,7 @@ class AccessCompiler(compiler.DefaultCompiler):
             'dow': 'w',
             'week': 'ww'
     })
-
+        
     def visit_select_precolumns(self, select):
         """Access puts TOP, it's version of LIMIT here """
         s = select.distinct and "DISTINCT " or ""
@@ -393,8 +400,7 @@ class AccessCompiler(compiler.DefaultCompiler):
         field = self.extract_map.get(extract.field, extract.field)
         return 'DATEPART("%s", %s)' % (field, self.process(extract.expr))
 
-
-class AccessSchemaGenerator(compiler.SchemaGenerator):
+class AccessDDLCompiler(compiler.DDLCompiler):
     def get_column_specification(self, column, **kwargs):
         colspec = self.preparer.format_column(column) + " " + column.type.dialect_impl(self.dialect).get_col_spec()
 
@@ -417,14 +423,9 @@ class AccessSchemaGenerator(compiler.SchemaGenerator):
 
         return colspec
 
-class AccessSchemaDropper(compiler.SchemaDropper):
-    def visit_index(self, index):
-        
+    def visit_drop_index(self, drop):
+        index = drop.element
         self.append("\nDROP INDEX [%s].[%s]" % (index.table.name, self._validate_identifier(index.name, False)))
-        self.execute()
-
-class AccessDefaultRunner(base.DefaultRunner):
-    pass
 
 class AccessIdentifierPreparer(compiler.IdentifierPreparer):
     reserved_words = compiler.RESERVED_WORDS.copy()
@@ -436,8 +437,6 @@ class AccessIdentifierPreparer(compiler.IdentifierPreparer):
 dialect = AccessDialect
 dialect.poolclass = pool.SingletonThreadPool
 dialect.statement_compiler = AccessCompiler
-dialect.schemagenerator = AccessSchemaGenerator
-dialect.schemadropper = AccessSchemaDropper
+dialect.ddlcompiler = AccessDDLCompiler
 dialect.preparer = AccessIdentifierPreparer
-dialect.defaultrunner = AccessDefaultRunner
-dialect.execution_ctx_cls = AccessExecutionContext
+dialect.execution_ctx_cls = AccessExecutionContext
\ No newline at end of file
diff --git a/lib/sqlalchemy/dialects/firebird/__init__.py b/lib/sqlalchemy/dialects/firebird/__init__.py
new file mode 100644
index 0000000..6b1b80d
--- /dev/null
@@ -0,0 +1,3 @@
+from sqlalchemy.dialects.firebird import base, kinterbasdb
+
+base.dialect = kinterbasdb.dialect
\ No newline at end of file
diff --git a/lib/sqlalchemy/dialects/firebird/base.py b/lib/sqlalchemy/dialects/firebird/base.py
new file mode 100644
index 0000000..57b89ed
--- /dev/null
@@ -0,0 +1,626 @@
+# firebird.py
+# Copyright (C) 2005, 2006, 2007, 2008, 2009 Michael Bayer mike_mp@zzzcomputing.com
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: http://www.opensource.org/licenses/mit-license.php
+
+"""
+Support for the Firebird database.
+
+Connectivity is usually supplied via the kinterbasdb_
+DBAPI module.
+
+Firebird dialects
+-----------------
+
+Firebird offers two distinct dialects_ (not to be confused with a
+SQLAlchemy ``Dialect``):
+
+dialect 1
+  This is the old syntax and behaviour, inherited from Interbase pre-6.0.
+
+dialect 3
+  This is the newer and supported syntax, introduced in Interbase 6.0.
+
+The SQLAlchemy Firebird dialect detects these versions and
+adjusts its representation of SQL accordingly.  However,
+support for dialect 1 is not well tested and probably has
+incompatibilities.
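+
+A minimal connection sketch (the host, database path and credentials
+below are placeholders, not values taken from this changeset)::
+
+    from sqlalchemy import create_engine
+
+    engine = create_engine('firebird://user:password@localhost/path/to/db.fdb')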
+
+Firebird Locking Behavior
+-------------------------
+
+Firebird locks tables aggressively.  For this reason, a DROP TABLE may
+hang until other transactions are released.  SQLAlchemy does its best
+to release transactions as quickly as possible.  The most common cause
+of hanging transactions is a non-fully consumed result set, i.e.::
+
+    result = engine.execute("select * from table")
+    row = result.fetchone()
+    return
+
+In the above example, the ``ResultProxy`` has not been fully consumed.  The
+connection will be returned to the pool and the transactional state
+rolled back once the Python garbage collector reclaims the objects
+which hold onto the connection, which often occurs asynchronously.
+The above use case can be alleviated by calling ``first()`` on the
+``ResultProxy`` which will fetch the first row and immediately close
+all remaining cursor/connection resources.
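+
+For example, the hanging snippet above can be rewritten as (a sketch)::
+
+    result = engine.execute("select * from table")
+    row = result.first()  # fetches one row and immediately closes the result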
+
+RETURNING support
+-----------------
+
+Firebird 2.0 supports returning a result set from inserts, and 2.1 extends
+that to deletes and updates.
+
+To use this, pass the column/expression list to the ``firebird_returning``
+parameter when creating the queries::
+
+  raises = empl.update(empl.c.sales > 100, values=dict(salary=empl.c.salary * 1.1),
+                       firebird_returning=[empl.c.id, empl.c.salary]).execute().fetchall()
+
+
+.. [#] Well, that is not the whole story, as the client may still ask for
+       a different (lower) dialect...
+
+.. _dialects: http://mc-computing.com/Databases/Firebird/SQL_Dialect.html
+.. _kinterbasdb: http://sourceforge.net/projects/kinterbasdb
+
+"""
+
+
+import datetime, decimal, re
+
+from sqlalchemy import schema as sa_schema
+from sqlalchemy import exc, types as sqltypes, sql, util
+from sqlalchemy.sql import expression
+from sqlalchemy.engine import base, default, reflection
+from sqlalchemy.sql import compiler
+
+from sqlalchemy.types import (BIGINT, BLOB, BOOLEAN, CHAR, DATE,
+                               FLOAT, INTEGER, NUMERIC, SMALLINT,
+                               TEXT, TIME, TIMESTAMP, VARCHAR)
+
+
+RESERVED_WORDS = set(
+   ["action", "active", "add", "admin", "after", "all", "alter", "and", "any",
+    "as", "asc", "ascending", "at", "auto", "autoddl", "avg", "based", "basename",
+    "base_name", "before", "begin", "between", "bigint", "blob", "blobedit", "buffer",
+    "by", "cache", "cascade", "case", "cast", "char", "character", "character_length",
+    "char_length", "check", "check_point_len", "check_point_length", "close", "collate",
+    "collation", "column", "commit", "committed", "compiletime", "computed", "conditional",
+    "connect", "constraint", "containing", "continue", "count", "create", "cstring",
+    "current", "current_connection", "current_date", "current_role", "current_time",
+    "current_timestamp", "current_transaction", "current_user", "cursor", "database",
+    "date", "day", "db_key", "debug", "dec", "decimal", "declare", "default", "delete",
+    "desc", "descending", "describe", "descriptor", "disconnect", "display", "distinct",
+    "do", "domain", "double", "drop", "echo", "edit", "else", "end", "entry_point",
+    "escape", "event", "exception", "execute", "exists", "exit", "extern", "external",
+    "extract", "fetch", "file", "filter", "float", "for", "foreign", "found", "free_it",
+    "from", "full", "function", "gdscode", "generator", "gen_id", "global", "goto",
+    "grant", "group", "group_commit_", "group_commit_wait", "having", "help", "hour",
+    "if", "immediate", "in", "inactive", "index", "indicator", "init", "inner", "input",
+    "input_type", "insert", "int", "integer", "into", "is", "isolation", "isql", "join",
+    "key", "lc_messages", "lc_type", "left", "length", "lev", "level", "like", "logfile",
+    "log_buffer_size", "log_buf_size", "long", "manual", "max", "maximum", "maximum_segment",
+    "max_segment", "merge", "message", "min", "minimum", "minute", "module_name", "month",
+    "names", "national", "natural", "nchar", "no", "noauto", "not", "null", "numeric",
+    "num_log_buffers", "num_log_bufs", "octet_length", "of", "on", "only", "open", "option",
+    "or", "order", "outer", "output", "output_type", "overflow", "page", "pagelength",
+    "pages", "page_size", "parameter", "password", "plan", "position", "post_event",
+    "precision", "prepare", "primary", "privileges", "procedure", "protected", "public",
+    "quit", "raw_partitions", "rdb$db_key", "read", "real", "record_version", "recreate",
+    "references", "release", "release", "reserv", "reserving", "restrict", "retain",
+    "return", "returning_values", "returns", "revoke", "right", "role", "rollback",
+    "row_count", "runtime", "savepoint", "schema", "second", "segment", "select",
+    "set", "shadow", "shared", "shell", "show", "singular", "size", "smallint",
+    "snapshot", "some", "sort", "sqlcode", "sqlerror", "sqlwarning", "stability",
+    "starting", "starts", "statement", "static", "statistics", "sub_type", "sum",
+    "suspend", "table", "terminator", "then", "time", "timestamp", "to", "transaction",
+    "translate", "translation", "trigger", "trim", "type", "uncommitted", "union",
+    "unique", "update", "upper", "user", "using", "value", "values", "varchar",
+    "variable", "varying", "version", "view", "wait", "wait_time", "weekday", "when",
+    "whenever", "where", "while", "with", "work", "write", "year", "yearday" ])
+
+
+class _FBBoolean(sqltypes.Boolean):
+    def result_processor(self, dialect):
+        def process(value):
+            if value is None:
+                return None
+            return value and True or False
+        return process
+
+    def bind_processor(self, dialect):
+        def process(value):
+            if value is True:
+                return 1
+            elif value is False:
+                return 0
+            elif value is None:
+                return None
+            else:
+                return value and True or False
+        return process
+
+
+colspecs = {
+    sqltypes.Boolean: _FBBoolean,
+}
+
+ischema_names = {
+      'SHORT': SMALLINT,
+       'LONG': BIGINT,
+       'QUAD': FLOAT,
+      'FLOAT': FLOAT,
+       'DATE': DATE,
+       'TIME': TIME,
+       'TEXT': TEXT,
+      'INT64': NUMERIC,
+     'DOUBLE': FLOAT,
+  'TIMESTAMP': TIMESTAMP,
+    'VARYING': VARCHAR,
+    'CSTRING': CHAR,
+       'BLOB': BLOB,
+    }
+
+
+# TODO: date conversion types (should be implemented as _FBDateTime, _FBDate, etc.
+# as bind/result functionality is required)
+
+class FBTypeCompiler(compiler.GenericTypeCompiler):
+    def visit_boolean(self, type_):
+        return self.visit_SMALLINT(type_)
+
+    def visit_datetime(self, type_):
+        return self.visit_TIMESTAMP(type_)
+
+    def visit_TEXT(self, type_):
+        return "BLOB SUB_TYPE 1"
+
+    def visit_BLOB(self, type_):
+        return "BLOB SUB_TYPE 0"
+
+
+class FBCompiler(sql.compiler.SQLCompiler):
+    """Firebird specific idiosincrasies"""
+
+    def visit_mod(self, binary, **kw):
+        # Firebird lacks a builtin modulo operator, but there is
+        # an equivalent function in the ib_udf library.
+        return "mod(%s, %s)" % (self.process(binary.left), self.process(binary.right))
+
+    def visit_alias(self, alias, asfrom=False, **kwargs):
+        if self.dialect._version_two:
+            return super(FBCompiler, self).visit_alias(alias, asfrom=asfrom, **kwargs)
+        else:
+            # Override to not use the AS keyword which FB 1.5 does not like
+            if asfrom:
+                alias_name = isinstance(alias.name, expression._generated_label) and \
+                                self._truncated_identifier("alias", alias.name) or alias.name
+
+                return self.process(alias.original, asfrom=asfrom, **kwargs) + " " + \
+                            self.preparer.format_alias(alias, alias_name)
+            else:
+                return self.process(alias.original, **kwargs)
+
+    def visit_substring_func(self, func, **kw):
+        s = self.process(func.clauses.clauses[0])
+        start = self.process(func.clauses.clauses[1])
+        if len(func.clauses.clauses) > 2:
+            length = self.process(func.clauses.clauses[2])
+            return "SUBSTRING(%s FROM %s FOR %s)" % (s, start, length)
+        else:
+            return "SUBSTRING(%s FROM %s)" % (s, start)
+
+    def visit_length_func(self, function, **kw):
+        if self.dialect._version_two:
+            return "char_length" + self.function_argspec(function)
+        else:
+            return "strlen" + self.function_argspec(function)
+
+    visit_char_length_func = visit_length_func
+
+    def function_argspec(self, func, **kw):
+        if func.clauses:
+            return self.process(func.clause_expr)
+        else:
+            return ""
+
+    def default_from(self):
+        return " FROM rdb$database"
+
+    def visit_sequence(self, seq):
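+        # Firebird sequences are "generators"; gen_id(<gen>, 1) increments
+        # the generator by 1 and returns the new value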
+        return "gen_id(%s, 1)" % self.preparer.format_sequence(seq)
+
+    def get_select_precolumns(self, select):
+        """Called when building a ``SELECT`` statement, position is just
+        before column list Firebird puts the limit and offset right
+        after the ``SELECT``...
+        """
+
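+        # e.g. limit(10).offset(5) on a DISTINCT select renders
+        # "FIRST 10 SKIP 5 DISTINCT" here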
+        result = ""
+        if select._limit:
+            result += "FIRST %d "  % select._limit
+        if select._offset:
+            result +="SKIP %d "  %  select._offset
+        if select._distinct:
+            result += "DISTINCT "
+        return result
+
+    def limit_clause(self, select):
+        """Already taken care of in the `get_select_precolumns` method."""
+
+        return ""
+
+    def returning_clause(self, stmt, returning_cols):
+
+        columns = [
+                self.process(
+                    self.label_select_column(None, c, asfrom=False), 
+                    within_columns_clause=True, 
+                    result_map=self.result_map
+                ) 
+                for c in expression._select_iterables(returning_cols)
+            ]
+        return 'RETURNING ' + ', '.join(columns)
+
+
+class FBDDLCompiler(sql.compiler.DDLCompiler):
+    """Firebird syntactic idiosincrasies"""
+
+    def visit_create_sequence(self, create):
+        """Generate a ``CREATE GENERATOR`` statement for the sequence."""
+
+        if self.dialect._version_two:
+            return "CREATE SEQUENCE %s" % self.preparer.format_sequence(create.element)
+        else:
+            return "CREATE GENERATOR %s" % self.preparer.format_sequence(create.element)
+
+    def visit_drop_sequence(self, drop):
+        """Generate a ``DROP GENERATOR`` statement for the sequence."""
+
+        if self.dialect._version_two:
+            return "DROP SEQUENCE %s" % self.preparer.format_sequence(drop.element)
+        else:
+            return "DROP GENERATOR %s" % self.preparer.format_sequence(drop.element)
+
+
+class FBDefaultRunner(base.DefaultRunner):
+    """Firebird specific idiosincrasies"""
+
+    def visit_sequence(self, seq):
+        """Get the next value from the sequence using ``gen_id()``."""
+
+        return self.execute_string("SELECT gen_id(%s, 1) FROM rdb$database" % \
+            self.dialect.identifier_preparer.format_sequence(seq))
+
+
+class FBIdentifierPreparer(sql.compiler.IdentifierPreparer):
+    """Install Firebird specific reserved words."""
+
+    reserved_words = RESERVED_WORDS
+
+    def __init__(self, dialect):
+        super(FBIdentifierPreparer, self).__init__(dialect, omit_schema=True)
+
+
+class FBDialect(default.DefaultDialect):
+    """Firebird dialect"""
+
+    name = 'firebird'
+
+    max_identifier_length = 31
+    
+    supports_sequences = True
+    sequences_optional = False
+    supports_default_values = True
+    postfetch_lastrowid = False
+    
+    requires_name_normalize = True
+    supports_empty_insert = False
+
+    statement_compiler = FBCompiler
+    ddl_compiler = FBDDLCompiler
+    defaultrunner = FBDefaultRunner
+    preparer = FBIdentifierPreparer
+    type_compiler = FBTypeCompiler
+
+    colspecs = colspecs
+    ischema_names = ischema_names
+
+    # defaults to dialect ver. 3;
+    # will be autodetected from the server
+    # upon first connect
+    _version_two = True
+
+    def initialize(self, connection):
+        super(FBDialect, self).initialize(connection)
+        self._version_two = self.server_version_info > (2, )
+        if not self._version_two:
+            # TODO: whatever other pre < 2.0 stuff goes here
+            self.ischema_names = ischema_names.copy()
+            self.ischema_names['TIMESTAMP'] = sqltypes.DATE
+            self.colspecs = {
+                sqltypes.DateTime: sqltypes.DATE
+            }
+        else:
+            self.implicit_returning = True
+            
+    def normalize_name(self, name):
+        # Remove trailing spaces: FB uses a CHAR() type,
+        # which is padded with spaces
+        name = name and name.rstrip()
+        if name is None:
+            return None
+        elif name.upper() == name and \
+            not self.identifier_preparer._requires_quotes(name.lower()):
+            return name.lower()
+        else:
+            return name
+
+    def denormalize_name(self, name):
+        if name is None:
+            return None
+        elif name.lower() == name and \
+            not self.identifier_preparer._requires_quotes(name.lower()):
+            return name.upper()
+        else:
+            return name
+
+    def has_table(self, connection, table_name, schema=None):
+        """Return ``True`` if the given table exists, ignoring the `schema`."""
+
+        tblqry = """
+        SELECT 1 FROM rdb$database
+        WHERE EXISTS (SELECT rdb$relation_name
+                      FROM rdb$relations
+                      WHERE rdb$relation_name=?)
+        """
+        c = connection.execute(tblqry, [self.denormalize_name(table_name)])
+        return c.first() is not None
+
+    def has_sequence(self, connection, sequence_name):
+        """Return ``True`` if the given sequence (generator) exists."""
+
+        genqry = """
+        SELECT 1 FROM rdb$database
+        WHERE EXISTS (SELECT rdb$generator_name
+                      FROM rdb$generators
+                      WHERE rdb$generator_name=?)
+        """
+        c = connection.execute(genqry, [self.denormalize_name(sequence_name)])
+        return c.first() is not None
+
+    def table_names(self, connection, schema):
+        s = """
+        SELECT DISTINCT rdb$relation_name
+        FROM rdb$relation_fields
+        WHERE rdb$system_flag=0 AND rdb$view_context IS NULL
+        """
+        return [self.normalize_name(row[0]) for row in connection.execute(s)]
+
+    @reflection.cache
+    def get_table_names(self, connection, schema=None, **kw):
+        return self.table_names(connection, schema)
+
+    @reflection.cache
+    def get_view_names(self, connection, schema=None, **kw):
+        s = """
+        SELECT distinct rdb$view_name
+        FROM rdb$view_relations
+        """
+        return [self.normalize_name(row[0]) for row in connection.execute(s)]
+
+    @reflection.cache
+    def get_view_definition(self, connection, view_name, schema=None, **kw):
+        qry = """
+        SELECT rdb$view_source AS view_source
+        FROM rdb$relations
+        WHERE rdb$relation_name=?
+        """
+        rp = connection.execute(qry, [self.denormalize_name(view_name)])
+        row = rp.first()
+        if row:
+            return row['view_source']
+        else:
+            return None
+
+    @reflection.cache
+    def get_primary_keys(self, connection, table_name, schema=None, **kw):
+        # Query to extract the PK/FK constrained fields of the given table
+        keyqry = """
+        SELECT se.rdb$field_name AS fname
+        FROM rdb$relation_constraints rc
+             JOIN rdb$index_segments se ON rc.rdb$index_name=se.rdb$index_name
+        WHERE rc.rdb$constraint_type=? AND rc.rdb$relation_name=?
+        """
+        tablename = self.denormalize_name(table_name)
+        # get primary key fields
+        c = connection.execute(keyqry, ["PRIMARY KEY", tablename])
+        pkfields = [self.normalize_name(r['fname']) for r in c.fetchall()]
+        return pkfields
+
+    @reflection.cache
+    def get_column_sequence(self, connection, table_name, column_name, schema=None, **kw):
+        tablename = self.denormalize_name(table_name)
+        colname = self.denormalize_name(column_name)
+        # Heuristic query to determine the generator associated with a PK field
+        genqry = """
+        SELECT trigdep.rdb$depended_on_name AS fgenerator
+        FROM rdb$dependencies tabdep
+             JOIN rdb$dependencies trigdep
+                  ON tabdep.rdb$dependent_name=trigdep.rdb$dependent_name
+                     AND trigdep.rdb$depended_on_type=14
+                     AND trigdep.rdb$dependent_type=2
+             JOIN rdb$triggers trig ON trig.rdb$trigger_name=tabdep.rdb$dependent_name
+        WHERE tabdep.rdb$depended_on_name=?
+          AND tabdep.rdb$depended_on_type=0
+          AND trig.rdb$trigger_type=1
+          AND tabdep.rdb$field_name=?
+          AND (SELECT count(*)
+               FROM rdb$dependencies trigdep2
+               WHERE trigdep2.rdb$dependent_name = trigdep.rdb$dependent_name) = 2
+        """
+        genc = connection.execute(genqry, [tablename, colname])
+        genr = genc.fetchone()
+        if genr is not None:
+            return dict(name=self.normalize_name(genr['fgenerator']))
+
+    @reflection.cache
+    def get_columns(self, connection, table_name, schema=None, **kw):
+        # Query to extract the details of all the fields of the given table
+        tblqry = """
+        SELECT DISTINCT r.rdb$field_name AS fname,
+                        r.rdb$null_flag AS null_flag,
+                        t.rdb$type_name AS ftype,
+                        f.rdb$field_sub_type AS stype,
+                        f.rdb$field_length AS flen,
+                        f.rdb$field_precision AS fprec,
+                        f.rdb$field_scale AS fscale,
+                        COALESCE(r.rdb$default_source, f.rdb$default_source) AS fdefault
+        FROM rdb$relation_fields r
+             JOIN rdb$fields f ON r.rdb$field_source=f.rdb$field_name
+             JOIN rdb$types t
+                  ON t.rdb$type=f.rdb$field_type AND t.rdb$field_name='RDB$FIELD_TYPE'
+        WHERE f.rdb$system_flag=0 AND r.rdb$relation_name=?
+        ORDER BY r.rdb$field_position
+        """
+        # get the PK, used to determine any associated sequence
+        pkey_cols = self.get_primary_keys(connection, table_name)
+
+        tablename = self.denormalize_name(table_name)
+        # get all of the fields for this table
+        c = connection.execute(tblqry, [tablename])
+        cols = []
+        while True:
+            row = c.fetchone()
+            if row is None:
+                break
+            name = self.normalize_name(row['fname'])
+            # get the data type
+
+            colspec = row['ftype'].rstrip()
+            coltype = self.ischema_names.get(colspec)
+            if coltype is None:
+                util.warn("Did not recognize type '%s' of column '%s'" %
+                          (colspec, name))
+                coltype = sqltypes.NULLTYPE
+            elif colspec == 'INT64':
+                coltype = coltype(precision=row['fprec'], scale=row['fscale'] * -1)
+            elif colspec in ('VARYING', 'CSTRING'):
+                coltype = coltype(row['flen'])
+            elif colspec == 'TEXT':
+                coltype = TEXT(row['flen'])
+            elif colspec == 'BLOB':
+                if row['stype'] == 1:
+                    coltype = TEXT()
+                else:
+                    coltype = BLOB()
+            else:
+                coltype = coltype(row)
+
+            # does it have a default value?
+            defvalue = None
+            if row['fdefault'] is not None:
+                # the value comes down as "DEFAULT 'value'"
+                assert row['fdefault'].upper().startswith('DEFAULT '), row
+                defvalue = row['fdefault'][8:]
+            col_d = {
+                'name' : name,
+                'type' : coltype,
+                'nullable' :  not bool(row['null_flag']),
+                'default' : defvalue
+            }
+
+            # if the PK is a single field, try to see if it's linked to
+            # a sequence through a trigger
+            if len(pkey_cols)==1 and name==pkey_cols[0]:
+                seq_d = self.get_column_sequence(connection, tablename, name)
+                if seq_d is not None:
+                    col_d['sequence'] = seq_d
+
+            cols.append(col_d)
+        return cols
+
+    @reflection.cache
+    def get_foreign_keys(self, connection, table_name, schema=None, **kw):
+        # Query to extract the details of each UK/FK of the given table
+        fkqry = """
+        SELECT rc.rdb$constraint_name AS cname,
+               cse.rdb$field_name AS fname,
+               ix2.rdb$relation_name AS targetrname,
+               se.rdb$field_name AS targetfname
+        FROM rdb$relation_constraints rc
+             JOIN rdb$indices ix1 ON ix1.rdb$index_name=rc.rdb$index_name
+             JOIN rdb$indices ix2 ON ix2.rdb$index_name=ix1.rdb$foreign_key
+             JOIN rdb$index_segments cse ON cse.rdb$index_name=ix1.rdb$index_name
+             JOIN rdb$index_segments se
+                  ON se.rdb$index_name=ix2.rdb$index_name
+                     AND se.rdb$field_position=cse.rdb$field_position
+        WHERE rc.rdb$constraint_type=? AND rc.rdb$relation_name=?
+        ORDER BY se.rdb$index_name, se.rdb$field_position
+        """
+        tablename = self.denormalize_name(table_name)
+
+        c = connection.execute(fkqry, ["FOREIGN KEY", tablename])
+        fks = util.defaultdict(lambda:{
+            'name' : None,
+            'constrained_columns' : [],
+            'referred_schema' : None,
+            'referred_table' : None,
+            'referred_columns' : []
+        })
+
+        for row in c:
+            cname = self.normalize_name(row['cname'])
+            fk = fks[cname]
+            if not fk['name']:
+                fk['name'] = cname
+                fk['referred_table'] = self.normalize_name(row['targetrname'])
+            fk['constrained_columns'].append(self.normalize_name(row['fname']))
+            fk['referred_columns'].append(
+                            self.normalize_name(row['targetfname']))
+        return fks.values()
+
+    @reflection.cache
+    def get_indexes(self, connection, table_name, schema=None, **kw):
+        qry = """
+        SELECT ix.rdb$index_name AS index_name,
+               ix.rdb$unique_flag AS unique_flag,
+               ic.rdb$field_name AS field_name
+        FROM rdb$indices ix
+             JOIN rdb$index_segments ic
+                  ON ix.rdb$index_name=ic.rdb$index_name
+             LEFT OUTER JOIN rdb$relation_constraints
+                  ON rdb$relation_constraints.rdb$index_name = ic.rdb$index_name
+        WHERE ix.rdb$relation_name=? AND ix.rdb$foreign_key IS NULL
+          AND rdb$relation_constraints.rdb$constraint_type IS NULL
+        ORDER BY index_name, field_name
+        """
+        c = connection.execute(qry, [self.denormalize_name(table_name)])
+
+        indexes = util.defaultdict(dict)
+        for row in c:
+            indexrec = indexes[row['index_name']]
+            if 'name' not in indexrec:
+                indexrec['name'] = self.normalize_name(row['index_name'])
+                indexrec['column_names'] = []
+                indexrec['unique'] = bool(row['unique_flag'])
+
+            indexrec['column_names'].append(self.normalize_name(row['field_name']))
+
+        return indexes.values()
+
+    def do_execute(self, cursor, statement, parameters, **kwargs):
+        # kinterbasdb does not accept None, but wants an empty list
+        # when there are no arguments.
+        cursor.execute(statement, parameters or [])
+
+    def do_rollback(self, connection):
+        # Use the retaining feature, which keeps the transaction going
+        connection.rollback(True)
+
+    def do_commit(self, connection):
+        # Use the retaining feature, which keeps the transaction going
+        connection.commit(True)
diff --git a/lib/sqlalchemy/dialects/firebird/kinterbasdb.py b/lib/sqlalchemy/dialects/firebird/kinterbasdb.py
new file mode 100644
index 0000000..7d30f87
--- /dev/null
@@ -0,0 +1,70 @@
+from sqlalchemy.dialects.firebird.base import FBDialect, FBCompiler
+from sqlalchemy.engine.default import DefaultExecutionContext
+
+_initialized_kb = False
+
+class Firebird_kinterbasdb(FBDialect):
+    driver = 'kinterbasdb'
+    supports_sane_rowcount = False
+    supports_sane_multi_rowcount = False
+
+    def __init__(self, type_conv=200, concurrency_level=1, **kwargs):
+        super(Firebird_kinterbasdb, self).__init__(**kwargs)
+
+        self.type_conv = type_conv
+        self.concurrency_level = concurrency_level
+
+    @classmethod
+    def dbapi(cls):
+        k = __import__('kinterbasdb')
+        return k
+
+    def create_connect_args(self, url):
+        opts = url.translate_connect_args(username='user')
+        if opts.get('port'):
+            opts['host'] = "%s/%s" % (opts['host'], opts['port'])
+            del opts['port']
+        opts.update(url.query)
+
+        type_conv = opts.pop('type_conv', self.type_conv)
+        concurrency_level = opts.pop('concurrency_level', self.concurrency_level)
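+        # initialize kinterbasdb at most once per process; init() is meant
+        # to run before the first connection is opened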
+        global _initialized_kb
+        if not _initialized_kb and self.dbapi is not None:
+            _initialized_kb = True
+            self.dbapi.init(type_conv=type_conv, concurrency_level=concurrency_level)
+        return ([], opts)
+
+    def _get_server_version_info(self, connection):
+        """Get the version of the Firebird server used by a connection.
+
+        Returns a tuple of (`major`, `minor`, `build`), three integers
+        representing the version of the attached server.
+        """
+
+        # This is the simpler approach (the other uses the services api),
+        # which for backward compatibility reasons returns a string like
+        #   LI-V6.3.3.12981 Firebird 2.0
+        # where the first version is a fake one resembling the old
+        # Interbase signature. This is more than enough for our purposes,
+        # as this is mainly (only?) used by the testsuite.
+
+        from re import match
+
+        fbconn = connection.connection
+        version = fbconn.server_version
+        m = match(r'\w+-V(\d+)\.(\d+)\.(\d+)\.(\d+) \w+ (\d+)\.(\d+)', version)
+        if not m:
+            raise AssertionError("Could not determine version from string '%s'" % version)
+        return tuple([int(x) for x in m.group(5, 6, 4)])
+
+    def is_disconnect(self, e):
+        if isinstance(e, self.dbapi.OperationalError):
+            return 'Unable to complete network request to host' in str(e)
+        elif isinstance(e, self.dbapi.ProgrammingError):
+            msg = str(e)
+            return ('Invalid connection state' in msg or
+                    'Invalid cursor state' in msg)
+        else:
+            return False
+
+dialect = Firebird_kinterbasdb
diff --git a/lib/sqlalchemy/dialects/informix/__init__.py b/lib/sqlalchemy/dialects/informix/__init__.py
new file mode 100644
index 0000000..f2fcc76
--- /dev/null
@@ -0,0 +1,3 @@
+from sqlalchemy.dialects.informix import base, informixdb
+
+base.dialect = informixdb.dialect
\ No newline at end of file
diff --git a/lib/sqlalchemy/databases/informix.py b/lib/sqlalchemy/dialects/informix/base.py
similarity index 59%
rename from lib/sqlalchemy/databases/informix.py
rename to lib/sqlalchemy/dialects/informix/base.py
index 4476af3b9c25ea85c2fad39528f3f6014d6b13cc..b69748fcf1155fe5ea8357b8d67c9e229d98960c 100644
@@ -5,6 +5,12 @@
 #
 # This module is part of SQLAlchemy and is released under
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
+"""Support for the Informix database.
+
+This dialect is *not* tested on SQLAlchemy 0.6.
+
+"""
+
 
 import datetime
 
@@ -14,55 +20,7 @@ from sqlalchemy.engine import default
 from sqlalchemy import types as sqltypes
 
 
-# for offset
-
-class informix_cursor(object):
-    def __init__( self , con ):
-        self.__cursor = con.cursor()
-        self.rowcount = 0
-
-    def offset( self , n ):
-        if n > 0:
-            self.fetchmany( n )
-            self.rowcount = self.__cursor.rowcount - n
-            if self.rowcount < 0:
-                self.rowcount = 0
-        else:
-            self.rowcount = self.__cursor.rowcount
-
-    def execute( self , sql , params ):
-        if params is None or len( params ) == 0:
-            params = []
-
-        return self.__cursor.execute( sql , params )
-
-    def __getattr__( self , name ):
-        if name not in ( 'offset' , '__cursor' , 'rowcount' , '__del__' , 'execute' ):
-            return getattr( self.__cursor , name )
-
-class InfoNumeric(sqltypes.Numeric):
-    def get_col_spec(self):
-        if not self.precision:
-            return 'NUMERIC'
-        else:
-            return "NUMERIC(%(precision)s, %(scale)s)" % {'precision': self.precision, 'scale' : self.scale}
-
-class InfoInteger(sqltypes.Integer):
-    def get_col_spec(self):
-        return "INTEGER"
-
-class InfoSmallInteger(sqltypes.Smallinteger):
-    def get_col_spec(self):
-        return "SMALLINT"
-
-class InfoDate(sqltypes.Date):
-    def get_col_spec( self ):
-        return "DATE"
-
 class InfoDateTime(sqltypes.DateTime ):
-    def get_col_spec(self):
-        return "DATETIME YEAR TO SECOND"
-
     def bind_processor(self, dialect):
         def process(value):
             if value is not None:
@@ -72,9 +30,6 @@ class InfoDateTime(sqltypes.DateTime ):
         return process
 
 class InfoTime(sqltypes.Time ):
-    def get_col_spec(self):
-        return "DATETIME HOUR TO SECOND"
-
     def bind_processor(self, dialect):
         def process(value):
             if value is not None:
@@ -91,35 +46,8 @@ class InfoTime(sqltypes.Time ):
                 return value
         return process
 
-class InfoText(sqltypes.String):
-    def get_col_spec(self):
-        return "VARCHAR(255)"
-
-class InfoString(sqltypes.String):
-    def get_col_spec(self):
-        return "VARCHAR(%(length)s)" % {'length' : self.length}
-
-    def bind_processor(self, dialect):
-        def process(value):
-            if value == '':
-                return None
-            else:
-                return value
-        return process
-
-class InfoChar(sqltypes.CHAR):
-    def get_col_spec(self):
-        return "CHAR(%(length)s)" % {'length' : self.length}
-
-class InfoBinary(sqltypes.Binary):
-    def get_col_spec(self):
-        return "BYTE"
 
 class InfoBoolean(sqltypes.Boolean):
-    default_type = 'NUM'
-    def get_col_spec(self):
-        return "SMALLINT"
-
     def result_processor(self, dialect):
         def process(value):
             if value is None:
@@ -140,104 +68,156 @@ class InfoBoolean(sqltypes.Boolean):
         return process
 
 colspecs = {
-    sqltypes.Integer : InfoInteger,
-    sqltypes.Smallinteger : InfoSmallInteger,
-    sqltypes.Numeric : InfoNumeric,
-    sqltypes.Float : InfoNumeric,
     sqltypes.DateTime : InfoDateTime,
-    sqltypes.Date : InfoDate,
     sqltypes.Time: InfoTime,
-    sqltypes.String : InfoString,
-    sqltypes.Binary : InfoBinary,
     sqltypes.Boolean : InfoBoolean,
-    sqltypes.Text : InfoText,
-    sqltypes.CHAR: InfoChar,
 }
 
 
 ischema_names = {
-    0   : InfoString,       # CHAR
-    1   : InfoSmallInteger, # SMALLINT
-    2   : InfoInteger,      # INT
-    3   : InfoNumeric,      # Float
-    3   : InfoNumeric,      # SmallFloat
-    5   : InfoNumeric,      # DECIMAL
-    6   : InfoInteger,      # Serial
-    7   : InfoDate,         # DATE
-    8   : InfoNumeric,      # MONEY
-    10  : InfoDateTime,     # DATETIME
-    11  : InfoBinary,       # BYTE
-    12  : InfoText,         # TEXT
-    13  : InfoString,       # VARCHAR
-    15  : InfoString,       # NCHAR
-    16  : InfoString,       # NVARCHAR
-    17  : InfoInteger,      # INT8
-    18  : InfoInteger,      # Serial8
-    43  : InfoString,       # LVARCHAR
-    -1  : InfoBinary,       # BLOB
-    -1  : InfoText,         # CLOB
+    0   : sqltypes.CHAR,           # CHAR
+    1   : sqltypes.SMALLINT,       # SMALLINT
+    2   : sqltypes.INTEGER,        # INT
+    3   : sqltypes.FLOAT,          # Float
+    3   : sqltypes.Float,          # SmallFloat
+    5   : sqltypes.DECIMAL,        # DECIMAL
+    6   : sqltypes.Integer,        # Serial
+    7   : sqltypes.DATE,           # DATE
+    8   : sqltypes.Numeric,        # MONEY
+    10  : sqltypes.DATETIME,       # DATETIME
+    11  : sqltypes.Binary,         # BYTE
+    12  : sqltypes.TEXT,           # TEXT
+    13  : sqltypes.VARCHAR,        # VARCHAR
+    15  : sqltypes.NCHAR,          # NCHAR
+    16  : sqltypes.NVARCHAR,       # NVARCHAR
+    17  : sqltypes.Integer,        # INT8
+    18  : sqltypes.Integer,        # Serial8
+    43  : sqltypes.String,         # LVARCHAR
+    -1  : sqltypes.BLOB,           # BLOB
+    -1  : sqltypes.CLOB,           # CLOB
 }
 
 
-class InfoExecutionContext(default.DefaultExecutionContext):
-    # cursor.sqlerrd
-    # 0 - estimated number of rows returned
-    # 1 - serial value after insert or ISAM error code
-    # 2 - number of rows processed
-    # 3 - estimated cost
-    # 4 - offset of the error into the SQL statement
-    # 5 - rowid after insert
-    def post_exec(self):
-        if getattr(self.compiled, "isinsert", False) and self.last_inserted_ids() is None:
-            self._last_inserted_ids = [self.cursor.sqlerrd[1]]
-        elif hasattr( self.compiled , 'offset' ):
-            self.cursor.offset( self.compiled.offset )
-        super(InfoExecutionContext, self).post_exec()
-
-    def create_cursor( self ):
-        return informix_cursor( self.connection.connection )
-
-class InfoDialect(default.DefaultDialect):
-    name = 'informix'
-    default_paramstyle = 'qmark'
-    # for informix 7.31
-    max_identifier_length = 18
+class InfoTypeCompiler(compiler.GenericTypeCompiler):
+    def visit_DATETIME(self, type_):
+        return "DATETIME YEAR TO SECOND"
+    
+    def visit_TIME(self, type_):
+        return "DATETIME HOUR TO SECOND"
+    
+    def visit_binary(self, type_):
+        return "BYTE"
+    
+    def visit_boolean(self, type_):
+        return "SMALLINT"
+        
+class InfoSQLCompiler(compiler.SQLCompiler):
 
-    def __init__(self, use_ansi=True, **kwargs):
-        self.use_ansi = use_ansi
-        default.DefaultDialect.__init__(self, **kwargs)
+    def __init__(self, *args, **kwargs):
+        self.limit = 0
+        self.offset = 0
 
-    def dbapi(cls):
-        import informixdb
-        return informixdb
-    dbapi = classmethod(dbapi)
+        compiler.SQLCompiler.__init__( self , *args, **kwargs )
+
+    def default_from(self):
+        return " from systables where tabname = 'systables' "
 
-    def is_disconnect(self, e):
-        if isinstance(e, self.dbapi.OperationalError):
-            return 'closed the connection' in str(e) or 'connection not open' in str(e)
+    def get_select_precolumns( self , select ):
+        s = select._distinct and "DISTINCT " or ""
+        # Informix only has FIRST (limit); emulate OFFSET by over-fetching limit + offset rows
+        if select._limit:
+            off = select._offset or 0
+            s += " FIRST %s " % ( select._limit + off )
         else:
-            return False
+            s += ""
+        return s
 
-    def do_begin(self , connect ):
-        cu = connect.cursor()
-        cu.execute( 'SET LOCK MODE TO WAIT' )
-        #cu.execute( 'SET ISOLATION TO REPEATABLE READ' )
+    def visit_select(self, select):
+        if select._offset:
+            self.offset = select._offset
+            self.limit  = select._limit or 0
+        # the columns in the ORDER BY clause must be in the SELECT list too
+
+        def __label( c ):
+            try:
+                return c._label.lower()
+            except:
+                return ''
 
-    def type_descriptor(self, typeobj):
-        return sqltypes.adapt_type(typeobj, colspecs)
+        # TODO: don't modify the original select, generate a new one
+        a = [ __label(c) for c in select._raw_columns ]
+        for c in select._order_by_clause.clauses:
+            if ( __label(c) not in a ):
+                select.append_column( c )
+
+        return compiler.SQLCompiler.visit_select(self, select)
+
+    def limit_clause(self, select):
+        return ""
 
-    def create_connect_args(self, url):
-        if url.host:
-            dsn = '%s@%s' % ( url.database , url.host )
+    def visit_function( self , func ):
+        if func.name.lower() == 'current_date':
+            return "today"
+        elif func.name.lower() == 'current_time':
+            return "CURRENT HOUR TO SECOND"
+        elif func.name.lower() in ( 'current_timestamp' , 'now' ):
+            return "CURRENT YEAR TO SECOND"
         else:
-            dsn = url.database
+            return compiler.SQLCompiler.visit_function( self , func )
 
-        if url.username:
-            opt = { 'user':url.username , 'password': url.password }
+    def visit_clauselist(self, list, **kwargs):
+        return ', '.join([s for s in [self.process(c) for c in list.clauses] if s is not None])
+
+class InfoDDLCompiler(compiler.DDLCompiler):
+    def get_column_specification(self, column, first_pk=False):
+        colspec = self.preparer.format_column(column)
+        if column.primary_key and len(column.foreign_keys)==0 and column.autoincrement and \
+           isinstance(column.type, sqltypes.Integer) and not getattr( self , 'has_serial' , False ) and first_pk:
+            colspec += " SERIAL"
+            self.has_serial = True
         else:
-            opt = {}
+            colspec += " " + self.dialect.type_compiler.process(column.type)
+            default = self.get_column_default_string(column)
+            if default is not None:
+                colspec += " DEFAULT " + default
+
+        if not column.nullable:
+            colspec += " NOT NULL"
+
+        return colspec
+
+    def post_create_table(self, table):
+        if hasattr( self , 'has_serial' ):
+            del self.has_serial
+        return ''
+
+class InfoIdentifierPreparer(compiler.IdentifierPreparer):
+    def __init__(self, dialect):
+        super(InfoIdentifierPreparer, self).__init__(dialect, initial_quote="'")
+
+    def format_constraint(self, constraint):
+        # Informix doesn't support names for constraints
+        return ''
+        
+    def _requires_quotes(self, value):
+        return False
+
+class InformixDialect(default.DefaultDialect):
+    name = 'informix'
+    # for informix 7.31
+    max_identifier_length = 18
+    type_compiler = InfoTypeCompiler
+    poolclass = pool.SingletonThreadPool
+    statement_compiler = InfoSQLCompiler
+    ddl_compiler = InfoDDLCompiler
+    preparer = InfoIdentifierPreparer
+    colspecs = colspecs
+    ischema_names = ischema_names
 
-        return ([dsn], opt)
+    def do_begin(self , connect ):
+        cu = connect.cursor()
+        cu.execute( 'SET LOCK MODE TO WAIT' )
+        #cu.execute( 'SET ISOLATION TO REPEATABLE READ' )
 
     def table_names(self, connection, schema):
         s = "select tabname from systables"
@@ -352,142 +332,3 @@ class InfoDialect(default.DefaultDialect):
         for cons_name, cons_type, local_column in rows:
             table.primary_key.add( table.c[local_column] )
 
-class InfoCompiler(compiler.DefaultCompiler):
-    """Info compiler modifies the lexical structure of Select statements to work under
-    non-ANSI configured Oracle databases, if the use_ansi flag is False."""
-
-    def __init__(self, *args, **kwargs):
-        self.limit = 0
-        self.offset = 0
-
-        compiler.DefaultCompiler.__init__( self , *args, **kwargs )
-
-    def default_from(self):
-        return " from systables where tabname = 'systables' "
-
-    def get_select_precolumns( self , select ):
-        s = select._distinct and "DISTINCT " or ""
-        # only has limit
-        if select._limit:
-            off = select._offset or 0
-            s += " FIRST %s " % ( select._limit + off )
-        else:
-            s += ""
-        return s
-
-    def visit_select(self, select):
-        if select._offset:
-            self.offset = select._offset
-            self.limit  = select._limit or 0
-        # the column in order by clause must in select too
-
-        def __label( c ):
-            try:
-                return c._label.lower()
-            except:
-                return ''
-
-        # TODO: dont modify the original select, generate a new one
-        a = [ __label(c) for c in select._raw_columns ]
-        for c in select._order_by_clause.clauses:
-            if ( __label(c) not in a ):
-                select.append_column( c )
-
-        return compiler.DefaultCompiler.visit_select(self, select)
-
-    def limit_clause(self, select):
-        return ""
-
-    def visit_function( self , func ):
-        if func.name.lower() == 'current_date':
-            return "today"
-        elif func.name.lower() == 'current_time':
-            return "CURRENT HOUR TO SECOND"
-        elif func.name.lower() in ( 'current_timestamp' , 'now' ):
-            return "CURRENT YEAR TO SECOND"
-        else:
-            return compiler.DefaultCompiler.visit_function( self , func )
-
-    def visit_clauselist(self, list, **kwargs):
-        return ', '.join([s for s in [self.process(c) for c in list.clauses] if s is not None])
-
-class InfoSchemaGenerator(compiler.SchemaGenerator):
-    def get_column_specification(self, column, first_pk=False):
-        colspec = self.preparer.format_column(column)
-        if column.primary_key and len(column.foreign_keys)==0 and column.autoincrement and \
-           isinstance(column.type, sqltypes.Integer) and not getattr( self , 'has_serial' , False ) and first_pk:
-            colspec += " SERIAL"
-            self.has_serial = True
-        else:
-            colspec += " " + column.type.dialect_impl(self.dialect).get_col_spec()
-            default = self.get_column_default_string(column)
-            if default is not None:
-                colspec += " DEFAULT " + default
-
-        if not column.nullable:
-            colspec += " NOT NULL"
-
-        return colspec
-
-    def post_create_table(self, table):
-        if hasattr( self , 'has_serial' ):
-            del self.has_serial
-        return ''
-
-    def visit_primary_key_constraint(self, constraint):
-        # for informix 7.31 not support constraint name
-        name = constraint.name
-        constraint.name = None
-        super(InfoSchemaGenerator, self).visit_primary_key_constraint(constraint)
-        constraint.name = name
-
-    def visit_unique_constraint(self, constraint):
-        # for informix 7.31 not support constraint name
-        name = constraint.name
-        constraint.name = None
-        super(InfoSchemaGenerator, self).visit_unique_constraint(constraint)
-        constraint.name = name
-
-    def visit_foreign_key_constraint( self , constraint ):
-        if constraint.name is not None:
-            constraint.use_alter = True
-        else:
-            super( InfoSchemaGenerator , self ).visit_foreign_key_constraint( constraint )
-
-    def define_foreign_key(self, constraint):
-        # for informix 7.31 not support constraint name
-        if constraint.use_alter:
-            name = constraint.name
-            constraint.name = None
-            self.append( "CONSTRAINT " )
-            super(InfoSchemaGenerator, self).define_foreign_key(constraint)
-            constraint.name = name
-            if name is not None:
-                self.append( " CONSTRAINT " + name )
-        else:
-            super(InfoSchemaGenerator, self).define_foreign_key(constraint)
-
-    def visit_index(self, index):
-        if len( index.columns ) == 1 and index.columns[0].foreign_key:
-            return
-        super(InfoSchemaGenerator, self).visit_index(index)
-
-class InfoIdentifierPreparer(compiler.IdentifierPreparer):
-    def __init__(self, dialect):
-        super(InfoIdentifierPreparer, self).__init__(dialect, initial_quote="'")
-
-    def _requires_quotes(self, value):
-        return False
-
-class InfoSchemaDropper(compiler.SchemaDropper):
-    def drop_foreignkey(self, constraint):
-        if constraint.name is not None:
-            super( InfoSchemaDropper , self ).drop_foreignkey( constraint )
-
-dialect = InfoDialect
-poolclass = pool.SingletonThreadPool
-dialect.statement_compiler = InfoCompiler
-dialect.schemagenerator = InfoSchemaGenerator
-dialect.schemadropper = InfoSchemaDropper
-dialect.preparer = InfoIdentifierPreparer
-dialect.execution_ctx_cls = InfoExecutionContext
\ No newline at end of file
diff --git a/lib/sqlalchemy/dialects/informix/informixdb.py b/lib/sqlalchemy/dialects/informix/informixdb.py
new file mode 100644
index 0000000..4e929e0
--- /dev/null
@@ -0,0 +1,79 @@
+from sqlalchemy.dialects.informix.base import InformixDialect
+from sqlalchemy.engine import default
+
+# cursor wrapper that emulates OFFSET by fetching and discarding rows
+
+class informix_cursor(object):
+    def __init__( self , con ):
+        self.__cursor = con.cursor()
+        self.rowcount = 0
+
+    def offset( self , n ):
+        if n > 0:
+            self.fetchmany( n )
+            self.rowcount = self.__cursor.rowcount - n
+            if self.rowcount < 0:
+                self.rowcount = 0
+        else:
+            self.rowcount = self.__cursor.rowcount
+
+    def execute( self , sql , params ):
+        if params is None or len( params ) == 0:
+            params = []
+
+        return self.__cursor.execute( sql , params )
+
+    def __getattr__( self , name ):
+        if name not in ( 'offset' , '__cursor' , 'rowcount' , '__del__' , 'execute' ):
+            return getattr( self.__cursor , name )
+
+
+class InfoExecutionContext(default.DefaultExecutionContext):
+    # cursor.sqlerrd
+    # 0 - estimated number of rows returned
+    # 1 - serial value after insert or ISAM error code
+    # 2 - number of rows processed
+    # 3 - estimated cost
+    # 4 - offset of the error into the SQL statement
+    # 5 - rowid after insert
+    def post_exec(self):
+        if getattr(self.compiled, "isinsert", False) and self.inserted_primary_key is None:
+            self._last_inserted_ids = [self.cursor.sqlerrd[1]]
+        elif hasattr( self.compiled , 'offset' ):
+            self.cursor.offset( self.compiled.offset )
+
+    def create_cursor( self ):
+        return informix_cursor( self.connection.connection )
+
+
+class Informix_informixdb(InformixDialect):
+    driver = 'informixdb'
+    default_paramstyle = 'qmark'
+    execution_ctx_cls = InfoExecutionContext
+    
+    @classmethod
+    def dbapi(cls):
+        return __import__('informixdb')
+
+    def create_connect_args(self, url):
+        if url.host:
+            dsn = '%s@%s' % ( url.database , url.host )
+        else:
+            dsn = url.database
+
+        if url.username:
+            opt = { 'user':url.username , 'password': url.password }
+        else:
+            opt = {}
+
+        return ([dsn], opt)
+
+
+    def is_disconnect(self, e):
+        if isinstance(e, self.dbapi.OperationalError):
+            return 'closed the connection' in str(e) or 'connection not open' in str(e)
+        else:
+            return False
+
+
+dialect = Informix_informixdb
\ No newline at end of file
diff --git a/lib/sqlalchemy/dialects/maxdb/__init__.py b/lib/sqlalchemy/dialects/maxdb/__init__.py
new file mode 100644
index 0000000..3f12448
--- /dev/null
@@ -0,0 +1,3 @@
+from sqlalchemy.dialects.maxdb import base, sapdb
+
+base.dialect = sapdb.dialect
\ No newline at end of file
diff --git a/lib/sqlalchemy/databases/maxdb.py b/lib/sqlalchemy/dialects/maxdb/base.py
similarity index 90%
rename from lib/sqlalchemy/databases/maxdb.py
rename to lib/sqlalchemy/dialects/maxdb/base.py
index 693295054e4a3fd5aebb98f79a4cfa07c3ff7eea..1ec95e03b4dfa22967123228aa9933ff58dd5c3f 100644
@@ -5,7 +5,7 @@
 
 """Support for the MaxDB database.
 
-TODO: More module docs!  MaxDB support is currently experimental.
+This dialect is *not* tested on SQLAlchemy 0.6.
 
 Overview
 --------
@@ -65,13 +65,6 @@ from sqlalchemy.engine import base as engine_base, default
 from sqlalchemy import types as sqltypes
 
 
-__all__ = [
-    'MaxString', 'MaxUnicode', 'MaxChar', 'MaxText', 'MaxInteger',
-    'MaxSmallInteger', 'MaxNumeric', 'MaxFloat', 'MaxTimestamp',
-    'MaxDate', 'MaxTime', 'MaxBoolean', 'MaxBlob',
-    ]
-
-
 class _StringType(sqltypes.String):
     _type = None
 
@@ -79,16 +72,6 @@ class _StringType(sqltypes.String):
         super(_StringType, self).__init__(length=length, **kw)
         self.encoding = encoding
 
-    def get_col_spec(self):
-        if self.length is None:
-            spec = 'LONG'
-        else:
-            spec = '%s(%s)' % (self._type, self.length)
-
-        if self.encoding is not None:
-            spec = ' '.join([spec, self.encoding.upper()])
-        return spec
-
     def bind_processor(self, dialect):
         if self.encoding == 'unicode':
             return None
@@ -156,16 +139,6 @@ class MaxText(_StringType):
         return spec
 
 
-class MaxInteger(sqltypes.Integer):
-    def get_col_spec(self):
-        return 'INTEGER'
-
-
-class MaxSmallInteger(MaxInteger):
-    def get_col_spec(self):
-        return 'SMALLINT'
-
-
 class MaxNumeric(sqltypes.Numeric):
     """The FIXED (also NUMERIC, DECIMAL) data type."""
 
@@ -177,29 +150,7 @@ class MaxNumeric(sqltypes.Numeric):
     def bind_processor(self, dialect):
         return None
 
-    def get_col_spec(self):
-        if self.scale and self.precision:
-            return 'FIXED(%s, %s)' % (self.precision, self.scale)
-        elif self.precision:
-            return 'FIXED(%s)' % self.precision
-        else:
-            return 'INTEGER'
-
-
-class MaxFloat(sqltypes.Float):
-    """The FLOAT data type."""
-
-    def get_col_spec(self):
-        if self.precision is None:
-            return 'FLOAT'
-        else:
-            return 'FLOAT(%s)' % (self.precision,)
-
-
 class MaxTimestamp(sqltypes.DateTime):
-    def get_col_spec(self):
-        return 'TIMESTAMP'
-
     def bind_processor(self, dialect):
         def process(value):
             if value is None:
@@ -242,9 +193,6 @@ class MaxTimestamp(sqltypes.DateTime):
 
 
 class MaxDate(sqltypes.Date):
-    def get_col_spec(self):
-        return 'DATE'
-
     def bind_processor(self, dialect):
         def process(value):
             if value is None:
@@ -279,9 +227,6 @@ class MaxDate(sqltypes.Date):
 
 
 class MaxTime(sqltypes.Time):
-    def get_col_spec(self):
-        return 'TIME'
-
     def bind_processor(self, dialect):
         def process(value):
             if value is None:
@@ -316,15 +261,7 @@ class MaxTime(sqltypes.Time):
         return process
 
 
-class MaxBoolean(sqltypes.Boolean):
-    def get_col_spec(self):
-        return 'BOOLEAN'
-
-
 class MaxBlob(sqltypes.Binary):
-    def get_col_spec(self):
-        return 'LONG BYTE'
-
     def bind_processor(self, dialect):
         def process(value):
             if value is None:
@@ -341,18 +278,54 @@ class MaxBlob(sqltypes.Binary):
                 return value.read(value.remainingLength())
         return process
 
+class MaxDBTypeCompiler(compiler.GenericTypeCompiler):
+    def _string_spec(self, string_spec, type_):
+        if type_.length is None:
+            spec = 'LONG'
+        else:
+            spec = '%s(%s)' % (string_spec, type_.length)
+
+        if getattr(type_, 'encoding', None):
+            spec = ' '.join([spec, type_.encoding.upper()])
+        return spec
+
+    def visit_text(self, type_):
+        spec = 'LONG'
+        if getattr(type_, 'encoding', None):
+            spec = ' '.join((spec, type_.encoding))
+        elif type_.convert_unicode:
+            spec = ' '.join((spec, 'UNICODE'))
+
+        return spec
 
+    def visit_char(self, type_):
+        return self._string_spec("CHAR", type_)
+
+    def visit_string(self, type_):
+        return self._string_spec("VARCHAR", type_)
+
+    def visit_binary(self, type_):
+        return "LONG BYTE"
+    
+    def visit_numeric(self, type_):
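+        # MaxDB's FIXED covers NUMERIC/DECIMAL; with no precision given,
+        # fall back to plain INTEGER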
+        if type_.scale and type_.precision:
+            return 'FIXED(%s, %s)' % (type_.precision, type_.scale)
+        elif type_.precision:
+            return 'FIXED(%s)' % type_.precision
+        else:
+            return 'INTEGER'
+    
+    def visit_BOOLEAN(self, type_):
+        return "BOOLEAN"
+        
 colspecs = {
-    sqltypes.Integer: MaxInteger,
-    sqltypes.Smallinteger: MaxSmallInteger,
     sqltypes.Numeric: MaxNumeric,
-    sqltypes.Float: MaxFloat,
     sqltypes.DateTime: MaxTimestamp,
     sqltypes.Date: MaxDate,
     sqltypes.Time: MaxTime,
     sqltypes.String: MaxString,
+    sqltypes.Unicode: MaxUnicode,
     sqltypes.Binary: MaxBlob,
-    sqltypes.Boolean: MaxBoolean,
     sqltypes.Text: MaxText,
     sqltypes.CHAR: MaxChar,
     sqltypes.TIMESTAMP: MaxTimestamp,
@@ -361,25 +334,25 @@ colspecs = {
     }
 
 ischema_names = {
-    'boolean': MaxBoolean,
-    'char': MaxChar,
-    'character': MaxChar,
-    'date': MaxDate,
-    'fixed': MaxNumeric,
-    'float': MaxFloat,
-    'int': MaxInteger,
-    'integer': MaxInteger,
-    'long binary': MaxBlob,
-    'long unicode': MaxText,
-    'long': MaxText,
-    'long': MaxText,
-    'smallint': MaxSmallInteger,
-    'time': MaxTime,
-    'timestamp': MaxTimestamp,
-    'varchar': MaxString,
+    'boolean': sqltypes.BOOLEAN,
+    'char': sqltypes.CHAR,
+    'character': sqltypes.CHAR,
+    'date': sqltypes.DATE,
+    'fixed': sqltypes.Numeric,
+    'float': sqltypes.FLOAT,
+    'int': sqltypes.INT,
+    'integer': sqltypes.INT,
+    'long binary': sqltypes.BLOB,
+    'long unicode': sqltypes.Text,
+    'long': sqltypes.Text,
+    'long': sqltypes.Text,
+    'smallint': sqltypes.SmallInteger,
+    'time': sqltypes.Time,
+    'timestamp': sqltypes.TIMESTAMP,
+    'varchar': sqltypes.VARCHAR,
     }
 
-
+# TODO: migrate this to sapdb.py
 class MaxDBExecutionContext(default.DefaultExecutionContext):
     def post_exec(self):
         # DB-API bug: if there were any functions as values,
@@ -421,6 +394,12 @@ class MaxDBExecutionContext(default.DefaultExecutionContext):
                     return MaxDBResultProxy(self)
         return engine_base.ResultProxy(self)
 
+    @property
+    def rowcount(self):
+        if hasattr(self, '_rowcount'):
+            return self._rowcount
+        else:
+            return self.cursor.rowcount
 
 class MaxDBCachedColumnRow(engine_base.RowProxy):
     """A RowProxy that only runs result_processors once per column."""
@@ -454,272 +433,17 @@ class MaxDBCachedColumnRow(engine_base.RowProxy):
         else:
             return self._get_col(key)
 
-    def __getattr__(self, name):
-        try:
-            return self._get_col(name)
-        except KeyError:
-            raise AttributeError(name)
-
-
-class MaxDBResultProxy(engine_base.ResultProxy):
-    _process_row = MaxDBCachedColumnRow
-
-
-class MaxDBDialect(default.DefaultDialect):
-    name = 'maxdb'
-    supports_alter = True
-    supports_unicode_statements = True
-    max_identifier_length = 32
-    supports_sane_rowcount = True
-    supports_sane_multi_rowcount = False
-    preexecute_pk_sequences = True
-
-    # MaxDB-specific
-    datetimeformat = 'internal'
-
-    def __init__(self, _raise_known_sql_errors=False, **kw):
-        super(MaxDBDialect, self).__init__(**kw)
-        self._raise_known = _raise_known_sql_errors
-
-        if self.dbapi is None:
-            self.dbapi_type_map = {}
-        else:
-            self.dbapi_type_map = {
-                'Long Binary': MaxBlob(),
-                'Long byte_t': MaxBlob(),
-                'Long Unicode': MaxText(),
-                'Timestamp': MaxTimestamp(),
-                'Date': MaxDate(),
-                'Time': MaxTime(),
-                datetime.datetime: MaxTimestamp(),
-                datetime.date: MaxDate(),
-                datetime.time: MaxTime(),
-            }
-
-    def dbapi(cls):
-        from sapdb import dbapi as _dbapi
-        return _dbapi
-    dbapi = classmethod(dbapi)
-
-    def create_connect_args(self, url):
-        opts = url.translate_connect_args(username='user')
-        opts.update(url.query)
-        return [], opts
-
-    def type_descriptor(self, typeobj):
-        if isinstance(typeobj, type):
-            typeobj = typeobj()
-        if isinstance(typeobj, sqltypes.Unicode):
-            return typeobj.adapt(MaxUnicode)
-        else:
-            return sqltypes.adapt_type(typeobj, colspecs)
-
-    def do_execute(self, cursor, statement, parameters, context=None):
-        res = cursor.execute(statement, parameters)
-        if isinstance(res, int) and context is not None:
-            context._rowcount = res
-
-    def do_release_savepoint(self, connection, name):
-        # Does MaxDB truly support RELEASE SAVEPOINT <id>?  All my attempts
-        # produce "SUBTRANS COMMIT/ROLLBACK not allowed without SUBTRANS
-        # BEGIN SQLSTATE: I7065"
-        # Note that ROLLBACK TO works fine.  In theory, a RELEASE should
-        # just free up some transactional resources early, before the overall
-        # COMMIT/ROLLBACK so omitting it should be relatively ok.
-        pass
-
-    def get_default_schema_name(self, connection):
-        try:
-            return self._default_schema_name
-        except AttributeError:
-            name = self.identifier_preparer._normalize_name(
-                connection.execute('SELECT CURRENT_SCHEMA FROM DUAL').scalar())
-            self._default_schema_name = name
-            return name
-
-    def has_table(self, connection, table_name, schema=None):
-        denormalize = self.identifier_preparer._denormalize_name
-        bind = [denormalize(table_name)]
-        if schema is None:
-            sql = ("SELECT tablename FROM TABLES "
-                   "WHERE TABLES.TABLENAME=? AND"
-                   "  TABLES.SCHEMANAME=CURRENT_SCHEMA ")
-        else:
-            sql = ("SELECT tablename FROM TABLES "
-                   "WHERE TABLES.TABLENAME = ? AND"
-                   "  TABLES.SCHEMANAME=? ")
-            bind.append(denormalize(schema))
-
-        rp = connection.execute(sql, bind)
-        found = bool(rp.fetchone())
-        rp.close()
-        return found
-
-    def table_names(self, connection, schema):
-        if schema is None:
-            sql = (" SELECT TABLENAME FROM TABLES WHERE "
-                   " SCHEMANAME=CURRENT_SCHEMA ")
-            rs = connection.execute(sql)
-        else:
-            sql = (" SELECT TABLENAME FROM TABLES WHERE "
-                   " SCHEMANAME=? ")
-            matchname = self.identifier_preparer._denormalize_name(schema)
-            rs = connection.execute(sql, matchname)
-        normalize = self.identifier_preparer._normalize_name
-        return [normalize(row[0]) for row in rs]
-
-    def reflecttable(self, connection, table, include_columns):
-        denormalize = self.identifier_preparer._denormalize_name
-        normalize = self.identifier_preparer._normalize_name
-
-        st = ('SELECT COLUMNNAME, MODE, DATATYPE, CODETYPE, LEN, DEC, '
-              '  NULLABLE, "DEFAULT", DEFAULTFUNCTION '
-              'FROM COLUMNS '
-              'WHERE TABLENAME=? AND SCHEMANAME=%s '
-              'ORDER BY POS')
-
-        fk = ('SELECT COLUMNNAME, FKEYNAME, '
-              '  REFSCHEMANAME, REFTABLENAME, REFCOLUMNNAME, RULE, '
-              '  (CASE WHEN REFSCHEMANAME = CURRENT_SCHEMA '
-              '   THEN 1 ELSE 0 END) AS in_schema '
-              'FROM FOREIGNKEYCOLUMNS '
-              'WHERE TABLENAME=? AND SCHEMANAME=%s '
-              'ORDER BY FKEYNAME ')
-
-        params = [denormalize(table.name)]
-        if not table.schema:
-            st = st % 'CURRENT_SCHEMA'
-            fk = fk % 'CURRENT_SCHEMA'
-        else:
-            st = st % '?'
-            fk = fk % '?'
-            params.append(denormalize(table.schema))
-
-        rows = connection.execute(st, params).fetchall()
-        if not rows:
-            raise exc.NoSuchTableError(table.fullname)
-
-        include_columns = set(include_columns or [])
-
-        for row in rows:
-            (name, mode, col_type, encoding, length, scale,
-             nullable, constant_def, func_def) = row
-
-            name = normalize(name)
-
-            if include_columns and name not in include_columns:
-                continue
-
-            type_args, type_kw = [], {}
-            if col_type == 'FIXED':
-                type_args = length, scale
-                # Convert FIXED(10) DEFAULT SERIAL to our Integer
-                if (scale == 0 and
-                    func_def is not None and func_def.startswith('SERIAL')):
-                    col_type = 'INTEGER'
-                    type_args = length,
-            elif col_type in 'FLOAT':
-                type_args = length,
-            elif col_type in ('CHAR', 'VARCHAR'):
-                type_args = length,
-                type_kw['encoding'] = encoding
-            elif col_type == 'LONG':
-                type_kw['encoding'] = encoding
-
-            try:
-                type_cls = ischema_names[col_type.lower()]
-                type_instance = type_cls(*type_args, **type_kw)
-            except KeyError:
-                util.warn("Did not recognize type '%s' of column '%s'" %
-                          (col_type, name))
-                type_instance = sqltypes.NullType
-
-            col_kw = {'autoincrement': False}
-            col_kw['nullable'] = (nullable == 'YES')
-            col_kw['primary_key'] = (mode == 'KEY')
-
-            if func_def is not None:
-                if func_def.startswith('SERIAL'):
-                    if col_kw['primary_key']:
-                        # No special default- let the standard autoincrement
-                        # support handle SERIAL pk columns.
-                        col_kw['autoincrement'] = True
-                    else:
-                        # strip current numbering
-                        col_kw['server_default'] = schema.DefaultClause(
-                            sql.text('SERIAL'))
-                        col_kw['autoincrement'] = True
-                else:
-                    col_kw['server_default'] = schema.DefaultClause(
-                        sql.text(func_def))
-            elif constant_def is not None:
-                col_kw['server_default'] = schema.DefaultClause(sql.text(
-                    "'%s'" % constant_def.replace("'", "''")))
-
-            table.append_column(schema.Column(name, type_instance, **col_kw))
-
-        fk_sets = itertools.groupby(connection.execute(fk, params),
-                                    lambda row: row.FKEYNAME)
-        for fkeyname, fkey in fk_sets:
-            fkey = list(fkey)
-            if include_columns:
-                key_cols = set([r.COLUMNNAME for r in fkey])
-                if key_cols != include_columns:
-                    continue
-
-            columns, referants = [], []
-            quote = self.identifier_preparer._maybe_quote_identifier
-
-            for row in fkey:
-                columns.append(normalize(row.COLUMNNAME))
-                if table.schema or not row.in_schema:
-                    referants.append('.'.join(
-                        [quote(normalize(row[c]))
-                         for c in ('REFSCHEMANAME', 'REFTABLENAME',
-                                   'REFCOLUMNNAME')]))
-                else:
-                    referants.append('.'.join(
-                        [quote(normalize(row[c]))
-                         for c in ('REFTABLENAME', 'REFCOLUMNNAME')]))
-
-            constraint_kw = {'name': fkeyname.lower()}
-            if fkey[0].RULE is not None:
-                rule = fkey[0].RULE
-                if rule.startswith('DELETE '):
-                    rule = rule[7:]
-                constraint_kw['ondelete'] = rule
-
-            table_kw = {}
-            if table.schema or not row.in_schema:
-                table_kw['schema'] = normalize(fkey[0].REFSCHEMANAME)
-
-            ref_key = schema._get_table_key(normalize(fkey[0].REFTABLENAME),
-                                            table_kw.get('schema'))
-            if ref_key not in table.metadata.tables:
-                schema.Table(normalize(fkey[0].REFTABLENAME),
-                             table.metadata,
-                             autoload=True, autoload_with=connection,
-                             **table_kw)
-
-            constraint = schema.ForeignKeyConstraint(columns, referants, link_to_name=True,
-                                                     **constraint_kw)
-            table.append_constraint(constraint)
-
-    def has_sequence(self, connection, name):
-        # [ticket:726] makes this schema-aware.
-        denormalize = self.identifier_preparer._denormalize_name
-        sql = ("SELECT sequence_name FROM SEQUENCES "
-               "WHERE SEQUENCE_NAME=? ")
+    def __getattr__(self, name):
+        try:
+            return self._get_col(name)
+        except KeyError:
+            raise AttributeError(name)
 
-        rp = connection.execute(sql, denormalize(name))
-        found = bool(rp.fetchone())
-        rp.close()
-        return found
 
+class MaxDBResultProxy(engine_base.ResultProxy):
+    _process_row = MaxDBCachedColumnRow
 
-class MaxDBCompiler(compiler.DefaultCompiler):
-    operators = compiler.DefaultCompiler.operators.copy()
-    operators[sql_operators.mod] = lambda x, y: 'mod(%s, %s)' % (x, y)
+class MaxDBCompiler(compiler.SQLCompiler):
 
     function_conversion = {
         'CURRENT_DATE': 'DATE',
@@ -734,6 +458,9 @@ class MaxDBCompiler(compiler.DefaultCompiler):
         'TIMEZONE', 'TRANSACTION', 'TRUE', 'USER', 'UID', 'USERGROUP',
         'UTCDATE', 'UTCDIFF'])
 
+    def visit_mod(self, binary, **kw):
+        return "mod(%s, %s)" % (self.process(binary.left), self.process(binary.right))
+        
     def default_from(self):
         return ' FROM DUAL'
 
@@ -756,11 +483,13 @@ class MaxDBCompiler(compiler.DefaultCompiler):
         else:
             return " WITH LOCK EXCLUSIVE"
 
-    def apply_function_parens(self, func):
-        if func.name.upper() in self.bare_functions:
-            return len(func.clauses) > 0
+    def function_argspec(self, fn, **kw):
+        if fn.name.upper() in self.bare_functions:
+            return ""
+        elif len(fn.clauses) > 0:
+            return compiler.SQLCompiler.function_argspec(self, fn, **kw)
         else:
-            return True
+            return ""
 
     def visit_function(self, fn, **kw):
         transform = self.function_conversion.get(fn.name.upper(), None)
@@ -947,10 +676,10 @@ class MaxDBIdentifierPreparer(compiler.IdentifierPreparer):
             return name
 
 
-class MaxDBSchemaGenerator(compiler.SchemaGenerator):
+class MaxDBDDLCompiler(compiler.DDLCompiler):
     def get_column_specification(self, column, **kw):
         colspec = [self.preparer.format_column(column),
-                   column.type.dialect_impl(self.dialect).get_col_spec()]
+                   self.dialect.type_compiler.process(column.type)]
 
         if not column.nullable:
             colspec.append('NOT NULL')
@@ -996,7 +725,7 @@ class MaxDBSchemaGenerator(compiler.SchemaGenerator):
         else:
             return None
 
-    def visit_sequence(self, sequence):
+    def visit_create_sequence(self, create):
         """Creates a SEQUENCE.
 
         TODO: move to module doc?
@@ -1024,7 +753,8 @@ class MaxDBSchemaGenerator(compiler.SchemaGenerator):
         maxdb_no_cache
           Defaults to False.  If true, sets NOCACHE.
         """
-
+        sequence = create.element
+        
         if (not sequence.optional and
             (not self.checkfirst or
              not self.dialect.has_sequence(self.connection, sequence.name))):
@@ -1061,18 +791,250 @@ class MaxDBSchemaGenerator(compiler.SchemaGenerator):
             elif opts.get('no_cache', False):
                 ddl.append('NOCACHE')
 
-            self.append(' '.join(ddl))
-            self.execute()
+            return ' '.join(ddl)
 
 
-class MaxDBSchemaDropper(compiler.SchemaDropper):
-    def visit_sequence(self, sequence):
-        if (not sequence.optional and
-            (not self.checkfirst or
-             self.dialect.has_sequence(self.connection, sequence.name))):
-            self.append("DROP SEQUENCE %s" %
-                        self.preparer.format_sequence(sequence))
-            self.execute()
+class MaxDBDialect(default.DefaultDialect):
+    name = 'maxdb'
+    supports_alter = True
+    supports_unicode_statements = True
+    max_identifier_length = 32
+    supports_sane_rowcount = True
+    supports_sane_multi_rowcount = False
+
+    preparer = MaxDBIdentifierPreparer
+    statement_compiler = MaxDBCompiler
+    ddl_compiler = MaxDBDDLCompiler
+    defaultrunner = MaxDBDefaultRunner
+    execution_ctx_cls = MaxDBExecutionContext
+
+    colspecs = colspecs
+    ischema_names = ischema_names
+    
+    # MaxDB-specific
+    datetimeformat = 'internal'
+
+    def __init__(self, _raise_known_sql_errors=False, **kw):
+        super(MaxDBDialect, self).__init__(**kw)
+        self._raise_known = _raise_known_sql_errors
+
+        if self.dbapi is None:
+            self.dbapi_type_map = {}
+        else:
+            self.dbapi_type_map = {
+                'Long Binary': MaxBlob(),
+                'Long byte_t': MaxBlob(),
+                'Long Unicode': MaxText(),
+                'Timestamp': MaxTimestamp(),
+                'Date': MaxDate(),
+                'Time': MaxTime(),
+                datetime.datetime: MaxTimestamp(),
+                datetime.date: MaxDate(),
+                datetime.time: MaxTime(),
+            }
+
+    def do_execute(self, cursor, statement, parameters, context=None):
+        res = cursor.execute(statement, parameters)
+        if isinstance(res, int) and context is not None:
+            context._rowcount = res
+
+    def do_release_savepoint(self, connection, name):
+        # Does MaxDB truly support RELEASE SAVEPOINT <id>?  All my attempts
+        # produce "SUBTRANS COMMIT/ROLLBACK not allowed without SUBTRANS
+        # BEGIN SQLSTATE: I7065"
+        # Note that ROLLBACK TO works fine.  In theory, a RELEASE should
+        # just free up some transactional resources early, before the overall
+        # COMMIT/ROLLBACK so omitting it should be relatively ok.
+        pass
+
+    def get_default_schema_name(self, connection):
+        try:
+            return self._default_schema_name
+        except AttributeError:
+            name = self.identifier_preparer._normalize_name(
+                connection.execute('SELECT CURRENT_SCHEMA FROM DUAL').scalar())
+            self._default_schema_name = name
+            return name
+
+    def has_table(self, connection, table_name, schema=None):
+        denormalize = self.identifier_preparer._denormalize_name
+        bind = [denormalize(table_name)]
+        if schema is None:
+            sql = ("SELECT tablename FROM TABLES "
+                   "WHERE TABLES.TABLENAME=? AND"
+                   "  TABLES.SCHEMANAME=CURRENT_SCHEMA ")
+        else:
+            sql = ("SELECT tablename FROM TABLES "
+                   "WHERE TABLES.TABLENAME = ? AND"
+                   "  TABLES.SCHEMANAME=? ")
+            bind.append(denormalize(schema))
+
+        rp = connection.execute(sql, bind)
+        found = bool(rp.fetchone())
+        rp.close()
+        return found
+
+    def table_names(self, connection, schema):
+        if schema is None:
+            sql = (" SELECT TABLENAME FROM TABLES WHERE "
+                   " SCHEMANAME=CURRENT_SCHEMA ")
+            rs = connection.execute(sql)
+        else:
+            sql = (" SELECT TABLENAME FROM TABLES WHERE "
+                   " SCHEMANAME=? ")
+            matchname = self.identifier_preparer._denormalize_name(schema)
+            rs = connection.execute(sql, matchname)
+        normalize = self.identifier_preparer._normalize_name
+        return [normalize(row[0]) for row in rs]
+
+    def reflecttable(self, connection, table, include_columns):
+        denormalize = self.identifier_preparer._denormalize_name
+        normalize = self.identifier_preparer._normalize_name
+
+        st = ('SELECT COLUMNNAME, MODE, DATATYPE, CODETYPE, LEN, DEC, '
+              '  NULLABLE, "DEFAULT", DEFAULTFUNCTION '
+              'FROM COLUMNS '
+              'WHERE TABLENAME=? AND SCHEMANAME=%s '
+              'ORDER BY POS')
+
+        fk = ('SELECT COLUMNNAME, FKEYNAME, '
+              '  REFSCHEMANAME, REFTABLENAME, REFCOLUMNNAME, RULE, '
+              '  (CASE WHEN REFSCHEMANAME = CURRENT_SCHEMA '
+              '   THEN 1 ELSE 0 END) AS in_schema '
+              'FROM FOREIGNKEYCOLUMNS '
+              'WHERE TABLENAME=? AND SCHEMANAME=%s '
+              'ORDER BY FKEYNAME ')
+
+        params = [denormalize(table.name)]
+        if not table.schema:
+            st = st % 'CURRENT_SCHEMA'
+            fk = fk % 'CURRENT_SCHEMA'
+        else:
+            st = st % '?'
+            fk = fk % '?'
+            params.append(denormalize(table.schema))
+
+        rows = connection.execute(st, params).fetchall()
+        if not rows:
+            raise exc.NoSuchTableError(table.fullname)
+
+        include_columns = set(include_columns or [])
+
+        for row in rows:
+            (name, mode, col_type, encoding, length, scale,
+             nullable, constant_def, func_def) = row
+
+            name = normalize(name)
+
+            if include_columns and name not in include_columns:
+                continue
+
+            type_args, type_kw = [], {}
+            if col_type == 'FIXED':
+                type_args = length, scale
+                # Convert FIXED(10) DEFAULT SERIAL to our Integer
+                if (scale == 0 and
+                    func_def is not None and func_def.startswith('SERIAL')):
+                    col_type = 'INTEGER'
+                    type_args = length,
+            elif col_type == 'FLOAT':
+                type_args = length,
+            elif col_type in ('CHAR', 'VARCHAR'):
+                type_args = length,
+                type_kw['encoding'] = encoding
+            elif col_type == 'LONG':
+                type_kw['encoding'] = encoding
+
+            try:
+                type_cls = ischema_names[col_type.lower()]
+                type_instance = type_cls(*type_args, **type_kw)
+            except KeyError:
+                util.warn("Did not recognize type '%s' of column '%s'" %
+                          (col_type, name))
+                type_instance = sqltypes.NullType
+
+            col_kw = {'autoincrement': False}
+            col_kw['nullable'] = (nullable == 'YES')
+            col_kw['primary_key'] = (mode == 'KEY')
+
+            if func_def is not None:
+                if func_def.startswith('SERIAL'):
+                    if col_kw['primary_key']:
+                        # No special default- let the standard autoincrement
+                        # support handle SERIAL pk columns.
+                        col_kw['autoincrement'] = True
+                    else:
+                        # strip current numbering
+                        col_kw['server_default'] = schema.DefaultClause(
+                            sql.text('SERIAL'))
+                        col_kw['autoincrement'] = True
+                else:
+                    col_kw['server_default'] = schema.DefaultClause(
+                        sql.text(func_def))
+            elif constant_def is not None:
+                col_kw['server_default'] = schema.DefaultClause(sql.text(
+                    "'%s'" % constant_def.replace("'", "''")))
+
+            table.append_column(schema.Column(name, type_instance, **col_kw))
+
+        fk_sets = itertools.groupby(connection.execute(fk, params),
+                                    lambda row: row.FKEYNAME)
+        for fkeyname, fkey in fk_sets:
+            fkey = list(fkey)
+            if include_columns:
+                key_cols = set([r.COLUMNNAME for r in fkey])
+                if key_cols != include_columns:
+                    continue
+
+            columns, referants = [], []
+            quote = self.identifier_preparer._maybe_quote_identifier
+
+            for row in fkey:
+                columns.append(normalize(row.COLUMNNAME))
+                if table.schema or not row.in_schema:
+                    referants.append('.'.join(
+                        [quote(normalize(row[c]))
+                         for c in ('REFSCHEMANAME', 'REFTABLENAME',
+                                   'REFCOLUMNNAME')]))
+                else:
+                    referants.append('.'.join(
+                        [quote(normalize(row[c]))
+                         for c in ('REFTABLENAME', 'REFCOLUMNNAME')]))
+
+            constraint_kw = {'name': fkeyname.lower()}
+            if fkey[0].RULE is not None:
+                rule = fkey[0].RULE
+                if rule.startswith('DELETE '):
+                    rule = rule[7:]
+                constraint_kw['ondelete'] = rule
+
+            table_kw = {}
+            if table.schema or not row.in_schema:
+                table_kw['schema'] = normalize(fkey[0].REFSCHEMANAME)
+
+            ref_key = schema._get_table_key(normalize(fkey[0].REFTABLENAME),
+                                            table_kw.get('schema'))
+            if ref_key not in table.metadata.tables:
+                schema.Table(normalize(fkey[0].REFTABLENAME),
+                             table.metadata,
+                             autoload=True, autoload_with=connection,
+                             **table_kw)
+
+            constraint = schema.ForeignKeyConstraint(columns, referants, link_to_name=True,
+                                                     **constraint_kw)
+            table.append_constraint(constraint)
+
+    def has_sequence(self, connection, name):
+        # [ticket:726] makes this schema-aware.
+        denormalize = self.identifier_preparer._denormalize_name
+        sql = ("SELECT sequence_name FROM SEQUENCES "
+               "WHERE SEQUENCE_NAME=? ")
+
+        rp = connection.execute(sql, denormalize(name))
+        found = bool(rp.fetchone())
+        rp.close()
+        return found
+
 
 
 def _autoserial_column(table):
@@ -1090,10 +1052,3 @@ def _autoserial_column(table):
 
     return None, None
 
-dialect = MaxDBDialect
-dialect.preparer = MaxDBIdentifierPreparer
-dialect.statement_compiler = MaxDBCompiler
-dialect.schemagenerator = MaxDBSchemaGenerator
-dialect.schemadropper = MaxDBSchemaDropper
-dialect.defaultrunner = MaxDBDefaultRunner
-dialect.execution_ctx_cls = MaxDBExecutionContext
\ No newline at end of file
diff --git a/lib/sqlalchemy/dialects/maxdb/sapdb.py b/lib/sqlalchemy/dialects/maxdb/sapdb.py
new file mode 100644 (file)
index 0000000..10e6122
--- /dev/null
@@ -0,0 +1,17 @@
+from sqlalchemy.dialects.maxdb.base import MaxDBDialect
+
+class MaxDB_sapdb(MaxDBDialect):
+    driver = 'sapdb'
+    
+    @classmethod
+    def dbapi(cls):
+        from sapdb import dbapi as _dbapi
+        return _dbapi
+
+    def create_connect_args(self, url):
+        opts = url.translate_connect_args(username='user')
+        opts.update(url.query)
+        return [], opts
+
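+# A minimal usage sketch, following the standard dialect+driver URL
+# convention (host and credentials are hypothetical):
+#
+#     from sqlalchemy import create_engine
+#     engine = create_engine('maxdb+sapdb://user:pass@host/dbname')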
+
+dialect = MaxDB_sapdb
\ No newline at end of file
diff --git a/lib/sqlalchemy/dialects/mssql/__init__.py b/lib/sqlalchemy/dialects/mssql/__init__.py
new file mode 100644 (file)
index 0000000..e3a8290
--- /dev/null
@@ -0,0 +1,3 @@
+from sqlalchemy.dialects.mssql import base, pyodbc, adodbapi, pymssql
+
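+# make pyodbc the default dialect, used for plain mssql:// URLs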
+base.dialect = pyodbc.dialect
\ No newline at end of file
diff --git a/lib/sqlalchemy/dialects/mssql/adodbapi.py b/lib/sqlalchemy/dialects/mssql/adodbapi.py
new file mode 100644 (file)
index 0000000..10b8b33
--- /dev/null
@@ -0,0 +1,51 @@
+from sqlalchemy import types as sqltypes
+from sqlalchemy.dialects.mssql.base import MSDateTime, MSDialect
+import datetime, sys
+
+class MSDateTime_adodbapi(MSDateTime):
+    def result_processor(self, dialect):
+        def process(value):
+            # adodbapi will return datetimes with empty time values as datetime.date() objects.
+            # Promote them back to full datetime.datetime()
+            if type(value) is datetime.date:
+                return datetime.datetime(value.year, value.month, value.day)
+            return value
+        return process
+
+
+class MSDialect_adodbapi(MSDialect):
+    supports_sane_rowcount = True
+    supports_sane_multi_rowcount = True
+    supports_unicode = sys.maxunicode == 65535
+    supports_unicode_statements = True
+    driver = 'adodbapi'
+    
+    @classmethod
+    def import_dbapi(cls):
+        import adodbapi as module
+        return module
+
+    colspecs = MSDialect.colspecs.copy()
+    colspecs[sqltypes.DateTime] = MSDateTime_adodbapi
+
+    def create_connect_args(self, url):
+        keys = url.query
+
+        connectors = ["Provider=SQLOLEDB"]
+        if 'port' in keys:
+            connectors.append("Data Source=%s, %s" % (keys.get("host"), keys.get("port")))
+        else:
+            connectors.append("Data Source=%s" % keys.get("host"))
+        connectors.append("Initial Catalog=%s" % keys.get("database"))
+        user = keys.get("user")
+        if user:
+            connectors.append("User Id=%s" % user)
+            connectors.append("Password=%s" % keys.get("password", ""))
+        else:
+            connectors.append("Integrated Security=SSPI")
+        return [[";".join (connectors)], {}]
+
+    def is_disconnect(self, e):
+        return isinstance(e, self.dbapi.adodbapi.DatabaseError) and "'connection failure'" in str(e)
+
+dialect = MSDialect_adodbapi
diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py
new file mode 100644 (file)
index 0000000..cd031af
--- /dev/null
@@ -0,0 +1,1448 @@
+# mssql.py
+
+"""Support for the Microsoft SQL Server database.
+
+Driver
+------
+
+The MSSQL dialect will work with three different available drivers:
+
+* *pyodbc* - http://pyodbc.sourceforge.net/. This is the recommended
+  driver.
+
+* *pymssql* - http://pymssql.sourceforge.net/
+
+* *adodbapi* - http://adodbapi.sourceforge.net/
+
+Drivers are loaded in the order listed above based on availability.
+
+If you need to load a specific driver, include its ``module_name`` in
+the URL when creating the engine::
+
+    engine = create_engine('mssql+module_name://dsn')
+
+``module_name`` currently accepts: ``pyodbc``, ``pymssql``, and
+``adodbapi``.
+
+Currently the pyodbc driver offers the greatest level of
+compatibility.
+
+Connecting
+----------
+
+Connecting with create_engine() uses the standard URL approach of
+``mssql://user:pass@host/dbname[?key=value&key=value...]``.
+
+If the database name is present, the tokens are converted to a
+connection string with the specified values. If the database is not
+present, then the host token is taken directly as the DSN name.
+
+Examples of pyodbc connection string URLs:
+
+* *mssql+pyodbc://mydsn* - connects using the specified DSN named ``mydsn``.
+  The connection string that is created will appear like::
+
+    dsn=mydsn;TrustedConnection=Yes
+
+* *mssql+pyodbc://user:pass@mydsn* - connects using the DSN named
+  ``mydsn`` passing in the ``UID`` and ``PWD`` information. The
+  connection string that is created will appear like::
+
+    dsn=mydsn;UID=user;PWD=pass
+
+* *mssql+pyodbc://user:pass@mydsn/?LANGUAGE=us_english* - connects
+  using the DSN named ``mydsn`` passing in the ``UID`` and ``PWD``
+  information, plus the additional connection configuration option
+  ``LANGUAGE``. The connection string that is created will appear
+  like::
+
+    dsn=mydsn;UID=user;PWD=pass;LANGUAGE=us_english
+
+* *mssql+pyodbc://user:pass@host/db* - connects using a connection string
+  dynamically created that would appear like::
+
+    DRIVER={SQL Server};Server=host;Database=db;UID=user;PWD=pass
+
+* *mssql+pyodbc://user:pass@host:123/db* - connects using a connection
+  string that is dynamically created, which also includes the port
+  information using the comma syntax. If your connection string
+  requires the port information to be passed as a ``port`` keyword
+  see the next example. This will create the following connection
+  string::
+
+    DRIVER={SQL Server};Server=host,123;Database=db;UID=user;PWD=pass
+
+* *mssql+pyodbc://user:pass@host/db?port=123* - connects using a connection
+  string that is dynamically created that includes the port
+  information as a separate ``port`` keyword. This will create the
+  following connection string::
+
+    DRIVER={SQL Server};Server=host;Database=db;UID=user;PWD=pass;port=123
+
+If you require a connection string that is outside the options
+presented above, use the ``odbc_connect`` keyword to pass in a
+urlencoded connection string. What gets passed in will be urldecoded
+and passed directly.
+
+For example::
+
+    mssql+pyodbc:///?odbc_connect=dsn%3Dmydsn%3BDatabase%3Ddb
+
+would create the following connection string::
+
+    dsn=mydsn;Database=db
+
+Encoding your connection string can be easily accomplished through
+the python shell. For example::
+
+    >>> import urllib
+    >>> urllib.quote_plus('dsn=mydsn;Database=db')
+    'dsn%3Dmydsn%3BDatabase%3Ddb'
+
+Additional arguments which may be specified either as query string
+arguments on the URL, or as keyword argument to
+:func:`~sqlalchemy.create_engine()` are:
+
+* *query_timeout* - allows you to override the default query timeout.
+  Defaults to ``None``. This is only supported on pymssql.
+
+* *use_scope_identity* - allows you to specify that SCOPE_IDENTITY
+  should be used in place of the non-scoped version @@IDENTITY.
+  Defaults to True.
+
+* *max_identifier_length* - allows you to set the maximum length of
+  identifiers supported by the database. Defaults to 128. For pymssql
+  the default is 30.
+
+* *schema_name* - sets the schema name. Defaults to ``dbo``.
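+
+For example, either form sets ``use_scope_identity`` (a minimal
+sketch; the DSN name ``mydsn`` is hypothetical)::
+
+    engine = create_engine('mssql+pyodbc://user:pass@mydsn?use_scope_identity=1')
+
+    engine = create_engine('mssql+pyodbc://user:pass@mydsn',
+                           use_scope_identity=True)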
+
+Auto Increment Behavior
+-----------------------
+
+``IDENTITY`` columns are supported by using SQLAlchemy
+``schema.Sequence()`` objects. In other words::
+
+    Table('test', metadata,
+           Column('id', Integer,
+                  Sequence('blah', 100, 10), primary_key=True),
+           Column('name', String(20))
+         ).create()
+
+would yield::
+
+   CREATE TABLE test (
+     id INTEGER NOT NULL IDENTITY(100,10) PRIMARY KEY,
+     name VARCHAR(20) NULL
+     )
+
+Note that the ``start`` and ``increment`` values for sequences are
+optional and will default to 1,1.
+
+* Support for ``SET IDENTITY_INSERT ON`` mode (automagic on / off for
+  ``INSERT`` statements)
+
+* Support for auto-fetching of ``@@IDENTITY``/``SCOPE_IDENTITY()`` on
+  ``INSERT``
+
+Collation Support
+-----------------
+
+MSSQL specific string types support a collation parameter that
+creates a column-level specific collation for the column. The
+collation parameter accepts a Windows Collation Name or a SQL
+Collation Name. Supported types are MSChar, MSNChar, MSString,
+MSNVarchar, MSText, and MSNText. For example::
+
+    Column('login', String(32, collation='Latin1_General_CI_AS'))
+
+will yield::
+
+    login VARCHAR(32) COLLATE Latin1_General_CI_AS NULL
+
+LIMIT/OFFSET Support
+--------------------
+
+MSSQL has no support for the LIMIT or OFFSET keywords. LIMIT is
+supported directly through the ``TOP`` Transact SQL keyword::
+
+    select.limit
+
+will yield::
+
+    SELECT TOP n
+
+If using SQL Server 2005 or above, LIMIT with OFFSET
+support is available through the ``ROW_NUMBER OVER`` construct. 
+For versions below 2005, LIMIT with OFFSET usage will fail.
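+
+For example, a minimal sketch at the SQL expression layer (the
+``users`` table is hypothetical)::
+
+    stmt = select([users]).order_by(users.c.name).limit(10).offset(20)
+
+renders as a ``ROW_NUMBER() OVER (ORDER BY ...)`` subquery on SQL
+Server 2005 and above.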
+
+Nullability
+-----------
+MSSQL has support for three levels of column nullability. The default
+nullability allows nulls and is explicit in the CREATE TABLE
+construct::
+
+    name VARCHAR(20) NULL
+
+If ``nullable=None`` is specified then no nullability clause is
+emitted; in other words, the database's configured default is used.
+This will render::
+
+    name VARCHAR(20)
+
+If ``nullable`` is ``True`` or ``False`` then the column will be
+``NULL`` or ``NOT NULL`` respectively.
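+
+For example, a minimal sketch::
+
+    Column('name', String(20), nullable=None)
+
+renders simply ``name VARCHAR(20)``, deferring to the database's
+configured default.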
+
+Date / Time Handling
+--------------------
+DATE and TIME are supported.   Bind parameters are converted
+to datetime.datetime() objects as required by most MSSQL drivers,
+and results are processed from strings if needed.
+The DATE and TIME types are not available for MSSQL 2005 and
+previous - if a server version below 2008 is detected, DDL
+for these types will be issued as DATETIME.
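+
+For example, a minimal sketch (the column name is hypothetical)::
+
+    Column('created_on', Date)
+
+emits ``created_on DATE`` against SQL Server 2008 and above, but
+``created_on DATETIME`` against earlier server versions.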
+
+Compatibility Levels
+--------------------
+MSSQL supports the notion of setting compatibility levels at the
+database level. This allows, for instance, running a database that
+is compatible with SQL2000 while running on a SQL2005 database
+server. ``server_version_info`` will always return the database
+server version information (in this case SQL2005) and not the
+compatibility level information. Because of this, if running under
+a backwards compatibility mode SQLAlchemy may attempt to use T-SQL
+statements that the database server cannot parse.
+
+Known Issues
+------------
+
+* No support for more than one ``IDENTITY`` column per table
+
+* pymssql has problems with binary and unicode data that this module
+  does **not** work around
+
+"""
+import datetime, decimal, inspect, operator, sys, re
+import itertools
+
+from sqlalchemy import sql, schema as sa_schema, exc, util
+from sqlalchemy.sql import select, compiler, expression, \
+                            operators as sql_operators, \
+                            functions as sql_functions, util as sql_util
+from sqlalchemy.engine import default, base, reflection
+from sqlalchemy import types as sqltypes
+from decimal import Decimal as _python_Decimal
+from sqlalchemy.types import INTEGER, BIGINT, SMALLINT, DECIMAL, NUMERIC, \
+                                FLOAT, TIMESTAMP, DATETIME, DATE
+            
+
+from sqlalchemy.dialects.mssql import information_schema as ischema
+
+MS_2008_VERSION = (10,)
+MS_2005_VERSION = (9,)
+MS_2000_VERSION = (8,)
+
+RESERVED_WORDS = set(
+    ['add', 'all', 'alter', 'and', 'any', 'as', 'asc', 'authorization',
+     'backup', 'begin', 'between', 'break', 'browse', 'bulk', 'by', 'cascade',
+     'case', 'check', 'checkpoint', 'close', 'clustered', 'coalesce',
+     'collate', 'column', 'commit', 'compute', 'constraint', 'contains',
+     'containstable', 'continue', 'convert', 'create', 'cross', 'current',
+     'current_date', 'current_time', 'current_timestamp', 'current_user',
+     'cursor', 'database', 'dbcc', 'deallocate', 'declare', 'default',
+     'delete', 'deny', 'desc', 'disk', 'distinct', 'distributed', 'double',
+     'drop', 'dump', 'else', 'end', 'errlvl', 'escape', 'except', 'exec',
+     'execute', 'exists', 'exit', 'external', 'fetch', 'file', 'fillfactor',
+     'for', 'foreign', 'freetext', 'freetexttable', 'from', 'full',
+     'function', 'goto', 'grant', 'group', 'having', 'holdlock', 'identity',
+     'identity_insert', 'identitycol', 'if', 'in', 'index', 'inner', 'insert',
+     'intersect', 'into', 'is', 'join', 'key', 'kill', 'left', 'like',
+     'lineno', 'load', 'merge', 'national', 'nocheck', 'nonclustered', 'not',
+     'null', 'nullif', 'of', 'off', 'offsets', 'on', 'open', 'opendatasource',
+     'openquery', 'openrowset', 'openxml', 'option', 'or', 'order', 'outer',
+     'over', 'percent', 'pivot', 'plan', 'precision', 'primary', 'print',
+     'proc', 'procedure', 'public', 'raiserror', 'read', 'readtext',
+     'reconfigure', 'references', 'replication', 'restore', 'restrict',
+     'return', 'revert', 'revoke', 'right', 'rollback', 'rowcount',
+     'rowguidcol', 'rule', 'save', 'schema', 'securityaudit', 'select',
+     'session_user', 'set', 'setuser', 'shutdown', 'some', 'statistics',
+     'system_user', 'table', 'tablesample', 'textsize', 'then', 'to', 'top',
+     'tran', 'transaction', 'trigger', 'truncate', 'tsequal', 'union',
+     'unique', 'unpivot', 'update', 'updatetext', 'use', 'user', 'values',
+     'varying', 'view', 'waitfor', 'when', 'where', 'while', 'with',
+     'writetext',
+    ])
+
+
+class _MSNumeric(sqltypes.Numeric):
+    def result_processor(self, dialect):
+        if self.asdecimal:
+            def process(value):
+                if value is not None:
+                    return _python_Decimal(str(value))
+                else:
+                    return value
+            return process
+        else:
+            def process(value):
+                return float(value)
+            return process
+
+    def bind_processor(self, dialect):
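+        # MSSQL drivers reject Decimal values rendered in scientific
+        # notation, so format them as plain decimal strings by hand.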
+        def process(value):
+            if isinstance(value, decimal.Decimal):
+                if value.adjusted() < 0:
+                    result = "%s0.%s%s" % (
+                            (value < 0 and '-' or ''),
+                            '0' * (abs(value.adjusted()) - 1),
+                            "".join([str(nint) for nint in value._int]))
+
+                else:
+                    if 'E' in str(value):
+                        result = "%s%s%s" % (
+                                (value < 0 and '-' or ''),
+                                "".join([str(s) for s in value._int]),
+                                "0" * (value.adjusted() - (len(value._int)-1)))
+                    else:
+                        if (len(value._int) - 1) > value.adjusted():
+                            result = "%s%s.%s" % (
+                                    (value < 0 and '-' or ''),
+                                    "".join([str(s) for s in value._int][0:value.adjusted() + 1]),
+                                    "".join([str(s) for s in value._int][value.adjusted() + 1:]))
+                        else:
+                            result = "%s%s" % (
+                                    (value < 0 and '-' or ''),
+                                    "".join([str(s) for s in value._int][0:value.adjusted() + 1]))
+
+                return result
+
+            else:
+                return value
+
+        return process
+
+class REAL(sqltypes.Float):
+    """A type for ``real`` numbers."""
+
+    __visit_name__ = 'REAL'
+
+    def __init__(self):
+        super(REAL, self).__init__(precision=24)
+
+class TINYINT(sqltypes.Integer):
+    __visit_name__ = 'TINYINT'
+
+
+# MSSQL DATE/TIME types have varied behavior, sometimes returning
+# strings.  MSDate/TIME check for everything, and always
+# filter bind parameters into datetime objects (required by pyodbc,
+# not sure about other dialects).
+
+class _MSDate(sqltypes.Date):
+    def bind_processor(self, dialect):
+        def process(value):
+            if type(value) == datetime.date:
+                return datetime.datetime(value.year, value.month, value.day)
+            else:
+                return value
+        return process
+
+    _reg = re.compile(r"(\d+)-(\d+)-(\d+)")
+    def result_processor(self, dialect):
+        def process(value):
+            if isinstance(value, datetime.datetime):
+                return value.date()
+            elif isinstance(value, basestring):
+                return datetime.date(*[int(x or 0) for x in self._reg.match(value).groups()])
+            else:
+                return value
+        return process
+
+class TIME(sqltypes.TIME):
+    def __init__(self, precision=None, **kwargs):
+        self.precision = precision
+        super(TIME, self).__init__()
+
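+    # most MSSQL drivers require TIME binds as full datetime values, so
+    # anchor incoming times to a fixed zero date.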
+    __zero_date = datetime.date(1900, 1, 1)
+
+    def bind_processor(self, dialect):
+        def process(value):
+            if isinstance(value, datetime.datetime):
+                value = datetime.datetime.combine(self.__zero_date, value.time())
+            elif isinstance(value, datetime.time):
+                value = datetime.datetime.combine(self.__zero_date, value)
+            return value
+        return process
+
+    _reg = re.compile(r"(\d+):(\d+):(\d+)(?:\.(\d+))?")
+    def result_processor(self, dialect):
+        def process(value):
+            if isinstance(value, datetime.datetime):
+                return value.time()
+            elif isinstance(value, basestring):
+                return datetime.time(*[int(x or 0) for x in self._reg.match(value).groups()])
+            else:
+                return value
+        return process
+
+
+class _DateTimeBase(object):
+    def bind_processor(self, dialect):
+        def process(value):
+            # TODO: why ?
+            if type(value) == datetime.date:
+                return datetime.datetime(value.year, value.month, value.day)
+            else:
+                return value
+        return process
+
+class _MSDateTime(_DateTimeBase, sqltypes.DateTime):
+    pass
+
+class SMALLDATETIME(_DateTimeBase, sqltypes.DateTime):
+    __visit_name__ = 'SMALLDATETIME'
+
+class DATETIME2(_DateTimeBase, sqltypes.DateTime):
+    __visit_name__ = 'DATETIME2'
+    
+    def __init__(self, precision=None, **kwargs):
+        self.precision = precision
+
+
+# TODO: is this not an Interval ?
+class DATETIMEOFFSET(sqltypes.TypeEngine):
+    __visit_name__ = 'DATETIMEOFFSET'
+    
+    def __init__(self, precision=None, **kwargs):
+        self.precision = precision
+
+
+class _StringType(object):
+    """Base for MSSQL string types."""
+
+    def __init__(self, collation=None):
+        self.collation = collation
+
+    def __repr__(self):
+        attributes = inspect.getargspec(self.__init__)[0][1:]
+        attributes.extend(inspect.getargspec(_StringType.__init__)[0][1:])
+
+        params = {}
+        for attr in attributes:
+            val = getattr(self, attr)
+            if val is not None and val is not False:
+                params[attr] = val
+
+        return "%s(%s)" % (self.__class__.__name__,
+                           ', '.join(['%s=%r' % (k, params[k]) for k in params]))
+
+
+class TEXT(_StringType, sqltypes.TEXT):
+    """MSSQL TEXT type, for variable-length text up to 2^31 characters."""
+
+    def __init__(self, *args, **kw):
+        """Construct a TEXT.
+
+        :param collation: Optional, a column-level collation for this string
+          value. Accepts a Windows Collation Name or a SQL Collation Name.
+
+        """
+        collation = kw.pop('collation', None)
+        _StringType.__init__(self, collation)
+        sqltypes.Text.__init__(self, *args, **kw)
+
+class NTEXT(_StringType, sqltypes.UnicodeText):
+    """MSSQL NTEXT type, for variable-length unicode text up to 2^30
+    characters."""
+
+    __visit_name__ = 'NTEXT'
+    
+    def __init__(self, *args, **kwargs):
+        """Construct a NTEXT.
+
+        :param collation: Optional, a column-level collation for this string
+          value. Accepts a Windows Collation Name or a SQL Collation Name.
+
+        """
+        collation = kwargs.pop('collation', None)
+        _StringType.__init__(self, collation)
+        length = kwargs.pop('length', None)
+        sqltypes.UnicodeText.__init__(self, length, **kwargs)
+
+
+class VARCHAR(_StringType, sqltypes.VARCHAR):
+    """MSSQL VARCHAR type, for variable-length non-Unicode data with a maximum
+    of 8,000 characters."""
+
+    def __init__(self, *args, **kw):
+        """Construct a VARCHAR.
+
+        :param length: Optional, maximum data length, in characters.
+
+        :param convert_unicode: defaults to False.  If True, convert
+          ``unicode`` data sent to the database to a ``str``
+          bytestring, and convert bytestrings coming back from the
+          database into ``unicode``.
+
+          Bytestrings are encoded using the dialect's
+          :attr:`~sqlalchemy.engine.base.Dialect.encoding`, which
+          defaults to `utf-8`.
+
+          If False, may be overridden by
+          :attr:`sqlalchemy.engine.base.Dialect.convert_unicode`.
+
+        :param assert_unicode:
+
+          If None (the default), no assertion will take place unless
+          overridden by :attr:`sqlalchemy.engine.base.Dialect.assert_unicode`.
+
+          If 'warn', will issue a runtime warning if a ``str``
+          instance is used as a bind value.
+
+          If true, will raise an :exc:`sqlalchemy.exc.InvalidRequestError`.
+
+        :param collation: Optional, a column-level collation for this string
+          value. Accepts a Windows Collation Name or a SQL Collation Name.
+
+        """
+        collation = kw.pop('collation', None)
+        _StringType.__init__(self, collation)
+        sqltypes.VARCHAR.__init__(self, *args, **kw)
+
+class NVARCHAR(_StringType, sqltypes.NVARCHAR):
+    """MSSQL NVARCHAR type.
+
+    For variable-length unicode character data up to 4,000 characters."""
+
+    def __init__(self, *args, **kw):
+        """Construct a NVARCHAR.
+
+        :param length: Optional, Maximum data length, in characters.
+
+        :param collation: Optional, a column-level collation for this string
+          value. Accepts a Windows Collation Name or a SQL Collation Name.
+
+        """
+        collation = kw.pop('collation', None)
+        _StringType.__init__(self, collation)
+        sqltypes.NVARCHAR.__init__(self, *args, **kw)
+
+class CHAR(_StringType, sqltypes.CHAR):
+    """MSSQL CHAR type, for fixed-length non-Unicode data with a maximum
+    of 8,000 characters."""
+
+    def __init__(self, *args, **kw):
+        """Construct a CHAR.
+
+        :param length: Optional, maximum data length, in characters.
+
+        :param convert_unicode: defaults to False.  If True, convert
+          ``unicode`` data sent to the database to a ``str``
+          bytestring, and convert bytestrings coming back from the
+          database into ``unicode``.
+
+          Bytestrings are encoded using the dialect's
+          :attr:`~sqlalchemy.engine.base.Dialect.encoding`, which
+          defaults to `utf-8`.
+
+          If False, may be overridden by
+          :attr:`sqlalchemy.engine.base.Dialect.convert_unicode`.
+
+        :param assert_unicode:
+
+          If None (the default), no assertion will take place unless
+          overridden by :attr:`sqlalchemy.engine.base.Dialect.assert_unicode`.
+
+          If 'warn', will issue a runtime warning if a ``str``
+          instance is used as a bind value.
+
+          If true, will raise an :exc:`sqlalchemy.exc.InvalidRequestError`.
+
+        :param collation: Optional, a column-level collation for this string
+          value. Accepts a Windows Collation Name or a SQL Collation Name.
+
+        """
+        collation = kw.pop('collation', None)
+        _StringType.__init__(self, collation)
+        sqltypes.CHAR.__init__(self, *args, **kw)
+
+class NCHAR(_StringType, sqltypes.NCHAR):
+    """MSSQL NCHAR type.
+
+    For fixed-length unicode character data up to 4,000 characters."""
+
+    def __init__(self, *args, **kw):
+        """Construct an NCHAR.
+
+        :param length: Optional, Maximum data length, in characters.
+
+        :param collation: Optional, a column-level collation for this string
+          value. Accepts a Windows Collation Name or a SQL Collation Name.
+
+        """
+        collation = kw.pop('collation', None)
+        _StringType.__init__(self, collation)
+        sqltypes.NCHAR.__init__(self, *args, **kw)
+
+class BINARY(sqltypes.Binary):
+    __visit_name__ = 'BINARY'
+
+class VARBINARY(sqltypes.Binary):
+    __visit_name__ = 'VARBINARY'
+        
+class IMAGE(sqltypes.Binary):
+    __visit_name__ = 'IMAGE'
+
+class BIT(sqltypes.TypeEngine):
+    __visit_name__ = 'BIT'
+    
+class _MSBoolean(sqltypes.Boolean):
+    def result_processor(self, dialect):
+        def process(value):
+            if value is None:
+                return None
+            return value and True or False
+        return process
+
+    def bind_processor(self, dialect):
+        def process(value):
+            if value is True:
+                return 1
+            elif value is False:
+                return 0
+            elif value is None:
+                return None
+            else:
+                return value and True or False
+        return process
+
+class MONEY(sqltypes.TypeEngine):
+    __visit_name__ = 'MONEY'
+
+class SMALLMONEY(sqltypes.TypeEngine):
+    __visit_name__ = 'SMALLMONEY'
+
+class UNIQUEIDENTIFIER(sqltypes.TypeEngine):
+    __visit_name__ = "UNIQUEIDENTIFIER"
+
+class SQL_VARIANT(sqltypes.TypeEngine):
+    __visit_name__ = 'SQL_VARIANT'
+
+# old names.
+MSNumeric = _MSNumeric
+MSDateTime = _MSDateTime
+MSDate = _MSDate
+MSBoolean = _MSBoolean
+MSReal = REAL
+MSTinyInteger = TINYINT
+MSTime = TIME
+MSSmallDateTime = SMALLDATETIME
+MSDateTime2 = DATETIME2
+MSDateTimeOffset = DATETIMEOFFSET
+MSText = TEXT
+MSNText = NTEXT
+MSString = VARCHAR
+MSNVarchar = NVARCHAR
+MSChar = CHAR
+MSNChar = NCHAR
+MSBinary = BINARY
+MSVarBinary = VARBINARY
+MSImage = IMAGE
+MSBit = BIT
+MSMoney = MONEY
+MSSmallMoney = SMALLMONEY
+MSUniqueIdentifier = UNIQUEIDENTIFIER
+MSVariant = SQL_VARIANT
+
+colspecs = {
+    sqltypes.Numeric : _MSNumeric,
+    sqltypes.DateTime : _MSDateTime,
+    sqltypes.Date : _MSDate,
+    sqltypes.Time : TIME,
+    sqltypes.Boolean : _MSBoolean,
+}
+
+ischema_names = {
+    'int' : INTEGER,
+    'bigint': BIGINT,
+    'smallint' : SMALLINT,
+    'tinyint' : TINYINT,
+    'varchar' : VARCHAR,
+    'nvarchar' : NVARCHAR,
+    'char' : CHAR,
+    'nchar' : NCHAR,
+    'text' : TEXT,
+    'ntext' : NTEXT,
+    'decimal' : DECIMAL,
+    'numeric' : NUMERIC,
+    'float' : FLOAT,
+    'datetime' : DATETIME,
+    'datetime2' : DATETIME2,
+    'datetimeoffset' : DATETIMEOFFSET,
+    'date': DATE,
+    'time': TIME,
+    'smalldatetime' : SMALLDATETIME,
+    'binary' : BINARY,
+    'varbinary' : VARBINARY,
+    'bit': BIT,
+    'real' : REAL,
+    'image' : IMAGE,
+    'timestamp': TIMESTAMP,
+    'money': MONEY,
+    'smallmoney': SMALLMONEY,
+    'uniqueidentifier': UNIQUEIDENTIFIER,
+    'sql_variant': SQL_VARIANT,
+}
+
+
+class MSTypeCompiler(compiler.GenericTypeCompiler):
+    def _extend(self, spec, type_):
+        """Extend a string-type declaration with standard SQL
+        COLLATE annotations.
+
+        """
+
+        if getattr(type_, 'collation', None):
+            collation = 'COLLATE %s' % type_.collation
+        else:
+            collation = None
+
+        if type_.length:
+            spec = spec + "(%d)" % type_.length
+        
+        return ' '.join([c for c in (spec, collation)
+            if c is not None])
+
+    def visit_FLOAT(self, type_):
+        precision = getattr(type_, 'precision', None)
+        if precision is None:
+            return "FLOAT"
+        else:
+            return "FLOAT(%(precision)s)" % {'precision': precision}
+
+    def visit_REAL(self, type_):
+        return "REAL"
+
+    def visit_TINYINT(self, type_):
+        return "TINYINT"
+
+    def visit_DATETIMEOFFSET(self, type_):
+        if type_.precision:
+            return "DATETIMEOFFSET(%s)" % type_.precision
+        else:
+            return "DATETIMEOFFSET"
+
+    def visit_TIME(self, type_):
+        precision = getattr(type_, 'precision', None)
+        if precision:
+            return "TIME(%s)" % precision
+        else:
+            return "TIME"
+
+    def visit_DATETIME2(self, type_):
+        precision = getattr(type_, 'precision', None)
+        if precision:
+            return "DATETIME2(%s)" % precision
+        else:
+            return "DATETIME2"
+
+    def visit_SMALLDATETIME(self, type_):
+        return "SMALLDATETIME"
+
+    def visit_unicode(self, type_):
+        return self.visit_NVARCHAR(type_)
+        
+    def visit_unicode_text(self, type_):
+        return self.visit_NTEXT(type_)
+        
+    def visit_NTEXT(self, type_):
+        return self._extend("NTEXT", type_)
+
+    def visit_TEXT(self, type_):
+        return self._extend("TEXT", type_)
+
+    def visit_VARCHAR(self, type_):
+        return self._extend("VARCHAR", type_)
+
+    def visit_CHAR(self, type_):
+        return self._extend("CHAR", type_)
+
+    def visit_NCHAR(self, type_):
+        return self._extend("NCHAR", type_)
+
+    def visit_NVARCHAR(self, type_):
+        return self._extend("NVARCHAR", type_)
+
+    def visit_date(self, type_):
+        if self.dialect.server_version_info < MS_2008_VERSION:
+            return self.visit_DATETIME(type_)
+        else:
+            return self.visit_DATE(type_)
+
+    def visit_time(self, type_):
+        if self.dialect.server_version_info < MS_2008_VERSION:
+            return self.visit_DATETIME(type_)
+        else:
+            return self.visit_TIME(type_)
+            
+    def visit_binary(self, type_):
+        if type_.length:
+            return self.visit_BINARY(type_)
+        else:
+            return self.visit_IMAGE(type_)
+
+    def visit_BINARY(self, type_):
+        if type_.length:
+            return "BINARY(%s)" % type_.length
+        else:
+            return "BINARY"
+
+    def visit_IMAGE(self, type_):
+        return "IMAGE"
+
+    def visit_VARBINARY(self, type_):
+        if type_.length:
+            return "VARBINARY(%s)" % type_.length
+        else:
+            return "VARBINARY"
+
+    def visit_boolean(self, type_):
+        return self.visit_BIT(type_)
+
+    def visit_BIT(self, type_):
+        return "BIT"
+
+    def visit_MONEY(self, type_):
+        return "MONEY"
+
+    def visit_SMALLMONEY(self, type_):
+        return 'SMALLMONEY'
+
+    def visit_UNIQUEIDENTIFIER(self, type_):
+        return "UNIQUEIDENTIFIER"
+
+    def visit_SQL_VARIANT(self, type_):
+        return 'SQL_VARIANT'
+
+class MSExecutionContext(default.DefaultExecutionContext):
+    _enable_identity_insert = False
+    _select_lastrowid = False
+    _result_proxy = None
+    _lastrowid = None
+    
+    def pre_exec(self):
+        """Activate IDENTITY_INSERT if needed."""
+
+        if self.isinsert:
+            tbl = self.compiled.statement.table
+            seq_column = tbl._autoincrement_column
+            insert_has_sequence = seq_column is not None
+            
+            if insert_has_sequence:
+                self._enable_identity_insert = seq_column.key in self.compiled_parameters[0]
+            else:
+                self._enable_identity_insert = False
+            
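+            # fetch the last inserted identity afterwards only when the
+            # server generates it: no RETURNING clause, no explicit
+            # IDENTITY_INSERT, and not an executemany batch.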
+            self._select_lastrowid = insert_has_sequence and \
+                                        not self.compiled.returning and \
+                                        not self._enable_identity_insert and \
+                                        not self.executemany
+            
+            if self._enable_identity_insert:
+                self.cursor.execute("SET IDENTITY_INSERT %s ON" % 
+                    self.dialect.identifier_preparer.format_table(tbl))
+
+    def post_exec(self):
+        """Disable IDENTITY_INSERT if enabled."""
+        
+        if self._select_lastrowid:
+            if self.dialect.use_scope_identity:
+                self.cursor.execute("SELECT scope_identity() AS lastrowid")
+            else:
+                self.cursor.execute("SELECT @@identity AS lastrowid")
+            row = self.cursor.fetchall()[0]   # fetchall() ensures the cursor is consumed without closing it
+            self._lastrowid = int(row[0])
+
+        if (self.isinsert or self.isupdate or self.isdelete) and self.compiled.returning:
+            self._result_proxy = base.FullyBufferedResultProxy(self)
+            
+        if self._enable_identity_insert:
+            self.cursor.execute("SET IDENTITY_INSERT %s OFF" % self.dialect.identifier_preparer.format_table(self.compiled.statement.table))
+        
+    def get_lastrowid(self):
+        return self._lastrowid
+        
+    def handle_dbapi_exception(self, e):
+        if self._enable_identity_insert:
+            try:
+                self.cursor.execute("SET IDENTITY_INSERT %s OFF" % self.dialect.identifier_preparer.format_table(self.compiled.statement.table))
+            except:
+                pass
+
+    def get_result_proxy(self):
+        if self._result_proxy:
+            return self._result_proxy
+        else:
+            return base.ResultProxy(self)
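+
+# Illustrative sketch of the flow above (table and values are made up, not
+# part of the dialect).  Explicitly supplying a value for an IDENTITY column
+# causes pre_exec()/post_exec() to bracket the INSERT:
+#
+#   SET IDENTITY_INSERT accounts ON
+#   INSERT INTO accounts (id, name) VALUES (7, 'jack')
+#   SET IDENTITY_INSERT accounts OFF
+#
+# whereas omitting it leaves generation to the server, and post_exec()
+# fetches the new value:
+#
+#   INSERT INTO accounts (name) VALUES ('jack')
+#   SELECT scope_identity() AS lastrowid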
+
+class MSSQLCompiler(compiler.SQLCompiler):
+
+    extract_map = compiler.SQLCompiler.extract_map.copy()
+    extract_map.update({
+        'doy': 'dayofyear',
+        'dow': 'weekday',
+        'milliseconds': 'millisecond',
+        'microseconds': 'microsecond'
+    })
+
+    def __init__(self, *args, **kwargs):
+        super(MSSQLCompiler, self).__init__(*args, **kwargs)
+        self.tablealiases = {}
+
+    def visit_now_func(self, fn, **kw):
+        return "CURRENT_TIMESTAMP"
+        
+    def visit_current_date_func(self, fn, **kw):
+        return "GETDATE()"
+        
+    def visit_length_func(self, fn, **kw):
+        return "LEN%s" % self.function_argspec(fn, **kw)
+        
+    def visit_char_length_func(self, fn, **kw):
+        return "LEN%s" % self.function_argspec(fn, **kw)
+        
+    def visit_concat_op(self, binary):
+        return "%s + %s" % (self.process(binary.left), self.process(binary.right))
+        
+    def visit_match_op(self, binary):
+        return "CONTAINS (%s, %s)" % (self.process(binary.left), self.process(binary.right))
+        
+    def get_select_precolumns(self, select):
+        """ MS-SQL puts TOP, it's version of LIMIT here """
+        if select._distinct or select._limit:
+            s = select._distinct and "DISTINCT " or ""
+            
+            if select._limit:
+                if not select._offset:
+                    s += "TOP %s " % (select._limit,)
+            return s
+        return compiler.SQLCompiler.get_select_precolumns(self, select)
+
+    def limit_clause(self, select):
+        # Limit in mssql is after the select keyword
+        return ""
+
+    def visit_select(self, select, **kwargs):
+        """Look for ``LIMIT`` and OFFSET in a select statement, and if
+        so tries to wrap it in a subquery with ``row_number()`` criterion.
+
+        """
+        if not getattr(select, '_mssql_visit', None) and select._offset:
+            # to use ROW_NUMBER(), an ORDER BY is required.
+            orderby = self.process(select._order_by_clause)
+            if not orderby:
+                raise exc.InvalidRequestError('MSSQL requires an order_by when using an offset.')
+
+            _offset = select._offset
+            _limit = select._limit
+            select._mssql_visit = True
+            select = select.column(sql.literal_column("ROW_NUMBER() OVER (ORDER BY %s)" % orderby).label("mssql_rn")).order_by(None).alias()
+
+            limitselect = sql.select([c for c in select.c if c.key!='mssql_rn'])
+            limitselect.append_whereclause("mssql_rn>%d" % _offset)
+            if _limit is not None:
+                limitselect.append_whereclause("mssql_rn<=%d" % (_limit + _offset))
+            return self.process(limitselect, iswrapper=True, **kwargs)
+        else:
+            return compiler.SQLCompiler.visit_select(self, select, **kwargs)
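+
+    # Illustrative sketch (hypothetical table "t"): with both limit and
+    # offset, e.g. select([t]).order_by(t.c.id).limit(10).offset(5), the
+    # statement compiles to roughly:
+    #
+    #   SELECT anon_1.id FROM
+    #     (SELECT t.id AS id, ROW_NUMBER() OVER (ORDER BY t.id) AS mssql_rn
+    #      FROM t) AS anon_1
+    #   WHERE mssql_rn > 5 AND mssql_rn <= 15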
+
+    def _schema_aliased_table(self, table):
+        if getattr(table, 'schema', None) is not None:
+            if table not in self.tablealiases:
+                self.tablealiases[table] = table.alias()
+            return self.tablealiases[table]
+        else:
+            return None
+
+    def visit_table(self, table, mssql_aliased=False, **kwargs):
+        if mssql_aliased:
+            return super(MSSQLCompiler, self).visit_table(table, **kwargs)
+
+        # alias schema-qualified tables
+        alias = self._schema_aliased_table(table)
+        if alias is not None:
+            return self.process(alias, mssql_aliased=True, **kwargs)
+        else:
+            return super(MSSQLCompiler, self).visit_table(table, **kwargs)
+
+    def visit_alias(self, alias, **kwargs):
+        # translate for schema-qualified table aliases
+        self.tablealiases[alias.original] = alias
+        kwargs['mssql_aliased'] = True
+        return super(MSSQLCompiler, self).visit_alias(alias, **kwargs)
+
+    def visit_extract(self, extract):
+        field = self.extract_map.get(extract.field, extract.field)
+        return 'DATEPART("%s", %s)' % (field, self.process(extract.expr))
+
+    def visit_rollback_to_savepoint(self, savepoint_stmt):
+        return "ROLLBACK TRANSACTION %s" % self.preparer.format_savepoint(savepoint_stmt)
+
+    def visit_column(self, column, result_map=None, **kwargs):
+        if column.table is not None and \
+            (not self.isupdate and not self.isdelete) or self.is_subquery():
+            # translate for schema-qualified table aliases
+            t = self._schema_aliased_table(column.table)
+            if t is not None:
+                converted = expression._corresponding_column_or_error(t, column)
+
+                if result_map is not None:
+                    result_map[column.name.lower()] = (column.name, (column, ), column.type)
+
+                return super(MSSQLCompiler, self).visit_column(converted, result_map=None, **kwargs)
+
+        return super(MSSQLCompiler, self).visit_column(column, result_map=result_map, **kwargs)
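+
+    # Illustrative sketch of the schema-aliasing above (hypothetical table):
+    # Table('t', metadata, Column('x', Integer), schema='remote') is selected
+    # through an anonymous alias so that column references need not carry the
+    # schema qualifier, roughly:
+    #
+    #   SELECT t_1.x FROM remote.t AS t_1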
+
+    def visit_binary(self, binary, **kwargs):
+        """Move bind parameters to the right-hand side of an operator, where
+        possible.
+
+        """
+        if isinstance(binary.left, expression._BindParamClause) and binary.operator == operator.eq \
+            and not isinstance(binary.right, expression._BindParamClause):
+            return self.process(expression._BinaryExpression(binary.right, binary.left, binary.operator), **kwargs)
+        else:
+            if (binary.operator is operator.eq or binary.operator is operator.ne) and (
+                (isinstance(binary.left, expression._FromGrouping) and isinstance(binary.left.element, expression._ScalarSelect)) or \
+                (isinstance(binary.right, expression._FromGrouping) and isinstance(binary.right.element, expression._ScalarSelect)) or \
+                 isinstance(binary.left, expression._ScalarSelect) or isinstance(binary.right, expression._ScalarSelect)):
+                op = binary.operator == operator.eq and "IN" or "NOT IN"
+                return self.process(expression._BinaryExpression(binary.left, binary.right, op), **kwargs)
+            return super(MSSQLCompiler, self).visit_binary(binary, **kwargs)
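+
+    # Illustrative sketches (hypothetical column t.c.x): a bind on the left
+    # side of an equality, e.g. literal(5) == t.c.x, is flipped to render as
+    # "t.x = :param_1"; an equality against a scalar subquery, e.g.
+    # t.c.x == select([u.c.x]).as_scalar(), renders as "t.x IN (SELECT ...)".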
+
+    def returning_clause(self, stmt, returning_cols):
+
+        if self.isinsert or self.isupdate:
+            target = stmt.table.alias("inserted")
+        else:
+            target = stmt.table.alias("deleted")
+        
+        adapter = sql_util.ClauseAdapter(target)
+        def col_label(col):
+            adapted = adapter.traverse(col)
+            if isinstance(col, expression._Label):
+                return adapted.label(col.key)
+            else:
+                return self.label_select_column(None, adapted, asfrom=False)
+            
+        columns = [
+            self.process(
+                col_label(c), 
+                within_columns_clause=True, 
+                result_map=self.result_map
+            ) 
+            for c in expression._select_iterables(returning_cols)
+        ]
+        return 'OUTPUT ' + ', '.join(columns)
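+
+    # Illustrative sketch (hypothetical table "t"): for
+    # t.insert().values(x=5).returning(t.c.id), the RETURNING columns are
+    # rendered against the "inserted" pseudo-table, roughly:
+    #
+    #   INSERT INTO t (x) OUTPUT inserted.id VALUES (:x)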
+
+    def label_select_column(self, select, column, asfrom):
+        if isinstance(column, expression.Function):
+            return column.label(None)
+        else:
+            return super(MSSQLCompiler, self).label_select_column(select, column, asfrom)
+
+    def for_update_clause(self, select):
+        # "FOR UPDATE" is only allowed on "DECLARE CURSOR" which SQLAlchemy doesn't use
+        return ''
+
+    def order_by_clause(self, select):
+        order_by = self.process(select._order_by_clause)
+
+        # MSSQL only allows ORDER BY in subqueries if there is a LIMIT
+        if order_by and (not self.is_subquery() or select._limit):
+            return " ORDER BY " + order_by
+        else:
+            return ""
+
+
+class MSDDLCompiler(compiler.DDLCompiler):
+    def get_column_specification(self, column, **kwargs):
+        colspec = self.preparer.format_column(column) + " " + self.dialect.type_compiler.process(column.type)
+
+        if column.nullable is not None:
+            if not column.nullable or column.primary_key:
+                colspec += " NOT NULL"
+            else:
+                colspec += " NULL"
+        
+        if not column.table:
+            raise exc.InvalidRequestError("mssql requires Table-bound columns in order to generate DDL")
+            
+        seq_col = column.table._autoincrement_column
+
+        # install an IDENTITY Sequence if we have an implicit IDENTITY column
+        if seq_col is column:
+            sequence = getattr(column, 'sequence', None)
+            if sequence:
+                start, increment = sequence.start or 1, sequence.increment or 1
+            else:
+                start, increment = 1, 1
+            colspec += " IDENTITY(%s,%s)" % (start, increment)
+        else:
+            default = self.get_column_default_string(column)
+            if default is not None:
+                colspec += " DEFAULT " + default
+
+        return colspec
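+
+    # Illustrative sketch (hypothetical table): an Integer primary key that
+    # is the table's autoincrement column renders roughly as
+    # "id INTEGER NOT NULL IDENTITY(1,1)"; a sequence on the column with
+    # start=100, increment=10 would emit "IDENTITY(100,10)" instead.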
+
+    def visit_drop_index(self, drop):
+        return "\nDROP INDEX %s.%s" % (
+            self.preparer.quote_identifier(drop.element.table.name),
+            self.preparer.quote(self._validate_identifier(drop.element.name, False), drop.element.quote)
+            )
+
+
+class MSIdentifierPreparer(compiler.IdentifierPreparer):
+    reserved_words = RESERVED_WORDS
+
+    def __init__(self, dialect):
+        super(MSIdentifierPreparer, self).__init__(dialect, initial_quote='[', final_quote=']')
+
+    def _escape_identifier(self, value):
+        #TODO: determine MSSQL's escaping rules
+        return value
+
+    def quote_schema(self, schema, force=True):
+        """Prepare a quoted table and schema name."""
+        result = '.'.join([self.quote(x, force) for x in schema.split('.')])
+        return result
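+
+    # Illustrative sketch: quote_schema('mydb.dbo') quotes each dotted
+    # component separately, yielding "[mydb].[dbo]" rather than "[mydb.dbo]".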
+
+class MSDialect(default.DefaultDialect):
+    name = 'mssql'
+    supports_default_values = True
+    supports_empty_insert = False
+    execution_ctx_cls = MSExecutionContext
+    text_as_varchar = False
+    use_scope_identity = True
+    max_identifier_length = 128
+    schema_name = "dbo"
+    colspecs = colspecs
+    ischema_names = ischema_names
+    
+    supports_unicode_binds = True
+    postfetch_lastrowid = True
+    
+    server_version_info = ()
+    
+    statement_compiler = MSSQLCompiler
+    ddl_compiler = MSDDLCompiler
+    type_compiler = MSTypeCompiler
+    preparer = MSIdentifierPreparer
+
+    def __init__(self,
+                 query_timeout=None,
+                 use_scope_identity=True,
+                 max_identifier_length=None,
+                 schema_name="dbo", **opts):
+        self.query_timeout = int(query_timeout or 0)
+        self.schema_name = schema_name
+
+        self.use_scope_identity = use_scope_identity
+        self.max_identifier_length = int(max_identifier_length or 0) or \
+                self.max_identifier_length
+        super(MSDialect, self).__init__(**opts)
+    
+    def do_savepoint(self, connection, name):
+        util.warn("Savepoint support in mssql is experimental and may lead to data loss.")
+        connection.execute("IF @@TRANCOUNT = 0 BEGIN TRANSACTION")
+        connection.execute("SAVE TRANSACTION %s" % name)
+
+    def do_release_savepoint(self, connection, name):
+        pass
+    
+    def initialize(self, connection):
+        super(MSDialect, self).initialize(connection)
+        if self.server_version_info >= MS_2005_VERSION and 'implicit_returning' not in self.__dict__:
+            self.implicit_returning = True
+        
+    def get_default_schema_name(self, connection):
+        return self.default_schema_name
+        
+    def _get_default_schema_name(self, connection):
+        user_name = connection.scalar("SELECT user_name() as user_name;")
+        if user_name is not None:
+            # now, get the default schema
+            query = """
+            SELECT default_schema_name FROM
+            sys.database_principals
+            WHERE name = ?
+            AND type = 'S'
+            """
+            try:
+                default_schema_name = connection.scalar(query, [user_name])
+                if default_schema_name is not None:
+                    return default_schema_name
+            except:
+                # sys.database_principals is unavailable on servers older
+                # than SQL Server 2005; fall back to the configured name
+                pass
+        return self.schema_name
+
+    def table_names(self, connection, schema):
+        s = select([ischema.tables.c.table_name], ischema.tables.c.table_schema==schema)
+        return [row[0] for row in connection.execute(s)]
+
+
+    def has_table(self, connection, tablename, schema=None):
+        current_schema = schema or self.default_schema_name
+        columns = ischema.columns
+        s = sql.select([columns],
+                   current_schema
+                       and sql.and_(columns.c.table_name==tablename, columns.c.table_schema==current_schema)
+                       or columns.c.table_name==tablename,
+                   )
+
+        c = connection.execute(s)
+        row  = c.fetchone()
+        return row is not None
+
+    @reflection.cache
+    def get_schema_names(self, connection, **kw):
+        s = sql.select([ischema.schemata.c.schema_name],
+            order_by=[ischema.schemata.c.schema_name]
+        )
+        schema_names = [r[0] for r in connection.execute(s)]
+        return schema_names
+
+    @reflection.cache
+    def get_table_names(self, connection, schema=None, **kw):
+        current_schema = schema or self.default_schema_name
+        tables = ischema.tables
+        s = sql.select([tables.c.table_name],
+            sql.and_(
+                tables.c.table_schema == current_schema,
+                tables.c.table_type == 'BASE TABLE'
+            ),
+            order_by=[tables.c.table_name]
+        )
+        table_names = [r[0] for r in connection.execute(s)]
+        return table_names
+
+    @reflection.cache
+    def get_view_names(self, connection, schema=None, **kw):
+        current_schema = schema or self.default_schema_name
+        tables = ischema.tables
+        s = sql.select([tables.c.table_name],
+            sql.and_(
+                tables.c.table_schema == current_schema,
+                tables.c.table_type == 'VIEW'
+            ),
+            order_by=[tables.c.table_name]
+        )
+        view_names = [r[0] for r in connection.execute(s)]
+        return view_names
+
+    # The cursor reports it is closed after executing the sp.
+    @reflection.cache
+    def get_indexes(self, connection, tablename, schema=None, **kw):
+        current_schema = schema or self.default_schema_name
+        full_tname = "%s.%s" % (current_schema, tablename)
+        indexes = []
+        s = sql.text("exec sp_helpindex '%s'" % full_tname)
+        rp = connection.execute(s)
+        if rp.closed:
+            # did not work for this setup.
+            return []
+        for row in rp:
+            if 'primary key' not in row['index_description']:
+                indexes.append({
+                    'name' : row['index_name'],
+                    'column_names' : [c.strip() for c in row['index_keys'].split(',')],
+                    'unique': 'unique' in row['index_description']
+                })
+        return indexes
+
+    @reflection.cache
+    def get_view_definition(self, connection, viewname, schema=None, **kw):
+        current_schema = schema or self.default_schema_name
+        views = ischema.views
+        s = sql.select([views.c.view_definition],
+            sql.and_(
+                views.c.table_schema == current_schema,
+                views.c.table_name == viewname
+            ),
+        )
+        rp = connection.execute(s)
+        if rp:
+            view_def = rp.scalar()
+            return view_def
+
+    @reflection.cache
+    def get_columns(self, connection, tablename, schema=None, **kw):
+        # Get base columns
+        current_schema = schema or self.default_schema_name
+        columns = ischema.columns
+        s = sql.select([columns],
+                   current_schema
+                       and sql.and_(columns.c.table_name==tablename, columns.c.table_schema==current_schema)
+                       or columns.c.table_name==tablename,
+                   order_by=[columns.c.ordinal_position])
+        c = connection.execute(s)
+        cols = []
+        while True:
+            row = c.fetchone()
+            if row is None:
+                break
+            (name, type, nullable, charlen, numericprec, numericscale, default, collation) = (
+                row[columns.c.column_name],
+                row[columns.c.data_type],
+                row[columns.c.is_nullable] == 'YES',
+                row[columns.c.character_maximum_length],
+                row[columns.c.numeric_precision],
+                row[columns.c.numeric_scale],
+                row[columns.c.column_default],
+                row[columns.c.collation_name]
+            )
+            coltype = self.ischema_names.get(type, None)
+
+            kwargs = {}
+            if coltype in (MSString, MSChar, MSNVarchar, MSNChar, MSText, MSNText, MSBinary, MSVarBinary, sqltypes.Binary):
+                kwargs['length'] = charlen
+                if collation:
+                    kwargs['collation'] = collation
+                if coltype == MSText or (coltype in (MSString, MSNVarchar) and charlen == -1):
+                    kwargs.pop('length')
+
+            if coltype is None:
+                util.warn("Did not recognize type '%s' of column '%s'" % (type, name))
+                # NULLTYPE is already an instance; don't pass it to
+                # issubclass() or call it as a class below
+                coltype = sqltypes.NULLTYPE
+            else:
+                if issubclass(coltype, sqltypes.Numeric) and coltype is not MSReal:
+                    kwargs['scale'] = numericscale
+                    kwargs['precision'] = numericprec
+
+                coltype = coltype(**kwargs)
+            cdict = {
+                'name' : name,
+                'type' : coltype,
+                'nullable' : nullable,
+                'default' : default,
+                'autoincrement':False,
+            }
+            cols.append(cdict)
+        # autoincrement and identity
+        colmap = {}
+        for col in cols:
+            colmap[col['name']] = col
+        # We also run an sp_columns to check for identity columns:
+        cursor = connection.execute("sp_columns @table_name = '%s', @table_owner = '%s'" % (tablename, current_schema))
+        ic = None
+        while True:
+            row = cursor.fetchone()
+            if row is None:
+                break
+            (col_name, type_name) = row[3], row[5]
+            if type_name.endswith("identity") and col_name in colmap:
+                ic = col_name
+                colmap[col_name]['autoincrement'] = True
+                colmap[col_name]['sequence'] = dict(
+                                    name='%s_identity' % col_name)
+                break
+        cursor.close()
+        if ic is not None:
+            try:
+                # is this table_fullname reliable?
+                table_fullname = "%s.%s" % (current_schema, tablename)
+                cursor = connection.execute(
+                    sql.text("select ident_seed(:seed), ident_incr(:incr)"), 
+                    {'seed':table_fullname, 'incr':table_fullname}
+                )
+                row = cursor.fetchone()
+                cursor.close()
+                if row is not None:
+                    colmap[ic]['sequence'].update({
+                        'start' : int(row[0]),
+                        'increment' : int(row[1])
+                    })
+            except:
+                # ident_seed()/ident_incr() are a best-effort refinement;
+                # on any failure just keep the basic identity record
+                pass
+        return cols
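+
+    # Illustrative sketch of one returned column record (values made up):
+    #
+    #   {'name': 'id', 'type': INTEGER(), 'nullable': False,
+    #    'default': None, 'autoincrement': True,
+    #    'sequence': {'name': 'id_identity', 'start': 1, 'increment': 1}}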
+
+    @reflection.cache
+    def get_primary_keys(self, connection, tablename, schema=None, **kw):
+        current_schema = schema or self.default_schema_name
+        pkeys = []
+        # Add constraints (only the tables needed for primary keys)
+        TC = ischema.constraints        #information_schema.table_constraints
+        C  = ischema.key_constraints.alias('C') #information_schema.constraint_column_usage: the constrained column
+
+        # Primary key constraints
+        s = sql.select([C.c.column_name, TC.c.constraint_type],
+            sql.and_(TC.c.constraint_name == C.c.constraint_name,
+                     C.c.table_name == tablename,
+                     C.c.table_schema == current_schema)
+        )
+        c = connection.execute(s)
+        for row in c:
+            if 'PRIMARY' in row[TC.c.constraint_type.name]:
+                pkeys.append(row[0])
+        return pkeys
+
+    @reflection.cache
+    def get_foreign_keys(self, connection, tablename, schema=None, **kw):
+        current_schema = schema or self.default_schema_name
+        # Add constraints
+        RR = ischema.ref_constraints    #information_schema.referential_constraints
+        TC = ischema.constraints        #information_schema.table_constraints
+        C  = ischema.key_constraints.alias('C') #information_schema.constraint_column_usage: the constrained column
+        R  = ischema.key_constraints.alias('R') #information_schema.constraint_column_usage: the referenced column
+
+        # Foreign key constraints
+        s = sql.select([C.c.column_name,
+                        R.c.table_schema, R.c.table_name, R.c.column_name,
+                        RR.c.constraint_name, RR.c.match_option, RR.c.update_rule, RR.c.delete_rule],
+                       sql.and_(C.c.table_name == tablename,
+                                C.c.table_schema == current_schema,
+                                C.c.constraint_name == RR.c.constraint_name,
+                                R.c.constraint_name == RR.c.unique_constraint_name,
+                                C.c.ordinal_position == R.c.ordinal_position
+                                ),
+                       order_by = [RR.c.constraint_name, R.c.ordinal_position])
+        
+
+        # group rows by constraint name, to handle multi-column FKs
+        
+        def fkey_rec():
+            return {
+                'name' : None,
+                'constrained_columns' : [],
+                'referred_schema' : None,
+                'referred_table' : None,
+                'referred_columns' : []
+            }
+
+        fkeys = util.defaultdict(fkey_rec)
+        
+        for r in connection.execute(s).fetchall():
+            scol, rschema, rtbl, rcol, rfknm, fkmatch, fkuprule, fkdelrule = r
+
+            rec = fkeys[rfknm]
+            rec['name'] = rfknm
+            if not rec['referred_table']:
+                rec['referred_table'] = rtbl
+
+                if schema is not None or current_schema != rschema:
+                    rec['referred_schema'] = rschema
+            
+            local_cols, remote_cols = rec['constrained_columns'], rec['referred_columns']
+            
+            local_cols.append(scol)
+            remote_cols.append(rcol)
+
+        return fkeys.values()
+
+
+# fixme.  I added this for the tests to run. -Randall
+MSSQLDialect = MSDialect
diff --git a/lib/sqlalchemy/dialects/mssql/information_schema.py b/lib/sqlalchemy/dialects/mssql/information_schema.py
new file mode 100644 (file)
index 0000000..bb6ff31
--- /dev/null
@@ -0,0 +1,83 @@
+from sqlalchemy import Table, MetaData, Column, ForeignKey
+from sqlalchemy.types import String, Unicode, Integer, TypeDecorator
+
+ischema = MetaData()
+
+class CoerceUnicode(TypeDecorator):
+    impl = Unicode
+    
+    def process_bind_param(self, value, dialect):
+        if isinstance(value, str):
+            value = value.decode(dialect.encoding)
+        return value
+    
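+# Illustrative sketch: CoerceUnicode decodes plain str bind values with the
+# dialect's configured encoding, so a str value 'dbo' bound against
+# schemata.c.schema_name reaches the driver as u'dbo'.
+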
+schemata = Table("SCHEMATA", ischema,
+    Column("CATALOG_NAME", CoerceUnicode, key="catalog_name"),
+    Column("SCHEMA_NAME", CoerceUnicode, key="schema_name"),
+    Column("SCHEMA_OWNER", CoerceUnicode, key="schema_owner"),
+    schema="INFORMATION_SCHEMA")
+
+tables = Table("TABLES", ischema,
+    Column("TABLE_CATALOG", CoerceUnicode, key="table_catalog"),
+    Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
+    Column("TABLE_NAME", CoerceUnicode, key="table_name"),
+    Column("TABLE_TYPE", String, key="table_type"),
+    schema="INFORMATION_SCHEMA")
+
+columns = Table("COLUMNS", ischema,
+    Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
+    Column("TABLE_NAME", CoerceUnicode, key="table_name"),
+    Column("COLUMN_NAME", CoerceUnicode, key="column_name"),
+    Column("IS_NULLABLE", Integer, key="is_nullable"),
+    Column("DATA_TYPE", String, key="data_type"),
+    Column("ORDINAL_POSITION", Integer, key="ordinal_position"),
+    Column("CHARACTER_MAXIMUM_LENGTH", Integer, key="character_maximum_length"),
+    Column("NUMERIC_PRECISION", Integer, key="numeric_precision"),
+    Column("NUMERIC_SCALE", Integer, key="numeric_scale"),
+    Column("COLUMN_DEFAULT", Integer, key="column_default"),
+    Column("COLLATION_NAME", String, key="collation_name"),
+    schema="INFORMATION_SCHEMA")
+
+constraints = Table("TABLE_CONSTRAINTS", ischema,
+    Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
+    Column("TABLE_NAME", CoerceUnicode, key="table_name"),
+    Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"),
+    Column("CONSTRAINT_TYPE", String, key="constraint_type"),
+    schema="INFORMATION_SCHEMA")
+
+column_constraints = Table("CONSTRAINT_COLUMN_USAGE", ischema,
+    Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
+    Column("TABLE_NAME", CoerceUnicode, key="table_name"),
+    Column("COLUMN_NAME", CoerceUnicode, key="column_name"),
+    Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"),
+    schema="INFORMATION_SCHEMA")
+
+key_constraints = Table("KEY_COLUMN_USAGE", ischema,
+    Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
+    Column("TABLE_NAME", CoerceUnicode, key="table_name"),
+    Column("COLUMN_NAME", CoerceUnicode, key="column_name"),
+    Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"),
+    Column("ORDINAL_POSITION", Integer, key="ordinal_position"),
+    schema="INFORMATION_SCHEMA")
+
+ref_constraints = Table("REFERENTIAL_CONSTRAINTS", ischema,
+    Column("CONSTRAINT_CATALOG", CoerceUnicode, key="constraint_catalog"),
+    Column("CONSTRAINT_SCHEMA", CoerceUnicode, key="constraint_schema"),
+    Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"),
+    Column("UNIQUE_CONSTRAINT_CATLOG", CoerceUnicode, key="unique_constraint_catalog"),  # TODO: is CATLOG misspelled ?
+    Column("UNIQUE_CONSTRAINT_SCHEMA", CoerceUnicode, key="unique_constraint_schema"),
+    Column("UNIQUE_CONSTRAINT_NAME", CoerceUnicode, key="unique_constraint_name"),
+    Column("MATCH_OPTION", String, key="match_option"),
+    Column("UPDATE_RULE", String, key="update_rule"),
+    Column("DELETE_RULE", String, key="delete_rule"),
+    schema="INFORMATION_SCHEMA")
+
+views = Table("VIEWS", ischema,
+    Column("TABLE_CATALOG", CoerceUnicode, key="table_catalog"),
+    Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
+    Column("TABLE_NAME", CoerceUnicode, key="table_name"),
+    Column("VIEW_DEFINITION", CoerceUnicode, key="view_definition"),
+    Column("CHECK_OPTION", String, key="check_option"),
+    Column("IS_UPDATABLE", String, key="is_updatable"),
+    schema="INFORMATION_SCHEMA")
+
diff --git a/lib/sqlalchemy/dialects/mssql/pymssql.py b/lib/sqlalchemy/dialects/mssql/pymssql.py
new file mode 100644 (file)
index 0000000..0961c2e
--- /dev/null
@@ -0,0 +1,46 @@
+from sqlalchemy.dialects.mssql.base import MSDialect
+from sqlalchemy import types as sqltypes
+from sqlalchemy import util
+
+
+class MSDialect_pymssql(MSDialect):
+    supports_sane_rowcount = False
+    max_identifier_length = 30
+    driver = 'pymssql'
+
+    @classmethod
+    def import_dbapi(cls):
+        import pymssql as module
+        # pymssql doesn't have a Binary method.  we use string
+        # TODO: monkeypatching here is less than ideal
+        module.Binary = lambda st: str(st)
+        return module
+
+    def __init__(self, **params):
+        super(MSDialect_pymssql, self).__init__(**params)
+        self.use_scope_identity = True
+
+        # pymssql understands only ascii
+        if self.convert_unicode:
+            util.warn("pymssql does not support unicode")
+            self.encoding = params.get('encoding', 'ascii')
+
+
+    def create_connect_args(self, url):
+        if hasattr(self, 'query_timeout'):
+            # ick, globals ?   we might want to move this....
+            self.dbapi._mssql.set_query_timeout(self.query_timeout)
+
+        keys = url.query
+        if keys.get('port'):
+            # pymssql expects port as host:port, not a separate arg
+            keys['host'] = ''.join([keys.get('host', ''), ':', str(keys['port'])])
+            del keys['port']
+        return [[], keys]
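+
+    # Illustrative sketch: given connect arguments including port=1433 and
+    # host='somehost', the keys passed to pymssql.connect() end up as
+    # host='somehost:1433' with no separate port entry.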
+
+    def is_disconnect(self, e):
+        return isinstance(e, self.dbapi.DatabaseError) and "Error 10054" in str(e)
+
+    def do_begin(self, connection):
+        pass
+
+dialect = MSDialect_pymssql
\ No newline at end of file
diff --git a/lib/sqlalchemy/dialects/mssql/pyodbc.py b/lib/sqlalchemy/dialects/mssql/pyodbc.py
new file mode 100644 (file)
index 0000000..9a2a9e4
--- /dev/null
@@ -0,0 +1,79 @@
+from sqlalchemy.dialects.mssql.base import MSExecutionContext, MSDialect
+from sqlalchemy.connectors.pyodbc import PyODBCConnector
+from sqlalchemy import types as sqltypes
+import re
+import sys
+
+class MSExecutionContext_pyodbc(MSExecutionContext):
+    _embedded_scope_identity = False
+    
+    def pre_exec(self):
+        """where appropriate, issue "select scope_identity()" in the same statement.
+        
+        Background on why "scope_identity()" is preferable to "@@identity":
+        http://msdn.microsoft.com/en-us/library/ms190315.aspx
+        
+        Background on why we attempt to embed "scope_identity()" into the same
+        statement as the INSERT:
+        http://code.google.com/p/pyodbc/wiki/FAQs#How_do_I_retrieve_autogenerated/identity_values?
+        
+        """
+        
+        super(MSExecutionContext_pyodbc, self).pre_exec()
+
+        # don't embed the scope_identity select into an "INSERT .. DEFAULT VALUES"
+        if self._select_lastrowid and \
+                self.dialect.use_scope_identity and \
+                len(self.parameters[0]):
+            self._embedded_scope_identity = True
+            
+            self.statement += "; select scope_identity()"
+
+    def post_exec(self):
+        if self._embedded_scope_identity:
+            # Fetch the last inserted id from the manipulated statement
+            # We may have to skip over a number of result sets with no data (due to triggers, etc.)
+            while True:
+                try:
+                    # fetchall() ensures the cursor is consumed without closing it (FreeTDS particularly)
+                    row = self.cursor.fetchall()[0]  
+                    break
+                except self.dialect.dbapi.Error, e:
+                    # no way around this - nextset() consumes the previous set
+                    # so we need to just keep flipping
+                    self.cursor.nextset()
+                    
+            self._lastrowid = int(row[0])
+        else:
+            super(MSExecutionContext_pyodbc, self).post_exec()
+
+
+class MSDialect_pyodbc(PyODBCConnector, MSDialect):
+    supports_sane_rowcount = True
+    supports_sane_multi_rowcount = False
+
+    execution_ctx_cls = MSExecutionContext_pyodbc
+
+    pyodbc_driver_name = 'SQL Server'
+
+    def __init__(self, description_encoding='latin-1', **params):
+        super(MSDialect_pyodbc, self).__init__(**params)
+        self.description_encoding = description_encoding
+        self.use_scope_identity = self.dbapi and hasattr(self.dbapi.Cursor, 'nextset')
+        
+    def initialize(self, connection):
+        super(MSDialect_pyodbc, self).initialize(connection)
+        pyodbc = self.dbapi
+        
+        dbapi_con = connection.connection
+        
+        self._free_tds = re.match(r".*libtdsodbc.*\.so",  dbapi_con.getinfo(pyodbc.SQL_DRIVER_NAME))
+    
+        # the "Py2K only" part here is theoretical.
+        # have not tried pyodbc + python3.1 yet.
+        # Py2K
+        self.supports_unicode_statements = not self._free_tds
+        self.supports_unicode_binds = not self._free_tds
+        # end Py2K
+        
+dialect = MSDialect_pyodbc
diff --git a/lib/sqlalchemy/dialects/mysql/__init__.py b/lib/sqlalchemy/dialects/mysql/__init__.py
new file mode 100644 (file)
index 0000000..4106a29
--- /dev/null
@@ -0,0 +1,4 @@
+from sqlalchemy.dialects.mysql import base, mysqldb, pyodbc, zxjdbc
+
+# default dialect
+base.dialect = mysqldb.dialect
\ No newline at end of file
similarity index 66%
rename from lib/sqlalchemy/databases/mysql.py
rename to lib/sqlalchemy/dialects/mysql/base.py
index ba6b026ea29aac857be41bbe8563e904dfc2ff43..1c5c251e5441a7c7e1a8f47ab707dbe246f7fcba 100644 (file)
@@ -19,12 +19,12 @@ But if you would like to use one of the MySQL-specific or enhanced column
 types when creating tables with your :class:`~sqlalchemy.Table` definitions,
 then you will need to import them from this module::
 
-  from sqlalchemy.databases import mysql
+  from sqlalchemy.dialects.mysql import base as mysql
 
   Table('mytable', metadata,
         Column('id', Integer, primary_key=True),
-        Column('ittybittyblob', mysql.MSTinyBlob),
-        Column('biggy', mysql.MSBigInteger(unsigned=True)))
+        Column('ittybittyblob', mysql.TINYBLOB),
+        Column('biggy', mysql.BIGINT(unsigned=True)))
 
 All standard MySQL column types are supported.  The OpenGIS types are
 available for use via table reflection but have no special support or mapping
@@ -64,25 +64,6 @@ Nested Transactions                    5.0.3
 See the official MySQL documentation for detailed information about features
 supported in any given server release.
 
-Character Sets
---------------
-
-Many MySQL server installations default to a ``latin1`` encoding for client
-connections.  All data sent through the connection will be converted into
-``latin1``, even if you have ``utf8`` or another character set on your tables
-and columns.  With versions 4.1 and higher, you can change the connection
-character set either through server configuration or by including the
-``charset`` parameter in the URL used for ``create_engine``.  The ``charset``
-option is passed through to MySQL-Python and has the side-effect of also
-enabling ``use_unicode`` in the driver by default.  For regular encoded
-strings, also pass ``use_unicode=0`` in the connection arguments::
-
-  # set client encoding to utf8; all strings come back as unicode
-  create_engine('mysql:///mydb?charset=utf8')
-
-  # set client encoding to utf8; all strings come back as utf8 str
-  create_engine('mysql:///mydb?charset=utf8&use_unicode=0')
-
 Storage Engines
 ---------------
 
@@ -196,27 +177,20 @@ timely information affecting MySQL in SQLAlchemy.
 
 """
 
-import datetime, decimal, inspect, re, sys
-from array import array as _array
+import datetime, inspect, re, sys
 
-from sqlalchemy import exc, log, schema, sql, util
+from sqlalchemy import schema as sa_schema
+from sqlalchemy import exc, log, sql, util
 from sqlalchemy.sql import operators as sql_operators
 from sqlalchemy.sql import functions as sql_functions
 from sqlalchemy.sql import compiler
+from array import array as _array
 
+from sqlalchemy.engine import reflection
 from sqlalchemy.engine import base as engine_base, default
 from sqlalchemy import types as sqltypes
 
-
-__all__ = (
-    'MSBigInteger', 'MSMediumInteger', 'MSBinary', 'MSBit', 'MSBlob', 'MSBoolean',
-    'MSChar', 'MSDate', 'MSDateTime', 'MSDecimal', 'MSDouble',
-    'MSEnum', 'MSFloat', 'MSInteger', 'MSLongBlob', 'MSLongText',
-    'MSMediumBlob', 'MSMediumText', 'MSNChar', 'MSNVarChar',
-    'MSNumeric', 'MSSet', 'MSSmallInteger', 'MSString', 'MSText',
-    'MSTime', 'MSTimeStamp', 'MSTinyBlob', 'MSTinyInteger',
-    'MSTinyText', 'MSVarBinary', 'MSYear' )
-
+from sqlalchemy.types import DATE, DATETIME, BOOLEAN, TIME
 
 RESERVED_WORDS = set(
     ['accessible', 'add', 'all', 'alter', 'analyze','and', 'as', 'asc',
@@ -271,62 +245,45 @@ SET_RE = re.compile(
 class _NumericType(object):
     """Base for MySQL numeric types."""
 
-    def __init__(self, kw):
+    def __init__(self, **kw):
         self.unsigned = kw.pop('unsigned', False)
         self.zerofill = kw.pop('zerofill', False)
+        super(_NumericType, self).__init__(**kw)
+        
+class _FloatType(_NumericType, sqltypes.Float):
+    def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
+        if isinstance(self, (REAL, DOUBLE)) and \
+            (
+                (precision is None and scale is not None) or
+                (precision is not None and scale is None)
+            ):
+            raise exc.ArgumentError(
+                "You must specify both precision and scale or omit "
+                "both altogether.")
+
+        super(_FloatType, self).__init__(precision=precision, asdecimal=asdecimal, **kw)
+        self.scale = scale
 
-    def _extend(self, spec):
-        "Extend a numeric-type declaration with MySQL specific extensions."
-
-        if self.unsigned:
-            spec += ' UNSIGNED'
-        if self.zerofill:
-            spec += ' ZEROFILL'
-        return spec
-
+class _IntegerType(_NumericType, sqltypes.Integer):
+    def __init__(self, display_width=None, **kw):
+        self.display_width = display_width
+        super(_IntegerType, self).__init__(**kw)
 
-class _StringType(object):
+class _StringType(sqltypes.String):
     """Base for MySQL string types."""
 
     def __init__(self, charset=None, collation=None,
                  ascii=False, unicode=False, binary=False,
-                 national=False, **kwargs):
+                 national=False, **kw):
         self.charset = charset
         # allow collate= or collation=
-        self.collation = kwargs.get('collate', collation)
+        self.collation = kw.pop('collate', collation)
         self.ascii = ascii
         self.unicode = unicode
         self.binary = binary
         self.national = national
-
-    def _extend(self, spec):
-        """Extend a string-type declaration with standard SQL CHARACTER SET /
-        COLLATE annotations and MySQL specific extensions.
-        """
-
-        if self.charset:
-            charset = 'CHARACTER SET %s' % self.charset
-        elif self.ascii:
-            charset = 'ASCII'
-        elif self.unicode:
-            charset = 'UNICODE'
-        else:
-            charset = None
-
-        if self.collation:
-            collation = 'COLLATE %s' % self.collation
-        elif self.binary:
-            collation = 'BINARY'
-        else:
-            collation = None
-
-        if self.national:
-            # NATIONAL (aka NCHAR/NVARCHAR) trumps charsets.
-            return ' '.join([c for c in ('NATIONAL', spec, collation)
-                             if c is not None])
-        return ' '.join([c for c in (spec, charset, collation)
-                         if c is not None])
-
+        super(_StringType, self).__init__(**kw)
+        
     def __repr__(self):
         attributes = inspect.getargspec(self.__init__)[0][1:]
         attributes.extend(inspect.getargspec(_StringType.__init__)[0][1:])
@@ -341,10 +298,23 @@ class _StringType(object):
                            ', '.join(['%s=%r' % (k, params[k]) for k in params]))
 
 
-class MSNumeric(sqltypes.Numeric, _NumericType):
-    """MySQL NUMERIC type."""
+class _BinaryType(sqltypes.Binary):
+    """Base for MySQL binary types."""
+
+    def result_processor(self, dialect):
+        def process(value):
+            if value is None:
+                return None
+            else:
+                return util.buffer(value)
+        return process
 
-    def __init__(self, precision=10, scale=2, asdecimal=True, **kw):
+class NUMERIC(_NumericType, sqltypes.NUMERIC):
+    """MySQL NUMERIC type."""
+    
+    __visit_name__ = 'NUMERIC'
+    
+    def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
         """Construct a NUMERIC.
 
         :param precision: Total digits in this number.  If scale and precision
@@ -360,34 +330,15 @@ class MSNumeric(sqltypes.Numeric, _NumericType):
           numeric.
 
         """
-        _NumericType.__init__(self, kw)
-        sqltypes.Numeric.__init__(self, precision, scale, asdecimal=asdecimal, **kw)
-
-    def get_col_spec(self):
-        if self.precision is None:
-            return self._extend("NUMERIC")
-        else:
-            return self._extend("NUMERIC(%(precision)s, %(scale)s)" % {'precision': self.precision, 'scale' : self.scale})
-
-    def bind_processor(self, dialect):
-        return None
+        super(NUMERIC, self).__init__(precision=precision, scale=scale, asdecimal=asdecimal, **kw)
 
-    def result_processor(self, dialect):
-        if not self.asdecimal:
-            def process(value):
-                if isinstance(value, decimal.Decimal):
-                    return float(value)
-                else:
-                    return value
-            return process
-        else:
-            return None
 
-
-class MSDecimal(MSNumeric):
+class DECIMAL(_NumericType, sqltypes.DECIMAL):
     """MySQL DECIMAL type."""
-
-    def __init__(self, precision=10, scale=2, asdecimal=True, **kw):
+    
+    __visit_name__ = 'DECIMAL'
+    
+    def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
         """Construct a DECIMAL.
 
         :param precision: Total digits in this number.  If scale and precision
@@ -403,20 +354,14 @@ class MSDecimal(MSNumeric):
           numeric.
 
         """
-        super(MSDecimal, self).__init__(precision, scale, asdecimal=asdecimal, **kw)
-
-    def get_col_spec(self):
-        if self.precision is None:
-            return self._extend("DECIMAL")
-        elif self.scale is None:
-            return self._extend("DECIMAL(%(precision)s)" % {'precision': self.precision})
-        else:
-            return self._extend("DECIMAL(%(precision)s, %(scale)s)" % {'precision': self.precision, 'scale' : self.scale})
-
+        super(DECIMAL, self).__init__(precision=precision, scale=scale, asdecimal=asdecimal, **kw)
 
-class MSDouble(sqltypes.Float, _NumericType):
+    
+class DOUBLE(_FloatType):
     """MySQL DOUBLE type."""
 
+    __visit_name__ = 'DOUBLE'
+
     def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
         """Construct a DOUBLE.
 
@@ -433,29 +378,13 @@ class MSDouble(sqltypes.Float, _NumericType):
           numeric.
 
         """
-        if ((precision is None and scale is not None) or
-            (precision is not None and scale is None)):
-            raise exc.ArgumentError(
-                "You must specify both precision and scale or omit "
-                "both altogether.")
-
-        _NumericType.__init__(self, kw)
-        sqltypes.Float.__init__(self, asdecimal=asdecimal, **kw)
-        self.scale = scale
-        self.precision = precision
+        super(DOUBLE, self).__init__(precision=precision, scale=scale, asdecimal=asdecimal, **kw)
 
-    def get_col_spec(self):
-        if self.precision is not None and self.scale is not None:
-            return self._extend("DOUBLE(%(precision)s, %(scale)s)" %
-                                {'precision': self.precision,
-                                 'scale' : self.scale})
-        else:
-            return self._extend('DOUBLE')
-
-
-class MSReal(MSDouble):
+class REAL(_FloatType):
     """MySQL REAL type."""
 
+    __visit_name__ = 'REAL'
+
     def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
         """Construct a REAL.
 
@@ -472,20 +401,13 @@ class MSReal(MSDouble):
           numeric.
 
         """
-        MSDouble.__init__(self, precision, scale, asdecimal, **kw)
+        super(REAL, self).__init__(precision=precision, scale=scale, asdecimal=asdecimal, **kw)
 
-    def get_col_spec(self):
-        if self.precision is not None and self.scale is not None:
-            return self._extend("REAL(%(precision)s, %(scale)s)" %
-                                {'precision': self.precision,
-                                 'scale' : self.scale})
-        else:
-            return self._extend('REAL')
-
-
-class MSFloat(sqltypes.Float, _NumericType):
+class FLOAT(_FloatType, sqltypes.FLOAT):
     """MySQL FLOAT type."""
 
+    __visit_name__ = 'FLOAT'
+
     def __init__(self, precision=None, scale=None, asdecimal=False, **kw):
         """Construct a FLOAT.
 
@@ -502,26 +424,16 @@ class MSFloat(sqltypes.Float, _NumericType):
           numeric.
 
         """
-        _NumericType.__init__(self, kw)
-        sqltypes.Float.__init__(self, asdecimal=asdecimal, **kw)
-        self.scale = scale
-        self.precision = precision
-
-    def get_col_spec(self):
-        if self.scale is not None and self.precision is not None:
-            return self._extend("FLOAT(%s, %s)" % (self.precision, self.scale))
-        elif self.precision is not None:
-            return self._extend("FLOAT(%s)" % (self.precision,))
-        else:
-            return self._extend("FLOAT")
+        super(FLOAT, self).__init__(precision=precision, scale=scale, asdecimal=asdecimal, **kw)
 
     def bind_processor(self, dialect):
         return None
 
-
-class MSInteger(sqltypes.Integer, _NumericType):
+class INTEGER(_IntegerType, sqltypes.INTEGER):
     """MySQL INTEGER type."""
 
+    __visit_name__ = 'INTEGER'
+
     def __init__(self, display_width=None, **kw):
         """Construct an INTEGER.
 
@@ -535,24 +447,13 @@ class MSInteger(sqltypes.Integer, _NumericType):
           numeric.
 
         """
-        if 'length' in kw:
-            util.warn_deprecated("'length' is deprecated for MSInteger and subclasses.  Use 'display_width'.")
-            self.display_width = kw.pop('length')
-        else:
-            self.display_width = display_width
-        _NumericType.__init__(self, kw)
-        sqltypes.Integer.__init__(self, **kw)
-
-    def get_col_spec(self):
-        if self.display_width is not None:
-            return self._extend("INTEGER(%(display_width)s)" % {'display_width': self.display_width})
-        else:
-            return self._extend("INTEGER")
+        super(INTEGER, self).__init__(display_width=display_width, **kw)
 
-
-class MSBigInteger(MSInteger):
+class BIGINT(_IntegerType, sqltypes.BIGINT):
     """MySQL BIGINTEGER type."""
 
+    __visit_name__ = 'BIGINT'
+
     def __init__(self, display_width=None, **kw):
         """Construct a BIGINTEGER.
 
@@ -566,18 +467,13 @@ class MSBigInteger(MSInteger):
           numeric.
 
         """
-        super(MSBigInteger, self).__init__(display_width, **kw)
-
-    def get_col_spec(self):
-        if self.display_width is not None:
-            return self._extend("BIGINT(%(display_width)s)" % {'display_width': self.display_width})
-        else:
-            return self._extend("BIGINT")
+        super(BIGINT, self).__init__(display_width=display_width, **kw)
 
-
-class MSMediumInteger(MSInteger):
+class MEDIUMINT(_IntegerType):
     """MySQL MEDIUMINTEGER type."""
 
+    __visit_name__ = 'MEDIUMINT'
+
     def __init__(self, display_width=None, **kw):
         """Construct a MEDIUMINTEGER
 
@@ -591,19 +487,13 @@ class MSMediumInteger(MSInteger):
           numeric.
 
         """
-        super(MSMediumInteger, self).__init__(display_width, **kw)
-
-    def get_col_spec(self):
-        if self.display_width is not None:
-            return self._extend("MEDIUMINT(%(display_width)s)" % {'display_width': self.display_width})
-        else:
-            return self._extend("MEDIUMINT")
+        super(MEDIUMINT, self).__init__(display_width=display_width, **kw)
 
-
-
-class MSTinyInteger(MSInteger):
+class TINYINT(_IntegerType):
     """MySQL TINYINT type."""
 
+    __visit_name__ = 'TINYINT'
+
     def __init__(self, display_width=None, **kw):
         """Construct a TINYINT.
 
@@ -621,18 +511,13 @@ class MSTinyInteger(MSInteger):
           numeric.
 
         """
-        super(MSTinyInteger, self).__init__(display_width, **kw)
-
-    def get_col_spec(self):
-        if self.display_width is not None:
-            return self._extend("TINYINT(%s)" % self.display_width)
-        else:
-            return self._extend("TINYINT")
-
+        super(TINYINT, self).__init__(display_width=display_width, **kw)
 
-class MSSmallInteger(sqltypes.Smallinteger, MSInteger):
+class SMALLINT(_IntegerType, sqltypes.SMALLINT):
     """MySQL SMALLINTEGER type."""
 
+    __visit_name__ = 'SMALLINT'
+
     def __init__(self, display_width=None, **kw):
         """Construct a SMALLINTEGER.
 
@@ -646,18 +531,9 @@ class MSSmallInteger(sqltypes.Smallinteger, MSInteger):
           numeric.
 
         """
-        self.display_width = display_width
-        _NumericType.__init__(self, kw)
-        sqltypes.SmallInteger.__init__(self, **kw)
+        super(SMALLINT, self).__init__(display_width=display_width, **kw)
 
-    def get_col_spec(self):
-        if self.display_width is not None:
-            return self._extend("SMALLINT(%(display_width)s)" % {'display_width': self.display_width})
-        else:
-            return self._extend("SMALLINT")
-
-
-class MSBit(sqltypes.TypeEngine):
+class BIT(sqltypes.TypeEngine):
     """MySQL BIT type.
 
     This type is for MySQL 5.0.3 or greater for MyISAM, and 5.0.5 or greater for
@@ -666,6 +542,8 @@ class MSBit(sqltypes.TypeEngine):
 
     """
 
+    __visit_name__ = 'BIT'
+
     def __init__(self, length=None):
         """Construct a BIT.
 
@@ -685,32 +563,10 @@ class MSBit(sqltypes.TypeEngine):
             return value
         return process
 
-    def get_col_spec(self):
-        if self.length is not None:
-            return "BIT(%s)" % self.length
-        else:
-            return "BIT"
-
-
-class MSDateTime(sqltypes.DateTime):
-    """MySQL DATETIME type."""
-
-    def get_col_spec(self):
-        return "DATETIME"
-
-
-class MSDate(sqltypes.Date):
-    """MySQL DATE type."""
-
-    def get_col_spec(self):
-        return "DATE"
-
-
-class MSTime(sqltypes.Time):
+class _MSTime(sqltypes.Time):
     """MySQL TIME type."""
 
-    def get_col_spec(self):
-        return "TIME"
+    __visit_name__ = 'TIME'
 
     def result_processor(self, dialect):
         def process(value):
@@ -721,45 +577,24 @@ class MSTime(sqltypes.Time):
                 return None
         return process
 
-class MSTimeStamp(sqltypes.TIMESTAMP):
-    """MySQL TIMESTAMP type.
-
-    To signal the orm to automatically re-select modified rows to retrieve the
-    updated timestamp, add a ``server_default`` to your
-    :class:`~sqlalchemy.Column` specification::
-
-        from sqlalchemy.databases import mysql
-        Column('updated', mysql.MSTimeStamp,
-               server_default=sql.text('CURRENT_TIMESTAMP')
-              )
-
-    The full range of MySQL 4.1+ TIMESTAMP defaults can be specified in
-    the the default::
+class TIMESTAMP(sqltypes.TIMESTAMP):
+    """MySQL TIMESTAMP type."""
+    __visit_name__ = 'TIMESTAMP'
 
-        server_default=sql.text('CURRENT TIMESTAMP ON UPDATE CURRENT_TIMESTAMP')
-
-    """
-
-    def get_col_spec(self):
-        return "TIMESTAMP"
-
-
-class MSYear(sqltypes.TypeEngine):
+class YEAR(sqltypes.TypeEngine):
     """MySQL YEAR type, for single byte storage of years 1901-2155."""
 
+    __visit_name__ = 'YEAR'
+
     def __init__(self, display_width=None):
         self.display_width = display_width
 
-    def get_col_spec(self):
-        if self.display_width is None:
-            return "YEAR"
-        else:
-            return "YEAR(%s)" % self.display_width
-
-class MSText(_StringType, sqltypes.Text):
+class TEXT(_StringType, sqltypes.TEXT):
     """MySQL TEXT type, for text up to 2^16 characters."""
 
-    def __init__(self, length=None, **kwargs):
+    __visit_name__ = 'TEXT'
+
+    def __init__(self, length=None, **kw):
         """Construct a TEXT.
 
         :param length: Optional, if provided the server may optimize storage
@@ -787,20 +622,13 @@ class MSText(_StringType, sqltypes.Text):
           only the collation of character data.
 
         """
-        _StringType.__init__(self, **kwargs)
-        sqltypes.Text.__init__(self, length,
-                               kwargs.get('convert_unicode', False), kwargs.get('assert_unicode', None))
+        super(TEXT, self).__init__(length=length, **kw)
 
-    def get_col_spec(self):
-        if self.length:
-            return self._extend("TEXT(%d)" % self.length)
-        else:
-            return self._extend("TEXT")
-
-
-class MSTinyText(MSText):
+class TINYTEXT(_StringType):
     """MySQL TINYTEXT type, for text up to 2^8 characters."""
 
+    __visit_name__ = 'TINYTEXT'
+
     def __init__(self, **kwargs):
         """Construct a TINYTEXT.
 
@@ -825,16 +653,13 @@ class MSTinyText(MSText):
           only the collation of character data.
 
         """
+        super(TINYTEXT, self).__init__(**kwargs)
 
-        super(MSTinyText, self).__init__(**kwargs)
-
-    def get_col_spec(self):
-        return self._extend("TINYTEXT")
-
-
-class MSMediumText(MSText):
+class MEDIUMTEXT(_StringType):
     """MySQL MEDIUMTEXT type, for text up to 2^24 characters."""
 
+    __visit_name__ = 'MEDIUMTEXT'
+
     def __init__(self, **kwargs):
         """Construct a MEDIUMTEXT.
 
@@ -859,15 +684,13 @@ class MSMediumText(MSText):
           only the collation of character data.
 
         """
-        super(MSMediumText, self).__init__(**kwargs)
+        super(MEDIUMTEXT, self).__init__(**kwargs)
 
-    def get_col_spec(self):
-        return self._extend("MEDIUMTEXT")
-
-
-class MSLongText(MSText):
+class LONGTEXT(_StringType):
     """MySQL LONGTEXT type, for text up to 2^32 characters."""
 
+    __visit_name__ = 'LONGTEXT'
+
     def __init__(self, **kwargs):
         """Construct a LONGTEXT.
 
@@ -892,15 +715,14 @@ class MSLongText(MSText):
           only the collation of character data.
 
         """
-        super(MSLongText, self).__init__(**kwargs)
-
-    def get_col_spec(self):
-        return self._extend("LONGTEXT")
+        super(LONGTEXT, self).__init__(**kwargs)
 
-
-class MSString(_StringType, sqltypes.String):
+    
+class VARCHAR(_StringType, sqltypes.VARCHAR):
     """MySQL VARCHAR type, for variable-length character data."""
 
+    __visit_name__ = 'VARCHAR'
+
     def __init__(self, length=None, **kwargs):
         """Construct a VARCHAR.
 
@@ -925,22 +747,15 @@ class MSString(_StringType, sqltypes.String):
           only the collation of character data.
 
         """
-        _StringType.__init__(self, **kwargs)
-        sqltypes.String.__init__(self, length,
-                                 kwargs.get('convert_unicode', False), kwargs.get('assert_unicode', None))
-
-    def get_col_spec(self):
-        if self.length:
-            return self._extend("VARCHAR(%d)" % self.length)
-        else:
-            return self._extend("VARCHAR")
+        super(VARCHAR, self).__init__(length=length, **kwargs)
 
-
-class MSChar(_StringType, sqltypes.CHAR):
+class CHAR(_StringType, sqltypes.CHAR):
     """MySQL CHAR type, for fixed-length character data."""
 
+    __visit_name__ = 'CHAR'
+
     def __init__(self, length, **kwargs):
-        """Construct an NCHAR.
+        """Construct a CHAR.
 
         :param length: Maximum data length, in characters.
 
@@ -952,21 +767,17 @@ class MSChar(_StringType, sqltypes.CHAR):
           compatible with the national character set.
 
         """
-        _StringType.__init__(self, **kwargs)
-        sqltypes.CHAR.__init__(self, length,
-                               kwargs.get('convert_unicode', False))
-
-    def get_col_spec(self):
-        return self._extend("CHAR(%(length)s)" % {'length' : self.length})
-
+        super(CHAR, self).__init__(length=length, **kwargs)
 
-class MSNVarChar(_StringType, sqltypes.String):
+class NVARCHAR(_StringType, sqltypes.NVARCHAR):
     """MySQL NVARCHAR type.
 
     For variable-length character data in the server's configured national
     character set.
     """
 
+    __visit_name__ = 'NVARCHAR'
+
     def __init__(self, length=None, **kwargs):
         """Construct an NVARCHAR.
 
@@ -981,23 +792,18 @@ class MSNVarChar(_StringType, sqltypes.String):
 
         """
         kwargs['national'] = True
-        _StringType.__init__(self, **kwargs)
-        sqltypes.String.__init__(self, length,
-                                 kwargs.get('convert_unicode', False))
-
-    def get_col_spec(self):
-        # We'll actually generate the equiv. "NATIONAL VARCHAR" instead
-        # of "NVARCHAR".
-        return self._extend("VARCHAR(%(length)s)" % {'length': self.length})
+        super(NVARCHAR, self).__init__(length=length, **kwargs)
 
 
-class MSNChar(_StringType, sqltypes.CHAR):
+class NCHAR(_StringType, sqltypes.NCHAR):
     """MySQL NCHAR type.
 
     For fixed-length character data in the server's configured national
     character set.
     """
 
+    __visit_name__ = 'NCHAR'
+
     def __init__(self, length=None, **kwargs):
         """Construct an NCHAR.  Arguments are:
 
@@ -1012,52 +818,28 @@ class MSNChar(_StringType, sqltypes.CHAR):
 
         """
         kwargs['national'] = True
-        _StringType.__init__(self, **kwargs)
-        sqltypes.CHAR.__init__(self, length,
-                               kwargs.get('convert_unicode', False))
-    def get_col_spec(self):
-        # We'll actually generate the equiv. "NATIONAL CHAR" instead of "NCHAR".
-        return self._extend("CHAR(%(length)s)" % {'length': self.length})
-
-
-class _BinaryType(sqltypes.Binary):
-    """Base for MySQL binary types."""
+        super(NCHAR, self).__init__(length=length, **kwargs)
 
-    def get_col_spec(self):
-        if self.length:
-            return "BLOB(%d)" % self.length
-        else:
-            return "BLOB"
 
-    def result_processor(self, dialect):
-        def process(value):
-            if value is None:
-                return None
-            else:
-                return util.buffer(value)
-        return process
 
-class MSVarBinary(_BinaryType):
+class VARBINARY(_BinaryType):
     """MySQL VARBINARY type, for variable length binary data."""
 
+    __visit_name__ = 'VARBINARY'
+
     def __init__(self, length=None, **kw):
         """Construct a VARBINARY.  Arguments are:
 
         :param length: Maximum data length, in bytes.
 
         """
-        super(MSVarBinary, self).__init__(length, **kw)
-
-    def get_col_spec(self):
-        if self.length:
-            return "VARBINARY(%d)" % self.length
-        else:
-            return "BLOB"
-
+        super(VARBINARY, self).__init__(length=length, **kw)
 
-class MSBinary(_BinaryType):
+class BINARY(_BinaryType):
     """MySQL BINARY type, for fixed length binary data"""
 
+    __visit_name__ = 'BINARY'
+
     def __init__(self, length=None, **kw):
         """Construct a BINARY.
 
@@ -1068,25 +850,13 @@ class MSBinary(_BinaryType):
           specified, this will generate a BLOB.  This usage is deprecated.
 
         """
-        super(MSBinary, self).__init__(length, **kw)
-
-    def get_col_spec(self):
-        if self.length:
-            return "BINARY(%d)" % self.length
-        else:
-            return "BLOB"
-
-    def result_processor(self, dialect):
-        def process(value):
-            if value is None:
-                return None
-            else:
-                return util.buffer(value)
-        return process
+        super(BINARY, self).__init__(length=length, **kw)
 
-class MSBlob(_BinaryType):
+class BLOB(_BinaryType, sqltypes.BLOB):
     """MySQL BLOB type, for binary data up to 2^16 bytes"""
 
+    __visit_name__ = 'BLOB'
+
     def __init__(self, length=None, **kw):
         """Construct a BLOB.  Arguments are:
 
@@ -1095,50 +865,29 @@ class MSBlob(_BinaryType):
           ``length`` characters.
 
         """
-        super(MSBlob, self).__init__(length, **kw)
-
-    def get_col_spec(self):
-        if self.length:
-            return "BLOB(%d)" % self.length
-        else:
-            return "BLOB"
-
-    def result_processor(self, dialect):
-        def process(value):
-            if value is None:
-                return None
-            else:
-                return util.buffer(value)
-        return process
-
-    def __repr__(self):
-        return "%s()" % self.__class__.__name__
+        super(BLOB, self).__init__(length=length, **kw)
 
 
-class MSTinyBlob(MSBlob):
+class TINYBLOB(_BinaryType):
     """MySQL TINYBLOB type, for binary data up to 2^8 bytes."""
+    
+    __visit_name__ = 'TINYBLOB'
 
-    def get_col_spec(self):
-        return "TINYBLOB"
-
-
-class MSMediumBlob(MSBlob):
+class MEDIUMBLOB(_BinaryType):
     """MySQL MEDIUMBLOB type, for binary data up to 2^24 bytes."""
 
-    def get_col_spec(self):
-        return "MEDIUMBLOB"
-
+    __visit_name__ = 'MEDIUMBLOB'
 
-class MSLongBlob(MSBlob):
+class LONGBLOB(_BinaryType):
     """MySQL LONGBLOB type, for binary data up to 2^32 bytes."""
 
-    def get_col_spec(self):
-        return "LONGBLOB"
-
+    __visit_name__ = 'LONGBLOB'
 
-class MSEnum(MSString):
+class ENUM(_StringType):
     """MySQL ENUM type."""
 
+    __visit_name__ = 'ENUM'
+
     def __init__(self, *enums, **kw):
         """Construct an ENUM.
 
@@ -1225,10 +974,10 @@ class MSEnum(MSString):
 
         self.strict = kw.pop('strict', False)
         length = max([len(v) for v in self.enums] + [0])
-        super(MSEnum, self).__init__(length, **kw)
+        super(ENUM, self).__init__(length=length, **kw)
 
     def bind_processor(self, dialect):
-        super_convert = super(MSEnum, self).bind_processor(dialect)
+        super_convert = super(ENUM, self).bind_processor(dialect)
         def process(value):
             if self.strict and value is not None and value not in self.enums:
                 raise exc.InvalidRequestError('"%s" not a valid value for '
@@ -1239,15 +988,11 @@ class MSEnum(MSString):
                 return value
         return process
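
A minimal usage sketch, assuming the new dialects import path: with strict=True, the bind_processor above rejects values outside the declared list at bind time.

    from sqlalchemy import MetaData, Table, Column
    from sqlalchemy.dialects.mysql import ENUM

    shirts = Table('shirts', MetaData(),
        Column('size', ENUM('small', 'medium', 'large', strict=True)))
    # binding size='XL' raises InvalidRequestError in the processor above
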
 
-    def get_col_spec(self):
-        quoted_enums = []
-        for e in self.enums:
-            quoted_enums.append("'%s'" % e.replace("'", "''"))
-        return self._extend("ENUM(%s)" % ",".join(quoted_enums))
-
-class MSSet(MSString):
+class SET(_StringType):
     """MySQL SET type."""
 
+    __visit_name__ = 'SET'
+
     def __init__(self, *values, **kw):
         """Construct a SET.
 
@@ -1281,7 +1026,7 @@ class MSSet(MSString):
           only the collation of character data.
 
         """
-        self.__ddl_values = values
+        self._ddl_values = values
 
         strip_values = []
         for a in values:
@@ -1292,7 +1037,7 @@ class MSSet(MSString):
 
         self.values = strip_values
         length = max([len(v) for v in strip_values] + [0])
-        super(MSSet, self).__init__(length, **kw)
+        super(SET, self).__init__(length=length, **kw)
 
     def result_processor(self, dialect):
         def process(value):
@@ -1316,7 +1061,7 @@ class MSSet(MSString):
         return process
 
     def bind_processor(self, dialect):
-        super_convert = super(MSSet, self).bind_processor(dialect)
+        super_convert = super(SET, self).bind_processor(dialect)
         def process(value):
             if value is None or isinstance(value, (int, long, basestring)):
                 pass
@@ -1332,15 +1077,10 @@ class MSSet(MSString):
                 return value
         return process
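
A rough round-trip sketch for the SET processors, assuming the same import path; the DDL values are emitted verbatim by the type compiler, so they are passed pre-quoted.

    from sqlalchemy.dialects.mysql import SET

    roles = SET("'admin'", "'user'")
    # result side: driver string "admin,user" -> set(['admin', 'user'])
    # bind side:   set(['admin', 'user'])     -> "admin,user"
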
 
-    def get_col_spec(self):
-        return self._extend("SET(%s)" % ",".join(self.__ddl_values))
-
-
-class MSBoolean(sqltypes.Boolean):
+class _MSBoolean(sqltypes.Boolean):
     """MySQL BOOLEAN type."""
 
-    def get_col_spec(self):
-        return "BOOL"
+    __visit_name__ = 'BOOLEAN'
 
     def result_processor(self, dialect):
         def process(value):
@@ -1361,74 +1101,91 @@ class MSBoolean(sqltypes.Boolean):
                 return value and True or False
         return process
 
+# old names
+MSBoolean = _MSBoolean
+MSTime = _MSTime
+MSSet = SET
+MSEnum = ENUM
+MSLongBlob = LONGBLOB
+MSMediumBlob = MEDIUMBLOB
+MSTinyBlob = TINYBLOB
+MSBlob = BLOB
+MSBinary = BINARY
+MSVarBinary = VARBINARY
+MSNChar = NCHAR
+MSNVarChar = NVARCHAR
+MSChar = CHAR
+MSString = VARCHAR
+MSLongText = LONGTEXT
+MSMediumText = MEDIUMTEXT
+MSTinyText = TINYTEXT
+MSText = TEXT
+MSYear = YEAR
+MSTimeStamp = TIMESTAMP
+MSBit = BIT
+MSSmallInteger = SMALLINT
+MSTinyInteger = TINYINT
+MSMediumInteger = MEDIUMINT
+MSBigInteger = BIGINT
+MSNumeric = NUMERIC
+MSDecimal = DECIMAL
+MSDouble = DOUBLE
+MSReal = REAL
+MSFloat = FLOAT
+MSInteger = INTEGER
+
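
A quick sanity sketch for the aliases above (this module being dialects/mysql/base.py):

    from sqlalchemy.dialects.mysql import base as mysql

    assert mysql.MSString is mysql.VARCHAR
    assert mysql.MSEnum is mysql.ENUM
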
 colspecs = {
-    sqltypes.Integer: MSInteger,
-    sqltypes.Smallinteger: MSSmallInteger,
-    sqltypes.Numeric: MSNumeric,
-    sqltypes.Float: MSFloat,
-    sqltypes.DateTime: MSDateTime,
-    sqltypes.Date: MSDate,
-    sqltypes.Time: MSTime,
-    sqltypes.String: MSString,
-    sqltypes.Binary: MSBlob,
-    sqltypes.Boolean: MSBoolean,
-    sqltypes.Text: MSText,
-    sqltypes.CHAR: MSChar,
-    sqltypes.NCHAR: MSNChar,
-    sqltypes.TIMESTAMP: MSTimeStamp,
-    sqltypes.BLOB: MSBlob,
-    MSDouble: MSDouble,
-    MSReal: MSReal,
-    _BinaryType: _BinaryType,
+    sqltypes.Numeric: NUMERIC,
+    sqltypes.Float: FLOAT,
+    sqltypes.Binary: _BinaryType,
+    sqltypes.Boolean: _MSBoolean,
+    sqltypes.Time: _MSTime,
 }
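
Only types needing MySQL-specific bind/result processing remain in colspecs; DDL rendering for the rest is now driven by __visit_name__ through the type compiler defined below.
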
 
 # Everything 3.23 through 5.1 excepting OpenGIS types.
 ischema_names = {
-    'bigint': MSBigInteger,
-    'binary': MSBinary,
-    'bit': MSBit,
-    'blob': MSBlob,
-    'boolean':MSBoolean,
-    'char': MSChar,
-    'date': MSDate,
-    'datetime': MSDateTime,
-    'decimal': MSDecimal,
-    'double': MSDouble,
-    'enum': MSEnum,
-    'fixed': MSDecimal,
-    'float': MSFloat,
-    'int': MSInteger,
-    'integer': MSInteger,
-    'longblob': MSLongBlob,
-    'longtext': MSLongText,
-    'mediumblob': MSMediumBlob,
-    'mediumint': MSMediumInteger,
-    'mediumtext': MSMediumText,
-    'nchar': MSNChar,
-    'nvarchar': MSNVarChar,
-    'numeric': MSNumeric,
-    'set': MSSet,
-    'smallint': MSSmallInteger,
-    'text': MSText,
-    'time': MSTime,
-    'timestamp': MSTimeStamp,
-    'tinyblob': MSTinyBlob,
-    'tinyint': MSTinyInteger,
-    'tinytext': MSTinyText,
-    'varbinary': MSVarBinary,
-    'varchar': MSString,
-    'year': MSYear,
+    'bigint': BIGINT,
+    'binary': BINARY,
+    'bit': BIT,
+    'blob': BLOB,
+    'boolean':BOOLEAN,
+    'char': CHAR,
+    'date': DATE,
+    'datetime': DATETIME,
+    'decimal': DECIMAL,
+    'double': DOUBLE,
+    'enum': ENUM,
+    'fixed': DECIMAL,
+    'float': FLOAT,
+    'int': INTEGER,
+    'integer': INTEGER,
+    'longblob': LONGBLOB,
+    'longtext': LONGTEXT,
+    'mediumblob': MEDIUMBLOB,
+    'mediumint': MEDIUMINT,
+    'mediumtext': MEDIUMTEXT,
+    'nchar': NCHAR,
+    'nvarchar': NVARCHAR,
+    'numeric': NUMERIC,
+    'set': SET,
+    'smallint': SMALLINT,
+    'text': TEXT,
+    'time': TIME,
+    'timestamp': TIMESTAMP,
+    'tinyblob': TINYBLOB,
+    'tinyint': TINYINT,
+    'tinytext': TINYTEXT,
+    'varbinary': VARBINARY,
+    'varchar': VARCHAR,
+    'year': YEAR,
 }
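
Reflection looks up the lower-cased type name reported by the server in this mapping, for example:

    ischema_names['mediumint']   # -> MEDIUMINT
    ischema_names['fixed']       # FIXED reflects as DECIMAL
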
 
-
 class MySQLExecutionContext(default.DefaultExecutionContext):
     def post_exec(self):
-        if self.compiled.isinsert and not self.executemany:
-            if (not len(self._last_inserted_ids) or
-                self._last_inserted_ids[0] is None):
-                self._last_inserted_ids = ([self.cursor.lastrowid] +
-                                           self._last_inserted_ids[1:])
-        elif (not self.isupdate and not self.should_autocommit and
+        # TODO: the 'charset' entry formerly kept in connection.info
+        # appears to be obsolete now
+        
+        if (not self.isupdate and not self.should_autocommit and
               self.statement and SET_RE.match(self.statement)):
             # This misses if a user forces autocommit on text('SET NAMES'),
             # which is probably a programming error anyhow.
@@ -1437,75 +1194,440 @@ class MySQLExecutionContext(default.DefaultExecutionContext):
     def should_autocommit_text(self, statement):
         return AUTOCOMMIT_RE.match(statement)
 
+class MySQLCompiler(compiler.SQLCompiler):
 
-class MySQLDialect(default.DefaultDialect):
-    """Details of the MySQL dialect.  Not used directly in application code."""
-    name = 'mysql'
-    supports_alter = True
-    supports_unicode_statements = False
-    # identifiers are 64, however aliases can be 255...
-    max_identifier_length = 255
-    supports_sane_rowcount = True
-    default_paramstyle = 'format'
-
-    def __init__(self, use_ansiquotes=None, **kwargs):
-        self.use_ansiquotes = use_ansiquotes
-        default.DefaultDialect.__init__(self, **kwargs)
+    extract_map = compiler.SQLCompiler.extract_map.copy()
+    extract_map.update({
+        'milliseconds': 'millisecond',
+    })
+    
+    def visit_random_func(self, fn, **kw):
+        return "rand%s" % self.function_argspec(fn)
+    
+    def visit_utc_timestamp_func(self, fn, **kw):
+        return "UTC_TIMESTAMP"
+        
+    def visit_concat_op(self, binary, **kw):
+        return "concat(%s, %s)" % (self.process(binary.left), self.process(binary.right))
+        
+    def visit_match_op(self, binary, **kw):
+        return "MATCH (%s) AGAINST (%s IN BOOLEAN MODE)" % (self.process(binary.left), self.process(binary.right))
+        
+    def visit_typeclause(self, typeclause):
+        type_ = typeclause.type.dialect_impl(self.dialect)
+        if isinstance(type_, sqltypes.Integer):
+            if getattr(type_, 'unsigned', False):
+                return 'UNSIGNED INTEGER'
+            else:
+                return 'SIGNED INTEGER'
+        elif isinstance(type_, sqltypes.TIMESTAMP):
+            return 'DATETIME'
+        elif isinstance(type_, (sqltypes.DECIMAL, sqltypes.DateTime, sqltypes.Date, sqltypes.Time)):
+            return self.dialect.type_compiler.process(type_)
+        elif isinstance(type_, sqltypes.Text):
+            return 'CHAR'
+        elif (isinstance(type_, sqltypes.String) and not
+              isinstance(type_, (ENUM, SET))):
+            if getattr(type_, 'length'):
+                return 'CHAR(%s)' % type_.length
+            else:
+                return 'CHAR'
+        elif isinstance(type_, sqltypes.Binary):
+            return 'BINARY'
+        elif isinstance(type_, NUMERIC):
+            return self.dialect.type_compiler.process(type_).replace('NUMERIC', 'DECIMAL')
+        else:
+            return None
 
-    def dbapi(cls):
-        import MySQLdb as mysql
-        return mysql
-    dbapi = classmethod(dbapi)
-
-    def create_connect_args(self, url):
-        opts = url.translate_connect_args(database='db', username='user',
-                                          password='passwd')
-        opts.update(url.query)
-
-        util.coerce_kw_type(opts, 'compress', bool)
-        util.coerce_kw_type(opts, 'connect_timeout', int)
-        util.coerce_kw_type(opts, 'client_flag', int)
-        util.coerce_kw_type(opts, 'local_infile', int)
-        # Note: using either of the below will cause all strings to be returned
-        # as Unicode, both in raw SQL operations and with column types like
-        # String and MSString.
-        util.coerce_kw_type(opts, 'use_unicode', bool)
-        util.coerce_kw_type(opts, 'charset', str)
-
-        # Rich values 'cursorclass' and 'conv' are not supported via
-        # query string.
-
-        ssl = {}
-        for key in ['ssl_ca', 'ssl_key', 'ssl_cert', 'ssl_capath', 'ssl_cipher']:
-            if key in opts:
-                ssl[key[4:]] = opts[key]
-                util.coerce_kw_type(ssl, key[4:], str)
-                del opts[key]
-        if ssl:
-            opts['ssl'] = ssl
-
-        # FOUND_ROWS must be set in CLIENT_FLAGS to enable
-        # supports_sane_rowcount.
-        client_flag = opts.get('client_flag', 0)
-        if self.dbapi is not None:
-            try:
-                import MySQLdb.constants.CLIENT as CLIENT_FLAGS
-                client_flag |= CLIENT_FLAGS.FOUND_ROWS
-            except:
-                pass
-            opts['client_flag'] = client_flag
-        return [[], opts]
+    def visit_cast(self, cast, **kwargs):
+        # No cast until 4, no decimals until 5.
+        type_ = self.process(cast.typeclause)
+        if type_ is None:
+            return self.process(cast.clause)
 
-    def type_descriptor(self, typeobj):
-        return sqltypes.adapt_type(typeobj, colspecs)
+        return 'CAST(%s AS %s)' % (self.process(cast.clause), type_)
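
A sketch of what the two methods above render, assuming the 0.6 expression API:

    from sqlalchemy import cast, Integer, String
    from sqlalchemy.sql import column

    cast(column('x'), Integer)  # CAST(x AS SIGNED INTEGER)
    cast(column('x'), String)   # CAST(x AS CHAR), since no length is given
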
 
-    def do_executemany(self, cursor, statement, parameters, context=None):
-        rowcount = cursor.executemany(statement, parameters)
-        if context is not None:
-            context._rowcount = rowcount
+    def get_select_precolumns(self, select):
+        if isinstance(select._distinct, basestring):
+            return select._distinct.upper() + " "
+        elif select._distinct:
+            return "DISTINCT "
+        else:
+            return ""
+
+    def visit_join(self, join, asfrom=False, **kwargs):
+        # 'JOIN ... ON ...' for inner joins isn't available until 4.0.
+        # Apparently < 3.23.17 requires theta joins for inner joins
+        # (but not outer).  Not generating these currently, but
+        # support can be added, preferably after dialects are
+        # refactored to be version-sensitive.
+        return ''.join(
+            (self.process(join.left, asfrom=True),
+             (join.isouter and " LEFT OUTER JOIN " or " INNER JOIN "),
+             self.process(join.right, asfrom=True),
+             " ON ",
+             self.process(join.onclause)))
+
+    def for_update_clause(self, select):
+        if select.for_update == 'read':
+            return ' LOCK IN SHARE MODE'
+        else:
+            return super(MySQLCompiler, self).for_update_clause(select)
+
+    def limit_clause(self, select):
+        # MySQL supports:
+        #   LIMIT <limit>
+        #   LIMIT <offset>, <limit>
+        # and in server versions > 3.3:
+        #   LIMIT <limit> OFFSET <offset>
+        # The latter is more readable for offsets but we're stuck with the
+        # former until we can refine dialects by server revision.
+
+        limit, offset = select._limit, select._offset
+
+        if (limit, offset) == (None, None):
+            return ''
+        elif offset is not None:
+            # As suggested by the MySQL docs, need to apply an
+            # artificial limit if one wasn't provided
+            if limit is None:
+                limit = 18446744073709551615
+            return ' \n LIMIT %s, %s' % (offset, limit)
+        else:
+            # No offset provided, so just use the limit
+            return ' \n LIMIT %s' % (limit,)
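
The three cases above, sketched with the 0.6 expression API:

    from sqlalchemy.sql import select, table, column

    t = table('t', column('id'))
    select([t]).limit(10)             # ... LIMIT 10
    select([t]).limit(10).offset(5)   # ... LIMIT 5, 10
    select([t]).offset(5)             # ... LIMIT 5, 18446744073709551615
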
+
+    def visit_update(self, update_stmt):
+        self.stack.append({'from': set([update_stmt.table])})
+
+        self.isupdate = True
+        colparams = self._get_colparams(update_stmt)
+
+        text = "UPDATE " + self.preparer.format_table(update_stmt.table) + \
+                " SET " + ', '.join(["%s=%s" % (self.preparer.format_column(c[0]), c[1]) for c in colparams])
+
+        if update_stmt._whereclause:
+            text += " WHERE " + self.process(update_stmt._whereclause)
+
+        limit = update_stmt.kwargs.get('mysql_limit', None)
+        if limit:
+            text += " LIMIT %s" % limit
+
+        self.stack.pop(-1)
+
+        return text
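
A hedged usage sketch for the mysql_limit keyword consumed above; table and column names are placeholders.

    from sqlalchemy.sql import table, column

    t = table('t', column('name'))
    t.update(values={'name': 'x'}, mysql_limit=10)
    # UPDATE t SET name=%s LIMIT 10
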
+
+# ug.  "InnoDB needs indexes on foreign keys and referenced keys [...].
+#       Starting with MySQL 4.1.2, these indexes are created automatically.
+#       In older versions, the indexes must be created explicitly or the
+#       creation of foreign key constraints fails."
+
+class MySQLDDLCompiler(compiler.DDLCompiler):
+    def get_column_specification(self, column, **kw):
+        """Builds column DDL."""
+
+        colspec = [self.preparer.format_column(column),
+                    self.dialect.type_compiler.process(column.type)
+                   ]
+
+        default = self.get_column_default_string(column)
+        if default is not None:
+            colspec.append('DEFAULT ' + default)
+
+        if not column.nullable:
+            colspec.append('NOT NULL')
+
+        if column.primary_key and column.autoincrement:
+            try:
+                first = [c for c in column.table.primary_key.columns
+                         if (c.autoincrement and
+                             isinstance(c.type, sqltypes.Integer) and
+                             not c.foreign_keys)].pop(0)
+                if column is first:
+                    colspec.append('AUTO_INCREMENT')
+            except IndexError:
+                pass
+
+        return ' '.join(colspec)
 
-    def supports_unicode_statements(self):
-        return True
+    def post_create_table(self, table):
+        """Build table-level CREATE options like ENGINE and COLLATE."""
+
+        table_opts = []
+        for k in table.kwargs:
+            if k.startswith('mysql_'):
+                opt = k[6:].upper()
+                joiner = '='
+                if opt in ('TABLESPACE', 'DEFAULT CHARACTER SET',
+                           'CHARACTER SET', 'COLLATE'):
+                    joiner = ' '
+
+                table_opts.append(joiner.join((opt, table.kwargs[k])))
+        return ' '.join(table_opts)
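
A usage sketch for the mysql_* table keywords handled above; names are placeholders.

    from sqlalchemy import MetaData, Table, Column, Integer

    t = Table('data', MetaData(),
              Column('id', Integer, primary_key=True),
              mysql_engine='InnoDB',
              mysql_charset='utf8')
    # CREATE TABLE data (...) ENGINE=InnoDB CHARSET=utf8
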
+
+    def visit_drop_index(self, drop):
+        index = drop.element
+        
+        return "\nDROP INDEX %s ON %s" % \
+                    (self.preparer.quote(self._validate_identifier(index.name, False), index.quote),
+                     self.preparer.format_table(index.table))
+
+    def visit_drop_constraint(self, drop):
+        constraint = drop.element
+        if isinstance(constraint, sa_schema.ForeignKeyConstraint):
+            qual = "FOREIGN KEY "
+            const = self.preparer.format_constraint(constraint)
+        elif isinstance(constraint, sa_schema.PrimaryKeyConstraint):
+            qual = "PRIMARY KEY "
+            const = ""
+        elif isinstance(constraint, sa_schema.UniqueConstraint):
+            qual = "INDEX "
+            const = self.preparer.format_constraint(constraint)
+        else:
+            qual = ""
+            const = self.preparer.format_constraint(constraint)
+        return "ALTER TABLE %s DROP %s%s" % \
+                    (self.preparer.format_table(constraint.table),
+                    qual, const)
+
+class MySQLTypeCompiler(compiler.GenericTypeCompiler):
+    def _extend_numeric(self, type_, spec):
+        "Extend a numeric-type declaration with MySQL specific extensions."
+
+        if not self._mysql_type(type_):
+            return spec
+
+        if type_.unsigned:
+            spec += ' UNSIGNED'
+        if type_.zerofill:
+            spec += ' ZEROFILL'
+        return spec
+
+    def _extend_string(self, type_, defaults, spec):
+        """Extend a string-type declaration with standard SQL CHARACTER SET /
+        COLLATE annotations and MySQL-specific extensions.
+
+        """
+        
+        def attr(name):
+            return getattr(type_, name, defaults.get(name))
+            
+        if attr('charset'):
+            charset = 'CHARACTER SET %s' % attr('charset')
+        elif attr('ascii'):
+            charset = 'ASCII'
+        elif attr('unicode'):
+            charset = 'UNICODE'
+        else:
+            charset = None
+
+        if attr('collation'):
+            collation = 'COLLATE %s' % type_.collation
+        elif attr('binary'):
+            collation = 'BINARY'
+        else:
+            collation = None
+
+        if attr('national'):
+            # NATIONAL (aka NCHAR/NVARCHAR) trumps charsets.
+            return ' '.join([c for c in ('NATIONAL', spec, collation)
+                             if c is not None])
+        return ' '.join([c for c in (spec, charset, collation)
+                         if c is not None])
+    
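
A sketch of the resulting DDL for this module's string types:

    VARCHAR(30, charset='utf8', binary=True)
    # compiles to: VARCHAR(30) CHARACTER SET utf8 BINARY
    NVARCHAR(20)
    # compiles to: NATIONAL VARCHAR(20)
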
+    def _mysql_type(self, type_):
+        return isinstance(type_, (_StringType, _NumericType, _BinaryType))
+    
+    def visit_NUMERIC(self, type_):
+        if type_.precision is None:
+            return self._extend_numeric(type_, "NUMERIC")
+        elif type_.scale is None:
+            return self._extend_numeric(type_, "NUMERIC(%(precision)s)" % {'precision': type_.precision})
+        else:
+            return self._extend_numeric(type_, "NUMERIC(%(precision)s, %(scale)s)" % {'precision': type_.precision, 'scale' : type_.scale})
+
+    def visit_DECIMAL(self, type_):
+        if type_.precision is None:
+            return self._extend_numeric(type_, "DECIMAL")
+        elif type_.scale is None:
+            return self._extend_numeric(type_, "DECIMAL(%(precision)s)" % {'precision': type_.precision})
+        else:
+            return self._extend_numeric(type_, "DECIMAL(%(precision)s, %(scale)s)" % {'precision': type_.precision, 'scale' : type_.scale})
+
+    def visit_DOUBLE(self, type_):
+        if type_.precision is not None and type_.scale is not None:
+            return self._extend_numeric(type_, "DOUBLE(%(precision)s, %(scale)s)" %
+                                {'precision': type_.precision,
+                                 'scale' : type_.scale})
+        else:
+            return self._extend_numeric(type_, 'DOUBLE')
+
+    def visit_REAL(self, type_):
+        if type_.precision is not None and type_.scale is not None:
+            return self._extend_numeric(type_, "REAL(%(precision)s, %(scale)s)" %
+                                {'precision': type_.precision,
+                                 'scale' : type_.scale})
+        else:
+            return self._extend_numeric(type_, 'REAL')
+    
+    def visit_FLOAT(self, type_):
+        if self._mysql_type(type_) and type_.scale is not None and type_.precision is not None:
+            return self._extend_numeric(type_, "FLOAT(%s, %s)" % (type_.precision, type_.scale))
+        elif type_.precision is not None:
+            return self._extend_numeric(type_, "FLOAT(%s)" % (type_.precision,))
+        else:
+            return self._extend_numeric(type_, "FLOAT")
+    
+    def visit_INTEGER(self, type_):
+        if self._mysql_type(type_) and type_.display_width is not None:
+            return self._extend_numeric(type_, "INTEGER(%(display_width)s)" % {'display_width': type_.display_width})
+        else:
+            return self._extend_numeric(type_, "INTEGER")
+        
+    def visit_BIGINT(self, type_):
+        if self._mysql_type(type_) and type_.display_width is not None:
+            return self._extend_numeric(type_, "BIGINT(%(display_width)s)" % {'display_width': type_.display_width})
+        else:
+            return self._extend_numeric(type_, "BIGINT")
+    
+    def visit_MEDIUMINT(self, type_):
+        if self._mysql_type(type_) and type_.display_width is not None:
+            return self._extend_numeric(type_, "MEDIUMINT(%(display_width)s)" % {'display_width': type_.display_width})
+        else:
+            return self._extend_numeric(type_, "MEDIUMINT")
+
+    def visit_TINYINT(self, type_):
+        if self._mysql_type(type_) and type_.display_width is not None:
+            return self._extend_numeric(type_, "TINYINT(%s)" % type_.display_width)
+        else:
+            return self._extend_numeric(type_, "TINYINT")
+
+    def visit_SMALLINT(self, type_):
+        if self._mysql_type(type_) and type_.display_width is not None:
+            return self._extend_numeric(type_, "SMALLINT(%(display_width)s)" % {'display_width': type_.display_width})
+        else:
+            return self._extend_numeric(type_, "SMALLINT")
+
+    def visit_BIT(self, type_):
+        if type_.length is not None:
+            return "BIT(%s)" % type_.length
+        else:
+            return "BIT"
+    
+    def visit_DATETIME(self, type_):
+        return "DATETIME"
+
+    def visit_DATE(self, type_):
+        return "DATE"
+
+    def visit_TIME(self, type_):
+        return "TIME"
+
+    def visit_TIMESTAMP(self, type_):
+        return 'TIMESTAMP'
+
+    def visit_YEAR(self, type_):
+        if type_.display_width is None:
+            return "YEAR"
+        else:
+            return "YEAR(%s)" % type_.display_width
+    
+    def visit_TEXT(self, type_):
+        if type_.length:
+            return self._extend_string(type_, {}, "TEXT(%d)" % type_.length)
+        else:
+            return self._extend_string(type_, {}, "TEXT")
+        
+    def visit_TINYTEXT(self, type_):
+        return self._extend_string(type_, {}, "TINYTEXT")
+
+    def visit_MEDIUMTEXT(self, type_):
+        return self._extend_string(type_, {}, "MEDIUMTEXT")
+    
+    def visit_LONGTEXT(self, type_):
+        return self._extend_string(type_, {}, "LONGTEXT")
+    
+    def visit_VARCHAR(self, type_):
+        if type_.length:
+            return self._extend_string(type_, {}, "VARCHAR(%d)" % type_.length)
+        else:
+            return self._extend_string(type_, {}, "VARCHAR")
+    
+    def visit_CHAR(self, type_):
+        return self._extend_string(type_, {'national':True}, "CHAR(%(length)s)" % {'length' : type_.length})
+
+    def visit_NVARCHAR(self, type_):
+        # We'll actually generate the equiv. "NATIONAL VARCHAR" instead
+        # of "NVARCHAR".
+        return self._extend_string(type_, {'national':True}, "VARCHAR(%(length)s)" % {'length': type_.length})
+    
+    def visit_NCHAR(self, type_):
+        # We'll actually generate the equiv. "NATIONAL CHAR" instead of "NCHAR".
+        return self._extend_string(type_, {'national':True}, "CHAR(%(length)s)" % {'length': type_.length})
+    
+    def visit_VARBINARY(self, type_):
+        if type_.length:
+            return "VARBINARY(%d)" % type_.length
+        else:
+            return self.visit_BLOB(type_)
+    
+    def visit_binary(self, type_):
+        return self.visit_BLOB(type_)
+        
+    def visit_BINARY(self, type_):
+        if type_.length:
+            return "BINARY(%d)" % type_.length
+        else:
+            return self.visit_BLOB(type_)
+    
+    def visit_BLOB(self, type_):
+        if type_.length:
+            return "BLOB(%d)" % type_.length
+        else:
+            return "BLOB"
+    
+    def visit_TINYBLOB(self, type_):
+        return "TINYBLOB"
+
+    def visit_MEDIUMBLOB(self, type_):
+        return "MEDIUMBLOB"
+
+    def visit_LONGBLOB(self, type_):
+        return "LONGBLOB"
+
+    def visit_ENUM(self, type_):
+        quoted_enums = []
+        for e in type_.enums:
+            quoted_enums.append("'%s'" % e.replace("'", "''"))
+        return self._extend_string(type_, {}, "ENUM(%s)" % ",".join(quoted_enums))
+        
+    def visit_SET(self, type_):
+        return self._extend_string(type_, {}, "SET(%s)" % ",".join(type_._ddl_values))
+
+    def visit_BOOLEAN(self, type_):
+        return "BOOL"
+        
+
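
An end-to-end sketch of the numeric extension logic, assuming `dialect` is a constructed MySQL dialect instance:

    spec = dialect.type_compiler.process(INTEGER(display_width=11, unsigned=True))
    # spec == 'INTEGER(11) UNSIGNED'
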
+class MySQLDialect(default.DefaultDialect):
+    """Details of the MySQL dialect.  Not used directly in application code."""
+    name = 'mysql'
+    supports_alter = True
+    # identifiers are 64, however aliases can be 255...
+    max_identifier_length = 255
+    
+    supports_sane_rowcount = True
+    supports_sane_multi_rowcount = False
+    
+    default_paramstyle = 'format'
+    colspecs = colspecs
+    
+    statement_compiler = MySQLCompiler
+    ddl_compiler = MySQLDDLCompiler
+    type_compiler = MySQLTypeCompiler
+    ischema_names = ischema_names
+    
+    def __init__(self, use_ansiquotes=None, **kwargs):
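+        # use_ansiquotes is accepted for backwards compatibility but is
+        # now ignored; ANSI_QUOTES is detected from the server's sql_mode
+        # during initialize().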
+        default.DefaultDialect.__init__(self, **kwargs)
 
     def do_commit(self, connection):
         """Execute a COMMIT."""
@@ -1518,7 +1640,7 @@ class MySQLDialect(default.DefaultDialect):
         try:
             connection.commit()
         except:
-            if self._server_version_info(connection) < (3, 23, 15):
+            if self.server_version_info < (3, 23, 15):
                 args = sys.exc_info()[1].args
                 if args and args[0] == 1064:
                     return
@@ -1530,59 +1652,66 @@ class MySQLDialect(default.DefaultDialect):
         try:
             connection.rollback()
         except:
-            if self._server_version_info(connection) < (3, 23, 15):
+            if self.server_version_info < (3, 23, 15):
                 args = sys.exc_info()[1].args
                 if args and args[0] == 1064:
                     return
             raise
 
     def do_begin_twophase(self, connection, xid):
-        connection.execute("XA BEGIN %s", xid)
+        connection.execute(sql.text("XA BEGIN :xid"), xid=xid)
 
     def do_prepare_twophase(self, connection, xid):
-        connection.execute("XA END %s", xid)
-        connection.execute("XA PREPARE %s", xid)
+        connection.execute(sql.text("XA END :xid"), xid=xid)
+        connection.execute(sql.text("XA PREPARE :xid"), xid=xid)
 
     def do_rollback_twophase(self, connection, xid, is_prepared=True,
                              recover=False):
         if not is_prepared:
-            connection.execute("XA END %s", xid)
-        connection.execute("XA ROLLBACK %s", xid)
+            connection.execute(sql.text("XA END :xid"), xid=xid)
+        connection.execute(sql.text("XA ROLLBACK :xid"), xid=xid)
 
     def do_commit_twophase(self, connection, xid, is_prepared=True,
                            recover=False):
         if not is_prepared:
             self.do_prepare_twophase(connection, xid)
-        connection.execute("XA COMMIT %s", xid)
+        connection.execute(sql.text("XA COMMIT :xid"), xid=xid)
 
     def do_recover_twophase(self, connection):
         resultset = connection.execute("XA RECOVER")
         return [row['data'][0:row['gtrid_length']] for row in resultset]
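
These hooks back the engine-level two-phase API; a hedged sketch of the resulting XA conversation, where `engine` and `some_table` are placeholders:

    conn = engine.connect()
    xa = conn.begin_twophase()               # XA BEGIN :xid
    conn.execute(some_table.insert(), x=1)
    xa.prepare()                             # XA END :xid; XA PREPARE :xid
    xa.commit()                              # XA COMMIT :xid
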
 
-    def do_ping(self, connection):
-        connection.ping()
-
     def is_disconnect(self, e):
         if isinstance(e, self.dbapi.OperationalError):
-            return e.args[0] in (2006, 2013, 2014, 2045, 2055)
+            return self._extract_error_code(e) in (2006, 2013, 2014, 2045, 2055)
         elif isinstance(e, self.dbapi.InterfaceError):  # if underlying connection is closed, this is the error you get
             return "(0, '')" in str(e)
         else:
             return False
 
+    def _compat_fetchall(self, rp, charset=None):
+        """Proxy result rows to smooth over MySQL-Python driver inconsistencies."""
+
+        return [_DecodingRowProxy(row, charset) for row in rp.fetchall()]
+
+    def _compat_fetchone(self, rp, charset=None):
+        """Proxy a result row to smooth over MySQL-Python driver inconsistencies."""
+
+        return _DecodingRowProxy(rp.fetchone(), charset)
+
+    def _extract_error_code(self, exception):
+        raise NotImplementedError()
+    
     def get_default_schema_name(self, connection):
         return connection.execute('SELECT DATABASE()').scalar()
-    get_default_schema_name = engine_base.connection_memoize(
-        ('dialect', 'default_schema_name'))(get_default_schema_name)
 
     def table_names(self, connection, schema):
         """Return a Unicode SHOW TABLES from a given schema."""
 
-        charset = self._detect_charset(connection)
-        self._autoset_identifier_style(connection)
+        charset = self._connection_charset
         rp = connection.execute("SHOW TABLES FROM %s" %
             self.identifier_preparer.quote_identifier(schema))
-        return [row[0] for row in _compat_fetchall(rp, charset=charset)]
+        return [row[0] for row in self._compat_fetchall(rp, charset=charset)]
 
     def has_table(self, connection, table_name, schema=None):
         # SHOW TABLE STATUS LIKE and SHOW TABLES LIKE do not function properly
@@ -1595,7 +1724,6 @@ class MySQLDialect(default.DefaultDialect):
         # full_name = self.identifier_preparer.format_table(table,
         #                                                   use_schema=True)
 
-        self._autoset_identifier_style(connection)
 
         full_name = '.'.join(self.identifier_preparer._quote_free_identifiers(
             schema, table_name))
@@ -1609,73 +1737,191 @@ class MySQLDialect(default.DefaultDialect):
                 rs.close()
                 return have
             except exc.SQLError, e:
-                if e.orig.args[0] == 1146:
+                if self._extract_error_code(e) == 1146:
                     return False
                 raise
         finally:
             if rs:
                 rs.close()
+    
+    def initialize(self, connection):
+        self.server_version_info = self._get_server_version_info(connection)
+        self._connection_charset = self._detect_charset(connection)
+        self._server_casing = self._detect_casing(connection)
+        self._server_collations = self._detect_collations(connection)
+        self._server_ansiquotes = self._detect_ansiquotes(connection)
+            
+        if self._server_ansiquotes:
+            self.preparer = MySQLANSIIdentifierPreparer
+        else:
+            self.preparer = MySQLIdentifierPreparer
+        self.identifier_preparer = self.preparer(self)
+
+    @reflection.cache
+    def get_schema_names(self, connection, **kw):
+        rp = connection.execute("SHOW schemas")
+        return [r[0] for r in rp]
+
+    @reflection.cache
+    def get_table_names(self, connection, schema=None, **kw):
+        if schema is None:
+            schema = self.get_default_schema_name(connection)
+        if self.server_version_info < (5, 0, 2):
+            return self.table_names(connection, schema)
+        charset = self._connection_charset
+        rp = connection.execute("SHOW FULL TABLES FROM %s" %
+                self.identifier_preparer.quote_identifier(schema))
+        
+        return [row[0] for row in self._compat_fetchall(rp, charset=charset)\
+                                                    if row[1] == 'BASE TABLE']
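
The Table_type column of SHOW FULL TABLES relied on here is only available from 5.0.2 onward, hence the fallback to the plain table_names() listing.
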
+
+    @reflection.cache
+    def get_view_names(self, connection, schema=None, **kw):
+        if self.server_version_info < (5, 0, 2):
+            raise NotImplementedError
+        if schema is None:
+            schema = self.get_default_schema_name(connection)
+        charset = self._connection_charset
+        rp = connection.execute("SHOW FULL TABLES FROM %s" %
+                self.identifier_preparer.quote_identifier(schema))
+        return [row[0] for row in self._compat_fetchall(rp, charset=charset)\
+                                                    if row[1] == 'VIEW']
+
+    @reflection.cache
+    def get_table_options(self, connection, table_name, schema=None, **kw):
+
+        parsed_state = self._parsed_state_or_create(connection, table_name, schema, **kw)
+        return parsed_state.table_options
+
+    @reflection.cache
+    def get_columns(self, connection, table_name, schema=None, **kw):
+        parsed_state = self._parsed_state_or_create(connection, table_name, schema, **kw)
+        return parsed_state.columns
+
+    @reflection.cache
+    def get_primary_keys(self, connection, table_name, schema=None, **kw):
+        parsed_state = self._parsed_state_or_create(connection, table_name, schema, **kw)
+        for key in parsed_state.keys:
+            if key['type'] == 'PRIMARY':
+                # There can be only one.
+                return [s[0] for s in key['columns']]
+        return []
+
+    @reflection.cache
+    def get_foreign_keys(self, connection, table_name, schema=None, **kw):
+
+        parsed_state = self._parsed_state_or_create(connection, table_name, schema, **kw)
+        default_schema = None
+        
+        fkeys = []
 
-    def server_version_info(self, connection):
-        """A tuple of the database server version.
-
-        Formats the remote server version as a tuple of version values,
-        e.g. ``(5, 0, 44)``.  If there are strings in the version number
-        they will be in the tuple too, so don't count on these all being
-        ``int`` values.
-
-        This is a fast check that does not require a round trip.  It is also
-        cached per-Connection.
-        """
-
-        return self._server_version_info(connection.connection.connection)
-    server_version_info = engine_base.connection_memoize(
-        ('mysql', 'server_version_info'))(server_version_info)
+        for spec in parsed_state.constraints:
+            # only FOREIGN KEYs
+            ref_name = spec['table'][-1]
+            ref_schema = len(spec['table']) > 1 and spec['table'][-2] or schema
 
-    def _server_version_info(self, dbapi_con):
-        """Convert a MySQL-python server_info string into a tuple."""
+            if not ref_schema:
+                if default_schema is None:
+                    default_schema = \
+                        connection.dialect.get_default_schema_name(connection)
+                if schema == default_schema:
+                    ref_schema = schema
 
-        version = []
-        r = re.compile('[.\-]')
-        for n in r.split(dbapi_con.get_server_info()):
-            try:
-                version.append(int(n))
-            except ValueError:
-                version.append(n)
-        return tuple(version)
+            loc_names = spec['local']
+            ref_names = spec['foreign']
 
-    def reflecttable(self, connection, table, include_columns):
-        """Load column definitions from the server."""
+            con_kw = {}
+            for opt in ('name', 'onupdate', 'ondelete'):
+                if spec.get(opt, False):
+                    con_kw[opt] = spec[opt]
 
-        charset = self._detect_charset(connection)
-        self._autoset_identifier_style(connection)
+            fkey_d = {
+                'name' : spec['name'],
+                'constrained_columns' : loc_names,
+                'referred_schema' : ref_schema,
+                'referred_table' : ref_name,
+                'referred_columns' : ref_names,
+                'options' : con_kw
+            }
+            fkeys.append(fkey_d)
+        return fkeys
+
+    @reflection.cache
+    def get_indexes(self, connection, table_name, schema=None, **kw):
+
+        parsed_state = self._parsed_state_or_create(connection, table_name, schema, **kw)
+        
+        indexes = []
+        for spec in parsed_state.keys:
+            unique = False
+            flavor = spec['type']
+            if flavor == 'PRIMARY':
+                continue
+            if flavor == 'UNIQUE':
+                unique = True
+            elif flavor in (None, 'FULLTEXT', 'SPATIAL'):
+                pass
+            else:
+                self.logger.info(
+                    "Converting unknown KEY type %s to a plain KEY" % flavor)
+            index_d = {}
+            index_d['name'] = spec['name']
+            index_d['column_names'] = [s[0] for s in spec['columns']]
+            index_d['unique'] = unique
+            index_d['type'] = flavor
+            indexes.append(index_d)
+        return indexes
+
+    @reflection.cache
+    def get_view_definition(self, connection, view_name, schema=None, **kw):
+
+        charset = self._connection_charset
+        full_name = '.'.join(self.identifier_preparer._quote_free_identifiers(
+            schema, view_name))
+        sql = self._show_create_table(connection, None, charset,
+                                      full_name=full_name)
+        return sql
 
+    def _parsed_state_or_create(self, connection, table_name, schema=None, **kw):
+        return self._setup_parser(
+                        connection, 
+                        table_name, 
+                        schema, 
+                        info_cache=kw.get('info_cache', None)
+                    )
+        
+    @reflection.cache
+    def _setup_parser(self, connection, table_name, schema=None, **kw):
+        charset = self._connection_charset
         try:
-            reflector = self.reflector
+            parser = self.parser
         except AttributeError:
             preparer = self.identifier_preparer
-            if (self.server_version_info(connection) < (4, 1) and
-                self.use_ansiquotes):
+            if (self.server_version_info < (4, 1) and
+                self._server_ansiquotes):
                 # ANSI_QUOTES doesn't affect SHOW CREATE TABLE on < 4.1
                 preparer = MySQLIdentifierPreparer(self)
-
-            self.reflector = reflector = MySQLSchemaReflector(preparer)
-
-        sql = self._show_create_table(connection, table, charset)
+            self.parser = parser = MySQLTableDefinitionParser(self, preparer)
+        full_name = '.'.join(self.identifier_preparer._quote_free_identifiers(
+            schema, table_name))
+        sql = self._show_create_table(connection, None, charset,
+                                      full_name=full_name)
         if sql.startswith('CREATE ALGORITHM'):
             # Adapt views to something table-like.
-            columns = self._describe_table(connection, table, charset)
-            sql = reflector._describe_to_create(table, columns)
-
-        self._adjust_casing(connection, table)
-
-        return reflector.reflect(connection, table, sql, charset,
-                                 only=include_columns)
-
-    def _adjust_casing(self, connection, table, charset=None):
+            columns = self._describe_table(connection, None, charset,
+                                           full_name=full_name)
+            sql = parser._describe_to_create(table_name, columns)
+        return parser.parse(sql, charset)
+  
+    def _adjust_casing(self, table, charset=None):
         """Adjust Table name to the server case sensitivity, if needed."""
 
-        casing = self._detect_casing(connection)
+        casing = self._server_casing
 
         # For winxx database hosts.  TODO: is this really needed?
         if casing == 1 and table.name != table.name.lower():
@@ -1683,50 +1929,8 @@ class MySQLDialect(default.DefaultDialect):
             lc_alias = schema._get_table_key(table.name, table.schema)
             table.metadata.tables[lc_alias] = table
 
-
     def _detect_charset(self, connection):
-        """Sniff out the character set in use for connection results."""
-
-        # Allow user override, won't sniff if force_charset is set.
-        if ('mysql', 'force_charset') in connection.info:
-            return connection.info[('mysql', 'force_charset')]
-
-        # Note: MySQL-python 1.2.1c7 seems to ignore changes made
-        # on a connection via set_character_set()
-        if self.server_version_info(connection) < (4, 1, 0):
-            try:
-                return connection.connection.character_set_name()
-            except AttributeError:
-                # < 1.2.1 final MySQL-python drivers have no charset support.
-                # a query is needed.
-                pass
-
-        # Prefer 'character_set_results' for the current connection over the
-        # value in the driver.  SET NAMES or individual variable SETs will
-        # change the charset without updating the driver's view of the world.
-        #
-        # If it's decided that issuing that sort of SQL leaves you SOL, then
-        # this can prefer the driver value.
-        rs = connection.execute("SHOW VARIABLES LIKE 'character_set%%'")
-        opts = dict([(row[0], row[1]) for row in _compat_fetchall(rs)])
-
-        if 'character_set_results' in opts:
-            return opts['character_set_results']
-        try:
-            return connection.connection.character_set_name()
-        except AttributeError:
-            # Still no charset on < 1.2.1 final...
-            if 'character_set' in opts:
-                return opts['character_set']
-            else:
-                util.warn(
-                    "Could not detect the connection character set with this "
-                    "combination of MySQL server and MySQL-python. "
-                    "MySQL-python >= 1.2.2 is recommended.  Assuming latin1.")
-                return 'latin1'
-    _detect_charset = engine_base.connection_memoize(
-        ('mysql', 'charset'))(_detect_charset)
-
+        raise NotImplementedError()
 
     def _detect_casing(self, connection):
         """Sniff out identifier case sensitivity.
@@ -1737,8 +1941,8 @@ class MySQLDialect(default.DefaultDialect):
         """
         # http://dev.mysql.com/doc/refman/5.0/en/name-case-sensitivity.html
 
-        charset = self._detect_charset(connection)
-        row = _compat_fetchone(connection.execute(
+        charset = self._connection_charset
+        row = self._compat_fetchone(connection.execute(
             "SHOW VARIABLES LIKE 'lower_case_table_names'"),
                                charset=charset)
         if not row:
@@ -1754,8 +1958,6 @@ class MySQLDialect(default.DefaultDialect):
                 cs = int(row[1])
             row.close()
         return cs
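
For reference, the decoded lower_case_table_names values follow the server convention: 0 stores and compares names case-sensitively, 1 stores them lowercased and compares case-insensitively, and 2 stores them as given but compares lowercased.
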
-    _detect_casing = engine_base.connection_memoize(
-        ('mysql', 'lower_case_table_names'))(_detect_casing)
 
     def _detect_collations(self, connection):
         """Pull the active COLLATIONS list from the server.
@@ -1764,49 +1966,22 @@ class MySQLDialect(default.DefaultDialect):
         """
 
         collations = {}
-        if self.server_version_info(connection) < (4, 1, 0):
+        if self.server_version_info < (4, 1, 0):
             pass
         else:
-            charset = self._detect_charset(connection)
+            charset = self._connection_charset
             rs = connection.execute('SHOW COLLATION')
-            for row in _compat_fetchall(rs, charset):
+            for row in self._compat_fetchall(rs, charset):
                 collations[row[0]] = row[1]
         return collations
-    _detect_collations = engine_base.connection_memoize(
-        ('mysql', 'collations'))(_detect_collations)
-
-    def use_ansiquotes(self, useansi):
-        self._use_ansiquotes = useansi
-        if useansi:
-            self.preparer = MySQLANSIIdentifierPreparer
-        else:
-            self.preparer = MySQLIdentifierPreparer
-        # icky
-        if hasattr(self, 'identifier_preparer'):
-            self.identifier_preparer = self.preparer(self)
-        if hasattr(self, 'reflector'):
-            del self.reflector
-
-    use_ansiquotes = property(lambda s: s._use_ansiquotes, use_ansiquotes,
-                              doc="True if ANSI_QUOTES is in effect.")
-
-    def _autoset_identifier_style(self, connection, charset=None):
-        """Detect and adjust for the ANSI_QUOTES sql mode.
 
-        If the dialect's use_ansiquotes is unset, query the server's sql mode
-        and reset the identifier style.
-
-        Note that this currently *only* runs during reflection.  Ideally this
-        would run the first time a connection pool connects to the database,
-        but the infrastructure for that is not yet in place.
-        """
-
-        if self.use_ansiquotes is not None:
-            return
+    def _detect_ansiquotes(self, connection):
+        """Detect and adjust for the ANSI_QUOTES sql mode."""
 
-        row = _compat_fetchone(
+        row = self._compat_fetchone(
             connection.execute("SHOW VARIABLES LIKE 'sql_mode'"),
-                               charset=charset)
+                               charset=self._connection_charset)
+
         if not row:
             mode = ''
         else:
@@ -1816,7 +1991,7 @@ class MySQLDialect(default.DefaultDialect):
                 mode_no = int(mode)
                 mode = (mode_no | 4 == mode_no) and 'ANSI_QUOTES' or ''
 
-        self.use_ansiquotes = 'ANSI_QUOTES' in mode
+        return 'ANSI_QUOTES' in mode
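
(Very old servers report sql_mode as a numeric bitmask rather than a name list; the integer branch above treats bit 0x4 as ANSI_QUOTES.)
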
 
     def _show_create_table(self, connection, table, charset=None,
                            full_name=None):
@@ -1831,11 +2006,11 @@ class MySQLDialect(default.DefaultDialect):
             try:
                 rp = connection.execute(st)
             except exc.SQLError, e:
-                if e.orig.args[0] == 1146:
+                if self._extract_error_code(e) == 1146:
                     raise exc.NoSuchTableError(full_name)
                 else:
                     raise
-            row = _compat_fetchone(rp, charset=charset)
+            row = self._compat_fetchone(rp, charset=charset)
             if not row:
                 raise exc.NoSuchTableError(full_name)
             return row[1].strip()
@@ -1858,326 +2033,163 @@ class MySQLDialect(default.DefaultDialect):
             try:
                 rp = connection.execute(st)
             except exc.SQLError, e:
-                if e.orig.args[0] == 1146:
+                if self._extract_error_code(e) == 1146:
                     raise exc.NoSuchTableError(full_name)
                 else:
                     raise
-            rows = _compat_fetchall(rp, charset=charset)
+            rows = self._compat_fetchall(rp, charset=charset)
         finally:
             if rp:
                 rp.close()
         return rows
 
-class _MySQLPythonRowProxy(object):
-    """Return consistent column values for all versions of MySQL-python.
-
-    Smooth over data type issues (esp. with alpha driver versions) and
-    normalize strings as Unicode regardless of user-configured driver
-    encoding settings.
-    """
-
-    # Some MySQL-python versions can return some columns as
-    # sets.Set(['value']) (seriously) but thankfully that doesn't
-    # seem to come up in DDL queries.
-
-    def __init__(self, rowproxy, charset):
-        self.rowproxy = rowproxy
-        self.charset = charset
-    def __getitem__(self, index):
-        item = self.rowproxy[index]
-        if isinstance(item, _array):
-            item = item.tostring()
-        if self.charset and isinstance(item, str):
-            return item.decode(self.charset)
-        else:
-            return item
-    def __getattr__(self, attr):
-        item = getattr(self.rowproxy, attr)
-        if isinstance(item, _array):
-            item = item.tostring()
-        if self.charset and isinstance(item, str):
-            return item.decode(self.charset)
-        else:
-            return item
-
-
-class MySQLCompiler(compiler.DefaultCompiler):
-    operators = compiler.DefaultCompiler.operators.copy()
-    operators.update({
-        sql_operators.concat_op: lambda x, y: "concat(%s, %s)" % (x, y),
-        sql_operators.mod: '%%',
-        sql_operators.match_op: lambda x, y: "MATCH (%s) AGAINST (%s IN BOOLEAN MODE)" % (x, y)
-    })
-    functions = compiler.DefaultCompiler.functions.copy()
-    functions.update ({
-        sql_functions.random: 'rand%(expr)s',
-        "utc_timestamp":"UTC_TIMESTAMP"
-        })
-
-    extract_map = compiler.DefaultCompiler.extract_map.copy()
-    extract_map.update ({
-        'milliseconds': 'millisecond',
-    })
-
-    def visit_typeclause(self, typeclause):
-        type_ = typeclause.type.dialect_impl(self.dialect)
-        if isinstance(type_, MSInteger):
-            if getattr(type_, 'unsigned', False):
-                return 'UNSIGNED INTEGER'
-            else:
-                return 'SIGNED INTEGER'
-        elif isinstance(type_, (MSDecimal, MSDateTime, MSDate, MSTime)):
-            return type_.get_col_spec()
-        elif isinstance(type_, MSText):
-            return 'CHAR'
-        elif (isinstance(type_, _StringType) and not
-              isinstance(type_, (MSEnum, MSSet))):
-            if getattr(type_, 'length'):
-                return 'CHAR(%s)' % type_.length
-            else:
-                return 'CHAR'
-        elif isinstance(type_, _BinaryType):
-            return 'BINARY'
-        elif isinstance(type_, MSNumeric):
-            return type_.get_col_spec().replace('NUMERIC', 'DECIMAL')
-        elif isinstance(type_, MSTimeStamp):
-            return 'DATETIME'
-        elif isinstance(type_, (MSDateTime, MSDate, MSTime)):
-            return type_.get_col_spec()
-        else:
-            return None
-
-    def visit_cast(self, cast, **kwargs):
-        # No cast until 4, no decimals until 5.
-        type_ = self.process(cast.typeclause)
-        if type_ is None:
-            return self.process(cast.clause)
-
-        return 'CAST(%s AS %s)' % (self.process(cast.clause), type_)
-
-
-    def post_process_text(self, text):
-        if '%%' in text:
-            util.warn("The SQLAlchemy MySQLDB dialect now automatically escapes '%' in text() expressions to '%%'.")
-        return text.replace('%', '%%')
-
-    def get_select_precolumns(self, select):
-        if isinstance(select._distinct, basestring):
-            return select._distinct.upper() + " "
-        elif select._distinct:
-            return "DISTINCT "
-        else:
-            return ""
-
-    def visit_join(self, join, asfrom=False, **kwargs):
-        # 'JOIN ... ON ...' for inner joins isn't available until 4.0.
-        # Apparently < 3.23.17 requires theta joins for inner joins
-        # (but not outer).  Not generating these currently, but
-        # support can be added, preferably after dialects are
-        # refactored to be version-sensitive.
-        return ''.join(
-            (self.process(join.left, asfrom=True),
-             (join.isouter and " LEFT OUTER JOIN " or " INNER JOIN "),
-             self.process(join.right, asfrom=True),
-             " ON ",
-             self.process(join.onclause)))
-
-    def for_update_clause(self, select):
-        if select.for_update == 'read':
-            return ' LOCK IN SHARE MODE'
-        else:
-            return super(MySQLCompiler, self).for_update_clause(select)
-
-    def limit_clause(self, select):
-        # MySQL supports:
-        #   LIMIT <limit>
-        #   LIMIT <offset>, <limit>
-        # and in server versions > 3.3:
-        #   LIMIT <limit> OFFSET <offset>
-        # The latter is more readable for offsets but we're stuck with the
-        # former until we can refine dialects by server revision.
-
-        limit, offset = select._limit, select._offset
-
-        if (limit, offset) == (None, None):
-            return ''
-        elif offset is not None:
-            # As suggested by the MySQL docs, need to apply an
-            # artificial limit if one wasn't provided
-            if limit is None:
-                limit = 18446744073709551615
-            return ' \n LIMIT %s, %s' % (offset, limit)
-        else:
-            # No offset provided, so just use the limit
-            return ' \n LIMIT %s' % (limit,)
-
-    def visit_update(self, update_stmt):
-        self.stack.append({'from': set([update_stmt.table])})
-
-        self.isupdate = True
-        colparams = self._get_colparams(update_stmt)
-
-        text = "UPDATE " + self.preparer.format_table(update_stmt.table) + " SET " + ', '.join(["%s=%s" % (self.preparer.format_column(c[0]), c[1]) for c in colparams])
-
-        if update_stmt._whereclause:
-            text += " WHERE " + self.process(update_stmt._whereclause)
-
-        limit = update_stmt.kwargs.get('mysql_limit', None)
-        if limit:
-            text += " LIMIT %s" % limit
-
-        self.stack.pop(-1)
-
-        return text
-
-# ug.  "InnoDB needs indexes on foreign keys and referenced keys [...].
-#       Starting with MySQL 4.1.2, these indexes are created automatically.
-#       In older versions, the indexes must be created explicitly or the
-#       creation of foreign key constraints fails."
-
-class MySQLSchemaGenerator(compiler.SchemaGenerator):
-    def get_column_specification(self, column, first_pk=False):
-        """Builds column DDL."""
-
-        colspec = [self.preparer.format_column(column),
-                   column.type.dialect_impl(self.dialect).get_col_spec()]
-
-        default = self.get_column_default_string(column)
-        if default is not None:
-            colspec.append('DEFAULT ' + default)
-
-        if not column.nullable:
-            colspec.append('NOT NULL')
-
-        if column.primary_key and column.autoincrement:
-            try:
-                first = [c for c in column.table.primary_key.columns
-                         if (c.autoincrement and
-                             isinstance(c.type, sqltypes.Integer) and
-                             not c.foreign_keys)].pop(0)
-                if column is first:
-                    colspec.append('AUTO_INCREMENT')
-            except IndexError:
-                pass
-
-        return ' '.join(colspec)
-
-    def post_create_table(self, table):
-        """Build table-level CREATE options like ENGINE and COLLATE."""
-
-        table_opts = []
-        for k in table.kwargs:
-            if k.startswith('mysql_'):
-                opt = k[6:].upper()
-                joiner = '='
-                if opt in ('TABLESPACE', 'DEFAULT CHARACTER SET',
-                           'CHARACTER SET', 'COLLATE'):
-                    joiner = ' '
-
-                table_opts.append(joiner.join((opt, table.kwargs[k])))
-        return ' '.join(table_opts)
-
-
-class MySQLSchemaDropper(compiler.SchemaDropper):
-    def visit_index(self, index):
-        self.append("\nDROP INDEX %s ON %s" %
-                    (self.preparer.quote(self._validate_identifier(index.name, False), index.quote),
-                     self.preparer.format_table(index.table)))
-        self.execute()
-
-    def drop_foreignkey(self, constraint):
-        self.append("ALTER TABLE %s DROP FOREIGN KEY %s" %
-                    (self.preparer.format_table(constraint.table),
-                     self.preparer.format_constraint(constraint)))
-        self.execute()
-
-
-class MySQLSchemaReflector(object):
-    """Parses SHOW CREATE TABLE output."""
-
-    def __init__(self, identifier_preparer):
-        """Construct a MySQLSchemaReflector.
-
-        identifier_preparer
-          An ANSIIdentifierPreparer type, used to determine the identifier
-          quoting style in effect.
-        """
-
-        self.preparer = identifier_preparer
+class ReflectedState(object):
+    """Stores raw information about a SHOW CREATE TABLE statement."""
+    
+    def __init__(self):
+        self.columns = []
+        self.table_options = {}
+        self.table_name = None
+        self.keys = []
+        self.constraints = []
+        
+class MySQLTableDefinitionParser(object):
+    """Parses the results of a SHOW CREATE TABLE statement."""
+    
+    def __init__(self, dialect, preparer):
+        self.dialect = dialect
+        self.preparer = preparer
         self._prep_regexes()
 
-    def reflect(self, connection, table, show_create, charset, only=None):
-        """Parse MySQL SHOW CREATE TABLE and fill in a ''Table''.
-
-        show_create
-          Unicode output of SHOW CREATE TABLE
-
-        table
-          A ``Table``, to be loaded with Columns, Indexes, etc.
-          table.name will be set if not already
-
-        charset
-          FIXME, some constructed values (like column defaults)
-          currently can't be Unicode.  ``charset`` will convert them
-          into the connection character set.
-
-        only
-           An optional sequence of column names.  If provided, only
-           these columns will be reflected, and any keys or constraints
-           that include columns outside this set will also be omitted.
-           That means that if ``only`` includes only one column in a
-           two-part primary key, the entire primary key will be omitted.
-        """
-
-        keys, constraints = [], []
-
-        if only:
-            only = set(only)
-
+    def parse(self, show_create, charset):
+        state = ReflectedState()
+        state.charset = charset
         for line in re.split(r'\r?\n', show_create):
             if line.startswith('  ' + self.preparer.initial_quote):
-                self._add_column(table, line, charset, only)
+                self._parse_column(line, state)
             # a regular table options line
             elif line.startswith(') '):
-                self._set_options(table, line)
+                self._parse_table_options(line, state)
             # an ANSI-mode table options line
             elif line == ')':
                 pass
             elif line.startswith('CREATE '):
-                self._set_name(table, line)
+                self._parse_table_name(line, state)
             # Not present in real reflection, but may be if loading from a file.
             elif not line:
                 pass
             else:
-                type_, spec = self.parse_constraints(line)
+                type_, spec = self._parse_constraints(line)
                 if type_ is None:
                     util.warn("Unknown schema content: %r" % line)
                 elif type_ == 'key':
-                    keys.append(spec)
+                    state.keys.append(spec)
                 elif type_ == 'constraint':
-                    constraints.append(spec)
+                    state.constraints.append(spec)
                 else:
                     pass
+                    
+        return state
+        
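+    # Illustrative usage only (names here are not part of any public
+    # API): given the text of a SHOW CREATE TABLE statement, the parser
+    # returns a ReflectedState holding the raw pieces, e.g.
+    #
+    #   parser = MySQLTableDefinitionParser(dialect, preparer)
+    #   state = parser.parse(show_create_sql, charset='utf8')
+    #   state.table_name, state.columns, state.keys, state.constraints
+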
+    def _parse_constraints(self, line):
+        """Parse a KEY or CONSTRAINT line.
+
+        line
+          A line of SHOW CREATE TABLE output
+        """
+
+        # KEY
+        m = self._re_key.match(line)
+        if m:
+            spec = m.groupdict()
+            # convert columns into name, length pairs
+            spec['columns'] = self._parse_keyexprs(spec['columns'])
+            return 'key', spec
+
+        # CONSTRAINT
+        m = self._re_constraint.match(line)
+        if m:
+            spec = m.groupdict()
+            spec['table'] = \
+              self.preparer.unformat_identifiers(spec['table'])
+            spec['local'] = [c[0]
+                             for c in self._parse_keyexprs(spec['local'])]
+            spec['foreign'] = [c[0]
+                               for c in self._parse_keyexprs(spec['foreign'])]
+            return 'constraint', spec
+
+        # PARTITION and SUBPARTITION
+        m = self._re_partition.match(line)
+        if m:
+            # Punt!
+            return 'partition', line
+
+        # No match.
+        return (None, line)
+
+    def _parse_table_name(self, line, state):
+        """Extract the table name.
+
+        line
+          The first line of SHOW CREATE TABLE
+        """
+
+        regex, cleanup = self._pr_name
+        m = regex.match(line)
+        if m:
+            state.table_name = cleanup(m.group('name'))
+
+    def _parse_table_options(self, line, state):
+        """Build a dictionary of all reflected table-level options.
+
+        line
+          The final line of SHOW CREATE TABLE output.
+        """
+
+        options = {}
 
-        self._set_keys(table, keys, only)
-        self._set_constraints(table, constraints, connection, only)
+        if not line or line == ')':
+            pass
+
+        else:
+            r_eq_trim = self._re_options_util['=']
+
+            for regex, cleanup in self._pr_options:
+                m = regex.search(line)
+                if not m:
+                    continue
+                directive, value = m.group('directive'), m.group('val')
+                directive = r_eq_trim.sub('', directive).lower()
+                if cleanup:
+                    value = cleanup(value)
+                options[directive] = value
+
+        for nope in ('auto_increment', 'data_directory', 'index_directory'):
+            options.pop(nope, None)
+
+        for opt, val in options.items():
+            state.table_options['mysql_%s' % opt] = val
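+
+    # Illustration (hypothetical input): a closing line such as
+    #   ") ENGINE=InnoDB DEFAULT CHARSET=utf8"
+    # would leave entries along the lines of
+    #   {'mysql_engine': 'InnoDB', 'mysql_default charset': 'utf8'}
+    # in state.table_options; the exact keys depend on the option
+    # regexes prepared in _prep_regexes().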
 
-    def _set_name(self, table, line):
-        """Override a Table name with the reflected name.
+    def _parse_column(self, line, state):
+        """Extract column details.
 
-        table
-          A ``Table``
+        Falls back to a 'minimal support' variant if full parse fails.
 
         line
-          The first line of SHOW CREATE TABLE output.
+          Any column-bearing line from SHOW CREATE TABLE
         """
 
-        # Don't override by default.
-        if table.name is None:
-            table.name = self.parse_name(line)
-
-    def _add_column(self, table, line, charset, only=None):
-        spec = self.parse_column(line)
+        charset = state.charset
+        spec = None
+        m = self._re_column.match(line)
+        if m:
+            spec = m.groupdict()
+            spec['full'] = True
+        else:
+            m = self._re_column_loose.match(line)
+            if m:
+                spec = m.groupdict()
+                spec['full'] = False
         if not spec:
             util.warn("Unknown column definition %r" % line)
             return
@@ -2187,18 +2199,13 @@ class MySQLSchemaReflector(object):
         name, type_, args, notnull = \
               spec['name'], spec['coltype'], spec['arg'], spec['notnull']
 
-        if only and name not in only:
-            self.logger.info("Omitting reflected column %s.%s" %
-                             (table.name, name))
-            return
-
         # Convention says that TINYINT(1) columns == BOOLEAN
         if type_ == 'tinyint' and args == '1':
             type_ = 'boolean'
             args = None
 
         try:
-            col_type = ischema_names[type_]
+            col_type = self.dialect.ischema_names[type_]
         except KeyError:
             util.warn("Did not recognize type '%s' of column '%s'" %
                       (type_, name))
@@ -2229,6 +2236,7 @@ class MySQLSchemaReflector(object):
         col_args, col_kw = [], {}
 
         # NOT NULL
+        col_kw['nullable'] = True
         if spec.get('notnull', False):
             col_kw['nullable'] = False
 
@@ -2240,131 +2248,64 @@ class MySQLSchemaReflector(object):
 
         # DEFAULT
         default = spec.get('default', None)
-        if default is not None and default != 'NULL':
-            # Defaults should be in the native charset for the moment
-            default = default.encode(charset)
-            if type_ == 'timestamp':
-                # can't be NULL for TIMESTAMPs
-                if (default[0], default[-1]) != ("'", "'"):
-                    default = sql.text(default)
-            else:
-                default = default[1:-1]
-            col_args.append(schema.DefaultClause(default))
-
-        table.append_column(schema.Column(name, type_instance,
-                                          *col_args, **col_kw))
 
-    def _set_keys(self, table, keys, only):
-        """Add ``Index`` and ``PrimaryKeyConstraint`` items to a ``Table``.
+        if default == 'NULL':
+            # eliminates the need to deal with this later.
+            default = None
+            
+        col_d = dict(name=name, type=type_instance, default=default)
+        col_d.update(col_kw)
+        state.columns.append(col_d)
 
-        Most of the information gets dropped here; more is reflected than
-        the schema objects can currently represent.
-
-        table
-          A ``Table``
+    def _describe_to_create(self, table_name, columns):
+        """Re-format DESCRIBE output as a SHOW CREATE TABLE string.
 
-        keys
-          A sequence of key specifications produced by `constraints`
+        DESCRIBE is a much simpler reflection and is sufficient for
+        reflecting views for runtime use.  This method formats DDL
+        for columns only; keys are omitted.
 
-        only
-          Optional `set` of column names.  If provided, keys covering
-          columns not in this set will be omitted.
+        `columns` is a sequence of DESCRIBE or SHOW COLUMNS 6-tuples.
+        SHOW FULL COLUMNS FROM rows must be rearranged for use with
+        this function.
         """
 
-        for spec in keys:
-            flavor = spec['type']
-            col_names = [s[0] for s in spec['columns']]
-
-            if only and not set(col_names).issubset(only):
-                if flavor is None:
-                    flavor = 'index'
-                self.logger.info(
-                    "Omitting %s KEY for (%s), key covers ommitted columns." %
-                    (flavor, ', '.join(col_names)))
-                continue
-
-            constraint = False
-            if flavor == 'PRIMARY':
-                key = schema.PrimaryKeyConstraint()
-                constraint = True
-            elif flavor == 'UNIQUE':
-                key = schema.Index(spec['name'], unique=True)
-            elif flavor in (None, 'FULLTEXT', 'SPATIAL'):
-                key = schema.Index(spec['name'])
-            else:
-                self.logger.info(
-                    "Converting unknown KEY type %s to a plain KEY" % flavor)
-                key = schema.Index(spec['name'])
-
-            for col in [table.c[name] for name in col_names]:
-                key.append_column(col)
-
-            if constraint:
-                table.append_constraint(key)
-
-    def _set_constraints(self, table, constraints, connection, only):
-        """Apply constraints to a ``Table``."""
-
-        default_schema = None
-
-        for spec in constraints:
-            # only FOREIGN KEYs
-            ref_name = spec['table'][-1]
-            ref_schema = len(spec['table']) > 1 and spec['table'][-2] or table.schema
-
-            if not ref_schema:
-                if default_schema is None:
-                    default_schema = connection.dialect.get_default_schema_name(
-                        connection)
-                if table.schema == default_schema:
-                    ref_schema = table.schema
-
-            loc_names = spec['local']
-            if only and not set(loc_names).issubset(only):
-                self.logger.info(
-                    "Omitting FOREIGN KEY for (%s), key covers ommitted "
-                    "columns." % (', '.join(loc_names)))
-                continue
-
-            ref_key = schema._get_table_key(ref_name, ref_schema)
-            if ref_key in table.metadata.tables:
-                ref_table = table.metadata.tables[ref_key]
-            else:
-                ref_table = schema.Table(
-                    ref_name, table.metadata, schema=ref_schema,
-                    autoload=True, autoload_with=connection)
-
-            ref_names = spec['foreign']
-
-            if ref_schema:
-                refspec = [".".join([ref_schema, ref_name, column]) for column in ref_names]
-            else:
-                refspec = [".".join([ref_name, column]) for column in ref_names]
-
-            con_kw = {}
-            for opt in ('name', 'onupdate', 'ondelete'):
-                if spec.get(opt, False):
-                    con_kw[opt] = spec[opt]
-
-            key = schema.ForeignKeyConstraint(loc_names, refspec, link_to_name=True, **con_kw)
-            table.append_constraint(key)
+        buffer = []
+        for row in columns:
+            (name, col_type, nullable, default, extra) = \
+                   [row[i] for i in (0, 1, 2, 4, 5)]
 
-    def _set_options(self, table, line):
-        """Apply safe reflected table options to a ``Table``.
+            line = [' ']
+            line.append(self.preparer.quote_identifier(name))
+            line.append(col_type)
+            if not nullable:
+                line.append('NOT NULL')
+            if default:
+                if 'auto_increment' in default:
+                    pass
+                elif (col_type.startswith('timestamp') and
+                      default.startswith('C')):
+                    line.append('DEFAULT')
+                    line.append(default)
+                elif default == 'NULL':
+                    line.append('DEFAULT')
+                    line.append(default)
+                else:
+                    line.append('DEFAULT')
+                    line.append("'%s'" % default.replace("'", "''"))
+            if extra:
+                line.append(extra)
 
-        table
-          A ``Table``
+            buffer.append(' '.join(line))
 
-        line
-          The final line of SHOW CREATE TABLE output.
-        """
+        return ''.join([('CREATE TABLE %s (\n' %
+                         self.preparer.quote_identifier(table_name)),
+                        ',\n'.join(buffer),
+                        '\n) '])
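+
+    # Illustration (hypothetical row): a SHOW COLUMNS 6-tuple such as
+    #   ('id', 'int(11)', '', 'PRI', None, 'auto_increment')
+    # is rendered by this method as the DDL fragment
+    #   `id` int(11) NOT NULL auto_increment
+    # within the generated CREATE TABLE text.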
 
-        options = self.parse_table_options(line)
-        for nope in ('auto_increment', 'data_directory', 'index_directory'):
-            options.pop(nope, None)
+    def _parse_keyexprs(self, identifiers):
+        """Unpack '"col"(2),"col" ASC'-ish strings into components."""
 
-        for opt, val in options.items():
-            table.kwargs['mysql_%s' % opt] = val
+        return self._re_keyexprs.findall(identifiers)
 
     def _prep_regexes(self):
         """Pre-compile regular expressions."""
@@ -2522,154 +2463,42 @@ class MySQLSchemaReflector(object):
                  r'(?P<val>%s)' % (re.escape(directive), regex))
         self._pr_options.append(_pr_compile(regex))
 
+log.class_logger(MySQLTableDefinitionParser)
+log.class_logger(MySQLDialect)
 
-    def parse_name(self, line):
-        """Extract the table name.
-
-        line
-          The first line of SHOW CREATE TABLE
-        """
-
-        regex, cleanup = self._pr_name
-        m = regex.match(line)
-        if not m:
-            return None
-        return cleanup(m.group('name'))
-
-    def parse_column(self, line):
-        """Extract column details.
-
-        Falls back to a 'minimal support' variant if full parse fails.
-
-        line
-          Any column-bearing line from SHOW CREATE TABLE
-        """
-
-        m = self._re_column.match(line)
-        if m:
-            spec = m.groupdict()
-            spec['full'] = True
-            return spec
-        m = self._re_column_loose.match(line)
-        if m:
-            spec = m.groupdict()
-            spec['full'] = False
-            return spec
-        return None
-
-    def parse_constraints(self, line):
-        """Parse a KEY or CONSTRAINT line.
-
-        line
-          A line of SHOW CREATE TABLE output
-        """
-
-        # KEY
-        m = self._re_key.match(line)
-        if m:
-            spec = m.groupdict()
-            # convert columns into name, length pairs
-            spec['columns'] = self._parse_keyexprs(spec['columns'])
-            return 'key', spec
-
-        # CONSTRAINT
-        m = self._re_constraint.match(line)
-        if m:
-            spec = m.groupdict()
-            spec['table'] = \
-              self.preparer.unformat_identifiers(spec['table'])
-            spec['local'] = [c[0]
-                             for c in self._parse_keyexprs(spec['local'])]
-            spec['foreign'] = [c[0]
-                               for c in self._parse_keyexprs(spec['foreign'])]
-            return 'constraint', spec
-
-        # PARTITION and SUBPARTITION
-        m = self._re_partition.match(line)
-        if m:
-            # Punt!
-            return 'partition', line
-
-        # No match.
-        return (None, line)
-
-    def parse_table_options(self, line):
-        """Build a dictionary of all reflected table-level options.
-
-        line
-          The final line of SHOW CREATE TABLE output.
-        """
-
-        options = {}
-
-        if not line or line == ')':
-            return options
-
-        r_eq_trim = self._re_options_util['=']
-
-        for regex, cleanup in self._pr_options:
-            m = regex.search(line)
-            if not m:
-                continue
-            directive, value = m.group('directive'), m.group('val')
-            directive = r_eq_trim.sub('', directive).lower()
-            if cleanup:
-                value = cleanup(value)
-            options[directive] = value
-
-        return options
-
-    def _describe_to_create(self, table, columns):
-        """Re-format DESCRIBE output as a SHOW CREATE TABLE string.
-
-        DESCRIBE is a much simpler reflection and is sufficient for
-        reflecting views for runtime use.  This method formats DDL
-        for columns only; keys are omitted.
-
-        `columns` is a sequence of DESCRIBE or SHOW COLUMNS 6-tuples.
-        SHOW FULL COLUMNS FROM rows must be rearranged for use with
-        this function.
-        """
-
-        buffer = []
-        for row in columns:
-            (name, col_type, nullable, default, extra) = \
-                   [row[i] for i in (0, 1, 2, 4, 5)]
-
-            line = [' ']
-            line.append(self.preparer.quote_identifier(name))
-            line.append(col_type)
-            if not nullable:
-                line.append('NOT NULL')
-            if default:
-                if 'auto_increment' in default:
-                    pass
-                elif (col_type.startswith('timestamp') and
-                      default.startswith('C')):
-                    line.append('DEFAULT')
-                    line.append(default)
-                elif default == 'NULL':
-                    line.append('DEFAULT')
-                    line.append(default)
-                else:
-                    line.append('DEFAULT')
-                    line.append("'%s'" % default.replace("'", "''"))
-            if extra:
-                line.append(extra)
 
-            buffer.append(' '.join(line))
+class _DecodingRowProxy(object):
+    """Return unicode-decoded values based on type inspection.
 
-        return ''.join([('CREATE TABLE %s (\n' %
-                         self.preparer.quote_identifier(table.name)),
-                        ',\n'.join(buffer),
-                        '\n) '])
+    Smooth over data type issues (esp. with alpha driver versions) and
+    normalize strings as Unicode regardless of user-configured driver
+    encoding settings.
 
-    def _parse_keyexprs(self, identifiers):
-        """Unpack '"col"(2),"col" ASC'-ish strings into components."""
+    """
 
-        return self._re_keyexprs.findall(identifiers)
+    # Some MySQL-python versions can return some columns as
+    # sets.Set(['value']) (seriously) but thankfully that doesn't
+    # seem to come up in DDL queries.
 
-log.class_logger(MySQLSchemaReflector)
+    def __init__(self, rowproxy, charset):
+        self.rowproxy = rowproxy
+        self.charset = charset
+    def __getitem__(self, index):
+        item = self.rowproxy[index]
+        if isinstance(item, _array):
+            item = item.tostring()
+        if self.charset and isinstance(item, str):
+            return item.decode(self.charset)
+        else:
+            return item
+    def __getattr__(self, attr):
+        item = getattr(self.rowproxy, attr)
+        if isinstance(item, _array):
+            item = item.tostring()
+        if self.charset and isinstance(item, str):
+            return item.decode(self.charset)
+        else:
+            return item
 
 
 class _MySQLIdentifierPreparer(compiler.IdentifierPreparer):
@@ -2691,7 +2520,7 @@ class MySQLIdentifierPreparer(_MySQLIdentifierPreparer):
 
     def __init__(self, dialect):
         super(MySQLIdentifierPreparer, self).__init__(dialect, initial_quote="`")
-
+        
     def _escape_identifier(self, value):
         return value.replace('`', '``')
 
@@ -2704,17 +2533,6 @@ class MySQLANSIIdentifierPreparer(_MySQLIdentifierPreparer):
 
     pass
 
-
-def _compat_fetchall(rp, charset=None):
-    """Proxy result rows to smooth over MySQL-Python driver inconsistencies."""
-
-    return [_MySQLPythonRowProxy(row, charset) for row in rp.fetchall()]
-
-def _compat_fetchone(rp, charset=None):
-    """Proxy a result row to smooth over MySQL-Python driver inconsistencies."""
-
-    return _MySQLPythonRowProxy(rp.fetchone(), charset)
-
 def _pr_compile(regex, cleanup=None):
     """Prepare a 2-tuple of compiled regex and callable."""
 
@@ -2725,8 +2543,3 @@ def _re_compile(regex):
 
     return re.compile(regex, re.I | re.UNICODE)
 
-dialect = MySQLDialect
-dialect.statement_compiler = MySQLCompiler
-dialect.schemagenerator = MySQLSchemaGenerator
-dialect.schemadropper = MySQLSchemaDropper
-dialect.execution_ctx_cls = MySQLExecutionContext
diff --git a/lib/sqlalchemy/dialects/mysql/mysqldb.py b/lib/sqlalchemy/dialects/mysql/mysqldb.py
new file mode 100644 (file)
index 0000000..6ecfc4b
--- /dev/null
@@ -0,0 +1,194 @@
+"""Support for the MySQL database via the MySQL-python adapter.
+
+Character Sets
+--------------
+
+Many MySQL server installations default to a ``latin1`` encoding for client
+connections.  All data sent through the connection will be converted into
+``latin1``, even if you have ``utf8`` or another character set on your tables
+and columns.  With MySQL server versions 4.1 and higher, you can change the
+connection character set either through server configuration or by including
+the ``charset`` parameter in the URL used for ``create_engine``.  The
+``charset`` option is passed through to MySQL-Python and has the side effect
+of also enabling ``use_unicode`` in the driver by default.  To receive plain
+encoded strings instead, also pass ``use_unicode=0`` in the connection
+arguments::
+
+  # set client encoding to utf8; all strings come back as unicode
+  create_engine('mysql:///mydb?charset=utf8')
+
+  # set client encoding to utf8; all strings come back as utf8 str
+  create_engine('mysql:///mydb?charset=utf8&use_unicode=0')
+"""
+
+import decimal
+import re
+
+from sqlalchemy.dialects.mysql.base import (DECIMAL, MySQLDialect, MySQLExecutionContext,
+                                            MySQLCompiler, NUMERIC, _NumericType)
+from sqlalchemy.engine import base as engine_base, default
+from sqlalchemy.sql import operators as sql_operators
+from sqlalchemy import exc, log, schema, sql, types as sqltypes, util
+
+class MySQL_mysqldbExecutionContext(MySQLExecutionContext):
+    
+    @property
+    def rowcount(self):
+        if hasattr(self, '_rowcount'):
+            return self._rowcount
+        else:
+            return self.cursor.rowcount
+        
+        
+class MySQL_mysqldbCompiler(MySQLCompiler):
+    def visit_mod(self, binary, **kw):
+        return self.process(binary.left) + " %% " + self.process(binary.right)
+    
+    def post_process_text(self, text):
+        return text.replace('%', '%%')
+
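+# Illustration only: because MySQLdb uses the 'format' paramstyle, a
+# literal '%' must be doubled before it reaches the driver.  A statement
+# such as
+#
+#   text("SELECT * FROM t WHERE name LIKE 'a%'")
+#
+# is rewritten by post_process_text() above to "... LIKE 'a%%'", which
+# the driver collapses back to a single '%'.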
+
+class _DecimalType(_NumericType):
+    def result_processor(self, dialect):
+        if self.asdecimal:
+            return
+        def process(value):
+            if isinstance(value, decimal.Decimal):
+                return float(value)
+            else:
+                return value
+        return process
+
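+# Illustration: with asdecimal=False, a driver-returned
+# decimal.Decimal("1.5") is coerced to the float 1.5 by the processor
+# above; with asdecimal=True no processor is installed and Decimal
+# values pass through unchanged.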
+
+class _MySQLdbNumeric(_DecimalType, NUMERIC):
+    pass
+
+
+class _MySQLdbDecimal(_DecimalType, DECIMAL):
+    pass
+
+
+class MySQL_mysqldb(MySQLDialect):
+    driver = 'mysqldb'
+    supports_unicode_statements = False
+    supports_sane_rowcount = True
+    supports_sane_multi_rowcount = True
+
+    default_paramstyle = 'format'
+    execution_ctx_cls = MySQL_mysqldbExecutionContext
+    statement_compiler = MySQL_mysqldbCompiler
+
+    colspecs = util.update_copy(
+        MySQLDialect.colspecs,
+        {
+            sqltypes.Numeric: _MySQLdbNumeric,
+            DECIMAL: _MySQLdbDecimal
+        }
+    )
+    
+    @classmethod
+    def dbapi(cls):
+        return __import__('MySQLdb')
+
+    def do_executemany(self, cursor, statement, parameters, context=None):
+        rowcount = cursor.executemany(statement, parameters)
+        if context is not None:
+            context._rowcount = rowcount
+
+    def create_connect_args(self, url):
+        opts = url.translate_connect_args(database='db', username='user',
+                                          password='passwd')
+        opts.update(url.query)
+
+        util.coerce_kw_type(opts, 'compress', bool)
+        util.coerce_kw_type(opts, 'connect_timeout', int)
+        util.coerce_kw_type(opts, 'client_flag', int)
+        util.coerce_kw_type(opts, 'local_infile', int)
+        # Note: using either of the below will cause all strings to be returned
+        # as Unicode, both in raw SQL operations and with column types like
+        # String and MSString.
+        util.coerce_kw_type(opts, 'use_unicode', bool)
+        util.coerce_kw_type(opts, 'charset', str)
+
+        # Rich values 'cursorclass' and 'conv' are not supported via
+        # query string.
+
+        ssl = {}
+        for key in ['ssl_ca', 'ssl_key', 'ssl_cert', 'ssl_capath', 'ssl_cipher']:
+            if key in opts:
+                ssl[key[4:]] = opts[key]
+                util.coerce_kw_type(ssl, key[4:], str)
+                del opts[key]
+        if ssl:
+            opts['ssl'] = ssl
+
+        # FOUND_ROWS must be set in CLIENT_FLAGS to enable
+        # supports_sane_rowcount.
+        client_flag = opts.get('client_flag', 0)
+        if self.dbapi is not None:
+            try:
+                from MySQLdb.constants import CLIENT as CLIENT_FLAGS
+                client_flag |= CLIENT_FLAGS.FOUND_ROWS
+            except ImportError:
+                pass
+            opts['client_flag'] = client_flag
+        return [[], opts]
+    
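+    # Illustration (hypothetical URL): given
+    #   create_engine('mysql://scott:tiger@localhost/test?charset=utf8')
+    # the opts returned above include
+    #   {'user': 'scott', 'passwd': 'tiger', 'host': 'localhost',
+    #    'db': 'test', 'charset': 'utf8'}
+    # plus a computed 'client_flag' when MySQLdb is importable.
+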
+    def do_ping(self, connection):
+        connection.ping()
+
+    def _get_server_version_info(self, connection):
+        dbapi_con = connection.connection
+        version = []
+        r = re.compile(r'[.\-]')
+        for n in r.split(dbapi_con.get_server_info()):
+            try:
+                version.append(int(n))
+            except ValueError:
+                version.append(n)
+        return tuple(version)
+
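+    # Illustration: a server version string such as "5.0.45-community-nt"
+    # splits on '.' and '-' into the tuple (5, 0, 45, 'community', 'nt').
+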
+    def _extract_error_code(self, exception):
+        try:
+            return exception.orig.args[0]
+        except AttributeError:
+            return None
+
+    def _detect_charset(self, connection):
+        """Sniff out the character set in use for connection results."""
+
+        # Note: MySQL-python 1.2.1c7 seems to ignore changes made
+        # on a connection via set_character_set()
+        if self.server_version_info < (4, 1, 0):
+            try:
+                return connection.connection.character_set_name()
+            except AttributeError:
+                # < 1.2.1 final MySQL-python drivers have no charset support.
+                # a query is needed.
+                pass
+
+        # Prefer 'character_set_results' for the current connection over the
+        # value in the driver.  SET NAMES or individual variable SETs will
+        # change the charset without updating the driver's view of the world.
+        #
+        # If it's decided that issuing that sort of SQL leaves you SOL, then
+        # this can prefer the driver value.
+        rs = connection.execute("SHOW VARIABLES LIKE 'character_set%%'")
+        opts = dict([(row[0], row[1]) for row in self._compat_fetchall(rs)])
+
+        if 'character_set_results' in opts:
+            return opts['character_set_results']
+        try:
+            return connection.connection.character_set_name()
+        except AttributeError:
+            # Still no charset on < 1.2.1 final...
+            if 'character_set' in opts:
+                return opts['character_set']
+            else:
+                util.warn(
+                    "Could not detect the connection character set with this "
+                    "combination of MySQL server and MySQL-python. "
+                    "MySQL-python >= 1.2.2 is recommended.  Assuming latin1.")
+                return 'latin1'
+
+
+dialect = MySQL_mysqldb
diff --git a/lib/sqlalchemy/dialects/mysql/pyodbc.py b/lib/sqlalchemy/dialects/mysql/pyodbc.py
new file mode 100644 (file)
index 0000000..1ea7ec8
--- /dev/null
@@ -0,0 +1,54 @@
+from sqlalchemy.dialects.mysql.base import MySQLDialect, MySQLExecutionContext
+from sqlalchemy.connectors.pyodbc import PyODBCConnector
+from sqlalchemy.engine import base as engine_base
+from sqlalchemy import util
+import re
+
+class MySQL_pyodbcExecutionContext(MySQLExecutionContext):
+
+    def get_lastrowid(self):
+        cursor = self.create_cursor()
+        cursor.execute("SELECT LAST_INSERT_ID()")
+        lastrowid = cursor.fetchone()[0]
+        cursor.close()
+        return lastrowid
+
+class MySQL_pyodbc(PyODBCConnector, MySQLDialect):
+    supports_unicode_statements = False
+    execution_ctx_cls = MySQL_pyodbcExecutionContext
+
+    pyodbc_driver_name = "MySQL"
+    
+    def __init__(self, **kw):
+        # deal with http://code.google.com/p/pyodbc/issues/detail?id=25
+        kw.setdefault('convert_unicode', True)
+        MySQLDialect.__init__(self, **kw)
+        PyODBCConnector.__init__(self, **kw)
+
+    def _detect_charset(self, connection):
+        """Sniff out the character set in use for connection results."""
+
+        # Prefer 'character_set_results' for the current connection over the
+        # value in the driver.  SET NAMES or individual variable SETs will
+        # change the charset without updating the driver's view of the world.
+        #
+        # If it's decided that issuing that sort of SQL leaves you SOL, then
+        # this can prefer the driver value.
+        rs = connection.execute("SHOW VARIABLES LIKE 'character_set%%'")
+        opts = dict([(row[0], row[1]) for row in self._compat_fetchall(rs)])
+        for key in ('character_set_connection', 'character_set'):
+            if opts.get(key, None):
+                return opts[key]
+
+        util.warn("Could not detect the connection character set.  Assuming latin1.")
+        return 'latin1'
+    
+    def _extract_error_code(self, exception):
+        m = re.compile(r"\((\d+)\)").search(str(exception.orig.args))
+        c = m.group(1)
+        if c:
+            return int(c)
+        else:
+            return None
+
+dialect = MySQL_pyodbc
diff --git a/lib/sqlalchemy/dialects/mysql/zxjdbc.py b/lib/sqlalchemy/dialects/mysql/zxjdbc.py
new file mode 100644 (file)
index 0000000..6cdc6f4
--- /dev/null
@@ -0,0 +1,95 @@
+"""Support for the MySQL database via Jython's zxjdbc JDBC connector.
+
+JDBC Driver
+-----------
+
+The official MySQL JDBC driver is at
+http://dev.mysql.com/downloads/connector/j/.
+
+"""
+import re
+
+from sqlalchemy import types as sqltypes, util
+from sqlalchemy.connectors.zxJDBC import ZxJDBCConnector
+from sqlalchemy.dialects.mysql.base import BIT, MySQLDialect, MySQLExecutionContext
+
+class _JDBCBit(BIT):
+    def result_processor(self, dialect):
+        """Converts boolean or byte arrays from MySQL Connector/J to longs."""
+        def process(value):
+            if value is None:
+                return value
+            if isinstance(value, bool):
+                return int(value)
+            v = 0L
+            for i in value:
+                v = v << 8 | (i & 0xff)
+            value = v
+            return value
+        return process
+
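+# Illustration: Connector/J may hand back BIT data as a byte array; the
+# processor above folds, e.g., [0x01, 0x02] into the long 258 (0x0102),
+# and a plain boolean True into 1.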
+
+class MySQL_jdbcExecutionContext(MySQLExecutionContext):
+    def get_lastrowid(self):
+        cursor = self.create_cursor()
+        cursor.execute("SELECT LAST_INSERT_ID()")
+        lastrowid = cursor.fetchone()[0]
+        cursor.close()
+        return lastrowid
+
+
+class MySQL_jdbc(ZxJDBCConnector, MySQLDialect):
+    execution_ctx_cls = MySQL_jdbcExecutionContext
+
+    jdbc_db_name = 'mysql'
+    jdbc_driver_name = 'com.mysql.jdbc.Driver'
+
+    colspecs = util.update_copy(
+        MySQLDialect.colspecs,
+        {
+            sqltypes.Time: sqltypes.Time,
+            BIT: _JDBCBit
+        }
+    )
+
+    def _detect_charset(self, connection):
+        """Sniff out the character set in use for connection results."""
+        # Prefer 'character_set_results' for the current connection over the
+        # value in the driver.  SET NAMES or individual variable SETs will
+        # change the charset without updating the driver's view of the world.
+        #
+        # If it's decided that issuing that sort of SQL leaves you SOL, then
+        # this can prefer the driver value.
+        rs = connection.execute("SHOW VARIABLES LIKE 'character_set%%'")
+        opts = dict((row[0], row[1]) for row in self._compat_fetchall(rs))
+        for key in ('character_set_connection', 'character_set'):
+            if opts.get(key, None):
+                return opts[key]
+
+        util.warn("Could not detect the connection character set.  Assuming latin1.")
+        return 'latin1'
+
+    def _driver_kwargs(self):
+        """return kw arg dict to be sent to connect()."""
+        return dict(CHARSET=self.encoding, yearIsDateType='false')
+
+    def _extract_error_code(self, exception):
+        # e.g.: DBAPIError: (Error) Table 'test.u2' doesn't exist
+        # [SQLCode: 1146], [SQLState: 42S02] 'DESCRIBE `u2`' ()
+        m = re.compile(r"\[SQLCode\: (\d+)\]").search(str(exception.orig.args))
+        c = m.group(1)
+        if c:
+            return int(c)
+
+    def _get_server_version_info(self, connection):
+        dbapi_con = connection.connection
+        version = []
+        r = re.compile(r'[.\-]')
+        for n in r.split(dbapi_con.dbversion):
+            try:
+                version.append(int(n))
+            except ValueError:
+                version.append(n)
+        return tuple(version)
+
+dialect = MySQL_jdbc
diff --git a/lib/sqlalchemy/dialects/oracle/__init__.py b/lib/sqlalchemy/dialects/oracle/__init__.py
new file mode 100644 (file)
index 0000000..3b4379a
--- /dev/null
@@ -0,0 +1,3 @@
+from sqlalchemy.dialects.oracle import base, cx_oracle, zxjdbc
+
+base.dialect = cx_oracle.dialect
diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py
new file mode 100644 (file)
index 0000000..419cced
--- /dev/null
@@ -0,0 +1,904 @@
+# oracle/base.py
+# Copyright (C) 2005, 2006, 2007, 2008, 2009 Michael Bayer mike_mp@zzzcomputing.com
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: http://www.opensource.org/licenses/mit-license.php
+"""Support for the Oracle database.
+
+Oracle versions 8 through current (11g at the time of this writing) are supported.
+
+For information on connecting via specific drivers, see the documentation
+for that driver.
+
+Connect Arguments
+-----------------
+
+The dialect supports several :func:`~sqlalchemy.create_engine()` arguments which 
+affect the behavior of the dialect regardless of the driver in use.
+
+* *use_ansi* - Use ANSI JOIN constructs (see the section on Oracle 8).  Defaults
+  to ``True``.  If ``False``, Oracle-8 compatible constructs are used for joins.
+
+* *optimize_limits* - defaults to ``False``.  See the section on LIMIT/OFFSET.
+
+Auto Increment Behavior
+-----------------------
+
+SQLAlchemy Table objects which include integer primary keys are usually assumed to have
+"autoincrementing" behavior, meaning they can generate their own primary key values upon
+INSERT.  Since Oracle has no "autoincrement" feature, SQLAlchemy relies upon sequences 
+to produce these values.   With the Oracle dialect, *a sequence must always be explicitly
+specified to enable autoincrement*.  This diverges from the majority of documentation 
+examples, which assume the use of an autoincrement-capable database.   To specify sequences,
+use the sqlalchemy.schema.Sequence object which is passed to a Column construct::
+
+  t = Table('mytable', metadata, 
+        Column('id', Integer, Sequence('id_seq'), primary_key=True),
+        Column(...), ...
+  )
+
+This step is also required when using table reflection, i.e. autoload=True::
+
+  t = Table('mytable', metadata, 
+        Column('id', Integer, Sequence('id_seq'), primary_key=True),
+        autoload=True
+  ) 
+
+Identifier Casing
+-----------------
+
+In Oracle, the data dictionary represents all case-insensitive identifier names 
+using UPPERCASE text.   SQLAlchemy, on the other hand, considers an all-lowercase
+identifier name to be case insensitive.   The Oracle dialect converts all case-insensitive
+identifiers to and from those two formats during schema-level communication, such as
+reflection of tables and indexes.   Using an UPPERCASE name on the SQLAlchemy side indicates a 
+case-sensitive identifier, and SQLAlchemy will quote the name; this will cause mismatches
+against data dictionary data received from Oracle, so unless identifier names have been
+truly created as case sensitive (i.e. using quoted names), all-lowercase names should be
+used on the SQLAlchemy side.
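+
+For example (illustrative names only)::
+
+  t1 = Table('my_table', metadata, autoload=True)   # matches Oracle's MY_TABLE
+  t2 = Table('MY_TABLE', metadata, autoload=True)   # quoted, case sensitive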
+
+Unicode
+-------
+
+SQLAlchemy 0.6 uses the "native unicode" mode provided as of cx_oracle 5.  cx_oracle 5.0.2
+or greater is recommended for support of NCLOB.   If not using cx_oracle 5, the NLS_LANG
+environment variable needs to be set in order for the oracle client library to use 
+proper encoding, such as "AMERICAN_AMERICA.UTF8".
+
+Also note that Oracle supports unicode data through the NVARCHAR and NCLOB data types.
+When using the SQLAlchemy Unicode and UnicodeText types, these DDL types will be used
+within CREATE TABLE statements.   Usage of VARCHAR2 and CLOB with unicode text still 
+requires NLS_LANG to be set.
+
+LIMIT/OFFSET Support
+--------------------
+
+Oracle has no support for the LIMIT or OFFSET keywords.  Whereas previous versions of SQLAlchemy
+used the "ROW NUMBER OVER..." construct to simulate LIMIT/OFFSET, SQLAlchemy 0.5 now uses 
+a wrapped subquery approach in conjunction with ROWNUM.  The exact methodology is taken from
+http://www.oracle.com/technology/oramag/oracle/06-sep/o56asktom.html .  Note that the 
+"FIRST_ROWS()" optimization hint mentioned there is not used by default, as the user community felt
+this was stepping into the realm of optimization better left to the DBA, but the
+hint can be added by enabling the optimize_limits=True flag on create_engine().
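+
+For example, with an illustrative DSN::
+
+  engine = create_engine('oracle://scott:tiger@mydsn', optimize_limits=True)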
+
+ON UPDATE CASCADE
+-----------------
+
+Oracle doesn't have native ON UPDATE CASCADE functionality.  A trigger-based solution 
+is available at http://asktom.oracle.com/tkyte/update_cascade/index.html .
+
+When using the SQLAlchemy ORM, the ORM has limited ability to manually issue
+cascading updates - specify ForeignKey objects using the 
+"deferrable=True, initially='deferred'" keyword arguments,
+and specify "passive_updates=False" on each relation(), as sketched below.
+
+Oracle 8 Compatibility
+----------------------
+
+When using Oracle 8, a "use_ansi=False" flag is available which converts all
+JOIN phrases into the WHERE clause, and in the case of LEFT OUTER JOIN
+makes use of Oracle's (+) operator.
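+
+For example, with ``use_ansi=False`` an outer join expressed as::
+
+  select([t1, t2]).select_from(t1.outerjoin(t2, t1.c.id == t2.c.t1_id))
+
+renders, roughly, as::
+
+  SELECT ... FROM t1, t2 WHERE t1.id = t2.t1_id(+)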
+
+Synonym/DBLINK Reflection
+-------------------------
+
+When using reflection with Table objects, the dialect can optionally search for tables
+indicated by synonyms that reference DBLINK-ed tables by passing the flag 
+oracle_resolve_synonyms=True as a keyword argument to the Table construct.  If DBLINK 
+is not in use this flag should be left off.
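+
+For example::
+
+  t = Table('mytable', metadata, autoload=True,
+            autoload_with=engine, oracle_resolve_synonyms=True)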
+
+"""
+
+import random, re
+
+from sqlalchemy import schema as sa_schema
+from sqlalchemy import util, sql, log
+from sqlalchemy.engine import default, base, reflection
+from sqlalchemy.sql import compiler, visitors, expression
+from sqlalchemy.sql import operators as sql_operators, functions as sql_functions
+from sqlalchemy import types as sqltypes
+from sqlalchemy.types import VARCHAR, NVARCHAR, CHAR, DATE, DATETIME, \
+                BLOB, CLOB, TIMESTAMP, FLOAT
+                
+RESERVED_WORDS = set('''SHARE RAW DROP BETWEEN FROM DESC OPTION PRIOR LONG THEN DEFAULT ALTER IS INTO MINUS INTEGER NUMBER GRANT IDENTIFIED ALL TO ORDER ON FLOAT DATE HAVING CLUSTER NOWAIT RESOURCE ANY TABLE INDEX FOR UPDATE WHERE CHECK SMALLINT WITH DELETE BY ASC REVOKE LIKE SIZE RENAME NOCOMPRESS NULL GROUP VALUES AS IN VIEW EXCLUSIVE COMPRESS SYNONYM SELECT INSERT EXISTS NOT TRIGGER ELSE CREATE INTERSECT PCTFREE DISTINCT USER CONNECT SET MODE OF UNIQUE VARCHAR2 VARCHAR LOCK OR CHAR DECIMAL UNION PUBLIC AND START UID COMMENT'''.split()) 
+
+class RAW(sqltypes.Binary):
+    pass
+OracleRaw = RAW
+
+class NCLOB(sqltypes.Text):
+    __visit_name__ = 'NCLOB'
+
+VARCHAR2 = VARCHAR
+NVARCHAR2 = NVARCHAR
+
+class NUMBER(sqltypes.Numeric):
+    __visit_name__ = 'NUMBER'
+    
+class BFILE(sqltypes.Binary):
+    __visit_name__ = 'BFILE'
+
+class DOUBLE_PRECISION(sqltypes.Numeric):
+    __visit_name__ = 'DOUBLE_PRECISION'
+
+class LONG(sqltypes.Text):
+    __visit_name__ = 'LONG'
+    
+class _OracleBoolean(sqltypes.Boolean):
+    def get_dbapi_type(self, dbapi):
+        return dbapi.NUMBER
+    
+    def result_processor(self, dialect):
+        def process(value):
+            if value is None:
+                return None
+            return value and True or False
+        return process
+
+    def bind_processor(self, dialect):
+        def process(value):
+            if value is True:
+                return 1
+            elif value is False:
+                return 0
+            elif value is None:
+                return None
+            else:
+                return value and True or False
+        return process
+
+colspecs = {
+    sqltypes.Boolean : _OracleBoolean,
+}
+
+ischema_names = {
+    'VARCHAR2' : VARCHAR,
+    'NVARCHAR2' : NVARCHAR,
+    'CHAR' : CHAR,
+    'DATE' : DATE,
+    'DATETIME' : DATETIME,
+    'NUMBER' : NUMBER,
+    'BLOB' : BLOB,
+    'BFILE' : BFILE,
+    'CLOB' : CLOB,
+    'NCLOB' : NCLOB,
+    'TIMESTAMP' : TIMESTAMP,
+    'RAW' : RAW,
+    'FLOAT' : FLOAT,
+    'DOUBLE PRECISION' : DOUBLE_PRECISION,
+    'LONG' : LONG,
+}
+
+
+class OracleTypeCompiler(compiler.GenericTypeCompiler):
+    # Note:
+    # Oracle DATE == DATETIME
+    # Oracle does not allow milliseconds in DATE
+    # Oracle does not support TIME columns
+    
+    def visit_datetime(self, type_):
+        return self.visit_DATE(type_)
+    
+    def visit_float(self, type_):
+        if type_.precision is None:
+            return "NUMERIC"
+        else:
+            return "NUMERIC(%(precision)s, %(scale)s)" % {'precision': type_.precision, 'scale' : 2}
+        
+    def visit_unicode(self, type_):
+        return self.visit_NVARCHAR(type_)
+        
+    def visit_VARCHAR(self, type_):
+        return "VARCHAR(%(length)s)" % {'length' : type_.length}
+
+    def visit_NVARCHAR(self, type_):
+        return "NVARCHAR2(%(length)s)" % {'length' : type_.length}
+    
+    def visit_text(self, type_):
+        return self.visit_CLOB(type_)
+
+    def visit_unicode_text(self, type_):
+        return self.visit_NCLOB(type_)
+
+    def visit_binary(self, type_):
+        return self.visit_BLOB(type_)
+    
+    def visit_boolean(self, type_):
+        return self.visit_SMALLINT(type_)
+    
+    def visit_RAW(self, type_):
+        return "RAW(%(length)s)" % {'length' : type_.length}
+
+class OracleCompiler(compiler.SQLCompiler):
+    """Oracle compiler modifies the lexical structure of Select
+    statements to work under non-ANSI configured Oracle databases, if
+    the use_ansi flag is False.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super(OracleCompiler, self).__init__(*args, **kwargs)
+        self.__wheres = {}
+        self._quoted_bind_names = {}
+
+    def visit_mod(self, binary, **kw):
+        return "mod(%s, %s)" % (self.process(binary.left), self.process(binary.right))
+    
+    def visit_now_func(self, fn, **kw):
+        return "CURRENT_TIMESTAMP"
+    
+    def visit_char_length_func(self, fn, **kw):
+        return "LENGTH" + self.function_argspec(fn, **kw)
+        
+    def visit_match_op(self, binary, **kw):
+        return "CONTAINS (%s, %s)" % (self.process(binary.left), self.process(binary.right))
+    
+    def function_argspec(self, fn, **kw):
+        if len(fn.clauses) > 0:
+            return compiler.SQLCompiler.function_argspec(self, fn, **kw)
+        else:
+            return ""
+        
+    def default_from(self):
+        """Called when a ``SELECT`` statement has no froms, and no ``FROM`` clause is to be appended.
+
+        The Oracle compiler tacks a "FROM DUAL" to the statement.
+        """
+
+        return " FROM DUAL"
+
+    def visit_join(self, join, **kwargs):
+        if self.dialect.use_ansi:
+            return compiler.SQLCompiler.visit_join(self, join, **kwargs)
+        else:
+            return self.process(join.left, asfrom=True) + ", " + self.process(join.right, asfrom=True)
+
+    def _get_nonansi_join_whereclause(self, froms):
+        clauses = []
+
+        def visit_join(join):
+            if join.isouter:
+                def visit_binary(binary):
+                    if binary.operator == sql_operators.eq:
+                        if binary.left.table is join.right:
+                            binary.left = _OuterJoinColumn(binary.left)
+                        elif binary.right.table is join.right:
+                            binary.right = _OuterJoinColumn(binary.right)
+                clauses.append(visitors.cloned_traverse(join.onclause, {}, {'binary':visit_binary}))
+            else:
+                clauses.append(join.onclause)
+
+        for f in froms:
+            visitors.traverse(f, {}, {'join':visit_join})
+        return sql.and_(*clauses)
+
+    def visit_outer_join_column(self, vc):
+        return self.process(vc.column) + "(+)"
+
+    def visit_sequence(self, seq):
+        return self.dialect.identifier_preparer.format_sequence(seq) + ".nextval"
+
+    def visit_alias(self, alias, asfrom=False, **kwargs):
+        """Oracle doesn't like ``FROM table AS alias``.  Is the AS standard SQL??"""
+
+        if asfrom:
+            alias_name = isinstance(alias.name, expression._generated_label) and \
+                            self._truncated_identifier("alias", alias.name) or alias.name
+            
+            return self.process(alias.original, asfrom=asfrom, **kwargs) + " " + self.preparer.format_alias(alias, alias_name)
+        else:
+            return self.process(alias.original, **kwargs)
+
+    def returning_clause(self, stmt, returning_cols):
+            
+        def create_out_param(col, i):
+            bindparam = sql.outparam("ret_%d" % i, type_=col.type)
+            self.binds[bindparam.key] = bindparam
+            return self.bindparam_string(self._truncate_bindparam(bindparam))
+        
+        columnlist = list(expression._select_iterables(returning_cols))
+        
+        # within_columns_clause=False so that labels (foo AS bar) don't render
+        columns = [self.process(c, within_columns_clause=False) for c in columnlist]
+        
+        binds = [create_out_param(c, i) for i, c in enumerate(columnlist)]
+        
+        return 'RETURNING ' + ', '.join(columns) +  " INTO " + ", ".join(binds)
+
+    def _TODO_visit_compound_select(self, select):
+        """Need to determine how to get ``LIMIT``/``OFFSET`` into a ``UNION`` for Oracle."""
+        pass
+
+    def visit_select(self, select, **kwargs):
+        """Look for ``LIMIT`` and OFFSET in a select statement, and if
+        so tries to wrap it in a subquery with ``rownum`` criterion.
+        """
+
+        if not getattr(select, '_oracle_visit', None):
+            if not self.dialect.use_ansi:
+                if self.stack and 'from' in self.stack[-1]:
+                    existingfroms = self.stack[-1]['from']
+                else:
+                    existingfroms = None
+
+                froms = select._get_display_froms(existingfroms)
+                whereclause = self._get_nonansi_join_whereclause(froms)
+                if whereclause:
+                    select = select.where(whereclause)
+                    select._oracle_visit = True
+
+            if select._limit is not None or select._offset is not None:
+                # See http://www.oracle.com/technology/oramag/oracle/06-sep/o56asktom.html
+                #
+                # Generalized form of an Oracle pagination query:
+                #   select ... from (
+                #     select /*+ FIRST_ROWS(N) */ ...., rownum as ora_rn from (
+                #         select distinct ... where ... order by ...
+                #     ) where ROWNUM <= :limit+:offset
+                #   ) where ora_rn > :offset
+                # Outer select and "ROWNUM as ora_rn" can be dropped if limit=0
+
+                # TODO: use annotations instead of clone + attr set ?
+                select = select._generate()
+                select._oracle_visit = True
+
+                # Wrap the middle select and add the hint
+                limitselect = sql.select([c for c in select.c])
+                if select._limit and self.dialect.optimize_limits:
+                    limitselect = limitselect.prefix_with("/*+ FIRST_ROWS(%d) */" % select._limit)
+
+                limitselect._oracle_visit = True
+                limitselect._is_wrapper = True
+
+                # If needed, add the limiting clause
+                if select._limit is not None:
+                    max_row = select._limit
+                    if select._offset is not None:
+                        max_row += select._offset
+                    limitselect.append_whereclause(
+                            sql.literal_column("ROWNUM") <= max_row)
+
+                # If needed, add the ora_rn, and wrap again with offset.
+                if select._offset is None:
+                    select = limitselect
+                else:
+                    limitselect = limitselect.column(
+                            sql.literal_column("ROWNUM").label("ora_rn"))
+                    limitselect._oracle_visit = True
+                    limitselect._is_wrapper = True
+
+                    offsetselect = sql.select(
+                            [c for c in limitselect.c if c.key != 'ora_rn'])
+                    offsetselect._oracle_visit = True
+                    offsetselect._is_wrapper = True
+
+                    offsetselect.append_whereclause(
+                            sql.literal_column("ora_rn") > select._offset)
+
+                    select = offsetselect
+
+        kwargs['iswrapper'] = getattr(select, '_is_wrapper', False)
+        return compiler.SQLCompiler.visit_select(self, select, **kwargs)
+
+    def limit_clause(self, select):
+        return ""
+
+    def for_update_clause(self, select):
+        if select.for_update == "nowait":
+            return " FOR UPDATE NOWAIT"
+        else:
+            return super(OracleCompiler, self).for_update_clause(select)
+
+class OracleDDLCompiler(compiler.DDLCompiler):
+
+    def visit_create_sequence(self, create):
+        return "CREATE SEQUENCE %s" % self.preparer.format_sequence(create.element)
+
+    def visit_drop_sequence(self, drop):
+        return "DROP SEQUENCE %s" % self.preparer.format_sequence(drop.element)
+
+    def define_constraint_cascades(self, constraint):
+        text = ""
+        if constraint.ondelete is not None:
+            text += " ON DELETE %s" % constraint.ondelete
+            
+        # Oracle has no ON UPDATE CASCADE -
+        # it's only available via triggers: http://asktom.oracle.com/tkyte/update_cascade/index.html
+        if constraint.onupdate is not None:
+            util.warn(
+                "Oracle does not contain native UPDATE CASCADE "
+                "functionality - onupdates will not be rendered for foreign keys. "
+                "Consider using deferrable=True, initially='deferred' or triggers.")
+        
+        return text
+
+class OracleDefaultRunner(base.DefaultRunner):
+    def visit_sequence(self, seq):
+        return self.execute_string("SELECT " + 
+                    self.dialect.identifier_preparer.format_sequence(seq) + 
+                    ".nextval FROM DUAL", ())
+
+class OracleIdentifierPreparer(compiler.IdentifierPreparer):
+    
+    reserved_words = set([x.lower() for x in RESERVED_WORDS])
+    illegal_initial_characters = re.compile(r'[0-9_$]')
+
+    def _bindparam_requires_quotes(self, value):
+        """Return True if the given identifier requires quoting."""
+        lc_value = value.lower()
+        return (lc_value in self.reserved_words
+                or self.illegal_initial_characters.match(value[0])
+                or not self.legal_characters.match(unicode(value))
+                )
+    
+    def format_savepoint(self, savepoint):
+        name = re.sub(r'^_+', '', savepoint.ident)
+        return super(OracleIdentifierPreparer, self).format_savepoint(savepoint, name)
+        
+class OracleDialect(default.DefaultDialect):
+    name = 'oracle'
+    supports_alter = True
+    supports_unicode_statements = False
+    supports_unicode_binds = False
+    max_identifier_length = 30
+    supports_sane_rowcount = True
+    supports_sane_multi_rowcount = False
+
+    supports_sequences = True
+    sequences_optional = False
+    postfetch_lastrowid = False
+    
+    default_paramstyle = 'named'
+    colspecs = colspecs
+    ischema_names = ischema_names
+    requires_name_normalize = True
+    
+    supports_default_values = False
+    supports_empty_insert = False
+    
+    statement_compiler = OracleCompiler
+    ddl_compiler = OracleDDLCompiler
+    type_compiler = OracleTypeCompiler
+    preparer = OracleIdentifierPreparer
+    defaultrunner = OracleDefaultRunner
+    
+    reflection_options = ('oracle_resolve_synonyms', )
+    
+    
+    def __init__(self, 
+                use_ansi=True, 
+                optimize_limits=False, 
+                **kwargs):
+        default.DefaultDialect.__init__(self, **kwargs)
+        self.use_ansi = use_ansi
+        self.optimize_limits = optimize_limits
+
+# TODO: implement server_version_info for oracle
+#    def initialize(self, connection):
+#        super(OracleDialect, self).initialize(connection)
+#        self.implicit_returning = self.server_version_info > (10, ) and \
+#                                        self.__dict__.get('implicit_returning', True)
+
+    def do_release_savepoint(self, connection, name):
+        # Oracle does not support RELEASE SAVEPOINT
+        pass
+
+    def has_table(self, connection, table_name, schema=None):
+        if not schema:
+            schema = self.get_default_schema_name(connection)
+        cursor = connection.execute(
+            sql.text("SELECT table_name FROM all_tables "
+                     "WHERE table_name = :name AND owner = :schema_name"),
+            name=self.denormalize_name(table_name), schema_name=self.denormalize_name(schema))
+        return cursor.fetchone() is not None
+
+    def has_sequence(self, connection, sequence_name, schema=None):
+        if not schema:
+            schema = self.get_default_schema_name(connection)
+        cursor = connection.execute(
+            sql.text("SELECT sequence_name FROM all_sequences "
+                     "WHERE sequence_name = :name AND sequence_owner = :schema_name"),
+            name=self.denormalize_name(sequence_name), schema_name=self.denormalize_name(schema))
+        return cursor.fetchone() is not None
+
+    def normalize_name(self, name):
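+        # Oracle stores case-insensitive identifiers as uppercase; SQLAlchemy
+        # uses all-lowercase as its case-insensitive convention, so fold an
+        # unquoted all-uppercase name down to lowercase here.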
+        if name is None:
+            return None
+        elif (name.upper() == name and
+              not self.identifier_preparer._requires_quotes(name.lower().decode(self.encoding))):
+            return name.lower().decode(self.encoding)
+        else:
+            return name.decode(self.encoding)
+
+    def denormalize_name(self, name):
+        if name is None:
+            return None
+        elif name.lower() == name and not self.identifier_preparer._requires_quotes(name.lower()):
+            return name.upper().encode(self.encoding)
+        else:
+            return name.encode(self.encoding)
+
+    def get_default_schema_name(self, connection):
+        return self.normalize_name(connection.execute('SELECT USER FROM DUAL').scalar())
+
+    def table_names(self, connection, schema):
+        # note that table_names() isn't loading DBLINKed or synonym'ed tables
+        if schema is None:
+            cursor = connection.execute(
+                "SELECT table_name FROM all_tables "
+                "WHERE nvl(tablespace_name, 'no tablespace') NOT IN ('SYSTEM', 'SYSAUX')")
+        else:
+            s = sql.text(
+                "SELECT table_name FROM all_tables "
+                "WHERE nvl(tablespace_name, 'no tablespace') NOT IN ('SYSTEM', 'SYSAUX') "
+                "AND OWNER = :owner")
+            cursor = connection.execute(s, owner=self.denormalize_name(schema))
+        return [self.normalize_name(row[0]) for row in cursor]
+
+    def _resolve_synonym(self, connection, desired_owner=None, desired_synonym=None, desired_table=None):
+        """search for a local synonym matching the given desired owner/name.
+
+        if desired_owner is None, attempts to locate a distinct owner.
+
+        returns the actual name, owner, dblink name, and synonym name if found.
+        """
+
+        q = "SELECT owner, table_owner, table_name, db_link, synonym_name FROM all_synonyms WHERE "
+        clauses = []
+        params = {}
+        if desired_synonym:
+            clauses.append("synonym_name = :synonym_name")
+            params['synonym_name'] = desired_synonym
+        if desired_owner:
+            clauses.append("table_owner = :desired_owner")
+            params['desired_owner'] = desired_owner
+        if desired_table:
+            clauses.append("table_name = :tname")
+            params['tname'] = desired_table
+
+        q += " AND ".join(clauses)
+
+        result = connection.execute(sql.text(q), **params)
+        if desired_owner:
+            row = result.fetchone()
+            if row:
+                return row['table_name'], row['table_owner'], row['db_link'], row['synonym_name']
+            else:
+                return None, None, None, None
+        else:
+            rows = result.fetchall()
+            if len(rows) > 1:
+                raise AssertionError("There are multiple tables visible to the schema; you must specify an owner")
+            elif len(rows) == 1:
+                row = rows[0]
+                return row['table_name'], row['table_owner'], row['db_link'], row['synonym_name']
+            else:
+                return None, None, None, None
+
+    @reflection.cache
+    def _prepare_reflection_args(self, connection, table_name, schema=None,
+                                 resolve_synonyms=False, dblink='', **kw):
+
+        if resolve_synonyms:
+            actual_name, owner, dblink, synonym = self._resolve_synonym(connection, desired_owner=self.denormalize_name(schema), desired_synonym=self.denormalize_name(table_name))
+        else:
+            actual_name, owner, dblink, synonym = None, None, None, None
+        if not actual_name:
+            actual_name = self.denormalize_name(table_name)
+        if not dblink:
+            dblink = ''
+        if not owner:
+            owner = self.denormalize_name(schema or self.get_default_schema_name(connection))
+        return (actual_name, owner, dblink, synonym)
+
+    @reflection.cache
+    def get_schema_names(self, connection, **kw):
+        s = "SELECT username FROM all_users ORDER BY username"
+        cursor = connection.execute(s,)
+        return [self.normalize_name(row[0]) for row in cursor]
+
+    @reflection.cache
+    def get_table_names(self, connection, schema=None, **kw):
+        schema = self.denormalize_name(schema or self.get_default_schema_name(connection))
+        return self.table_names(connection, schema)
+
+    @reflection.cache
+    def get_view_names(self, connection, schema=None, **kw):
+        schema = self.denormalize_name(schema or self.get_default_schema_name(connection))
+        s = sql.text("SELECT view_name FROM all_views WHERE owner = :owner")
+        cursor = connection.execute(s, owner=schema)
+        return [self.normalize_name(row[0]) for row in cursor]
+
+    @reflection.cache
+    def get_columns(self, connection, table_name, schema=None, **kw):
+        """
+
+        kw arguments can be:
+
+            oracle_resolve_synonyms
+
+            dblink
+
+        """
+
+        resolve_synonyms = kw.get('oracle_resolve_synonyms', False)
+        dblink = kw.get('dblink', '')
+        info_cache = kw.get('info_cache')
+
+        (table_name, schema, dblink, synonym) = \
+            self._prepare_reflection_args(connection, table_name, schema,
+                                          resolve_synonyms, dblink,
+                                          info_cache=info_cache)
+        columns = []
+        c = connection.execute(sql.text(
+                "SELECT column_name, data_type, data_length, data_precision, data_scale, "
+                "nullable, data_default FROM ALL_TAB_COLUMNS%(dblink)s "
+                "WHERE table_name = :table_name AND owner = :owner" % {'dblink': dblink}),
+                               table_name=table_name, owner=schema)
+
+        for row in c:
+            (colname, coltype, length, precision, scale, nullable, default) = \
+                (self.normalize_name(row[0]), row[1], row[2], row[3], row[4], row[5]=='Y', row[6])
+
+            # INTEGER if the scale is 0 and precision is null
+            # NUMBER if the scale and precision are both null
+            # NUMBER(9,2) if the precision is 9 and the scale is 2
+            # NUMBER(3) if the precision is 3 and scale is 0
+            # length is ignored except for CHAR and VARCHAR2
+            if coltype == 'NUMBER':
+                if precision is None and scale is None:
+                    coltype = sqltypes.NUMERIC
+                elif precision is None and scale == 0:
+                    coltype = sqltypes.INTEGER
+                else:
+                    coltype = sqltypes.NUMERIC(precision, scale)
+            elif coltype == 'CHAR' or coltype == 'VARCHAR2':
+                coltype = self.ischema_names.get(coltype)(length)
+            else:
+                coltype = re.sub(r'\(\d+\)', '', coltype)
+                try:
+                    coltype = self.ischema_names[coltype]
+                except KeyError:
+                    util.warn("Did not recognize type '%s' of column '%s'" %
+                              (coltype, colname))
+                    coltype = sqltypes.NULLTYPE
+
+            cdict = {
+                'name': colname,
+                'type': coltype,
+                'nullable': nullable,
+                'default': default,
+            }
+            columns.append(cdict)
+        return columns
+
+    @reflection.cache
+    def get_indexes(self, connection, table_name, schema=None,
+                    resolve_synonyms=False, dblink='', **kw):
+
+        info_cache = kw.get('info_cache')
+        (table_name, schema, dblink, synonym) = \
+            self._prepare_reflection_args(connection, table_name, schema,
+                                          resolve_synonyms, dblink,
+                                          info_cache=info_cache)
+        indexes = []
+        q = sql.text("""
+        SELECT a.index_name, a.column_name, b.uniqueness
+        FROM ALL_IND_COLUMNS%(dblink)s a
+        INNER JOIN ALL_INDEXES%(dblink)s b
+            ON a.index_name = b.index_name
+            AND a.table_owner = b.table_owner
+            AND a.table_name = b.table_name
+        WHERE a.table_name = :table_name
+        AND a.table_owner = :schema
+        ORDER BY a.index_name, a.column_position""" % {'dblink': dblink})
+        rp = connection.execute(q, table_name=self.denormalize_name(table_name),
+                                schema=self.denormalize_name(schema))
+        last_index_name = None
+        pkeys = self.get_primary_keys(connection, table_name, schema,
+                                      resolve_synonyms=resolve_synonyms,
+                                      dblink=dblink,
+                                      info_cache=kw.get('info_cache'))
+        uniqueness = dict(NONUNIQUE=False, UNIQUE=True)
+        for rset in rp:
+            # don't include the primary key columns
+            if rset.column_name in [s.upper() for s in pkeys]:
+                continue
+            if rset.index_name != last_index_name:
+                index = dict(name=self.normalize_name(rset.index_name), column_names=[])
+                indexes.append(index)
+            index['unique'] = uniqueness.get(rset.uniqueness, False)
+            index['column_names'].append(self.normalize_name(rset.column_name))
+            last_index_name = rset.index_name
+        return indexes
+
+    @reflection.cache
+    def _get_constraint_data(self, connection, table_name, schema=None,
+                            dblink='', **kw):
+
+        rp = connection.execute(
+            sql.text("""SELECT
+             ac.constraint_name,
+             ac.constraint_type,
+             loc.column_name AS local_column,
+             rem.table_name AS remote_table,
+             rem.column_name AS remote_column,
+             rem.owner AS remote_owner,
+             loc.position as loc_pos,
+             rem.position as rem_pos
+           FROM all_constraints%(dblink)s ac,
+             all_cons_columns%(dblink)s loc,
+             all_cons_columns%(dblink)s rem
+           WHERE ac.table_name = :table_name
+           AND ac.constraint_type IN ('R','P')
+           AND ac.owner = :owner
+           AND ac.owner = loc.owner
+           AND ac.constraint_name = loc.constraint_name
+           AND ac.r_owner = rem.owner(+)
+           AND ac.r_constraint_name = rem.constraint_name(+)
+           AND (rem.position IS NULL or loc.position=rem.position)
+           ORDER BY ac.constraint_name, loc.position""" % {'dblink': dblink}),
+            table_name=table_name, owner=schema)
+        constraint_data = rp.fetchall()
+        return constraint_data
+
+    @reflection.cache
+    def get_primary_keys(self, connection, table_name, schema=None, **kw):
+        """
+
+        kw arguments can be:
+
+            oracle_resolve_synonyms
+
+            dblink
+
+        """
+
+        resolve_synonyms = kw.get('oracle_resolve_synonyms', False)
+        dblink = kw.get('dblink', '')
+        info_cache = kw.get('info_cache')
+
+        (table_name, schema, dblink, synonym) = \
+            self._prepare_reflection_args(connection, table_name, schema,
+                                          resolve_synonyms, dblink,
+                                          info_cache=info_cache)
+        pkeys = []
+        constraint_data = self._get_constraint_data(connection, table_name,
+                                        schema, dblink,
+                                        info_cache=kw.get('info_cache'))
+                                        
+        for row in constraint_data:
+            (cons_name, cons_type, local_column, remote_table, remote_column, remote_owner) = \
+                row[0:2] + tuple([self.normalize_name(x) for x in row[2:6]])
+            if cons_type == 'P':
+                pkeys.append(local_column)
+        return pkeys
+
+    @reflection.cache
+    def get_foreign_keys(self, connection, table_name, schema=None, **kw):
+        """
+
+        kw arguments can be:
+
+            oracle_resolve_synonyms
+
+            dblink
+
+        """
+
+        requested_schema = schema # to check later on
+        resolve_synonyms = kw.get('oracle_resolve_synonyms', False)
+        dblink = kw.get('dblink', '')
+        info_cache = kw.get('info_cache')
+
+        (table_name, schema, dblink, synonym) = \
+            self._prepare_reflection_args(connection, table_name, schema,
+                                          resolve_synonyms, dblink,
+                                          info_cache=info_cache)
+
+        constraint_data = self._get_constraint_data(connection, table_name,
+                                                schema, dblink,
+                                                info_cache=kw.get('info_cache'))
+
+        def fkey_rec():
+            return {
+                'name' : None,
+                'constrained_columns' : [],
+                'referred_schema' : None,
+                'referred_table' : None,
+                'referred_columns' : []
+            }
+
+        fkeys = util.defaultdict(fkey_rec)
+        
+        for row in constraint_data:
+            (cons_name, cons_type, local_column, remote_table, remote_column, remote_owner) = \
+                    row[0:2] + tuple([self.normalize_name(x) for x in row[2:6]])
+
+            if cons_type == 'R':
+                if remote_table is None:
+                    # ticket 363
+                    util.warn(
+                        ("Got 'None' querying 'table_name' from "
+                         "all_cons_columns%(dblink)s - does the user have "
+                         "proper rights to the table?") % {'dblink':dblink})
+                    continue
+
+                rec = fkeys[cons_name]
+                rec['name'] = cons_name
+                local_cols, remote_cols = rec['constrained_columns'], rec['referred_columns']
+
+                if not rec['referred_table']:
+                    if resolve_synonyms:
+                        ref_remote_name, ref_remote_owner, ref_dblink, ref_synonym = \
+                                self._resolve_synonym(
+                                    connection, 
+                                    desired_owner=self.denormalize_name(remote_owner), 
+                                    desired_table=self.denormalize_name(remote_table)
+                                )
+                        if ref_synonym:
+                            remote_table = self.normalize_name(ref_synonym)
+                            remote_owner = self.normalize_name(ref_remote_owner)
+                    
+                    rec['referred_table'] = remote_table
+                    
+                    if requested_schema is not None or self.denormalize_name(remote_owner) != schema:
+                        rec['referred_schema'] = remote_owner
+                
+                local_cols.append(local_column)
+                remote_cols.append(remote_column)
+
+        return fkeys.values()
+
+    @reflection.cache
+    def get_view_definition(self, connection, view_name, schema=None,
+                            resolve_synonyms=False, dblink='', **kw):
+        info_cache = kw.get('info_cache')
+        (view_name, schema, dblink, synonym) = \
+            self._prepare_reflection_args(connection, view_name, schema,
+                                          resolve_synonyms, dblink,
+                                          info_cache=info_cache)
+        s = sql.text("""
+        SELECT text FROM all_views
+        WHERE owner = :schema
+        AND view_name = :view_name
+        """)
+        rp = connection.execute(s,
+                                view_name=view_name, schema=schema).scalar()
+        if rp:
+            return rp.decode(self.encoding)
+        else:
+            return None
+
+
+
+class _OuterJoinColumn(sql.ClauseElement):
+    __visit_name__ = 'outer_join_column'
+    
+    def __init__(self, column):
+        self.column = column
+
+
+
diff --git a/lib/sqlalchemy/dialects/oracle/cx_oracle.py b/lib/sqlalchemy/dialects/oracle/cx_oracle.py
new file mode 100644 (file)
index 0000000..d8a0c44
--- /dev/null
@@ -0,0 +1,371 @@
+"""Support for the Oracle database via the cx_oracle driver.
+
+Driver
+------
+
+The Oracle dialect uses the cx_oracle driver, available at 
+http://cx-oracle.sourceforge.net/ .   The dialect has several behaviors 
+which are specifically tailored towards compatibility with this module.
+
+Connecting
+----------
+
+Connecting with create_engine() uses the standard URL approach of 
+``oracle://user:pass@host:port/dbname[?key=value&key=value...]``.  If dbname is present, the 
+host, port, and dbname tokens are converted to a TNS name using the cx_oracle 
+:func:`makedsn()` function.  Otherwise, the host token is taken directly as a TNS name.
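+
+For example (a sketch; host, credentials, and SID are hypothetical)::
+
+    from sqlalchemy import create_engine
+
+    # host/port/dbname are converted to a TNS name via cx_Oracle.makedsn()
+    engine = create_engine('oracle://scott:tiger@localhost:1521/orcl')
+
+    # with no dbname, the host token is used directly as a TNS name
+    engine = create_engine('oracle://scott:tiger@mytns')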
+
+Additional arguments which may be specified either as query string arguments on the
+URL, or as keyword arguments to :func:`~sqlalchemy.create_engine()`, are the
+following (see the example after the list):
+
+* *allow_twophase* - enable two-phase transactions.  Defaults to ``True``.
+
+* *auto_convert_lobs* - defaults to True, see the section on LOB objects.
+
+* *auto_setinputsizes* - the cx_oracle.setinputsizes() call is issued for all bind parameters.
+  This is required for LOB datatypes but can be disabled to reduce overhead.  Defaults
+  to ``True``.
+
+* *mode* - This is given the string value of SYSDBA or SYSOPER, or alternatively an
+  integer value.  This value is only available as a URL query string argument.
+
+* *threaded* - enable multithreaded access to cx_oracle connections.  Defaults
+  to ``True``.  Note that this is the opposite default of cx_oracle itself.
+
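+For example (illustrative values; ``mode`` is accepted only on the URL)::
+
+    # as keyword arguments to create_engine()
+    engine = create_engine('oracle://scott:tiger@orcl',
+                           threaded=False, auto_setinputsizes=False)
+
+    # or as query string arguments on the URL
+    engine = create_engine('oracle://scott:tiger@orcl?mode=SYSDBA')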
+
+LOB Objects
+-----------
+
+cx_oracle presents some challenges when fetching LOB objects.  A LOB object in a result set
+is returned by cx_oracle as a cx_oracle.LOB object, which has a read() method.  By default,
+SQLAlchemy converts these LOB objects into Python strings, for two reasons.  First, a LOB
+object requires an active cursor association: if enough rows are fetched at once that
+cx_oracle must go back to the database for a new batch of rows, the LOB objects in the
+already-fetched rows become unreadable and raise an error.  SQLAlchemy therefore
+"pre-reads" all LOBs so that their data is fetched before further rows are read.
+The size of a "batch of rows" is controlled by the cursor.arraysize value, which SQLAlchemy
+defaults to 50 (cx_oracle normally defaults this to 1).
+
+Second, the LOB object is not a standard DBAPI return value, so SQLAlchemy seeks to
+"normalize" the results to look more like those of other DBAPIs.
+
+The conversion of LOB objects by this dialect is unique in SQLAlchemy in that it takes place
+for all statement executions, even plain string-based statements for which SQLA has no awareness
+of result typing.  This is so that calls like fetchmany() and fetchall() can work in all cases
+without raising cursor errors.  The conversion of LOB in all cases, as well as the "prefetch"
+of LOB objects, can be disabled using auto_convert_lobs=False.  
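+
+For example, to receive raw cx_oracle.LOB objects instead (a sketch)::
+
+    engine = create_engine('oracle://scott:tiger@orcl', auto_convert_lobs=False)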
+
+Two Phase Transaction Support
+-----------------------------
+
+Two Phase transactions are implemented using XA transactions.  They have been reported
+to work successfully, but this should be regarded as an experimental feature.
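+
+A minimal sketch (the DSN and ``some_table`` are hypothetical; ``begin_twophase()``
+is SQLAlchemy's generic two-phase API on the Connection object)::
+
+    engine = create_engine('oracle://scott:tiger@orcl', allow_twophase=True)
+    conn = engine.connect()
+    xact = conn.begin_twophase()
+    conn.execute(some_table.insert().values(data='some data'))
+    xact.prepare()
+    xact.commit()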
+
+"""
+
+from sqlalchemy.dialects.oracle.base import OracleCompiler, OracleDialect, RESERVED_WORDS
+from sqlalchemy.dialects.oracle import base as oracle
+from sqlalchemy.engine.default import DefaultExecutionContext
+from sqlalchemy.engine import base
+from sqlalchemy import types as sqltypes, util
+import datetime
+import random
+
+class _OracleDate(sqltypes.Date):
+    def bind_processor(self, dialect):
+        return None
+
+    def result_processor(self, dialect):
+        def process(value):
+            if not isinstance(value, datetime.datetime):
+                return value
+            else:
+                return value.date()
+        return process
+
+class _OracleDateTime(sqltypes.DateTime):
+    def result_processor(self, dialect):
+        def process(value):
+            if value is None or isinstance(value, datetime.datetime):
+                return value
+            else:
+                # convert cx_oracle datetime object returned pre-python 2.4
+                return datetime.datetime(value.year, value.month,
+                    value.day,value.hour, value.minute, value.second)
+        return process
+
+# Note:
+# Oracle DATE == DATETIME
+# Oracle does not allow milliseconds in DATE
+# Oracle does not support TIME columns
+
+# only if cx_oracle contains TIMESTAMP
+class _OracleTimestamp(sqltypes.TIMESTAMP):
+    def result_processor(self, dialect):
+        def process(value):
+            if value is None or isinstance(value, datetime.datetime):
+                return value
+            else:
+                # convert cx_oracle datetime object returned pre-python 2.4
+                return datetime.datetime(value.year, value.month,
+                    value.day,value.hour, value.minute, value.second)
+        return process
+
+class _LOBMixin(object):
+    def result_processor(self, dialect):
+        super_process = super(_LOBMixin, self).result_processor(dialect)
+        if not dialect.auto_convert_lobs:
+            return super_process
+        lob = dialect.dbapi.LOB
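+        # read() each LOB value eagerly, while its cursor association is
+        # still valid (see the module docstring regarding LOB prefetch)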
+        def process(value):
+            if isinstance(value, lob):
+                if super_process:
+                    return super_process(value.read())
+                else:
+                    return value.read()
+            else:
+                if super_process:
+                    return super_process(value)
+                else:
+                    return value
+        return process
+    
+class _OracleText(_LOBMixin, sqltypes.Text):
+    def get_dbapi_type(self, dbapi):
+        return dbapi.CLOB
+
+class _OracleUnicodeText(_LOBMixin, sqltypes.UnicodeText):
+    def get_dbapi_type(self, dbapi):
+        return dbapi.NCLOB
+
+
+class _OracleBinary(_LOBMixin, sqltypes.Binary):
+    def get_dbapi_type(self, dbapi):
+        return dbapi.BLOB
+
+    def bind_processor(self, dialect):
+        return None
+
+
+class _OracleRaw(_LOBMixin, oracle.RAW):
+    pass
+
+
+colspecs = {
+    sqltypes.DateTime : _OracleDateTime,
+    sqltypes.Date : _OracleDate,
+    sqltypes.Binary : _OracleBinary,
+    sqltypes.Boolean : oracle._OracleBoolean,
+    sqltypes.Text : _OracleText,
+    sqltypes.UnicodeText : _OracleUnicodeText,
+    sqltypes.TIMESTAMP : _OracleTimestamp,
+    oracle.RAW: _OracleRaw,
+}
+
+class Oracle_cx_oracleCompiler(OracleCompiler):
+    def bindparam_string(self, name):
+        if self.preparer._bindparam_requires_quotes(name):
+            quoted_name = '"%s"' % name
+            self._quoted_bind_names[name] = quoted_name
+            return OracleCompiler.bindparam_string(self, quoted_name)
+        else:
+            return OracleCompiler.bindparam_string(self, name)
+
+class Oracle_cx_oracleExecutionContext(DefaultExecutionContext):
+    def pre_exec(self):
+        quoted_bind_names = getattr(self.compiled, '_quoted_bind_names', {})
+        if quoted_bind_names:
+            for param in self.parameters:
+                for fromname, toname in self.compiled._quoted_bind_names.iteritems():
+                    param[toname.encode(self.dialect.encoding)] = param[fromname]
+                    del param[fromname]
+
+        if self.dialect.auto_setinputsizes:
+            self.set_input_sizes(quoted_bind_names, exclude_types=(self.dialect.dbapi.STRING,))
+            
+        if len(self.compiled_parameters) == 1:
+            for key in self.compiled.binds:
+                bindparam = self.compiled.binds[key]
+                name = self.compiled.bind_names[bindparam]
+                value = self.compiled_parameters[0][name]
+                if bindparam.isoutparam:
+                    dbtype = bindparam.type.dialect_impl(self.dialect).get_dbapi_type(self.dialect.dbapi)
+                    if not hasattr(self, 'out_parameters'):
+                        self.out_parameters = {}
+                    self.out_parameters[name] = self.cursor.var(dbtype)
+                    self.parameters[0][quoted_bind_names.get(name, name)] = self.out_parameters[name]
+        
+        
+    def create_cursor(self):
+        c = self._connection.connection.cursor()
+        if self.dialect.arraysize:
+            c.cursor.arraysize = self.dialect.arraysize
+        return c
+
+    def get_result_proxy(self):
+        if hasattr(self, 'out_parameters'):
+            if self.compiled_parameters is not None and len(self.compiled_parameters) == 1:
+                for bind, name in self.compiled.bind_names.iteritems():
+                    if name in self.out_parameters:
+                        type = bind.type
+                        result_processor = type.dialect_impl(self.dialect).result_processor(self.dialect)
+                        if result_processor is not None:
+                            self.out_parameters[name] = result_processor(self.out_parameters[name].getvalue())
+                        else:
+                            self.out_parameters[name] = self.out_parameters[name].getvalue()
+            else:
+                for k in self.out_parameters:
+                    self.out_parameters[k] = self.out_parameters[k].getvalue()
+
+        if self.cursor.description is not None:
+            for column in self.cursor.description:
+                type_code = column[1]
+                if type_code in self.dialect.ORACLE_BINARY_TYPES:
+                    return base.BufferedColumnResultProxy(self)
+        
+        if hasattr(self, 'out_parameters') and \
+            self.compiled.returning:
+                
+            return ReturningResultProxy(self)
+        else:
+            return base.ResultProxy(self)
+
+class ReturningResultProxy(base.FullyBufferedResultProxy):
+    """Result proxy which stuffs the _returning clause + outparams into the fetch."""
+    
+    def _cursor_description(self):
+        returning = self.context.compiled.returning
+        
+        ret = []
+        for c in returning:
+            if hasattr(c, 'key'):
+                ret.append((c.key, c.type))
+            else:
+                ret.append((c.anon_label, c.type))
+        return ret
+    
+    def _buffer_rows(self):
+        returning = self.context.compiled.returning
+        return [tuple(self.context.out_parameters["ret_%d" % i] for i, c in enumerate(returning))]
+
+class Oracle_cx_oracle(OracleDialect):
+    execution_ctx_cls = Oracle_cx_oracleExecutionContext
+    statement_compiler = Oracle_cx_oracleCompiler
+    driver = "cx_oracle"
+    colspecs = colspecs
+    
+    def __init__(self, 
+                auto_setinputsizes=True, 
+                auto_convert_lobs=True, 
+                threaded=True, 
+                allow_twophase=True, 
+                arraysize=50, **kwargs):
+        OracleDialect.__init__(self, **kwargs)
+        self.threaded = threaded
+        self.arraysize = arraysize
+        self.allow_twophase = allow_twophase
+        self.supports_timestamp = self.dbapi is None or hasattr(self.dbapi, 'TIMESTAMP')
+        self.auto_setinputsizes = auto_setinputsizes
+        self.auto_convert_lobs = auto_convert_lobs
+        
+        def vers(num):
+            return tuple([int(x) for x in num.split('.')])
+
+        if hasattr(self.dbapi, 'version'):
+            cx_oracle_ver = vers(self.dbapi.version)
+            self.supports_unicode_binds = cx_oracle_ver >= (5, 0)
+        
+        if self.dbapi is None or not self.auto_convert_lobs or 'CLOB' not in self.dbapi.__dict__:
+            self.dbapi_type_map = {}
+            self.ORACLE_BINARY_TYPES = []
+        else:
+            # only use this for LOB objects.  using it for strings, dates
+            # etc. leads to a little too much magic, reflection doesn't know if it should
+            # expect encoded strings or unicodes, etc.
+            self.dbapi_type_map = {
+                self.dbapi.CLOB: oracle.CLOB(),
+                self.dbapi.NCLOB:oracle.NCLOB(),
+                self.dbapi.BLOB: oracle.BLOB(),
+                self.dbapi.BINARY: oracle.RAW(),
+            }
+            self.ORACLE_BINARY_TYPES = [getattr(self.dbapi, k) for k in ["BFILE", "CLOB", "NCLOB", "BLOB"] if hasattr(self.dbapi, k)]
+    
+    @classmethod
+    def dbapi(cls):
+        import cx_Oracle
+        return cx_Oracle
+
+    def create_connect_args(self, url):
+        dialect_opts = dict(url.query)
+        for opt in ('use_ansi', 'auto_setinputsizes', 'auto_convert_lobs',
+                    'threaded', 'allow_twophase'):
+            if opt in dialect_opts:
+                util.coerce_kw_type(dialect_opts, opt, bool)
+                setattr(self, opt, dialect_opts[opt])
+
+        if url.database:
+            # if we have a database, then we have a remote host
+            port = url.port
+            if port:
+                port = int(port)
+            else:
+                port = 1521
+            dsn = self.dbapi.makedsn(url.host, port, url.database)
+        else:
+            # we have a local tnsname
+            dsn = url.host
+
+        opts = dict(
+            user=url.username,
+            password=url.password,
+            dsn=dsn,
+            threaded=self.threaded,
+            twophase=self.allow_twophase,
+            )
+        if 'mode' in url.query:
+            opts['mode'] = url.query['mode']
+            if isinstance(opts['mode'], basestring):
+                mode = opts['mode'].upper()
+                if mode == 'SYSDBA':
+                    opts['mode'] = self.dbapi.SYSDBA
+                elif mode == 'SYSOPER':
+                    opts['mode'] = self.dbapi.SYSOPER
+                else:
+                    util.coerce_kw_type(opts, 'mode', int)
+        # Can't set 'handle' or 'pool' via URL query args, use connect_args
+
+        return ([], opts)
+
+    def _get_server_version_info(self, connection):
+        return tuple(int(x) for x in connection.connection.version.split('.'))
+
+    def is_disconnect(self, e):
+        if isinstance(e, self.dbapi.InterfaceError):
+            return "not connected" in str(e)
+        else:
+            return "ORA-03114" in str(e) or "ORA-03113" in str(e)
+
+    def create_xid(self):
+        """create a two-phase transaction ID.
+
+        this id will be passed to do_begin_twophase(), do_rollback_twophase(),
+        do_commit_twophase().  its format is unspecified."""
+
+        id = random.randint(0, 2 ** 128)
+        return (0x1234, "%032x" % id, "%032x" % 9)
+
+    def do_begin_twophase(self, connection, xid):
+        connection.connection.begin(*xid)
+
+    def do_prepare_twophase(self, connection, xid):
+        connection.connection.prepare()
+
+    def do_rollback_twophase(self, connection, xid, is_prepared=True, recover=False):
+        self.do_rollback(connection.connection)
+
+    def do_commit_twophase(self, connection, xid, is_prepared=True, recover=False):
+        self.do_commit(connection.connection)
+
+    def do_recover_twophase(self, connection):
+        pass
+
+dialect = Oracle_cx_oracle
diff --git a/lib/sqlalchemy/dialects/oracle/zxjdbc.py b/lib/sqlalchemy/dialects/oracle/zxjdbc.py
new file mode 100644 (file)
index 0000000..a0ad088
--- /dev/null
@@ -0,0 +1,24 @@
+"""Support for the Oracle database via the zxjdbc JDBC connector."""
+import re
+
+from sqlalchemy.connectors.zxJDBC import ZxJDBCConnector
+from sqlalchemy.dialects.oracle.base import OracleDialect
+
+class Oracle_jdbc(ZxJDBCConnector, OracleDialect):
+
+    jdbc_db_name = 'oracle'
+    jdbc_driver_name = 'oracle.jdbc.driver.OracleDriver'
+
+    def create_connect_args(self, url):
+        hostname = url.host
+        port = url.port or '1521'
+        dbname = url.database
+        jdbc_url = 'jdbc:oracle:thin:@%s:%s:%s' % (hostname, port, dbname)
+        return [[jdbc_url, url.username, url.password, self.jdbc_driver_name],
+                self._driver_kwargs()]
+        
+    def _get_server_version_info(self, connection):
+        version = re.search(r'Release ([\d\.]+)', connection.connection.dbversion).group(1)
+        return tuple(int(x) for x in version.split('.'))
+        
+dialect = Oracle_jdbc
diff --git a/lib/sqlalchemy/dialects/postgres.py b/lib/sqlalchemy/dialects/postgres.py
new file mode 100644 (file)
index 0000000..e66989f
--- /dev/null
@@ -0,0 +1,9 @@
+# backwards compat with the old name
+from sqlalchemy.util import warn_deprecated
+
+warn_deprecated(
+    "The SQLAlchemy PostgreSQL dialect has been renamed from 'postgres' to 'postgresql'. "
+    "The new URL format is postgresql[+driver]://<user>:<pass>@<host>/<dbname>"
+    )
+
+from sqlalchemy.dialects.postgresql import *
\ No newline at end of file
diff --git a/lib/sqlalchemy/dialects/postgresql/__init__.py b/lib/sqlalchemy/dialects/postgresql/__init__.py
new file mode 100644 (file)
index 0000000..af9430a
--- /dev/null
@@ -0,0 +1,3 @@
+from sqlalchemy.dialects.postgresql import base, psycopg2, pg8000, zxjdbc
+
+base.dialect = psycopg2.dialect
\ No newline at end of file
similarity index 51%
rename from lib/sqlalchemy/databases/postgres.py
rename to lib/sqlalchemy/dialects/postgresql/base.py
index 154d971e359eb2c3891f49780acc50ccb61f110d..874907abc1299c8687863f7076ce57c3a7f70dc3 100644 (file)
@@ -1,34 +1,13 @@
-# postgres.py
+# postgresql.py
 # Copyright (C) 2005, 2006, 2007, 2008, 2009 Michael Bayer mike_mp@zzzcomputing.com
 #
 # This module is part of SQLAlchemy and is released under
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
 
-"""Support for the PostgreSQL database.
+"""Support for the PostgreSQL database.  
 
-Driver
-------
-
-The psycopg2 driver is supported, available at http://pypi.python.org/pypi/psycopg2/ .
-The dialect has several behaviors  which are specifically tailored towards compatibility 
-with this module.
-
-Note that psycopg1 is **not** supported.
-
-Connecting
-----------
-
-URLs are of the form `postgres://user:password@host:port/dbname[?key=value&key=value...]`.
-
-PostgreSQL-specific keyword arguments which are accepted by :func:`~sqlalchemy.create_engine()` are:
-
-* *server_side_cursors* - Enable the usage of "server side cursors" for SQL statements which support
-  this feature.  What this essentially means from a psycopg2 point of view is that the cursor is 
-  created using a name, e.g. `connection.cursor('some name')`, which has the effect that result rows
-  are not immediately pre-fetched and buffered after statement execution, but are instead left 
-  on the server and only retrieved as needed.    SQLAlchemy's :class:`~sqlalchemy.engine.base.ResultProxy`
-  uses special row-buffering behavior when this feature is enabled, such that groups of 100 rows 
-  at a time are fetched over the wire to reduce conversational overhead.
+For information on connecting using specific drivers, see the documentation section
+regarding that driver.
 
 Sequences/SERIAL
 ----------------
@@ -64,144 +43,76 @@ The dialect supports PG 8.3's ``INSERT..RETURNING`` and ``UPDATE..RETURNING`` sy
 but must be explicitly enabled on a per-statement basis::
 
     # INSERT..RETURNING
-    result = table.insert(postgres_returning=[table.c.col1, table.c.col2]).\\
+    result = table.insert(postgresql_returning=[table.c.col1, table.c.col2]).\\
         values(name='foo')
     print result.fetchall()
     
     # UPDATE..RETURNING
-    result = table.update(postgres_returning=[table.c.col1, table.c.col2]).\\
+    result = table.update(postgresql_returning=[table.c.col1, table.c.col2]).\\
         where(table.c.name=='foo').values(name='bar')
     print result.fetchall()
 
 Indexes
 -------
 
-PostgreSQL supports partial indexes. To create them pass a postgres_where
+PostgreSQL supports partial indexes. To create them pass a postgresql_where
 option to the Index constructor::
 
-  Index('my_index', my_table.c.id, postgres_where=tbl.c.value > 10)
+  Index('my_index', my_table.c.id, postgresql_where=tbl.c.value > 10)
 
-Transactions
-------------
-
-The PostgreSQL dialect fully supports SAVEPOINT and two-phase commit operations.
 
 
 """
 
-import decimal, random, re, string
+import re
 
+from sqlalchemy import schema as sa_schema
 from sqlalchemy import sql, schema, exc, util
-from sqlalchemy.engine import base, default
+from sqlalchemy.engine import base, default, reflection
 from sqlalchemy.sql import compiler, expression
 from sqlalchemy.sql import operators as sql_operators
 from sqlalchemy import types as sqltypes
 
+from sqlalchemy.types import INTEGER, BIGINT, SMALLINT, VARCHAR, \
+        CHAR, TEXT, FLOAT, NUMERIC, \
+        TIMESTAMP, TIME, DATE, BOOLEAN
 
-class PGInet(sqltypes.TypeEngine):
-    def get_col_spec(self):
-        return "INET"
-
-class PGCidr(sqltypes.TypeEngine):
-    def get_col_spec(self):
-        return "CIDR"
-
-class PGMacAddr(sqltypes.TypeEngine):
-    def get_col_spec(self):
-        return "MACADDR"
-
-class PGNumeric(sqltypes.Numeric):
-    def get_col_spec(self):
-        if not self.precision:
-            return "NUMERIC"
-        else:
-            return "NUMERIC(%(precision)s, %(scale)s)" % {'precision': self.precision, 'scale' : self.scale}
-
-    def bind_processor(self, dialect):
-        return None
-
-    def result_processor(self, dialect):
-        if self.asdecimal:
-            return None
-        else:
-            def process(value):
-                if isinstance(value, decimal.Decimal):
-                    return float(value)
-                else:
-                    return value
-            return process
-
-class PGFloat(sqltypes.Float):
-    def get_col_spec(self):
-        if not self.precision:
-            return "FLOAT"
-        else:
-            return "FLOAT(%(precision)s)" % {'precision': self.precision}
-
-
-class PGInteger(sqltypes.Integer):
-    def get_col_spec(self):
-        return "INTEGER"
-
-class PGSmallInteger(sqltypes.Smallinteger):
-    def get_col_spec(self):
-        return "SMALLINT"
-
-class PGBigInteger(PGInteger):
-    def get_col_spec(self):
-        return "BIGINT"
-
-class PGDateTime(sqltypes.DateTime):
-    def get_col_spec(self):
-        return "TIMESTAMP " + (self.timezone and "WITH" or "WITHOUT") + " TIME ZONE"
+class REAL(sqltypes.Float):
+    __visit_name__ = "REAL"
 
-class PGDate(sqltypes.Date):
-    def get_col_spec(self):
-        return "DATE"
+class BYTEA(sqltypes.Binary):
+    __visit_name__ = 'BYTEA'
 
-class PGTime(sqltypes.Time):
-    def get_col_spec(self):
-        return "TIME " + (self.timezone and "WITH" or "WITHOUT") + " TIME ZONE"
+class DOUBLE_PRECISION(sqltypes.Float):
+    __visit_name__ = 'DOUBLE_PRECISION'
+    
+class INET(sqltypes.TypeEngine):
+    __visit_name__ = "INET"
+PGInet = INET
 
-class PGInterval(sqltypes.TypeEngine):
-    def get_col_spec(self):
-        return "INTERVAL"
+class CIDR(sqltypes.TypeEngine):
+    __visit_name__ = "CIDR"
+PGCidr = CIDR
 
-class PGText(sqltypes.Text):
-    def get_col_spec(self):
-        return "TEXT"
+class MACADDR(sqltypes.TypeEngine):
+    __visit_name__ = "MACADDR"
+PGMacAddr = MACADDR
 
-class PGString(sqltypes.String):
-    def get_col_spec(self):
-        if self.length:
-            return "VARCHAR(%(length)d)" % {'length' : self.length}
-        else:
-            return "VARCHAR"
+class INTERVAL(sqltypes.TypeEngine):
+    __visit_name__ = 'INTERVAL'
+PGInterval = INTERVAL
 
-class PGChar(sqltypes.CHAR):
-    def get_col_spec(self):
-        if self.length:
-            return "CHAR(%(length)d)" % {'length' : self.length}
-        else:
-            return "CHAR"
+class BIT(sqltypes.TypeEngine):
+    __visit_name__ = 'BIT'
+PGBit = BIT
 
-class PGBinary(sqltypes.Binary):
-    def get_col_spec(self):
-        return "BYTEA"
+class UUID(sqltypes.TypeEngine):
+    __visit_name__ = 'UUID'
+PGUuid = UUID
 
-class PGBoolean(sqltypes.Boolean):
-    def get_col_spec(self):
-        return "BOOLEAN"
-
-class PGBit(sqltypes.TypeEngine):
-    def get_col_spec(self):
-        return "BIT"
-        
-class PGUuid(sqltypes.TypeEngine):
-    def get_col_spec(self):
-        return "UUID"
+class ARRAY(sqltypes.MutableType, sqltypes.Concatenable, sqltypes.TypeEngine):
+    __visit_name__ = 'ARRAY'
     
-class PGArray(sqltypes.MutableType, sqltypes.Concatenable, sqltypes.TypeEngine):
     def __init__(self, item_type, mutable=True):
         if isinstance(item_type, type):
             item_type = item_type()
@@ -259,133 +170,341 @@ class PGArray(sqltypes.MutableType, sqltypes.Concatenable, sqltypes.TypeEngine):
                         return item
             return [convert_item(item) for item in value]
         return process
-    def get_col_spec(self):
-        return self.item_type.get_col_spec() + '[]'
+PGArray = ARRAY
 
 colspecs = {
-    sqltypes.Integer : PGInteger,
-    sqltypes.Smallinteger : PGSmallInteger,
-    sqltypes.Numeric : PGNumeric,
-    sqltypes.Float : PGFloat,
-    sqltypes.DateTime : PGDateTime,
-    sqltypes.Date : PGDate,
-    sqltypes.Time : PGTime,
-    sqltypes.String : PGString,
-    sqltypes.Binary : PGBinary,
-    sqltypes.Boolean : PGBoolean,
-    sqltypes.Text : PGText,
-    sqltypes.CHAR: PGChar,
+    sqltypes.Interval:INTERVAL
 }
 
 ischema_names = {
-    'integer' : PGInteger,
-    'bigint' : PGBigInteger,
-    'smallint' : PGSmallInteger,
-    'character varying' : PGString,
-    'character' : PGChar,
-    '"char"' : PGChar,
-    'name': PGChar,
-    'text' : PGText,
-    'numeric' : PGNumeric,
-    'float' : PGFloat,
-    'real' : PGFloat,
-    'inet': PGInet,
-    'cidr': PGCidr,
-    'uuid':PGUuid,
-    'bit':PGBit,
-    'macaddr': PGMacAddr,
-    'double precision' : PGFloat,
-    'timestamp' : PGDateTime,
-    'timestamp with time zone' : PGDateTime,
-    'timestamp without time zone' : PGDateTime,
-    'time with time zone' : PGTime,
-    'time without time zone' : PGTime,
-    'date' : PGDate,
-    'time': PGTime,
-    'bytea' : PGBinary,
-    'boolean' : PGBoolean,
-    'interval':PGInterval,
+    'integer' : INTEGER,
+    'bigint' : BIGINT,
+    'smallint' : SMALLINT,
+    'character varying' : VARCHAR,
+    'character' : CHAR,
+    '"char"' : sqltypes.String,
+    'name' : sqltypes.String,
+    'text' : TEXT,
+    'numeric' : NUMERIC,
+    'float' : FLOAT,
+    'real' : REAL,
+    'inet': INET,
+    'cidr': CIDR,
+    'uuid': UUID,
+    'bit':BIT,
+    'macaddr': MACADDR,
+    'double precision' : DOUBLE_PRECISION,
+    'timestamp' : TIMESTAMP,
+    'timestamp with time zone' : TIMESTAMP,
+    'timestamp without time zone' : TIMESTAMP,
+    'time with time zone' : TIME,
+    'time without time zone' : TIME,
+    'date' : DATE,
+    'time': TIME,
+    'bytea' : BYTEA,
+    'boolean' : BOOLEAN,
+    'interval':INTERVAL,
 }
 
-# TODO: filter out 'FOR UPDATE' statements
-SERVER_SIDE_CURSOR_RE = re.compile(
-    r'\s*SELECT',
-    re.I | re.UNICODE)
-
-class PGExecutionContext(default.DefaultExecutionContext):
-    def create_cursor(self):
-        # TODO: coverage for server side cursors + select.for_update()
-        is_server_side = \
-            self.dialect.server_side_cursors and \
-            ((self.compiled and isinstance(self.compiled.statement, expression.Selectable) 
-                and not getattr(self.compiled.statement, 'for_update', False)) \
-            or \
-            (
-                (not self.compiled or isinstance(self.compiled.statement, expression._TextClause)) 
-                and self.statement and SERVER_SIDE_CURSOR_RE.match(self.statement))
-            )
 
-        self.__is_server_side = is_server_side
-        if is_server_side:
-            # use server-side cursors:
-            # http://lists.initd.org/pipermail/psycopg/2007-January/005251.html
-            ident = "c_%s_%s" % (hex(id(self))[2:], hex(random.randint(0, 65535))[2:])
-            return self._connection.connection.cursor(ident)
+
+class PGCompiler(compiler.SQLCompiler):
+    
+    def visit_match_op(self, binary, **kw):
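+        # PG full-text search: col.match('terms') renders as
+        # "col @@ to_tsquery(<bound terms>)"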
+        return "%s @@ to_tsquery(%s)" % (self.process(binary.left), self.process(binary.right))
+
+    def visit_ilike_op(self, binary, **kw):
+        escape = binary.modifiers.get("escape", None)
+        return '%s ILIKE %s' % (self.process(binary.left), self.process(binary.right)) \
+            + (escape and ' ESCAPE \'%s\'' % escape or '')
+
+    def visit_notilike_op(self, binary, **kw):
+        escape = binary.modifiers.get("escape", None)
+        return '%s NOT ILIKE %s' % (self.process(binary.left), self.process(binary.right)) \
+            + (escape and ' ESCAPE \'%s\'' % escape or '')
+
+    def post_process_text(self, text):
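+        # with the pyformat paramstyle, "%" is significant to the DBAPI;
+        # escape literal percent signs in text() SQL as "%%"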
+        if '%%' in text:
+            util.warn("The SQLAlchemy postgresql dialect now automatically escapes '%' in text() expressions to '%%'.")
+        return text.replace('%', '%%')
+
+    def visit_sequence(self, seq):
+        if seq.optional:
+            return None
+        else:
+            return "nextval('%s')" % self.preparer.format_sequence(seq)
+
+    def limit_clause(self, select):
+        text = ""
+        if select._limit is not None:
+            text +=  " \n LIMIT " + str(select._limit)
+        if select._offset is not None:
+            if select._limit is None:
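+                # OFFSET with no LIMIT: emit the explicit "LIMIT ALL" form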
+                text += " \n LIMIT ALL"
+            text += " OFFSET " + str(select._offset)
+        return text
+
+    def get_select_precolumns(self, select):
+        if select._distinct:
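+            # _distinct may be True (plain DISTINCT), a list/tuple of
+            # expressions (DISTINCT ON (...)), or a single expression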
+            if isinstance(select._distinct, bool):
+                return "DISTINCT "
+            elif isinstance(select._distinct, (list, tuple)):
+                return "DISTINCT ON (" + ', '.join(
+                    [(isinstance(col, basestring) and col or self.process(col)) for col in select._distinct]
+                )+ ") "
+            else:
+                return "DISTINCT ON (" + unicode(select._distinct) + ") "
+        else:
+            return ""
+
+    def for_update_clause(self, select):
+        if select.for_update == 'nowait':
+            return " FOR UPDATE NOWAIT"
+        else:
+            return super(PGCompiler, self).for_update_clause(select)
+
+    def returning_clause(self, stmt, returning_cols):
+        
+        columns = [
+                self.process(
+                    self.label_select_column(None, c, asfrom=False), 
+                    within_columns_clause=True, 
+                    result_map=self.result_map) 
+                for c in expression._select_iterables(returning_cols)
+            ]
+            
+        return 'RETURNING ' + ', '.join(columns)
+
+    def visit_extract(self, extract, **kwargs):
+        field = self.extract_map.get(extract.field, extract.field)
+        return "EXTRACT(%s FROM %s::timestamp)" % (
+            field, self.process(extract.expr))
+
+class PGDDLCompiler(compiler.DDLCompiler):
+    def get_column_specification(self, column, **kwargs):
+        colspec = self.preparer.format_column(column)
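+        # an integer primary key with no explicit default (or only an
+        # optional Sequence) renders as SERIAL/BIGSERIAL and PostgreSQL
+        # creates the backing sequence itself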
+        if column.primary_key and \
+            len(column.foreign_keys)==0 and \
+            column.autoincrement and \
+            isinstance(column.type, sqltypes.Integer) and \
+            not isinstance(column.type, sqltypes.SmallInteger) and \
+            (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)):
+            if isinstance(column.type, sqltypes.BigInteger):
+                colspec += " BIGSERIAL"
+            else:
+                colspec += " SERIAL"
         else:
-            return self._connection.connection.cursor()
+            colspec += " " + self.dialect.type_compiler.process(column.type)
+            default = self.get_column_default_string(column)
+            if default is not None:
+                colspec += " DEFAULT " + default
+
+        if not column.nullable:
+            colspec += " NOT NULL"
+        return colspec
+
+    def visit_create_sequence(self, create):
+        return "CREATE SEQUENCE %s" % self.preparer.format_sequence(create.element)
+            
+    def visit_drop_sequence(self, drop):
+        return "DROP SEQUENCE %s" % self.preparer.format_sequence(drop.element)
+        
+    def visit_create_index(self, create):
+        preparer = self.preparer
+        index = create.element
+        text = "CREATE "
+        if index.unique:
+            text += "UNIQUE "
+        text += "INDEX %s ON %s (%s)" \
+                    % (preparer.quote(self._validate_identifier(index.name, True), index.quote),
+                       preparer.format_table(index.table),
+                       ', '.join([preparer.format_column(c) for c in index.columns]))
+        
+        if "postgres_where" in index.kwargs:
+            whereclause = index.kwargs['postgres_where']
+            util.warn_deprecated("The 'postgres_where' argument has been renamed to 'postgresql_where'.")
+        elif 'postgresql_where' in index.kwargs:
+            whereclause = index.kwargs['postgresql_where']
+        else:
+            whereclause = None
+            
+        if whereclause is not None:
+            compiler = self._compile(whereclause, None)
+            # this might belong to the compiler class
+            inlined_clause = str(compiler) % dict(
+                [(key,bind.value) for key,bind in compiler.binds.iteritems()])
+            text += " WHERE " + inlined_clause
+        return text
+
+
+class PGDefaultRunner(base.DefaultRunner):
+    def __init__(self, context):
+        base.DefaultRunner.__init__(self, context)
+        # create a cursor which won't conflict with a server-side cursor
+        self.cursor = context._connection.connection.cursor()
     
-    def get_result_proxy(self):
-        if self.__is_server_side:
-            return base.BufferedRowResultProxy(self)
+    def get_column_default(self, column, isinsert=True):
+        if column.primary_key:
+            # pre-execute passive defaults on primary keys
+            if (isinstance(column.server_default, schema.DefaultClause) and
+                column.server_default.arg is not None):
+                return self.execute_string("select %s" % column.server_default.arg)
+            elif (isinstance(column.type, sqltypes.Integer) and column.autoincrement) \
+                    and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)):
+                sch = column.table.schema
+                # TODO: this has to build into the Sequence object so we can get the quoting
+                # logic from it
+                if sch is not None:
+                    stmt = "select nextval('\"%s\".\"%s_%s_seq\"')" % (sch, column.table.name, column.name)
+                else:
+                    stmt = "select nextval('\"%s_%s_seq\"')" % (column.table.name, column.name)
+
+                if self.dialect.supports_unicode_statements:
+                    return self.execute_string(stmt)
+                else:
+                    return self.execute_string(stmt.encode(self.dialect.encoding))
+
+        return super(PGDefaultRunner, self).get_column_default(column)
+
+    def visit_sequence(self, seq):
+        if not seq.optional:
+            return self.execute_string(("select nextval('%s')" % \
+                        self.dialect.identifier_preparer.format_sequence(seq)))
         else:
-            return base.ResultProxy(self)
+            return None
+
+class PGTypeCompiler(compiler.GenericTypeCompiler):
+    def visit_INET(self, type_):
+        return "INET"
+
+    def visit_CIDR(self, type_):
+        return "CIDR"
+
+    def visit_MACADDR(self, type_):
+        return "MACADDR"
+
+    def visit_FLOAT(self, type_):
+        if not type_.precision:
+            return "FLOAT"
+        else:
+            return "FLOAT(%(precision)s)" % {'precision': type_.precision}
+    
+    def visit_DOUBLE_PRECISION(self, type_):
+        return "DOUBLE PRECISION"
+        
+    def visit_BIGINT(self, type_):
+        return "BIGINT"
+
+    def visit_datetime(self, type_):
+        return self.visit_TIMESTAMP(type_)
+        
+    def visit_TIMESTAMP(self, type_):
+        return "TIMESTAMP " + (type_.timezone and "WITH" or "WITHOUT") + " TIME ZONE"
+
+    def visit_TIME(self, type_):
+        return "TIME " + (type_.timezone and "WITH" or "WITHOUT") + " TIME ZONE"
+
+    def visit_INTERVAL(self, type_):
+        return "INTERVAL"
+
+    def visit_BIT(self, type_):
+        return "BIT"
+
+    def visit_UUID(self, type_):
+        return "UUID"
+
+    def visit_binary(self, type_):
+        return self.visit_BYTEA(type_)
+        
+    def visit_BYTEA(self, type_):
+        return "BYTEA"
+
+    def visit_REAL(self, type_):
+        return "REAL"
+
+    def visit_ARRAY(self, type_):
+        return self.process(type_.item_type) + '[]'
+
+class PGIdentifierPreparer(compiler.IdentifierPreparer):
+    def _unquote_identifier(self, value):
+        if value[0] == self.initial_quote:
+            value = value[1:-1].replace('""','"')
+        return value
+
+class PGInspector(reflection.Inspector):
+
+    def __init__(self, conn):
+        reflection.Inspector.__init__(self, conn)
+
+    def get_table_oid(self, table_name, schema=None):
+        """Return the oid from `table_name` and `schema`."""
+
+        return self.dialect.get_table_oid(self.conn, table_name, schema,
+                                          info_cache=self.info_cache)
+    
 
 class PGDialect(default.DefaultDialect):
-    name = 'postgres'
+    name = 'postgresql'
     supports_alter = True
-    supports_unicode_statements = False
     max_identifier_length = 63
     supports_sane_rowcount = True
-    supports_sane_multi_rowcount = False
-    preexecute_pk_sequences = True
-    supports_pk_autoincrement = False
-    default_paramstyle = 'pyformat'
+    
+    supports_sequences = True
+    sequences_optional = True
+    preexecute_autoincrement_sequences = True
+    postfetch_lastrowid = False
+    
     supports_default_values = True
     supports_empty_insert = False
+    default_paramstyle = 'pyformat'
+    ischema_names = ischema_names
+    colspecs = colspecs
     
-    def __init__(self, server_side_cursors=False, **kwargs):
+    statement_compiler = PGCompiler
+    ddl_compiler = PGDDLCompiler
+    type_compiler = PGTypeCompiler
+    preparer = PGIdentifierPreparer
+    defaultrunner = PGDefaultRunner
+    inspector = PGInspector
+    isolation_level = None
+
+    def __init__(self, isolation_level=None, **kwargs):
         default.DefaultDialect.__init__(self, **kwargs)
-        self.server_side_cursors = server_side_cursors
-
-    def dbapi(cls):
-        import psycopg2 as psycopg
-        return psycopg
-    dbapi = classmethod(dbapi)
-
-    def create_connect_args(self, url):
-        opts = url.translate_connect_args(username='user')
-        if 'port' in opts:
-            opts['port'] = int(opts['port'])
-        opts.update(url.query)
-        return ([], opts)
+        self.isolation_level = isolation_level
 
-    def type_descriptor(self, typeobj):
-        return sqltypes.adapt_type(typeobj, colspecs)
+    def initialize(self, connection):
+        super(PGDialect, self).initialize(connection)
+        self.implicit_returning = self.server_version_info > (8, 3) and \
+                                        self.__dict__.get('implicit_returning', True)
+        
+    def visit_pool(self, pool):
+        if self.isolation_level is not None:
+            class SetIsolationLevel(object):
+                def __init__(self, isolation_level):
+                    self.isolation_level = isolation_level
+
+                def connect(self, conn, rec):
+                    cursor = conn.cursor()
+                    cursor.execute("SET SESSION CHARACTERISTICS AS TRANSACTION ISOLATION LEVEL %s"
+                                   % self.isolation_level)
+                    cursor.execute("COMMIT")
+                    cursor.close()
+            pool.add_listener(SetIsolationLevel(self.isolation_level))
 
     def do_begin_twophase(self, connection, xid):
         self.do_begin(connection.connection)
 
     def do_prepare_twophase(self, connection, xid):
-        connection.execute(sql.text("PREPARE TRANSACTION :tid", bindparams=[sql.bindparam('tid', xid)]))
+        connection.execute("PREPARE TRANSACTION '%s'" % xid)
 
     def do_rollback_twophase(self, connection, xid, is_prepared=True, recover=False):
         if is_prepared:
             if recover:
                 #FIXME: ugly hack to get out of transaction context when committing recoverable transactions
                 # Must find a way to make the dbapi not open a transaction.
-                connection.execute(sql.text("ROLLBACK"))
-            connection.execute(sql.text("ROLLBACK PREPARED :tid", bindparams=[sql.bindparam('tid', xid)]))
-            connection.execute(sql.text("BEGIN"))
+                connection.execute("ROLLBACK")
+            connection.execute("ROLLBACK PREPARED '%s'" % xid)
+            connection.execute("BEGIN")
             self.do_rollback(connection.connection)
         else:
             self.do_rollback(connection.connection)
@@ -393,9 +512,9 @@ class PGDialect(default.DefaultDialect):
     def do_commit_twophase(self, connection, xid, is_prepared=True, recover=False):
         if is_prepared:
             if recover:
-                connection.execute(sql.text("ROLLBACK"))
-            connection.execute(sql.text("COMMIT PREPARED :tid", bindparams=[sql.bindparam('tid', xid)]))
-            connection.execute(sql.text("BEGIN"))
+                connection.execute("ROLLBACK")
+            connection.execute("COMMIT PREPARED '%s'" % xid)
+            connection.execute("BEGIN")
             self.do_rollback(connection.connection)
         else:
             self.do_commit(connection.connection)
@@ -405,66 +524,151 @@ class PGDialect(default.DefaultDialect):
         return [row[0] for row in resultset]
 
     def get_default_schema_name(self, connection):
-        return connection.scalar("select current_schema()", None)
-    get_default_schema_name = base.connection_memoize(
-        ('dialect', 'default_schema_name'))(get_default_schema_name)
-
-    def last_inserted_ids(self):
-        if self.context.last_inserted_ids is None:
-            raise exc.InvalidRequestError("no INSERT executed, or can't use cursor.lastrowid without PostgreSQL OIDs enabled")
-        else:
-            return self.context.last_inserted_ids
+        return connection.scalar("select current_schema()")
 
     def has_table(self, connection, table_name, schema=None):
         # seems like case gets folded in pg_class...
         if schema is None:
-            cursor = connection.execute("""select relname from pg_class c join pg_namespace n on n.oid=c.relnamespace where n.nspname=current_schema() and lower(relname)=%(name)s""", {'name':table_name.lower().encode(self.encoding)});
+            cursor = connection.execute(
+                sql.text("select relname from pg_class c join pg_namespace n on "
+                    "n.oid=c.relnamespace where n.nspname=current_schema() and lower(relname)=:name",
+                    bindparams=[sql.bindparam('name', unicode(table_name.lower()), type_=sqltypes.Unicode)]
+                )
+            )
         else:
-            cursor = connection.execute("""select relname from pg_class c join pg_namespace n on n.oid=c.relnamespace where n.nspname=%(schema)s and lower(relname)=%(name)s""", {'name':table_name.lower().encode(self.encoding), 'schema':schema});
-        return bool( not not cursor.rowcount )
+            cursor = connection.execute(
+                sql.text("select relname from pg_class c join pg_namespace n on "
+                        "n.oid=c.relnamespace where n.nspname=:schema and lower(relname)=:name",
+                    bindparams=[sql.bindparam('name', unicode(table_name.lower()), type_=sqltypes.Unicode),
+                        sql.bindparam('schema', unicode(schema), type_=sqltypes.Unicode)] 
+                )
+            )
+        return bool(cursor.fetchone())
 
     def has_sequence(self, connection, sequence_name):
-        cursor = connection.execute('''SELECT relname FROM pg_class WHERE relkind = 'S' AND relnamespace IN ( SELECT oid FROM pg_namespace WHERE nspname NOT LIKE 'pg_%%' AND nspname != 'information_schema' AND relname = %(seqname)s);''', {'seqname': sequence_name.encode(self.encoding)})
-        return bool(not not cursor.rowcount)
-
-    def is_disconnect(self, e):
-        if isinstance(e, self.dbapi.OperationalError):
-            return 'closed the connection' in str(e) or 'connection not open' in str(e)
-        elif isinstance(e, self.dbapi.InterfaceError):
-            return 'connection already closed' in str(e) or 'cursor already closed' in str(e)
-        elif isinstance(e, self.dbapi.ProgrammingError):
-            # yes, it really says "losed", not "closed"
-            return "losed the connection unexpectedly" in str(e)
-        else:
-            return False
+        cursor = connection.execute(
+                    sql.text("SELECT relname FROM pg_class WHERE relkind = 'S' AND "
+                        "relnamespace IN ( SELECT oid FROM pg_namespace WHERE nspname NOT LIKE 'pg_%' "
+                        "AND nspname != 'information_schema' AND relname = :seqname)", 
+                        bindparams=[sql.bindparam('seqname', unicode(sequence_name), type_=sqltypes.Unicode)]
+                    ))
+        return bool(cursor.fetchone())
 
     def table_names(self, connection, schema):
-        s = """
-        SELECT relname
-        FROM pg_class c
-        WHERE relkind = 'r'
-          AND '%(schema)s' = (select nspname from pg_namespace n where n.oid = c.relnamespace)
-        """ % locals()
-        return [row[0].decode(self.encoding) for row in connection.execute(s)]
+        result = connection.execute(
+            sql.text(u"""SELECT relname
+                FROM pg_class c
+                WHERE relkind = 'r'
+                AND '%s' = (select nspname from pg_namespace n where n.oid = c.relnamespace)""" % schema,
+                typemap = {'relname':sqltypes.Unicode}
+            )
+        )
+        return [row[0] for row in result]
 
-    def server_version_info(self, connection):
+    def _get_server_version_info(self, connection):
         v = connection.execute("select version()").scalar()
         m = re.match('PostgreSQL (\d+)\.(\d+)\.(\d+)', v)
         if not m:
             raise AssertionError("Could not determine version from string '%s'" % v)
         return tuple([int(x) for x in m.group(1, 2, 3)])
 
-    def reflecttable(self, connection, table, include_columns):
-        preparer = self.identifier_preparer
-        if table.schema is not None:
+    @reflection.cache
+    def get_table_oid(self, connection, table_name, schema=None, **kw):
+        """Fetch the oid for schema.table_name.
+
+        Several reflection methods require the table oid.  The idea behind
+        this method is that the oid can be fetched once and cached for
+        subsequent calls.
+
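+        For example (mirroring how the other reflection methods in this
+        dialect call it)::
+
+            table_oid = self.get_table_oid(connection, table_name, schema,
+                                           info_cache=kw.get('info_cache'))
+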
+        """
+        table_oid = None
+        if schema is not None:
             schema_where_clause = "n.nspname = :schema"
-            schemaname = table.schema
-            if isinstance(schemaname, str):
-                schemaname = schemaname.decode(self.encoding)
         else:
             schema_where_clause = "pg_catalog.pg_table_is_visible(c.oid)"
-            schemaname = None
+        query = """
+            SELECT c.oid
+            FROM pg_catalog.pg_class c
+            LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
+            WHERE (%s)
+            AND c.relname = :table_name AND c.relkind in ('r','v')
+        """ % schema_where_clause
+        # Since we're binding to unicode, table_name and schema must be
+        # unicode.
+        table_name = unicode(table_name)
+        if schema is not None:
+            schema = unicode(schema)
+        s = sql.text(query, bindparams=[
+            sql.bindparam('table_name', type_=sqltypes.Unicode),
+            sql.bindparam('schema', type_=sqltypes.Unicode)
+            ],
+            typemap={'oid':sqltypes.Integer}
+        )
+        c = connection.execute(s, table_name=table_name, schema=schema)
+        table_oid = c.scalar()
+        if table_oid is None:
+            raise exc.NoSuchTableError(table_name)
+        return table_oid
+
+    @reflection.cache
+    def get_schema_names(self, connection, **kw):
+        s = """
+        SELECT nspname
+        FROM pg_namespace
+        ORDER BY nspname
+        """
+        rp = connection.execute(s)
+        # what about system tables?
+        schema_names = [row[0].decode(self.encoding) for row in rp \
+                        if not row[0].startswith('pg_')]
+        return schema_names
+
+    @reflection.cache
+    def get_table_names(self, connection, schema=None, **kw):
+        if schema is not None:
+            current_schema = schema
+        else:
+            current_schema = self.get_default_schema_name(connection)
+        table_names = self.table_names(connection, current_schema)
+        return table_names
+
+    @reflection.cache
+    def get_view_names(self, connection, schema=None, **kw):
+        if schema is not None:
+            current_schema = schema
+        else:
+            current_schema = self.get_default_schema_name(connection)
+        s = """
+        SELECT relname
+        FROM pg_class c
+        WHERE relkind = 'v'
+          AND '%(schema)s' = (select nspname from pg_namespace n where n.oid = c.relnamespace)
+        """ % dict(schema=current_schema)
+        view_names = [row[0].decode(self.encoding) for row in connection.execute(s)]
+        return view_names
+
+    @reflection.cache
+    def get_view_definition(self, connection, view_name, schema=None, **kw):
+        if schema is not None:
+            current_schema = schema
+        else:
+            current_schema = self.get_default_schema_name(connection)
+        s = """
+        SELECT definition FROM pg_views
+        WHERE schemaname = :schema
+        AND viewname = :view_name
+        """
+        rp = connection.execute(sql.text(s),
+                                view_name=view_name, schema=current_schema)
+        if rp:
+            view_def = rp.scalar().decode(self.encoding)
+            return view_def
 
+    @reflection.cache
+    def get_columns(self, connection, table_name, schema=None, **kw):
+
+        table_oid = self.get_table_oid(connection, table_name, schema,
+                                       info_cache=kw.get('info_cache'))
         SQL_COLS = """
             SELECT a.attname,
               pg_catalog.format_type(a.atttypid, a.atttypmod),
@@ -473,42 +677,28 @@ class PGDialect(default.DefaultDialect):
               AS DEFAULT,
               a.attnotnull, a.attnum, a.attrelid as table_oid
             FROM pg_catalog.pg_attribute a
-            WHERE a.attrelid = (
-                SELECT c.oid
-                FROM pg_catalog.pg_class c
-                     LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
-                     WHERE (%s)
-                     AND c.relname = :table_name AND c.relkind in ('r','v')
-            ) AND a.attnum > 0 AND NOT a.attisdropped
+            WHERE a.attrelid = :table_oid
+            AND a.attnum > 0 AND NOT a.attisdropped
             ORDER BY a.attnum
-        """ % schema_where_clause
-
-        s = sql.text(SQL_COLS, bindparams=[sql.bindparam('table_name', type_=sqltypes.Unicode), sql.bindparam('schema', type_=sqltypes.Unicode)], typemap={'attname':sqltypes.Unicode, 'default':sqltypes.Unicode})
-        tablename = table.name
-        if isinstance(tablename, str):
-            tablename = tablename.decode(self.encoding)
-        c = connection.execute(s, table_name=tablename, schema=schemaname)
+        """
+        s = sql.text(SQL_COLS, 
+            bindparams=[sql.bindparam('table_oid', type_=sqltypes.Integer)], 
+            typemap={'attname':sqltypes.Unicode, 'default':sqltypes.Unicode}
+        )
+        c = connection.execute(s, table_oid=table_oid)
         rows = c.fetchall()
-
-        if not rows:
-            raise exc.NoSuchTableError(table.name)
-
         domains = self._load_domains(connection)
-
+        # format columns
+        columns = []
         for name, format_type, default, notnull, attnum, table_oid in rows:
-            if include_columns and name not in include_columns:
-                continue
-
             ## strip (30) from character varying(30)
             attype = re.search('([^\([]+)', format_type).group(1)
             nullable = not notnull
             is_array = format_type.endswith('[]')
-
             try:
                 charlen = re.search('\(([\d,]+)\)', format_type).group(1)
             except:
                 charlen = False
-
             numericprec = False
             numericscale = False
             if attype == 'numeric':
@@ -523,35 +713,31 @@ class PGDialect(default.DefaultDialect):
             if attype == 'integer':
                 numericprec, numericscale = (32, 0)
                 charlen = False
-
             args = []
             for a in (charlen, numericprec, numericscale):
                 if a is None:
                     args.append(None)
                 elif a is not False:
                     args.append(int(a))
-
             kwargs = {}
             if attype == 'timestamp with time zone':
                 kwargs['timezone'] = True
             elif attype == 'timestamp without time zone':
                 kwargs['timezone'] = False
-
-            coltype = None
-            if attype in ischema_names:
-                coltype = ischema_names[attype]
+            if attype in self.ischema_names:
+                coltype = self.ischema_names[attype]
             else:
                 if attype in domains:
                     domain = domains[attype]
-                    if domain['attype'] in ischema_names:
+                    if domain['attype'] in self.ischema_names:
                         # A table can't override whether the domain is nullable.
                         nullable = domain['nullable']
-
                         if domain['default'] and not default:
                             # It can, however, override the default value, but can't set it to null.
                             default = domain['default']
-                        coltype = ischema_names[domain['attype']]
-
+                        coltype = self.ischema_names[domain['attype']]
+                else:
+                    coltype = None
             if coltype:
                 coltype = coltype(*args, **kwargs)
                 if is_array:
@@ -560,41 +746,46 @@ class PGDialect(default.DefaultDialect):
                 util.warn("Did not recognize type '%s' of column '%s'" %
                           (attype, name))
                 coltype = sqltypes.NULLTYPE
-
-            colargs = []
+            # adjust the default value
+            autoincrement = False
             if default is not None:
                 match = re.search(r"""(nextval\(')([^']+)('.*$)""", default)
                 if match is not None:
+                    autoincrement = True
                     # the default is related to a Sequence
-                    sch = table.schema
+                    sch = schema
                     if '.' not in match.group(2) and sch is not None:
                         # unconditionally quote the schema name.  this could
                         # later be enhanced to obey quoting rules / "quote schema"
                         default = match.group(1) + ('"%s"' % sch) + '.' + match.group(2) + match.group(3)
-                colargs.append(schema.DefaultClause(sql.text(default)))
-            table.append_column(schema.Column(name, coltype, nullable=nullable, *colargs))
 
+            column_info = dict(name=name, type=coltype, nullable=nullable,
+                               default=default, autoincrement=autoincrement)
+            columns.append(column_info)
+        return columns
 
-        # Primary keys
+    @reflection.cache
+    def get_primary_keys(self, connection, table_name, schema=None, **kw):
+        table_oid = self.get_table_oid(connection, table_name, schema,
+                                       info_cache=kw.get('info_cache'))
         PK_SQL = """
           SELECT attname FROM pg_attribute
           WHERE attrelid = (
              SELECT indexrelid FROM pg_index i
-             WHERE i.indrelid = :table
+             WHERE i.indrelid = :table_oid
              AND i.indisprimary = 't')
           ORDER BY attnum
         """
         t = sql.text(PK_SQL, typemap={'attname':sqltypes.Unicode})
-        c = connection.execute(t, table=table_oid)
-        for row in c.fetchall():
-            pk = row[0]
-            if pk in table.c:
-                col = table.c[pk]
-                table.primary_key.add(col)
-                if col.default is None:
-                    col.autoincrement = False
-
-        # Foreign keys
+        c = connection.execute(t, table_oid=table_oid)
+        primary_keys = [r[0] for r in c.fetchall()]
+        return primary_keys
+
+    @reflection.cache
+    def get_foreign_keys(self, connection, table_name, schema=None, **kw):
+        preparer = self.identifier_preparer
+        table_oid = self.get_table_oid(connection, table_name, schema,
+                                       info_cache=kw.get('info_cache'))
         FK_SQL = """
           SELECT conname, pg_catalog.pg_get_constraintdef(oid, true) as condef
           FROM  pg_catalog.pg_constraint r
@@ -604,51 +795,51 @@ class PGDialect(default.DefaultDialect):
 
         t = sql.text(FK_SQL, typemap={'conname':sqltypes.Unicode, 'condef':sqltypes.Unicode})
         c = connection.execute(t, table=table_oid)
+        fkeys = []
         for conname, condef in c.fetchall():
             m = re.search('FOREIGN KEY \((.*?)\) REFERENCES (?:(.*?)\.)?(.*?)\((.*?)\)', condef).groups()
             (constrained_columns, referred_schema, referred_table, referred_columns) = m
             constrained_columns = [preparer._unquote_identifier(x) for x in re.split(r'\s*,\s*', constrained_columns)]
             if referred_schema:
                 referred_schema = preparer._unquote_identifier(referred_schema)
-            elif table.schema is not None and table.schema == self.get_default_schema_name(connection):
+            elif schema is not None and schema == self.get_default_schema_name(connection):
                 # no schema (i.e. its the default schema), and the table we're
                 # reflecting has the default schema explicit, then use that.
                 # i.e. try to use the user's conventions
-                referred_schema = table.schema
+                referred_schema = schema
             referred_table = preparer._unquote_identifier(referred_table)
             referred_columns = [preparer._unquote_identifier(x) for x in re.split(r'\s*,\s', referred_columns)]
-
-            refspec = []
-            if referred_schema is not None:
-                schema.Table(referred_table, table.metadata, autoload=True, schema=referred_schema,
-                            autoload_with=connection)
-                for column in referred_columns:
-                    refspec.append(".".join([referred_schema, referred_table, column]))
-            else:
-                schema.Table(referred_table, table.metadata, autoload=True, autoload_with=connection)
-                for column in referred_columns:
-                    refspec.append(".".join([referred_table, column]))
-
-            table.append_constraint(schema.ForeignKeyConstraint(constrained_columns, refspec, conname, link_to_name=True))
-
-        # Indexes 
+            fkey_d = {
+                'name' : conname,
+                'constrained_columns' : constrained_columns,
+                'referred_schema' : referred_schema,
+                'referred_table' : referred_table,
+                'referred_columns' : referred_columns
+            }
+            fkeys.append(fkey_d)
+        return fkeys
+
+    @reflection.cache
+    def get_indexes(self, connection, table_name, schema, **kw):
+        table_oid = self.get_table_oid(connection, table_name, schema,
+                                       info_cache=kw.get('info_cache'))
         IDX_SQL = """
           SELECT c.relname, i.indisunique, i.indexprs, i.indpred,
             a.attname
           FROM pg_index i, pg_class c, pg_attribute a
-          WHERE i.indrelid = :table AND i.indexrelid = c.oid
+          WHERE i.indrelid = :table_oid AND i.indexrelid = c.oid
             AND a.attrelid = i.indexrelid AND i.indisprimary = 'f'
           ORDER BY c.relname, a.attnum
         """
         t = sql.text(IDX_SQL, typemap={'attname':sqltypes.Unicode})
-        c = connection.execute(t, table=table_oid)
-        indexes = {}
+        c = connection.execute(t, table_oid=table_oid)
+        index_names = {}
+        indexes = []
         sv_idx_name = None
         for row in c.fetchall():
             idx_name, unique, expr, prd, col = row
-
             if expr:
-                if not idx_name == sv_idx_name:
+                if idx_name != sv_idx_name:
                     util.warn(
                       "Skipped unsupported reflection of expression-based index %s"
                       % idx_name)
@@ -659,16 +850,16 @@ class PGDialect(default.DefaultDialect):
                    "Predicate of partial index %s ignored during reflection"
                    % idx_name)
                 sv_idx_name = idx_name
-
-            if not indexes.has_key(idx_name):
-                indexes[idx_name] = [unique, []]
-            indexes[idx_name][1].append(col)
-
-        for name, (unique, columns) in indexes.items():
-            schema.Index(name, *[table.columns[c] for c in columns], 
-                         **dict(unique=unique))
-
+            if idx_name in index_names:
+                index_d = index_names[idx_name]
+            else:
+                index_d = {'column_names':[]}
+                indexes.append(index_d)
+                index_names[idx_name] = index_d
+            index_d['name'] = idx_name
+            index_d['column_names'].append(col)
+            index_d['unique'] = unique
+        return indexes
 
     def _load_domains(self, connection):
         ## Load data types for domains:
@@ -705,185 +896,3 @@ class PGDialect(default.DefaultDialect):
 
         return domains
 
-
-class PGCompiler(compiler.DefaultCompiler):
-    operators = compiler.DefaultCompiler.operators.copy()
-    operators.update(
-        {
-            sql_operators.mod : '%%',
-            sql_operators.ilike_op: lambda x, y, escape=None: '%s ILIKE %s' % (x, y) + (escape and ' ESCAPE \'%s\'' % escape or ''),
-            sql_operators.notilike_op: lambda x, y, escape=None: '%s NOT ILIKE %s' % (x, y) + (escape and ' ESCAPE \'%s\'' % escape or ''),
-            sql_operators.match_op: lambda x, y: '%s @@ to_tsquery(%s)' % (x, y),
-        }
-    )
-
-    functions = compiler.DefaultCompiler.functions.copy()
-    functions.update (
-        {
-            'TIMESTAMP':util.deprecated(message="Use a literal string 'timestamp <value>' instead")(lambda x:'TIMESTAMP %s' % x),
-        }
-    )
-
-    def visit_sequence(self, seq):
-        if seq.optional:
-            return None
-        else:
-            return "nextval('%s')" % self.preparer.format_sequence(seq)
-
-    def post_process_text(self, text):
-        if '%%' in text:
-            util.warn("The SQLAlchemy psycopg2 dialect now automatically escapes '%' in text() expressions to '%%'.")
-        return text.replace('%', '%%')
-
-    def limit_clause(self, select):
-        text = ""
-        if select._limit is not None:
-            text +=  " \n LIMIT " + str(select._limit)
-        if select._offset is not None:
-            if select._limit is None:
-                text += " \n LIMIT ALL"
-            text += " OFFSET " + str(select._offset)
-        return text
-
-    def get_select_precolumns(self, select):
-        if select._distinct:
-            if isinstance(select._distinct, bool):
-                return "DISTINCT "
-            elif isinstance(select._distinct, (list, tuple)):
-                return "DISTINCT ON (" + ', '.join(
-                    [(isinstance(col, basestring) and col or self.process(col)) for col in select._distinct]
-                )+ ") "
-            else:
-                return "DISTINCT ON (" + unicode(select._distinct) + ") "
-        else:
-            return ""
-
-    def for_update_clause(self, select):
-        if select.for_update == 'nowait':
-            return " FOR UPDATE NOWAIT"
-        else:
-            return super(PGCompiler, self).for_update_clause(select)
-
-    def _append_returning(self, text, stmt):
-        returning_cols = stmt.kwargs['postgres_returning']
-        def flatten_columnlist(collist):
-            for c in collist:
-                if isinstance(c, expression.Selectable):
-                    for co in c.columns:
-                        yield co
-                else:
-                    yield c
-        columns = [self.process(c, within_columns_clause=True) for c in flatten_columnlist(returning_cols)]
-        text += ' RETURNING ' + string.join(columns, ', ')
-        return text
-
-    def visit_update(self, update_stmt):
-        text = super(PGCompiler, self).visit_update(update_stmt)
-        if 'postgres_returning' in update_stmt.kwargs:
-            return self._append_returning(text, update_stmt)
-        else:
-            return text
-
-    def visit_insert(self, insert_stmt):
-        text = super(PGCompiler, self).visit_insert(insert_stmt)
-        if 'postgres_returning' in insert_stmt.kwargs:
-            return self._append_returning(text, insert_stmt)
-        else:
-            return text
-
-    def visit_extract(self, extract, **kwargs):
-        field = self.extract_map.get(extract.field, extract.field)
-        return "EXTRACT(%s FROM %s::timestamp)" % (
-            field, self.process(extract.expr))
-
-
-class PGSchemaGenerator(compiler.SchemaGenerator):
-    def get_column_specification(self, column, **kwargs):
-        colspec = self.preparer.format_column(column)
-        if column.primary_key and len(column.foreign_keys)==0 and column.autoincrement and isinstance(column.type, sqltypes.Integer) and not isinstance(column.type, sqltypes.SmallInteger) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)):
-            if isinstance(column.type, PGBigInteger):
-                colspec += " BIGSERIAL"
-            else:
-                colspec += " SERIAL"
-        else:
-            colspec += " " + column.type.dialect_impl(self.dialect).get_col_spec()
-            default = self.get_column_default_string(column)
-            if default is not None:
-                colspec += " DEFAULT " + default
-
-        if not column.nullable:
-            colspec += " NOT NULL"
-        return colspec
-
-    def visit_sequence(self, sequence):
-        if not sequence.optional and (not self.checkfirst or not self.dialect.has_sequence(self.connection, sequence.name)):
-            self.append("CREATE SEQUENCE %s" % self.preparer.format_sequence(sequence))
-            self.execute()
-
-    def visit_index(self, index):
-        preparer = self.preparer
-        self.append("CREATE ")
-        if index.unique:
-            self.append("UNIQUE ")
-        self.append("INDEX %s ON %s (%s)" \
-                    % (preparer.quote(self._validate_identifier(index.name, True), index.quote),
-                       preparer.format_table(index.table),
-                       string.join([preparer.format_column(c) for c in index.columns], ', ')))
-        whereclause = index.kwargs.get('postgres_where', None)
-        if whereclause is not None:
-            compiler = self._compile(whereclause, None)
-            # this might belong to the compiler class
-            inlined_clause = str(compiler) % dict(
-                [(key,bind.value) for key,bind in compiler.binds.iteritems()])
-            self.append(" WHERE " + inlined_clause)
-        self.execute()
-
-class PGSchemaDropper(compiler.SchemaDropper):
-    def visit_sequence(self, sequence):
-        if not sequence.optional and (not self.checkfirst or self.dialect.has_sequence(self.connection, sequence.name)):
-            self.append("DROP SEQUENCE %s" % self.preparer.format_sequence(sequence))
-            self.execute()
-
-class PGDefaultRunner(base.DefaultRunner):
-    def __init__(self, context):
-        base.DefaultRunner.__init__(self, context)
-        # craete cursor which won't conflict with a server-side cursor
-        self.cursor = context._connection.connection.cursor()
-    
-    def get_column_default(self, column, isinsert=True):
-        if column.primary_key:
-            # pre-execute passive defaults on primary keys
-            if (isinstance(column.server_default, schema.DefaultClause) and
-                column.server_default.arg is not None):
-                return self.execute_string("select %s" % column.server_default.arg)
-            elif (isinstance(column.type, sqltypes.Integer) and column.autoincrement) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)):
-                sch = column.table.schema
-                # TODO: this has to build into the Sequence object so we can get the quoting
-                # logic from it
-                if sch is not None:
-                    exc = "select nextval('\"%s\".\"%s_%s_seq\"')" % (sch, column.table.name, column.name)
-                else:
-                    exc = "select nextval('\"%s_%s_seq\"')" % (column.table.name, column.name)
-                return self.execute_string(exc.encode(self.dialect.encoding))
-
-        return super(PGDefaultRunner, self).get_column_default(column)
-
-    def visit_sequence(self, seq):
-        if not seq.optional:
-            return self.execute_string(("select nextval('%s')" % self.dialect.identifier_preparer.format_sequence(seq)))
-        else:
-            return None
-
-class PGIdentifierPreparer(compiler.IdentifierPreparer):
-    def _unquote_identifier(self, value):
-        if value[0] == self.initial_quote:
-            value = value[1:-1].replace('""','"')
-        return value
-
-dialect = PGDialect
-dialect.statement_compiler = PGCompiler
-dialect.schemagenerator = PGSchemaGenerator
-dialect.schemadropper = PGSchemaDropper
-dialect.preparer = PGIdentifierPreparer
-dialect.defaultrunner = PGDefaultRunner
-dialect.execution_ctx_cls = PGExecutionContext
diff --git a/lib/sqlalchemy/dialects/postgresql/pg8000.py b/lib/sqlalchemy/dialects/postgresql/pg8000.py
new file mode 100644 (file)
index 0000000..e8dd031
--- /dev/null
@@ -0,0 +1,84 @@
+"""Support for the PostgreSQL database via the pg8000 driver.
+
+Connecting
+----------
+
+URLs are of the form `postgresql+pg8000://user:password@host:port/dbname[?key=value&key=value...]`.
+
+Unicode
+-------
+
+pg8000 requires that the PostgreSQL client encoding be configured in the postgresql.conf file
+in order to use encodings other than ascii.  Set this value to the same value as
+the "encoding" parameter on create_engine(), usually "utf-8".
+
+Interval
+--------
+
+Passing data from/to the Interval type is not yet supported.
+
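+Example
+-------
+
+A minimal connection sketch; the credentials, host and database name are
+hypothetical, and the "encoding" setting follows the note above::
+
+    from sqlalchemy import create_engine
+
+    engine = create_engine(
+        "postgresql+pg8000://scott:tiger@localhost/test",
+        encoding='utf-8'
+    )
+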
+"""
+from sqlalchemy.engine import default
+import decimal
+from sqlalchemy import util
+from sqlalchemy import types as sqltypes
+from sqlalchemy.dialects.postgresql.base import PGDialect, PGCompiler
+
+class _PGNumeric(sqltypes.Numeric):
+    def bind_processor(self, dialect):
+        return None
+
+    def result_processor(self, dialect):
+        if self.asdecimal:
+            return None
+        else:
+            def process(value):
+                if isinstance(value, decimal.Decimal):
+                    return float(value)
+                else:
+                    return value
+            return process
+
+class PostgreSQL_pg8000ExecutionContext(default.DefaultExecutionContext):
+    pass
+
+class PostgreSQL_pg8000Compiler(PGCompiler):
+    def visit_mod(self, binary, **kw):
+        return self.process(binary.left) + " %% " + self.process(binary.right)
+    
+    
+class PostgreSQL_pg8000(PGDialect):
+    driver = 'pg8000'
+
+    supports_unicode_statements = True
+    
+    supports_unicode_binds = True
+    
+    default_paramstyle = 'format'
+    supports_sane_multi_rowcount = False
+    execution_ctx_cls = PostgreSQL_pg8000ExecutionContext
+    statement_compiler = PostgreSQL_pg8000Compiler
+    
+    colspecs = util.update_copy(
+        PGDialect.colspecs,
+        {
+            sqltypes.Numeric : _PGNumeric,
+            sqltypes.Float: sqltypes.Float,  # prevents _PGNumeric from being used
+        }
+    )
+    
+    @classmethod
+    def dbapi(cls):
+        return __import__('pg8000').dbapi
+
+    def create_connect_args(self, url):
+        opts = url.translate_connect_args(username='user')
+        if 'port' in opts:
+            opts['port'] = int(opts['port'])
+        opts.update(url.query)
+        return ([], opts)
+
+    def is_disconnect(self, e):
+        return "connection is closed" in str(e)
+
+dialect = PostgreSQL_pg8000
diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg2.py b/lib/sqlalchemy/dialects/postgresql/psycopg2.py
new file mode 100644 (file)
index 0000000..a428878
--- /dev/null
@@ -0,0 +1,147 @@
+"""Support for the PostgreSQL database via the psycopg2 driver.
+
+Driver
+------
+
+The psycopg2 driver is supported, available at http://pypi.python.org/pypi/psycopg2/ .
+The dialect has several behaviors which are specifically tailored towards compatibility
+with this module.
+
+Note that psycopg1 is **not** supported.
+
+Connecting
+----------
+
+URLs are of the form `postgresql+psycopg2://user:password@host:port/dbname[?key=value&key=value...]`.
+
+psycopg2-specific keyword arguments which are accepted by :func:`~sqlalchemy.create_engine()` are:
+
+* *server_side_cursors* - Enable the usage of "server side cursors" for SQL statements which support
+  this feature.  What this essentially means from a psycopg2 point of view is that the cursor is 
+  created using a name, e.g. `connection.cursor('some name')`, which has the effect that result rows
+  are not immediately pre-fetched and buffered after statement execution, but are instead left 
+  on the server and only retrieved as needed.    SQLAlchemy's :class:`~sqlalchemy.engine.base.ResultProxy`
+  uses special row-buffering behavior when this feature is enabled, such that groups of 100 rows 
+  at a time are fetched over the wire to reduce conversational overhead.
+
+* *isolation_level* - Sets the transaction isolation level for each transaction
+  within the engine. Valid isolation levels are `READ_COMMITTED`,
+  `READ_UNCOMMITTED`, `REPEATABLE_READ`, and `SERIALIZABLE`; see the example below.
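+
+For example, both arguments may be passed to :func:`~sqlalchemy.create_engine()`;
+the connection values here are illustrative only::
+
+    from sqlalchemy import create_engine
+
+    engine = create_engine(
+        "postgresql+psycopg2://scott:tiger@localhost/test",
+        server_side_cursors=True,
+        isolation_level='SERIALIZABLE'
+    )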
+
+Transactions
+------------
+
+The psycopg2 dialect fully supports SAVEPOINT and two-phase commit operations.
+
+
+"""
+
+import decimal, random, re
+from sqlalchemy import util
+from sqlalchemy.engine import base, default
+from sqlalchemy.sql import expression
+from sqlalchemy.sql import operators as sql_operators
+from sqlalchemy import types as sqltypes
+from sqlalchemy.dialects.postgresql.base import PGDialect, PGCompiler
+
+class _PGNumeric(sqltypes.Numeric):
+    def bind_processor(self, dialect):
+        return None
+
+    def result_processor(self, dialect):
+        if self.asdecimal:
+            return None
+        else:
+            def process(value):
+                if isinstance(value, decimal.Decimal):
+                    return float(value)
+                else:
+                    return value
+            return process
+
+
+# TODO: filter out 'FOR UPDATE' statements
+SERVER_SIDE_CURSOR_RE = re.compile(
+    r'\s*SELECT',
+    re.I | re.UNICODE)
+
+class PostgreSQL_psycopg2ExecutionContext(default.DefaultExecutionContext):
+    def create_cursor(self):
+        # TODO: coverage for server side cursors + select.for_update()
+        is_server_side = \
+            self.dialect.server_side_cursors and \
+            ((self.compiled and isinstance(self.compiled.statement, expression.Selectable) 
+                and not getattr(self.compiled.statement, 'for_update', False)) \
+            or \
+            (
+                (not self.compiled or isinstance(self.compiled.statement, expression._TextClause)) 
+                and self.statement and SERVER_SIDE_CURSOR_RE.match(self.statement))
+            )
+
+        self.__is_server_side = is_server_side
+        if is_server_side:
+            # use server-side cursors:
+            # http://lists.initd.org/pipermail/psycopg/2007-January/005251.html
+            ident = "c_%s_%s" % (hex(id(self))[2:], hex(random.randint(0, 65535))[2:])
+            return self._connection.connection.cursor(ident)
+        else:
+            return self._connection.connection.cursor()
+
+    def get_result_proxy(self):
+        if self.__is_server_side:
+            return base.BufferedRowResultProxy(self)
+        else:
+            return base.ResultProxy(self)
+
+class PostgreSQL_psycopg2Compiler(PGCompiler):
+    def visit_mod(self, binary, **kw):
+        return self.process(binary.left) + " %% " + self.process(binary.right)
+    
+    def post_process_text(self, text):
+        return text.replace('%', '%%')
+
+class PostgreSQL_psycopg2(PGDialect):
+    driver = 'psycopg2'
+    supports_unicode_statements = False
+    default_paramstyle = 'pyformat'
+    supports_sane_multi_rowcount = False
+    execution_ctx_cls = PostgreSQL_psycopg2ExecutionContext
+    statement_compiler = PostgreSQL_psycopg2Compiler
+
+    colspecs = util.update_copy(
+        PGDialect.colspecs,
+        {
+            sqltypes.Numeric : _PGNumeric,
+            sqltypes.Float: sqltypes.Float,  # prevents _PGNumeric from being used
+        }
+    )
+
+    def __init__(self, server_side_cursors=False, **kwargs):
+        PGDialect.__init__(self, **kwargs)
+        self.server_side_cursors = server_side_cursors
+
+    @classmethod
+    def dbapi(cls):
+        psycopg = __import__('psycopg2')
+        return psycopg
+
+    def create_connect_args(self, url):
+        opts = url.translate_connect_args(username='user')
+        if 'port' in opts:
+            opts['port'] = int(opts['port'])
+        opts.update(url.query)
+        return ([], opts)
+
+    def is_disconnect(self, e):
+        if isinstance(e, self.dbapi.OperationalError):
+            return 'closed the connection' in str(e) or 'connection not open' in str(e)
+        elif isinstance(e, self.dbapi.InterfaceError):
+            return 'connection already closed' in str(e) or 'cursor already closed' in str(e)
+        elif isinstance(e, self.dbapi.ProgrammingError):
+            # yes, it really says "losed", not "closed"
+            return "losed the connection unexpectedly" in str(e)
+        else:
+            return False
+
+dialect = PostgreSQL_psycopg2
+    
diff --git a/lib/sqlalchemy/dialects/postgresql/pypostgresql.py b/lib/sqlalchemy/dialects/postgresql/pypostgresql.py
new file mode 100644 (file)
index 0000000..975006d
--- /dev/null
@@ -0,0 +1,80 @@
+"""Support for the PostgreSQL database via py-postgresql.
+
+Connecting
+----------
+
+URLs are of the form `postgresql+pypostgresql://user:password@host:port/dbname[?key=value&key=value...]`.
+
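+A minimal connection sketch (connection values are hypothetical)::
+
+    from sqlalchemy import create_engine
+
+    engine = create_engine("postgresql+pypostgresql://scott:tiger@localhost/test")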
+
+"""
+from sqlalchemy.engine import default
+import decimal
+from sqlalchemy import util
+from sqlalchemy import types as sqltypes
+from sqlalchemy.dialects.postgresql.base import PGDialect, PGDefaultRunner
+
+class PGNumeric(sqltypes.Numeric):
+    def bind_processor(self, dialect):
+        return None
+
+    def result_processor(self, dialect):
+        if self.asdecimal:
+            return None
+        else:
+            def process(value):
+                if isinstance(value, decimal.Decimal):
+                    return float(value)
+                else:
+                    return value
+            return process
+
+class PostgreSQL_pypostgresqlExecutionContext(default.DefaultExecutionContext):
+    pass
+
+class PostgreSQL_pypostgresqlDefaultRunner(PGDefaultRunner):
+    def execute_string(self, stmt, params=None):
+        return PGDefaultRunner.execute_string(self, stmt, params or ())
+        
+class PostgreSQL_pypostgresql(PGDialect):
+    driver = 'pypostgresql'
+
+    supports_unicode_statements = True
+    
+    supports_unicode_binds = True
+    description_encoding = None
+    
+    defaultrunner = PostgreSQL_pypostgresqlDefaultRunner
+    
+    default_paramstyle = 'format'
+    
+    supports_sane_rowcount = False  # alas....posting a bug now
+    
+    supports_sane_multi_rowcount = False
+    
+    execution_ctx_cls = PostgreSQL_pypostgresqlExecutionContext
+    colspecs = util.update_copy(
+        PGDialect.colspecs,
+        {
+            sqltypes.Numeric : PGNumeric,
+            sqltypes.Float: sqltypes.Float,  # prevents PGNumeric from being used
+        }
+    )
+    
+    @classmethod
+    def dbapi(cls):
+        from postgresql.driver import dbapi20
+        return dbapi20
+
+    def create_connect_args(self, url):
+        opts = url.translate_connect_args(username='user')
+        if 'port' in opts:
+            opts['port'] = int(opts['port'])
+        else:
+            opts['port'] = 5432
+        opts.update(url.query)
+        return ([], opts)
+
+    def is_disconnect(self, e):
+        return "connection is closed" in str(e)
+
+dialect = PostgreSQL_pypostgresql
diff --git a/lib/sqlalchemy/dialects/postgresql/zxjdbc.py b/lib/sqlalchemy/dialects/postgresql/zxjdbc.py
new file mode 100644 (file)
index 0000000..b707d2d
--- /dev/null
@@ -0,0 +1,28 @@
+"""Support for the PostgreSQL database via the zxjdbc JDBC connector.
+
+JDBC Driver
+-----------
+
+The official PostgreSQL JDBC driver is at http://jdbc.postgresql.org/.
+
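+Connecting
+----------
+
+zxjdbc is used with Jython.  A hypothetical example, assuming the driver jar is
+on the CLASSPATH::
+
+    from sqlalchemy import create_engine
+
+    engine = create_engine("postgresql+zxjdbc://scott:tiger@localhost/test")
+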
+"""
+from sqlalchemy.connectors.zxJDBC import ZxJDBCConnector
+from sqlalchemy.dialects.postgresql.base import PGCompiler, PGDialect
+
+class PostgreSQL_jdbcCompiler(PGCompiler):
+
+    def post_process_text(self, text):
+        # Don't escape '%' like PGCompiler
+        return text
+
+
+class PostgreSQL_jdbc(ZxJDBCConnector, PGDialect):
+    statement_compiler = PostgreSQL_jdbcCompiler
+
+    jdbc_db_name = 'postgresql'
+    jdbc_driver_name = 'org.postgresql.Driver'
+
+    def _get_server_version_info(self, connection):
+        return tuple(int(x) for x in connection.connection.dbversion.split('.'))
+
+dialect = PostgreSQL_jdbc
diff --git a/lib/sqlalchemy/dialects/sqlite/__init__.py b/lib/sqlalchemy/dialects/sqlite/__init__.py
new file mode 100644 (file)
index 0000000..3cc0887
--- /dev/null
@@ -0,0 +1,4 @@
+from sqlalchemy.dialects.sqlite import base, pysqlite
+
+# default dialect
+base.dialect = pysqlite.dialect
\ No newline at end of file
diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py
new file mode 100644 (file)
index 0000000..8dea91d
--- /dev/null
@@ -0,0 +1,526 @@
+# sqlite.py
+# Copyright (C) 2005, 2006, 2007, 2008, 2009 Michael Bayer mike_mp@zzzcomputing.com
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: http://www.opensource.org/licenses/mit-license.php
+"""Support for the SQLite database.
+
+For information on connecting using a specific driver, see the documentation
+section regarding that driver.
+
+Date and Time Types
+-------------------
+
+SQLite does not have built-in DATE, TIME, or DATETIME types, and pysqlite does not provide
+out-of-the-box functionality for translating values between Python `datetime` objects
+and a SQLite-supported format.  SQLAlchemy's own :class:`~sqlalchemy.types.DateTime`
+and related types provide date formatting and parsing functionality when SQLite is used.
+The implementation classes are :class:`_SLDateTime`, :class:`_SLDate` and :class:`_SLTime`.
+These types represent dates and times as ISO-formatted strings, which also nicely
+support ordering.  There's no reliance on typical "libc" internals for these functions,
+so historical dates are fully supported.
+
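+As a sketch of the storage format (values are illustrative)::
+
+    import datetime
+    from sqlalchemy import MetaData, Table, Column, DateTime, create_engine
+
+    engine = create_engine('sqlite://')
+    meta = MetaData()
+    t = Table('t', meta, Column('stamp', DateTime))
+    meta.create_all(engine)
+    engine.execute(t.insert(), stamp=datetime.datetime(2009, 8, 6, 21, 11, 27))
+    # stored in SQLite as the string '2009-08-06 21:11:27.000000'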
+
+"""
+
+import datetime, re, time
+
+from sqlalchemy import schema as sa_schema
+from sqlalchemy import sql, exc, pool, DefaultClause
+from sqlalchemy.engine import default
+from sqlalchemy.engine import reflection
+from sqlalchemy import types as sqltypes
+from sqlalchemy import util
+from sqlalchemy.sql import compiler, functions as sql_functions
+from sqlalchemy.util import NoneType
+
+from sqlalchemy.types import BLOB, BOOLEAN, CHAR, DATE, DATETIME, DECIMAL,\
+                            FLOAT, INTEGER, NUMERIC, SMALLINT, TEXT, TIME,\
+                            TIMESTAMP, VARCHAR
+                            
+
+class _NumericMixin(object):
+    def bind_processor(self, dialect):
+        type_ = self.asdecimal and str or float
+        def process(value):
+            if value is not None:
+                return type_(value)
+            else:
+                return value
+        return process
+
+class _SLNumeric(_NumericMixin, sqltypes.Numeric):
+    pass
+
+class _SLFloat(_NumericMixin, sqltypes.Float):
+    pass
+
+# since SQLite has no date types, we're assuming that SQLite via ODBC
+# or JDBC would similarly have no built-in date support, so the string-based logic
+# would apply to all implementing dialects.
+class _DateTimeMixin(object):
+    def _bind_processor(self, format, elements):
+        def process(value):
+            if not isinstance(value, (NoneType, datetime.date, datetime.datetime, datetime.time)):
+                raise TypeError("SQLite Date, Time, and DateTime types only accept Python datetime objects as input.")
+            elif value is not None:
+                return format % tuple([getattr(value, attr, 0) for attr in elements])
+            else:
+                return None
+        return process
+
+    def _result_processor(self, fn, regexp):
+        def process(value):
+            if value is not None:
+                return fn(*[int(x or 0) for x in regexp.match(value).groups()])
+            else:
+                return None
+        return process
+
+class _SLDateTime(_DateTimeMixin, sqltypes.DateTime):
+    __legacy_microseconds__ = False
+
+    def bind_processor(self, dialect):
+        if self.__legacy_microseconds__:
+            return self._bind_processor(
+                        "%4.4d-%2.2d-%2.2d %2.2d:%2.2d:%2.2d.%s", 
+                        ("year", "month", "day", "hour", "minute", "second", "microsecond")
+                        )
+        else:
+            return self._bind_processor(
+                        "%4.4d-%2.2d-%2.2d %2.2d:%2.2d:%2.2d.%06d", 
+                        ("year", "month", "day", "hour", "minute", "second", "microsecond")
+                        )
+
+    _reg = re.compile(r"(\d+)-(\d+)-(\d+)(?: (\d+):(\d+):(\d+)(?:\.(\d+))?)?")
+    def result_processor(self, dialect):
+        return self._result_processor(datetime.datetime, self._reg)
+
+class _SLDate(_DateTimeMixin, sqltypes.Date):
+    def bind_processor(self, dialect):
+        return self._bind_processor(
+                        "%4.4d-%2.2d-%2.2d", 
+                        ("year", "month", "day")
+                )
+
+    _reg = re.compile(r"(\d+)-(\d+)-(\d+)")
+    def result_processor(self, dialect):
+        return self._result_processor(datetime.date, self._reg)
+
+class _SLTime(_DateTimeMixin, sqltypes.Time):
+    __legacy_microseconds__ = False
+
+    def bind_processor(self, dialect):
+        if self.__legacy_microseconds__:
+            return self._bind_processor(
+                            "%2.2d:%2.2d:%2.2d.%s", 
+                            ("hour", "minute", "second", "microsecond")
+                    )
+        else:
+            return self._bind_processor(
+                            "%2.2d:%2.2d:%2.2d.%06d", 
+                            ("hour", "minute", "second", "microsecond")
+                    )
+
+    _reg = re.compile(r"(\d+):(\d+):(\d+)(?:\.(\d+))?")
+    def result_processor(self, dialect):
+        return self._result_processor(datetime.time, self._reg)
+
+
+class _SLBoolean(sqltypes.Boolean):
+    def bind_processor(self, dialect):
+        def process(value):
+            if value is None:
+                return None
+            return value and 1 or 0
+        return process
+
+    def result_processor(self, dialect):
+        def process(value):
+            if value is None:
+                return None
+            return value == 1
+        return process
+
+colspecs = {
+    sqltypes.Boolean: _SLBoolean,
+    sqltypes.Date: _SLDate,
+    sqltypes.DateTime: _SLDateTime,
+    sqltypes.Float: _SLFloat,
+    sqltypes.Numeric: _SLNumeric,
+    sqltypes.Time: _SLTime,
+}
+
+ischema_names = {
+    'BLOB': sqltypes.BLOB,
+    'BOOL': sqltypes.BOOLEAN,
+    'BOOLEAN': sqltypes.BOOLEAN,
+    'CHAR': sqltypes.CHAR,
+    'DATE': sqltypes.DATE,
+    'DATETIME': sqltypes.DATETIME,
+    'DECIMAL': sqltypes.DECIMAL,
+    'FLOAT': sqltypes.FLOAT,
+    'INT': sqltypes.INTEGER,
+    'INTEGER': sqltypes.INTEGER,
+    'NUMERIC': sqltypes.NUMERIC,
+    'REAL': sqltypes.Numeric,
+    'SMALLINT': sqltypes.SMALLINT,
+    'TEXT': sqltypes.TEXT,
+    'TIME': sqltypes.TIME,
+    'TIMESTAMP': sqltypes.TIMESTAMP,
+    'VARCHAR': sqltypes.VARCHAR,
+}
+
+
+
+class SQLiteCompiler(compiler.SQLCompiler):
+    extract_map = compiler.SQLCompiler.extract_map.copy()
+    extract_map.update({
+        'month': '%m',
+        'day': '%d',
+        'year': '%Y',
+        'second': '%S',
+        'hour': '%H',
+        'doy': '%j',
+        'minute': '%M',
+        'epoch': '%s',
+        'dow': '%w',
+        'week': '%W'
+    })
+
+    def visit_now_func(self, fn, **kw):
+        return "CURRENT_TIMESTAMP"
+    
+    def visit_char_length_func(self, fn, **kw):
+        return "length%s" % self.function_argspec(fn)
+        
+    def visit_cast(self, cast, **kwargs):
+        if self.dialect.supports_cast:
+            return super(SQLiteCompiler, self).visit_cast(cast)
+        else:
+            return self.process(cast.clause)
+
+    def visit_extract(self, extract):
+        try:
+            return "CAST(STRFTIME('%s', %s) AS INTEGER)" % (
+                self.extract_map[extract.field], self.process(extract.expr))
+        except KeyError:
+            raise exc.ArgumentError(
+                "%s is not a valid extract argument." % extract.field)
+
+    def limit_clause(self, select):
+        text = ""
+        if select._limit is not None:
+            text +=  " \n LIMIT " + str(select._limit)
+        if select._offset is not None:
+            if select._limit is None:
+                text += " \n LIMIT -1"
+            text += " OFFSET " + str(select._offset)
+        else:
+            text += " OFFSET 0"
+        return text
+
+    def for_update_clause(self, select):
+        # sqlite has no "FOR UPDATE" AFAICT
+        return ''
+
+
+class SQLiteDDLCompiler(compiler.DDLCompiler):
+
+    def get_column_specification(self, column, **kwargs):
+        colspec = self.preparer.format_column(column) + " " + self.dialect.type_compiler.process(column.type)
+        default = self.get_column_default_string(column)
+        if default is not None:
+            colspec += " DEFAULT " + default
+
+        if not column.nullable:
+            colspec += " NOT NULL"
+        return colspec
+
+class SQLiteTypeCompiler(compiler.GenericTypeCompiler):
+    def visit_binary(self, type_):
+        return self.visit_BLOB(type_)
+    
+class SQLiteIdentifierPreparer(compiler.IdentifierPreparer):
+    reserved_words = set([
+        'add', 'after', 'all', 'alter', 'analyze', 'and', 'as', 'asc',
+        'attach', 'autoincrement', 'before', 'begin', 'between', 'by',
+        'cascade', 'case', 'cast', 'check', 'collate', 'column', 'commit',
+        'conflict', 'constraint', 'create', 'cross', 'current_date',
+        'current_time', 'current_timestamp', 'database', 'default',
+        'deferrable', 'deferred', 'delete', 'desc', 'detach', 'distinct',
+        'drop', 'each', 'else', 'end', 'escape', 'except', 'exclusive',
+        'explain', 'false', 'fail', 'for', 'foreign', 'from', 'full', 'glob',
+        'group', 'having', 'if', 'ignore', 'immediate', 'in', 'index',
+        'indexed', 'initially', 'inner', 'insert', 'instead', 'intersect', 'into', 'is',
+        'isnull', 'join', 'key', 'left', 'like', 'limit', 'match', 'natural',
+        'not', 'notnull', 'null', 'of', 'offset', 'on', 'or', 'order', 'outer',
+        'plan', 'pragma', 'primary', 'query', 'raise', 'references',
+        'reindex', 'rename', 'replace', 'restrict', 'right', 'rollback',
+        'row', 'select', 'set', 'table', 'temp', 'temporary', 'then', 'to',
+        'transaction', 'trigger', 'true', 'union', 'unique', 'update', 'using',
+        'vacuum', 'values', 'view', 'virtual', 'when', 'where',
+        ])
+
+class SQLiteDialect(default.DefaultDialect):
+    name = 'sqlite'
+    supports_alter = False
+    supports_unicode_statements = True
+    supports_unicode_binds = True
+    supports_default_values = True
+    supports_empty_insert = False
+    supports_cast = True
+
+    default_paramstyle = 'qmark'
+    statement_compiler = SQLiteCompiler
+    ddl_compiler = SQLiteDDLCompiler
+    type_compiler = SQLiteTypeCompiler
+    preparer = SQLiteIdentifierPreparer
+    ischema_names = ischema_names
+    colspecs = colspecs
+    isolation_level = None
+
+    def __init__(self, isolation_level=None, **kwargs):
+        default.DefaultDialect.__init__(self, **kwargs)
+        if isolation_level and isolation_level not in ('SERIALIZABLE',
+                'READ UNCOMMITTED'):
+            raise exc.ArgumentError("Invalid value for isolation_level. "
+                "Valid isolation levels for sqlite are 'SERIALIZABLE' and "
+                "'READ UNCOMMITTED'.")
+        self.isolation_level = isolation_level
+
+    def visit_pool(self, pool):
+        if self.isolation_level is not None:
+            class SetIsolationLevel(object):
+                def __init__(self, isolation_level):
+                    if isolation_level == 'READ UNCOMMITTED':
+                        self.isolation_level = 1
+                    else:
+                        self.isolation_level = 0
+
+                def connect(self, conn, rec):
+                    cursor = conn.cursor()
+                    cursor.execute("PRAGMA read_uncommitted = %d" % self.isolation_level)
+                    cursor.close()
+            pool.add_listener(SetIsolationLevel(self.isolation_level))
+
+    def table_names(self, connection, schema):
+        if schema is not None:
+            qschema = self.identifier_preparer.quote_identifier(schema)
+            master = '%s.sqlite_master' % qschema
+            s = ("SELECT name FROM %s "
+                 "WHERE type='table' ORDER BY name") % (master,)
+            rs = connection.execute(s)
+        else:
+            try:
+                s = ("SELECT name FROM "
+                     " (SELECT * FROM sqlite_master UNION ALL "
+                     "  SELECT * FROM sqlite_temp_master) "
+                     "WHERE type='table' ORDER BY name")
+                rs = connection.execute(s)
+            except exc.DBAPIError:
+                # fall back to sqlite_master alone if querying
+                # sqlite_temp_master fails
+                s = ("SELECT name FROM sqlite_master "
+                     "WHERE type='table' ORDER BY name")
+                rs = connection.execute(s)
+
+        return [row[0] for row in rs]
+
+    def has_table(self, connection, table_name, schema=None):
+        quote = self.identifier_preparer.quote_identifier
+        if schema is not None:
+            pragma = "PRAGMA %s." % quote(schema)
+        else:
+            pragma = "PRAGMA "
+        qtable = quote(table_name)
+        cursor = _pragma_cursor(connection.execute("%stable_info(%s)" % (pragma, qtable)))
+        row = cursor.fetchone()
+
+        # consume remaining rows, to work around
+        # http://www.sqlite.org/cvstrac/tktview?tn=1884
+        while cursor.fetchone() is not None:
+            pass
+
+        return (row is not None)
+
+    @reflection.cache
+    def get_table_names(self, connection, schema=None, **kw):
+        return self.table_names(connection, schema)
+
+    @reflection.cache
+    def get_view_names(self, connection, schema=None, **kw):
+        if schema is not None:
+            qschema = self.identifier_preparer.quote_identifier(schema)
+            master = '%s.sqlite_master' % qschema
+            s = ("SELECT name FROM %s "
+                 "WHERE type='view' ORDER BY name") % (master,)
+            rs = connection.execute(s)
+        else:
+            try:
+                s = ("SELECT name FROM "
+                     " (SELECT * FROM sqlite_master UNION ALL "
+                     "  SELECT * FROM sqlite_temp_master) "
+                     "WHERE type='view' ORDER BY name")
+                rs = connection.execute(s)
+            except exc.DBAPIError:
+                # fall back to sqlite_master alone if querying
+                # sqlite_temp_master fails
+                s = ("SELECT name FROM sqlite_master "
+                     "WHERE type='view' ORDER BY name")
+                rs = connection.execute(s)
+
+        return [row[0] for row in rs]
+
+    @reflection.cache
+    def get_view_definition(self, connection, view_name, schema=None, **kw):
+        quote = self.identifier_preparer.quote_identifier
+        if schema is not None:
+            qschema = self.identifier_preparer.quote_identifier(schema)
+            master = '%s.sqlite_master' % qschema
+            s = ("SELECT sql FROM %s WHERE name = '%s'"
+                 "AND type='view'") % (master, view_name)
+            rs = connection.execute(s)
+        else:
+            try:
+                s = ("SELECT sql FROM "
+                     " (SELECT * FROM sqlite_master UNION ALL "
+                     "  SELECT * FROM sqlite_temp_master) "
+                     "WHERE name = '%s' "
+                     "AND type='view'") % view_name
+                rs = connection.execute(s)
+            except exc.DBAPIError:
+                # fall back to sqlite_master alone if querying
+                # sqlite_temp_master fails
+                s = ("SELECT sql FROM sqlite_master WHERE name = '%s' "
+                     "AND type='view'") % view_name
+                rs = connection.execute(s)
+
+        result = rs.fetchall()
+        if result:
+            return result[0].sql
+
+    @reflection.cache
+    def get_columns(self, connection, table_name, schema=None, **kw):
+        quote = self.identifier_preparer.quote_identifier
+        if schema is not None:
+            pragma = "PRAGMA %s." % quote(schema)
+        else:
+            pragma = "PRAGMA "
+        qtable = quote(table_name)
+        c = _pragma_cursor(connection.execute("%stable_info(%s)" % (pragma, qtable)))
+        found_table = False
+        columns = []
+        while True:
+            row = c.fetchone()
+            if row is None:
+                break
+            (name, type_, nullable, default, has_default, primary_key) = (
+                row[1], row[2].upper(), not row[3],
+                row[4], row[4] is not None, row[5]
+            )
+            name = re.sub(r'^\"|\"$', '', name)
+            if default:
+                default = re.sub(r"^\'|\'$", '', default)
+            match = re.match(r'(\w+)(\(.*?\))?', type_)
+            if match:
+                coltype = match.group(1)
+                args = match.group(2)
+            else:
+                coltype = "VARCHAR"
+                args = ''
+            try:
+                coltype = self.ischema_names[coltype]
+            except KeyError:
+                util.warn("Did not recognize type '%s' of column '%s'" %
+                          (coltype, name))
+                coltype = sqltypes.NullType
+            if args is not None:
+                args = re.findall(r'(\d+)', args)
+                coltype = coltype(*[int(a) for a in args])
+
+            columns.append({
+                'name' : name,
+                'type' : coltype,
+                'nullable' : nullable,
+                'default' : default,
+                'primary_key': primary_key
+            })
+        return columns
+
+    @reflection.cache
+    def get_primary_keys(self, connection, table_name, schema=None, **kw):
+        cols = self.get_columns(connection, table_name, schema, **kw)
+        pkeys = []
+        for col in cols:
+            if col['primary_key']:
+                pkeys.append(col['name'])
+        return pkeys
+
+    @reflection.cache
+    def get_foreign_keys(self, connection, table_name, schema=None, **kw):
+        quote = self.identifier_preparer.quote_identifier
+        if schema is not None:
+            pragma = "PRAGMA %s." % quote(schema)
+        else:
+            pragma = "PRAGMA "
+        qtable = quote(table_name)
+        c = _pragma_cursor(connection.execute("%sforeign_key_list(%s)" % (pragma, qtable)))
+        fkeys = []
+        fks = {}
+        while True:
+            row = c.fetchone()
+            if row is None:
+                break
+            (constraint_name, rtbl, lcol, rcol) = (row[0], row[2], row[3], row[4])
+            rtbl = re.sub(r'^\"|\"$', '', rtbl)
+            lcol = re.sub(r'^\"|\"$', '', lcol)
+            rcol = re.sub(r'^\"|\"$', '', rcol)
+            try:
+                fk = fks[constraint_name]
+            except KeyError:
+                fk = {
+                    'name' : constraint_name,
+                    'constrained_columns' : [],
+                    'referred_schema' : None,
+                    'referred_table' : rtbl,
+                    'referred_columns' : []
+                }
+                fkeys.append(fk)
+                fks[constraint_name] = fk
+
+            # a composite foreign key shows up as one row per column pair;
+            # record each column on the constraint exactly once
+            if lcol not in fk['constrained_columns']:
+                fk['constrained_columns'].append(lcol)
+            if rcol not in fk['referred_columns']:
+                fk['referred_columns'].append(rcol)
+        return fkeys
+
+    @reflection.cache
+    def get_indexes(self, connection, table_name, schema=None, **kw):
+        quote = self.identifier_preparer.quote_identifier
+        if schema is not None:
+            pragma = "PRAGMA %s." % quote(schema)
+        else:
+            pragma = "PRAGMA "
+        qtable = quote(table_name)
+        c = _pragma_cursor(connection.execute("%sindex_list(%s)" % (pragma, qtable)))
+        indexes = []
+        while True:
+            row = c.fetchone()
+            if row is None:
+                break
+            indexes.append(dict(name=row[1], column_names=[], unique=row[2]))
+        # loop through the indexes to collect their column names.
+        for idx in indexes:
+            c = connection.execute("%sindex_info(%s)" % (pragma, quote(idx['name'])))
+            cols = idx['column_names']
+            while True:
+                row = c.fetchone()
+                if row is None:
+                    break
+                cols.append(row[2])
+        return indexes
+
+
+def _pragma_cursor(cursor):
+    """work around SQLite issue whereby cursor.description is blank when PRAGMA returns no rows."""
+    
+    if cursor.closed:
+        cursor._fetchone_impl = lambda: None
+    return cursor
diff --git a/lib/sqlalchemy/dialects/sqlite/pysqlite.py b/lib/sqlalchemy/dialects/sqlite/pysqlite.py
new file mode 100644 (file)
index 0000000..a1873f3
--- /dev/null
@@ -0,0 +1,174 @@
+"""Support for the SQLite database via pysqlite.
+
+Note that pysqlite is the same driver as the ``sqlite3``
+module included with the Python distribution.
+
+Driver
+------
+
+When using Python 2.5 and above, the built in ``sqlite3`` driver is 
+already installed and no additional installation is needed.  Otherwise,
+the ``pysqlite2`` driver needs to be present.  This is the same driver as
+``sqlite3``, just with a different name.
+
+The ``pysqlite2`` driver will be loaded first, and if not found, ``sqlite3``
+is loaded.  This allows an explicitly installed pysqlite driver to take
+precedence over the built in one.   As with all dialects, a specific 
+DBAPI module may be provided to :func:`~sqlalchemy.create_engine()` to control 
+this explicitly::
+
+    from sqlite3 import dbapi2 as sqlite
+    e = create_engine('sqlite+pysqlite:///file.db', module=sqlite)
+
+Full documentation on pysqlite is available at:
+`<http://www.initd.org/pub/software/pysqlite/doc/usage-guide.html>`_
+
+Connect Strings
+---------------
+
+The file specification for the SQLite database is taken as the "database" portion of
+the URL.  Note that the format of a URL is::
+
+    driver://user:pass@host/database
+    
+This means that the actual filename to be used starts with the characters to the
+**right** of the third slash.   So connecting to a relative filepath looks like::
+
+    # relative path
+    e = create_engine('sqlite:///path/to/database.db')
+    
+An absolute path, which is denoted by starting with a slash, means you need **four**
+slashes::
+
+    # absolute path
+    e = create_engine('sqlite:////path/to/database.db')
+
+To use a Windows path, regular drive specifications and backslashes can be used.  
+Double backslashes are probably needed::
+
+    # absolute path on Windows
+    e = create_engine('sqlite:///C:\\\\path\\\\to\\\\database.db')
+
+The sqlite ``:memory:`` identifier is the default if no filepath is present.  Specify
+``sqlite://`` and nothing else::
+
+    # in-memory database
+    e = create_engine('sqlite://')
+
+Threading Behavior
+------------------
+
+Pysqlite connections do not support being moved between threads, unless
+the ``check_same_thread`` Pysqlite flag is set to ``False``.  In addition,
+when using an in-memory SQLite database, the full database exists only within 
+the scope of a single connection.  It is reported that an in-memory
+database does not support being shared between threads regardless of the 
+``check_same_thread`` flag - which means that a multithreaded
+application **cannot** share data from a ``:memory:`` database across threads
+unless access to the connection is limited to a single worker thread which communicates
+through a queueing mechanism to concurrent threads.
+
+To accommodate SQLite's default threading capabilities reasonably, the SQLite
+dialect specifies that :class:`~sqlalchemy.pool.SingletonThreadPool` be used
+by default.  This pool maintains a single SQLite connection per thread and
+holds connections open for up to five concurrent threads.  When more than
+five threads are used, a cleanup mechanism disposes of excess unused connections.
+
+Two optional pool implementations may be appropriate for particular SQLite
+usage scenarios (a usage sketch follows the list):
+
+ * the :class:`sqlalchemy.pool.StaticPool` might be appropriate for a multithreaded
+   application using an in-memory database, assuming the threading issues inherent in 
+   pysqlite are somehow accommodated.  This pool holds persistently onto a single connection
+   which is never closed, and is returned for all requests.
+   
+ * the :class:`sqlalchemy.pool.NullPool` might be appropriate for an application that
+   makes use of a file-based sqlite database.  This pool disables any actual "pooling"
+   behavior, and simply opens and closes real connections corresponding to the :func:`connect()`
+   and :func:`close()` methods.  SQLite can "connect" to a particular file with very high 
+   efficiency, so this option may actually perform better without the extra overhead
+   of :class:`SingletonThreadPool`.  NullPool will of course render a ``:memory:`` connection
+   useless since the database would be lost as soon as the connection is "returned" to the pool.
+
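+For example, a multithreaded application sharing an in-memory database might
+select :class:`~sqlalchemy.pool.StaticPool` explicitly (a sketch only; the
+pysqlite threading caveats described above still apply)::
+
+    from sqlalchemy import create_engine
+    from sqlalchemy.pool import StaticPool
+
+    # every checkout returns the same underlying pysqlite connection
+    e = create_engine('sqlite://',
+                      poolclass=StaticPool,
+                      connect_args={'check_same_thread': False})
+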
+Unicode
+-------
+
+In contrast to SQLAlchemy's active handling of date and time types for pysqlite, pysqlite's 
+default behavior regarding Unicode is that all strings are returned as Python unicode objects
+in all cases.  So even if the :class:`~sqlalchemy.types.Unicode` type is 
+*not* used, you will still always receive unicode data back from a result set.  It is 
+**strongly** recommended that you do use the :class:`~sqlalchemy.types.Unicode` type
+to represent strings, since it will raise a warning if a non-unicode Python string is 
+passed from the user application.  Mixing non-unicode objects with returned
+unicode objects can quickly create confusion, particularly when using the ORM,
+as internal data is not always represented by an actual database result string.
+
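+For example, a minimal table definition using the
+:class:`~sqlalchemy.types.Unicode` type::
+
+    from sqlalchemy import Table, Column, Integer, Unicode, MetaData
+
+    t = Table('mytable', MetaData(),
+              Column('id', Integer, primary_key=True),
+              # warns if a non-unicode string is bound to this column
+              Column('value', Unicode(100)))
+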
+"""
+
+from sqlalchemy.dialects.sqlite.base import SQLiteDialect
+from sqlalchemy import schema, exc, pool
+from sqlalchemy.engine import default
+from sqlalchemy import types as sqltypes
+from sqlalchemy import util
+
+class SQLite_pysqlite(SQLiteDialect):
+    default_paramstyle = 'qmark'
+    poolclass = pool.SingletonThreadPool
+    
+    # Py3K
+    #description_encoding = None
+    
+    driver = 'pysqlite'
+    
+    def __init__(self, **kwargs):
+        SQLiteDialect.__init__(self, **kwargs)
+        def vers(num):
+            return tuple([int(x) for x in num.split('.')])
+        if self.dbapi is not None:
+            sqlite_ver = self.dbapi.version_info
+            if sqlite_ver < (2, 1, '3'):
+                util.warn(
+                    ("The installed version of pysqlite2 (%s) is out-dated "
+                     "and will cause errors in some cases.  Version 2.1.3 "
+                     "or greater is recommended.") %
+                    '.'.join([str(subver) for subver in sqlite_ver]))
+            if self.dbapi.sqlite_version_info < (3, 3, 8):
+                self.supports_default_values = False
+        self.supports_cast = (
+            self.dbapi is None or
+            vers(self.dbapi.sqlite_version) >= vers("3.2.3"))
+
+    @classmethod
+    def dbapi(cls):
+        try:
+            from pysqlite2 import dbapi2 as sqlite
+        except ImportError, e:
+            try:
+                from sqlite3 import dbapi2 as sqlite #try the 2.5+ stdlib name.
+            except ImportError:
+                raise e
+        return sqlite
+
+    def _get_server_version_info(self, connection):
+        return self.dbapi.sqlite_version_info
+
+    def create_connect_args(self, url):
+        if url.username or url.password or url.host or url.port:
+            raise exc.ArgumentError(
+                "Invalid SQLite URL: %s\n"
+                "Valid SQLite URL forms are:\n"
+                " sqlite:///:memory: (or, sqlite://)\n"
+                " sqlite:///relative/path/to/file.db\n"
+                " sqlite:////absolute/path/to/file.db" % (url,))
+        filename = url.database or ':memory:'
+
+        opts = url.query.copy()
+        util.coerce_kw_type(opts, 'timeout', float)
+        util.coerce_kw_type(opts, 'isolation_level', str)
+        util.coerce_kw_type(opts, 'detect_types', int)
+        util.coerce_kw_type(opts, 'check_same_thread', bool)
+        util.coerce_kw_type(opts, 'cached_statements', int)
+
+        return ([filename], opts)
+
+    def is_disconnect(self, e):
+        return isinstance(e, self.dbapi.ProgrammingError) and \
+            "Cannot operate on a closed database." in str(e)
+
+dialect = SQLite_pysqlite
diff --git a/lib/sqlalchemy/dialects/sybase/__init__.py b/lib/sqlalchemy/dialects/sybase/__init__.py
new file mode 100644 (file)
index 0000000..f8baf33
--- /dev/null
@@ -0,0 +1,4 @@
+from sqlalchemy.dialects.sybase import base, pyodbc
+
+# default dialect
+base.dialect = pyodbc.dialect
\ No newline at end of file
diff --git a/lib/sqlalchemy/dialects/sybase/base.py b/lib/sqlalchemy/dialects/sybase/base.py
new file mode 100644 (file)
index 0000000..6f8c648
--- /dev/null
@@ -0,0 +1,458 @@
+# sybase.py
+# Copyright (C) 2007 Fisch Asset Management AG http://www.fam.ch
+# Coding: Alexander Houben alexander.houben@thor-solutions.ch
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: http://www.opensource.org/licenses/mit-license.php
+
+"""Support for the Sybase iAnywhere database.  
+
+This is not a full backend for Sybase ASE.
+
+This dialect is *not* tested on SQLAlchemy 0.6.
+
+
+Known issues / TODO:
+
+ * Uses the mx.ODBC driver from egenix (version 2.1.0)
+ * The current version of sqlalchemy.dialects.sybase only supports
+   mx.ODBC.Windows (other platforms such as mx.ODBC.unixODBC still need
+   some development)
+ * Support for pyodbc has been built in but is not yet complete (needs
+   further development)
+ * Results of running tests/alltests.py:
+     Ran 934 tests in 287.032s
+     FAILED (failures=3, errors=1)
+ * Tested on 'Adaptive Server Anywhere 9' (version 9.0.1.1751)
+"""
+
+import datetime, operator
+
+from sqlalchemy import util, sql, schema, exc
+from sqlalchemy.sql import compiler, expression
+from sqlalchemy.engine import default, base
+from sqlalchemy import types as sqltypes
+from sqlalchemy.sql import operators as sql_operators
+from sqlalchemy import MetaData, Table, Column
+from sqlalchemy import String, Integer, SMALLINT, CHAR, ForeignKey
+from sqlalchemy.dialects.sybase.schema import *
+
+RESERVED_WORDS = set([
+    "add", "all", "alter", "and",
+    "any", "as", "asc", "backup",
+    "begin", "between", "bigint", "binary",
+    "bit", "bottom", "break", "by",
+    "call", "capability", "cascade", "case",
+    "cast", "char", "char_convert", "character",
+    "check", "checkpoint", "close", "comment",
+    "commit", "connect", "constraint", "contains",
+    "continue", "convert", "create", "cross",
+    "cube", "current", "current_timestamp", "current_user",
+    "cursor", "date", "dbspace", "deallocate",
+    "dec", "decimal", "declare", "default",
+    "delete", "deleting", "desc", "distinct",
+    "do", "double", "drop", "dynamic",
+    "else", "elseif", "encrypted", "end",
+    "endif", "escape", "except", "exception",
+    "exec", "execute", "existing", "exists",
+    "externlogin", "fetch", "first", "float",
+    "for", "force", "foreign", "forward",
+    "from", "full", "goto", "grant",
+    "group", "having", "holdlock", "identified",
+    "if", "in", "index", "index_lparen",
+    "inner", "inout", "insensitive", "insert",
+    "inserting", "install", "instead", "int",
+    "integer", "integrated", "intersect", "into",
+    "iq", "is", "isolation", "join",
+    "key", "lateral", "left", "like",
+    "lock", "login", "long", "match",
+    "membership", "message", "mode", "modify",
+    "natural", "new", "no", "noholdlock",
+    "not", "notify", "null", "numeric",
+    "of", "off", "on", "open",
+    "option", "options", "or", "order",
+    "others", "out", "outer", "over",
+    "passthrough", "precision", "prepare", "primary",
+    "print", "privileges", "proc", "procedure",
+    "publication", "raiserror", "readtext", "real",
+    "reference", "references", "release", "remote",
+    "remove", "rename", "reorganize", "resource",
+    "restore", "restrict", "return", "revoke",
+    "right", "rollback", "rollup", "save",
+    "savepoint", "scroll", "select", "sensitive",
+    "session", "set", "setuser", "share",
+    "smallint", "some", "sqlcode", "sqlstate",
+    "start", "stop", "subtrans", "subtransaction",
+    "synchronize", "syntax_error", "table", "temporary",
+    "then", "time", "timestamp", "tinyint",
+    "to", "top", "tran", "trigger",
+    "truncate", "tsequal", "unbounded", "union",
+    "unique", "unknown", "unsigned", "update",
+    "updating", "user", "using", "validate",
+    "values", "varbinary", "varchar", "variable",
+    "varying", "view", "wait", "waitfor",
+    "when", "where", "while", "window",
+    "with", "with_cube", "with_lparen", "with_rollup",
+    "within", "work", "writetext",
+    ])
+
+
+class SybaseImage(sqltypes.Binary):
+    __visit_name__ = 'IMAGE'
+
+class SybaseBit(sqltypes.TypeEngine):
+    __visit_name__ = 'BIT'
+    
+class SybaseMoney(sqltypes.TypeEngine):
+    __visit_name__ = "MONEY"
+
+class SybaseSmallMoney(SybaseMoney):
+    __visit_name__ = "SMALLMONEY"
+
+class SybaseUniqueIdentifier(sqltypes.TypeEngine):
+    __visit_name__ = "UNIQUEIDENTIFIER"
+    
+class SybaseBoolean(sqltypes.Boolean):
+    def result_processor(self, dialect):
+        def process(value):
+            if value is None:
+                return None
+            return bool(value)
+        return process
+
+    def bind_processor(self, dialect):
+        def process(value):
+            if value is None:
+                return None
+            # coerce any truthy/falsy value to the BIT values 1/0
+            return value and 1 or 0
+        return process
+
+class SybaseTypeCompiler(compiler.GenericTypeCompiler):
+    def visit_binary(self, type_):
+        return self.visit_IMAGE(type_)
+    
+    def visit_boolean(self, type_):
+        return self.visit_BIT(type_)
+        
+    def visit_IMAGE(self, type_):
+        return "IMAGE"
+
+    def visit_BIT(self, type_):
+        return "BIT"
+
+    def visit_MONEY(self, type_):
+        return "MONEY"
+    
+    def visit_SMALLMONEY(self, type_):
+        return "SMALLMONEY"
+        
+    def visit_UNIQUEIDENTIFIER(self, type_):
+        return "UNIQUEIDENTIFIER"
+        
+colspecs = {
+    sqltypes.Binary : SybaseImage,
+    sqltypes.Boolean : SybaseBoolean,
+}
+
+ischema_names = {
+    'integer' : sqltypes.INTEGER,
+    'unsigned int' : sqltypes.Integer,
+    'unsigned smallint' : sqltypes.SmallInteger,
+    'unsigned bigint' : sqltypes.BigInteger,
+    'bigint': sqltypes.BIGINT,
+    'smallint' : sqltypes.SMALLINT,
+    'tinyint' : sqltypes.SmallInteger,
+    'varchar' : sqltypes.VARCHAR,
+    'long varchar' : sqltypes.Text,
+    'char' : sqltypes.CHAR,
+    'decimal' : sqltypes.DECIMAL,
+    'numeric' : sqltypes.NUMERIC,
+    'float' : sqltypes.FLOAT,
+    'double' : sqltypes.Numeric,
+    'binary' : sqltypes.Binary,
+    'long binary' : sqltypes.Binary,
+    'varbinary' : sqltypes.Binary,
+    'bit': SybaseBit,
+    'image' : SybaseImage,
+    'timestamp': sqltypes.TIMESTAMP,
+    'money': SybaseMoney,
+    'smallmoney': SybaseSmallMoney,
+    'uniqueidentifier': SybaseUniqueIdentifier,
+
+}
+
+
+class SybaseExecutionContext(default.DefaultExecutionContext):
+
+    def post_exec(self):
+        if self.compiled.isinsert:
+            table = self.compiled.statement.table
+            # get the inserted values of the primary key
+
+            # get any sequence IDs first (using @@identity)
+            self.cursor.execute("SELECT @@identity AS lastrowid")
+            row = self.cursor.fetchone()
+            lastrowid = int(row[0])
+            if lastrowid > 0:
+                # an IDENTITY was inserted, fetch it
+                # FIXME: always insert in front ? This only works if the IDENTITY is the first column, no ?!
+                if getattr(self, '_last_inserted_ids', None) is None:
+                    self._last_inserted_ids = [lastrowid]
+                else:
+                    self._last_inserted_ids = [lastrowid] + self._last_inserted_ids[1:]
+
+
+class SybaseSQLCompiler(compiler.SQLCompiler):
+
+    extract_map = compiler.SQLCompiler.extract_map.copy()
+    extract_map.update({
+        'doy': 'dayofyear',
+        'dow': 'weekday',
+        'milliseconds': 'millisecond'
+    })
+
+    def visit_mod(self, binary, **kw):
+        return "MOD(%s, %s)" % (self.process(binary.left), self.process(binary.right))
+
+    def bindparam_string(self, name):
+        res = super(SybaseSQLCompiler, self).bindparam_string(name)
+        if name.lower().startswith('literal'):
+            res = 'STRING(%s)' % res
+        return res
+
+    def get_select_precolumns(self, select):
+        s = select._distinct and "DISTINCT " or ""
+        if select._limit:
+            #if select._limit == 1:
+                #s += "FIRST "
+            #else:
+                #s += "TOP %s " % (select._limit,)
+            s += "TOP %s " % (select._limit,)
+        if select._offset:
+            if not select._limit:
+                # FIXME: sybase doesn't allow an offset without a limit
+                # so use a huge value for TOP here
+                s += "TOP 1000000 "
+            s += "START AT %s " % (select._offset+1,)
+        return s
+
+    def limit_clause(self, select):
+        # Limit in sybase is after the select keyword
+        return ""
+
+    def visit_binary(self, binary):
+        """Move bind parameters to the right-hand side of an operator, where possible."""
+        if isinstance(binary.left, expression._BindParamClause) and binary.operator == operator.eq:
+            return self.process(expression._BinaryExpression(binary.right, binary.left, binary.operator))
+        else:
+            return super(SybaseSQLCompiler, self).visit_binary(binary)
+
+    def label_select_column(self, select, column, asfrom):
+        if isinstance(column, expression.Function):
+            return column.label(None)
+        else:
+            return super(SybaseSQLCompiler, self).label_select_column(select, column, asfrom)
+
+    function_rewrites = {'current_date': 'getdate'}
+
+    def visit_function(self, func):
+        func.name = self.function_rewrites.get(func.name, func.name)
+        res = super(SybaseSQLCompiler, self).visit_function(func)
+        if func.name.lower() == 'getdate':
+            # apply a CAST to DATE; the original code referenced
+            # SybaseDate_mxodbc, which is not defined in this module
+            # FIXME: what about _pyodbc ?
+            cast = expression._Cast(func, sqltypes.Date)
+            # calling self.visit_cast(cast) here would recurse infinitely
+            res = "CAST(%s AS %s)" % (res, self.process(cast.typeclause))
+        return res
+
+    def visit_extract(self, extract):
+        field = self.extract_map.get(extract.field, extract.field)
+        return 'DATEPART("%s", %s)' % (field, self.process(extract.expr))
+
+    def for_update_clause(self, select):
+        # "FOR UPDATE" is only allowed on "DECLARE CURSOR" which SQLAlchemy doesn't use
+        return ''
+
+    def order_by_clause(self, select):
+        order_by = self.process(select._order_by_clause)
+
+        # SybaseSQL only allows ORDER BY in subqueries if there is a LIMIT
+        if order_by and (not self.is_subquery() or select._limit):
+            return " ORDER BY " + order_by
+        else:
+            return ""
+
+
+class SybaseDDLCompiler(compiler.DDLCompiler):
+    def get_column_specification(self, column, **kwargs):
+
+        colspec = self.preparer.format_column(column)
+
+        if (not getattr(column.table, 'has_sequence', False)) and column.primary_key and \
+                column.autoincrement and isinstance(column.type, sqltypes.Integer):
+            if column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional):
+                column.sequence = schema.Sequence(column.name + '_seq')
+
+        if hasattr(column, 'sequence'):
+            column.table.has_sequence = column
+            #colspec += " numeric(30,0) IDENTITY"
+            colspec += " Integer IDENTITY"
+        else:
+            colspec += " " + self.dialect.type_compiler.process(column.type)
+
+        if not column.nullable:
+            colspec += " NOT NULL"
+
+        default = self.get_column_default_string(column)
+        if default is not None:
+            colspec += " DEFAULT " + default
+
+        return colspec
+
+    def visit_drop_index(self, drop):
+        index = drop.element
+        return "\nDROP INDEX %s.%s" % (
+            self.preparer.quote_identifier(index.table.name),
+            self.preparer.quote(self._validate_identifier(index.name, False), index.quote)
+            )
+
+class SybaseIdentifierPreparer(compiler.IdentifierPreparer):
+    reserved_words = RESERVED_WORDS
+
+class SybaseDialect(default.DefaultDialect):
+    name = 'sybase'
+    supports_unicode_statements = False
+    supports_sane_rowcount = False
+    supports_sane_multi_rowcount = False
+    colspecs = colspecs
+    ischema_names = ischema_names
+
+    type_compiler = SybaseTypeCompiler
+    statement_compiler = SybaseSQLCompiler
+    ddl_compiler = SybaseDDLCompiler
+    preparer = SybaseIdentifierPreparer
+
+    schema_name = "dba"
+
+    def __init__(self, **params):
+        super(SybaseDialect, self).__init__(**params)
+        self.text_as_varchar = False
+
+    def last_inserted_ids(self):
+        return self.context.last_inserted_ids
+
+    def get_default_schema_name(self, connection):
+        return self.schema_name
+
+    def table_names(self, connection, schema):
+        """Ignore the schema and the charset for now."""
+        # use sql.and_(); the Python 'and' operator does not produce
+        # a SQL conjunction
+        s = sql.select([tables.c.table_name],
+                       sql.and_(
+                           sql.not_(tables.c.table_name.like("SYS%")),
+                           tables.c.creator >= 100)
+                       )
+        rp = connection.execute(s)
+        return [row[0] for row in rp.fetchall()]
+
+    def has_table(self, connection, tablename, schema=None):
+        # FIXME: ignore schemas for sybase
+        s = sql.select([tables.c.table_name], tables.c.table_name == tablename)
+
+        c = connection.execute(s)
+        row = c.fetchone()
+        return row is not None
+
+    def reflecttable(self, connection, table, include_columns):
+        # Get base columns
+        if table.schema is not None:
+            current_schema = table.schema
+        else:
+            current_schema = self.get_default_schema_name(connection)
+
+        s = sql.select([columns, domains],
+                       tables.c.table_name == table.name,
+                       from_obj=[columns.join(tables).join(domains)],
+                       order_by=[columns.c.column_id])
+
+        c = connection.execute(s)
+        found_table = False
+        # makes sure we append the columns in the correct order
+        while True:
+            row = c.fetchone()
+            if row is None:
+                break
+            found_table = True
+            (name, type, nullable, charlen, numericprec, numericscale, default, primary_key, max_identity, table_id, column_id) = (
+                row[columns.c.column_name],
+                row[domains.c.domain_name],
+                row[columns.c.nulls] == 'Y',
+                row[columns.c.width],
+                row[domains.c.precision],
+                row[columns.c.scale],
+                row[columns.c.default],
+                row[columns.c.pkey] == 'Y',
+                row[columns.c.max_identity],
+                row[tables.c.table_id],
+                row[columns.c.column_id],
+            )
+            if include_columns and name not in include_columns:
+                continue
+
+            # FIXME: else problems with SybaseBinary(size)
+            if numericscale == 0:
+                numericscale = None
+
+            args = []
+            for a in (charlen, numericprec, numericscale):
+                if a is not None:
+                    args.append(a)
+            coltype = self.ischema_names.get(type, None)
+            # 'long varchar' columns report a width of -1; treat them as Text
+            # (the original referenced SybaseString/SybaseText, which are not
+            # defined in this module)
+            if coltype == sqltypes.VARCHAR and charlen == -1:
+                coltype = sqltypes.Text()
+            else:
+                if coltype is None:
+                    util.warn("Did not recognize type '%s' of column '%s'" %
+                              (type, name))
+                    coltype = sqltypes.NullType()
+                else:
+                    coltype = coltype(*args)
+            colargs = []
+            if default is not None:
+                colargs.append(schema.DefaultClause(sql.text(default)))
+
+            # any sequences ?
+            col = schema.Column(name, coltype, nullable=nullable, primary_key=primary_key, *colargs)
+            if int(max_identity) > 0:
+                col.sequence = schema.Sequence(name + '_identity')
+                col.sequence.start = int(max_identity)
+                col.sequence.increment = 1
+
+            # append the column
+            table.append_column(col)
+
+        # any foreign key constraint for this table ?
+        # note: no multi-column foreign keys are considered
+        s = "select st1.table_name, sc1.column_name, st2.table_name, sc2.column_name from systable as st1 join sysfkcol on st1.table_id=sysfkcol.foreign_table_id join sysforeignkey join systable as st2 on sysforeignkey.primary_table_id = st2.table_id join syscolumn as sc1 on sysfkcol.foreign_column_id=sc1.column_id and sc1.table_id=st1.table_id join syscolumn as sc2 on sysfkcol.primary_column_id=sc2.column_id and sc2.table_id=st2.table_id where st1.table_name='%(table_name)s';" % { 'table_name' : table.name }
+        c = connection.execute(s)
+        foreignKeys = {}
+        while True:
+            row = c.fetchone()
+            if row is None:
+                break
+            (foreign_table, foreign_column, primary_table, primary_column) = (
+                row[0], row[1], row[2], row[3],
+            )
+            if primary_table not in foreignKeys:
+                foreignKeys[primary_table] = [['%s' % (foreign_column)], ['%s.%s'%(primary_table, primary_column)]]
+            else:
+                foreignKeys[primary_table][0].append('%s'%(foreign_column))
+                foreignKeys[primary_table][1].append('%s.%s'%(primary_table, primary_column))
+        for primary_table in foreignKeys.iterkeys():
+            #table.append_constraint(schema.ForeignKeyConstraint(['%s.%s'%(foreign_table, foreign_column)], ['%s.%s'%(primary_table,primary_column)]))
+            table.append_constraint(schema.ForeignKeyConstraint(foreignKeys[primary_table][0], foreignKeys[primary_table][1], link_to_name=True))
+
+        if not found_table:
+            raise exc.NoSuchTableError(table.name)
+
diff --git a/lib/sqlalchemy/dialects/sybase/mxodbc.py b/lib/sqlalchemy/dialects/sybase/mxodbc.py
new file mode 100644 (file)
index 0000000..86a23d5
--- /dev/null
@@ -0,0 +1,10 @@
+from sqlalchemy.dialects.sybase.base import SybaseDialect, SybaseExecutionContext
+from sqlalchemy.connectors.mxodbc import MxODBCConnector
+
+class SybaseExecutionContext_mxodbc(SybaseExecutionContext):
+    pass
+
+class Sybase_mxodbc(MxODBCConnector, SybaseDialect):
+    execution_ctx_cls = SybaseExecutionContext_mxodbc
+
+dialect = Sybase_mxodbc
\ No newline at end of file
diff --git a/lib/sqlalchemy/dialects/sybase/pyodbc.py b/lib/sqlalchemy/dialects/sybase/pyodbc.py
new file mode 100644 (file)
index 0000000..61c6f32
--- /dev/null
@@ -0,0 +1,11 @@
+from sqlalchemy.dialects.sybase.base import SybaseDialect, SybaseExecutionContext
+from sqlalchemy.connectors.pyodbc import PyODBCConnector
+
+class SybaseExecutionContext_pyodbc(SybaseExecutionContext):
+    pass
+
+
+class Sybase_pyodbc(PyODBCConnector, SybaseDialect):
+    execution_ctx_cls = SybaseExecutionContext_pyodbc
+
+dialect = Sybase_pyodbc
\ No newline at end of file
diff --git a/lib/sqlalchemy/dialects/sybase/schema.py b/lib/sqlalchemy/dialects/sybase/schema.py
new file mode 100644 (file)
index 0000000..15ac6b2
--- /dev/null
@@ -0,0 +1,51 @@
+from sqlalchemy import *
+
+ischema = MetaData()
+
+tables = Table("SYSTABLE", ischema,
+    Column("table_id", Integer, primary_key=True),
+    Column("file_id", SMALLINT),
+    Column("table_name", CHAR(128)),
+    Column("table_type", CHAR(10)),
+    Column("creator", Integer),
+    #schema="information_schema"
+    )
+
+domains = Table("SYSDOMAIN", ischema,
+    Column("domain_id", Integer, primary_key=True),
+    Column("domain_name", CHAR(128)),
+    Column("type_id", SMALLINT),
+    Column("precision", SMALLINT, quote=True),
+    #schema="information_schema"
+    )
+
+columns = Table("SYSCOLUMN", ischema,
+    Column("column_id", Integer, primary_key=True),
+    Column("table_id", Integer, ForeignKey(tables.c.table_id)),
+    Column("pkey", CHAR(1)),
+    Column("column_name", CHAR(128)),
+    Column("nulls", CHAR(1)),
+    Column("width", SMALLINT),
+    Column("domain_id", SMALLINT, ForeignKey(domains.c.domain_id)),
+    # FIXME: should be mx.BIGINT
+    Column("max_identity", Integer),
+    # FIXME: should be mx.ODBC.Windows.LONGVARCHAR
+    Column("default", String),
+    Column("scale", Integer),
+    #schema="information_schema"
+    )
+
+foreignkeys = Table("SYSFOREIGNKEY", ischema,
+    Column("foreign_table_id", Integer, ForeignKey(tables.c.table_id), primary_key=True),
+    Column("foreign_key_id", SMALLINT, primary_key=True),
+    Column("primary_table_id", Integer, ForeignKey(tables.c.table_id)),
+    #schema="information_schema"
+    )
+fkcols = Table("SYSFKCOL", ischema,
+    Column("foreign_table_id", Integer, ForeignKey(columns.c.table_id), primary_key=True),
+    Column("foreign_key_id", SMALLINT, ForeignKey(foreignkeys.c.foreign_key_id), primary_key=True),
+    Column("foreign_column_id", Integer, ForeignKey(columns.c.column_id), primary_key=True),
+    Column("primary_column_id", Integer),
+    #schema="information_schema"
+    )
+
diff --git a/lib/sqlalchemy/dialects/type_migration_guidelines.txt b/lib/sqlalchemy/dialects/type_migration_guidelines.txt
new file mode 100644 (file)
index 0000000..8ed1a17
--- /dev/null
@@ -0,0 +1,145 @@
+Rules for Migrating TypeEngine classes to 0.6
+---------------------------------------------
+
+1. the TypeEngine classes are used for:
+
+    a. Specifying behavior which needs to occur for bind parameters
+    or result row columns.
+    
+    b. Specifying types that are entirely specific to the database
+    in use and have no analogue in the sqlalchemy.types package.
+    
+    c. Specifying types where there is an analogue in sqlalchemy.types,
+    but the database in use takes vendor-specific flags for those
+    types.
+
+    d. If a TypeEngine class doesn't provide any of this, it should be
+    *removed* from the dialect.
+    
+2. the TypeEngine classes are *no longer* used for generating DDL.  Dialects
+now have a TypeCompiler subclass which uses the same visit_XXX model as
+other compilers.   
+
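+For illustration, a minimal sketch of the visit_XXX pattern (a hypothetical
+dialect; the names are invented):
+
+   from sqlalchemy.sql import compiler
+
+   class MyTypeCompiler(compiler.GenericTypeCompiler):
+       # DDL for a backend-native type
+       def visit_BIT(self, type_):
+           return "BIT"
+
+       # the generic Boolean type renders via the uppercase method
+       def visit_boolean(self, type_):
+           return self.visit_BIT(type_)
+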
+3. the "ischema_names" and "colspecs" dictionaries are now required members on
+the Dialect class.
+
+4. The names of types within dialects are now important.   If a dialect-specific type
+is a subclass of an existing generic type and is only provided for bind/result behavior, 
+the current mixed case naming can remain, i.e. _PGNumeric for Numeric - in this case, 
+end users would never need to use _PGNumeric directly.   However, if a dialect-specific 
+type is specifying a type *or* arguments that are not present generically, it should
+match the real name of the type on that backend, in uppercase.  E.g. postgresql.INET,
+mysql.ENUM, postgresql.ARRAY.  
+
+Or follow this handy flowchart:
+
+    is the type meant to provide bind/result                  is the type the same name as an
+    behavior to a generic type (i.e. MixedCase)  ---- no ---> UPPERCASE type in types.py ?
+    type in types.py ?                                          |                     |
+                    |                                           no                    yes
+                   yes                                          |                     |
+                    |                                           |             does your type need special
+                    |                                           +<--- yes --- behavior or arguments ?
+                    |                                           |                               |
+                    |                                           |                              no
+           name the type using                                  |                               |
+           _MixedCase, i.e.                                     v                               V
+           _OracleBoolean. it                          name the type                        don't make a
+           stays private to the dialect                identically as that                  type, make sure the dialect's
+           and is invoked *only* via                   within the DB,                       base.py imports the types.py
+           the colspecs dict.                          using UPPERCASE                      UPPERCASE name into its namespace
+                    |                                  (i.e. BIT, NCHAR, INTERVAL).
+                    |                                  Users can import it.
+                    |                                       |
+                    v                                       v
+           subclass the closest                        is the name of this type
+           MixedCase type in types.py,                 identical to an UPPERCASE
+           i.e.                        <--- no ------- name in types.py ?
+           class _DateTime(types.DateTime),
+           class DATETIME2(types.DateTime),                   |
+           class BIT(types.TypeEngine).                      yes
+                                                              |
+                                                              v
+                                                        the type should
+                                                        subclass the   
+                                                        UPPERCASE      
+                                                        type in types.py
+                                                        (i.e. class BLOB(types.BLOB))
+
+
+Example 1.   pysqlite needs bind/result processing for the DateTime type in types.py, 
+which applies to all DateTimes and subclasses.   It's named _SLDateTime and 
+subclasses types.DateTime.
+
+Example 2.  MS-SQL has a TIME type which takes a non-standard "precision" argument
+that is rendered within DDL.   So it's named TIME in the MS-SQL dialect's base.py, 
+and subclasses types.TIME.  Users can then say mssql.TIME(precision=10).
+
+Example 3.  The MS-SQL dialect also needs special bind/result processing for
+dates.  But its DATE type doesn't render DDL differently than that of a plain 
+DATE, i.e. it takes no special arguments.  Therefore we are just adding behavior
+to types.Date, so it's named _MSDate in the MS-SQL dialect's base.py, and subclasses
+types.Date.
+
+Example 4.  MySQL has a SET type; there's no analogue for this in types.py.  So
+MySQL names it SET in the dialect's base.py, and it subclasses types.String, since 
+it ultimately deals with strings.
+
+Example 5.  Postgresql has a DATETIME type.  The DBAPIs handle dates correctly,
+and no special arguments are used in PG's DDL beyond what types.py provides.  
+Postgresql dialect therefore imports types.DATETIME into its base.py.
+
+Ideally one should be able to specify a schema using names imported completely from a 
+dialect, all matching the real name on that backend:
+
+   from sqlalchemy.dialects.postgresql import base as pg
+   
+   t = Table('mytable', metadata,
+              Column('id', pg.INTEGER, primary_key=True),
+              Column('name', pg.VARCHAR(300)),
+              Column('inetaddr', pg.INET)
+   )
+
+where above, the INTEGER and VARCHAR types are ultimately from sqlalchemy.types, 
+but the PG dialect makes them available in its own namespace.
+
+5. "colspecs" now is a dictionary of generic or uppercased types from sqlalchemy.types
+linked to types specified in the dialect.   Again, if a type in the dialect does not
+specify any special behavior for bind_processor() or result_processor() and does not
+indicate a special type only available in this database, it must be *removed* from the 
+module and from this dictionary.
+
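+For illustration, a sketch of such a dictionary (the _MyDateTime name is
+hypothetical; it would supply bind/result behavior only):
+
+   colspecs = {
+       sqltypes.DateTime: _MyDateTime,
+   }
+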
+6. "ischema_names" indicates string descriptions of types as returned from the database
+linked to TypeEngine classes.   
+
+    a. The string name should be matched to the most specific type possible within
+    sqlalchemy.types, unless there is no matching type within sqlalchemy.types in which
+    case it points to a dialect type.   *It doesn't matter* if the dialect has its 
+    own subclass of that type with special bind/result behavior - reflect to the types.py
+    UPPERCASE type as much as possible.   With very few exceptions, all types
+    should reflect to an UPPERCASE type.
+    
+    b. If the dialect contains a matching dialect-specific type that takes extra arguments 
+    which the generic one does not, then point to the dialect-specific type.  E.g.
+    mssql.VARCHAR takes a "collation" parameter which should be preserved.
+    
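+For illustration (reflected type names on the left; the INET entry assumes a
+dialect-specific type of that name):
+
+   ischema_names = {
+       'varchar': sqltypes.VARCHAR,   # generic UPPERCASE type
+       'inet': INET,                  # dialect-specific type
+   }
+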
+7. DDL, or what was formerly issued by "get_col_spec()", is now handled exclusively by
+a subclass of compiler.GenericTypeCompiler.
+
+    a. your TypeCompiler class will receive generic and uppercase types from 
+    sqlalchemy.types.  Do not assume the presence of dialect-specific attributes on
+    these types. 
+    
+    b. the visit_UPPERCASE methods on GenericTypeCompiler should *not* be overridden with
+    methods that produce a different DDL name.   Uppercase types don't do any kind of 
+    "guessing" - if visit_TIMESTAMP is called, the DDL should render as TIMESTAMP in
+    all cases, regardless of whether or not that type is legal on the backend database.
+    
+    c. the visit_UPPERCASE methods *should* be overridden with methods that add additional
+    arguments and flags to those types.  
+    
+    d. the visit_lowercase methods are overridden to provide an interpretation of a generic 
+    type.  E.g.  visit_binary() might be overridden to say "return self.visit_BIT(type_)".
+    
+    e. visit_lowercase methods should *never* render strings directly - they
+    should always delegate to a visit_UPPERCASE() method.
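+
+For illustration, rules (c) and (d) together (the "collation" argument is
+hypothetical here):
+
+   class MyTypeCompiler(compiler.GenericTypeCompiler):
+       def visit_VARCHAR(self, type_):
+           spec = "VARCHAR"
+           if type_.length:
+               spec += "(%d)" % type_.length
+           # render the extra backend-specific flag, if present
+           if getattr(type_, 'collation', None):
+               spec += " COLLATE %s" % type_.collation
+           return spec
+
+       def visit_string(self, type_):
+           return self.visit_VARCHAR(type_)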
index bb2b1b5be4a21e4fc261a7d0bfece938dce4c6ea..694a2f71fa53d5a15fc320abcd0ea413aa6da4fe 100644 (file)
@@ -50,7 +50,9 @@ url.py
     within a URL.
 """
 
-import sqlalchemy.databases
+# not sure what this was used for
+#import sqlalchemy.databases  
+
 from sqlalchemy.engine.base import (
     BufferedColumnResultProxy,
     BufferedColumnRow,
@@ -66,9 +68,9 @@ from sqlalchemy.engine.base import (
     ResultProxy,
     RootTransaction,
     RowProxy,
-    SchemaIterator,
     Transaction,
-    TwoPhaseTransaction
+    TwoPhaseTransaction,
+    TypeCompiler
     )
 from sqlalchemy.engine import strategies
 from sqlalchemy import util
@@ -89,9 +91,9 @@ __all__ = (
     'ResultProxy',
     'RootTransaction',
     'RowProxy',
-    'SchemaIterator',
     'Transaction',
     'TwoPhaseTransaction',
+    'TypeCompiler',
     'create_engine',
     'engine_from_config',
     )
@@ -108,7 +110,7 @@ def create_engine(*args, **kwargs):
 
     The URL is a string in the form
     ``dialect://user:password@host/dbname[?key=value..]``, where
-    ``dialect`` is a name such as ``mysql``, ``oracle``, ``postgres``,
+    ``dialect`` is a name such as ``mysql``, ``oracle``, ``postgresql``,
     etc.  Alternatively, the URL can be an instance of
     :class:`~sqlalchemy.engine.url.URL`.
 
index 39085c359617067c8eef4ae953e37389a1d0728d..0a0b0ff0ca4bb4651f1975810c591bb3d9ba8bd1 100644 (file)
 Defines the basic components used to interface DB-API modules with
 higher-level statement-construction, connection-management, execution
 and result contexts.
-
 """
 
-__all__ = ['BufferedColumnResultProxy', 'BufferedColumnRow', 'BufferedRowResultProxy', 'Compiled', 'Connectable', 
-        'Connection', 'DefaultRunner', 'Dialect', 'Engine', 'ExecutionContext', 'NestedTransaction', 'ResultProxy', 
-        'RootTransaction', 'RowProxy', 'SchemaIterator', 'StringIO', 'Transaction', 'TwoPhaseTransaction', 'connection_memoize']
+__all__ = [
+    'BufferedColumnResultProxy', 'BufferedColumnRow', 'BufferedRowResultProxy',
+    'Compiled', 'Connectable', 'Connection', 'DefaultRunner', 'Dialect', 'Engine',
+    'ExecutionContext', 'NestedTransaction', 'ResultProxy', 'RootTransaction',
+    'RowProxy', 'SchemaIterator', 'StringIO', 'Transaction', 'TwoPhaseTransaction',
+    'connection_memoize']
 
-import inspect, StringIO
+import inspect, StringIO, sys
 from sqlalchemy import exc, schema, util, types, log
 from sqlalchemy.sql import expression
 
@@ -32,10 +34,14 @@ class Dialect(object):
     ExecutionContext, Compiled, DefaultGenerator, and TypeEngine.
 
     All Dialects implement the following attributes:
-    
+
     name
-      identifying name for the dialect (i.e. 'sqlite')
-      
+      identifying name for the dialect from a DBAPI-neutral point of view
+      (i.e. 'sqlite')
+
+    driver
+      identifying name for the dialect's DBAPI
+
     positional
       True if the paramstyle for this Dialect is positional.
 
@@ -51,20 +57,25 @@ class Dialect(object):
       type of encoding to use for unicode, usually defaults to
       'utf-8'.
 
-    schemagenerator
-      a :class:`~sqlalchemy.schema.SchemaVisitor` class which generates
-      schemas.
-
-    schemadropper
-      a :class:`~sqlalchemy.schema.SchemaVisitor` class which drops schemas.
-
     defaultrunner
       a :class:`~sqlalchemy.schema.SchemaVisitor` class which executes
       defaults.
 
     statement_compiler
-      a :class:`~sqlalchemy.engine.base.Compiled` class used to compile SQL
-      statements
+      a :class:`~Compiled` class used to compile SQL statements
+
+    ddl_compiler
+      a :class:`~Compiled` class used to compile DDL statements
+
+    server_version_info
+      a tuple containing a version number for the DB backend in use.
+      This value is only available for supporting dialects, and only for
+      a dialect that's been associated with a connection pool via
+      create_engine() or otherwise had its ``initialize()`` method called
+      with a connection.
+
+    execution_ctx_cls
+      a :class:`ExecutionContext` class used to handle statement execution
 
     preparer
       a :class:`~sqlalchemy.sql.compiler.IdentifierPreparer` class used to
@@ -77,27 +88,38 @@ class Dialect(object):
       The maximum length of identifier names.
 
     supports_unicode_statements
-      Indicate whether the DB-API can receive SQL statements as Python unicode strings
+      Indicate whether the DB-API can receive SQL statements as Python
+      unicode strings
+
+    supports_unicode_binds
+      Indicate whether the DB-API can receive string bind parameters
+      as Python unicode strings
 
     supports_sane_rowcount
-      Indicate whether the dialect properly implements rowcount for ``UPDATE`` and ``DELETE`` statements.
+      Indicate whether the dialect properly implements rowcount for
+      ``UPDATE`` and ``DELETE`` statements.
 
     supports_sane_multi_rowcount
-      Indicate whether the dialect properly implements rowcount for ``UPDATE`` and ``DELETE`` statements
-      when executed via executemany.
-
-    preexecute_pk_sequences
-      Indicate if the dialect should pre-execute sequences on primary key
-      columns during an INSERT, if it's desired that the new row's primary key
-      be available after execution.
-
-    supports_pk_autoincrement
-      Indicates if the dialect should allow the database to passively assign
-      a primary key column value.
-
+      Indicate whether the dialect properly implements rowcount for
+      ``UPDATE`` and ``DELETE`` statements when executed via
+      executemany.
+
+    preexecute_autoincrement_sequences
+      True if 'implicit' primary key functions must be executed separately
+      in order to get their value.   This is currently oriented towards
+      Postgresql.
+      
+    implicit_returning
+      use RETURNING or equivalent during INSERT execution in order to load 
+      newly generated primary keys and other column defaults in one execution,
+      which are then available via inserted_primary_key.
+      If an insert statement has returning() specified explicitly, 
+      the "implicit" functionality is not used and inserted_primary_key
+      will not be available.
+      
     dbapi_type_map
       A mapping of DB-API type objects present in this Dialect's
-      DB-API implmentation mapped to TypeEngine implementations used
+      DB-API implementation mapped to TypeEngine implementations used
       by the dialect.
 
       This is used to apply types to result sets based on the DB-API
@@ -105,13 +127,15 @@ class Dialect(object):
       result sets against textual statements where no explicit
       typemap was present.
 
-    supports_default_values
-      Indicates if the construct ``INSERT INTO tablename DEFAULT VALUES`` is supported
+    colspecs
+      A dictionary of TypeEngine classes from sqlalchemy.types mapped
+      to subclasses that are specific to the dialect class.  This
+      dictionary is class-level only and is not accessed from the
+      dialect instance itself.
 
-    description_encoding
-      type of encoding to use for unicode when working with metadata
-      descriptions. If set to ``None`` no encoding will be done.
-      This usually defaults to 'utf-8'.
+    supports_default_values
+      Indicates if the construct ``INSERT INTO tablename DEFAULT
+      VALUES`` is supported
     """
 
     def create_connect_args(self, url):
@@ -124,25 +148,28 @@ class Dialect(object):
 
         raise NotImplementedError()
 
+    @classmethod
+    def type_descriptor(cls, typeobj):
+        """Transform a generic type to a dialect-specific type.
 
-    def type_descriptor(self, typeobj):
-        """Transform a generic type to a database-specific type.
-
-        Transforms the given :class:`~sqlalchemy.types.TypeEngine` instance
-        from generic to database-specific.
-
-        Subclasses will usually use the
+        Dialect classes will usually use the
         :func:`~sqlalchemy.types.adapt_type` method in the types module to
         make this job easy.
+
+        The returned result is cached *per dialect class*, so it can
+        contain no dialect-instance state.
         """
 
         raise NotImplementedError()
 
+    def initialize(self, connection):
+        """Called during strategized creation of the dialect with a connection.
 
-    def server_version_info(self, connection):
-        """Return a tuple of the database's version number."""
+        Allows dialects to configure options based on server version info or
+        other properties.
+        """
 
-        raise NotImplementedError()
+        pass
 
     def reflecttable(self, connection, table, include_columns=None):
         """Load table description from the database.
@@ -156,6 +183,133 @@ class Dialect(object):
 
         raise NotImplementedError()
 
+    def get_columns(self, connection, table_name, schema=None, **kw):
+        """Return information about columns in `table_name`.
+
+        Given a :class:`~sqlalchemy.engine.Connection`, a string
+        `table_name`, and an optional string `schema`, return column
+        information as a list of dictionaries with these keys:
+
+        name
+          the column's name
+
+        type
+          [sqlalchemy.types#TypeEngine]
+
+        nullable
+          boolean
+
+        default
+          the column's default value
+
+        autoincrement
+          boolean
+
+        sequence
+          a dictionary of the form
+              {'name': str, 'start': int, 'increment': int}
+
+        Additional column attributes may be present.
+        """
+
+        raise NotImplementedError()
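
For example, a dialect's get_columns() for a hypothetical two-column table
might return a structure of this shape (a sketch; names and values are
illustrative only):

    from sqlalchemy.types import Integer, String

    columns = [
        {'name': 'id', 'type': Integer(), 'nullable': False,
         'default': None, 'autoincrement': True},
        {'name': 'name', 'type': String(50), 'nullable': True,
         'default': None, 'autoincrement': False},
    ]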
+
+    def get_primary_keys(self, connection, table_name, schema=None, **kw):
+        """Return information about primary keys in `table_name`.
+
+        Given a :class:`~sqlalchemy.engine.Connection`, a string
+        `table_name`, and an optional string `schema`, return primary
+        key information as a list of column names.
+        """
+
+        raise NotImplementedError()
+
+    def get_foreign_keys(self, connection, table_name, schema=None, **kw):
+        """Return information about foreign_keys in `table_name`.
+
+        Given a :class:`~sqlalchemy.engine.Connection`, a string
+        `table_name`, and an optional string `schema`, return foreign
+        key information as a list of dicts with these keys:
+
+        name
+          the constraint's name
+
+        constrained_columns
+          a list of column names that make up the foreign key
+
+        referred_schema
+          the name of the referred schema
+
+        referred_table
+          the name of the referred table
+
+        referred_columns
+          a list of column names in the referred table that correspond to
+          constrained_columns
+        """
+
+        raise NotImplementedError()
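
A sketch of the expected return shape, for a hypothetical 'addresses' table
referencing a 'users' table (values illustrative only):

    foreign_keys = [
        {'name': 'fk_addresses_user_id',
         'constrained_columns': ['user_id'],
         'referred_schema': None,
         'referred_table': 'users',
         'referred_columns': ['id']},
    ]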
+
+    def get_table_names(self, connection, schema=None, **kw):
+        """Return a list of table names for `schema`."""
+
+        raise NotImplementedError()
+
+    def get_view_names(self, connection, schema=None, **kw):
+        """Return a list of all view names available in the database.
+
+        :param schema: Optional, retrieve names from a non-default schema.
+        """
+
+        raise NotImplementedError()
+
+    def get_view_definition(self, connection, view_name, schema=None, **kw):
+        """Return view definition.
+
+        Given a :class:`~sqlalchemy.engine.Connection`, a string
+        `view_name`, and an optional string `schema`, return the view
+        definition.
+        """
+
+        raise NotImplementedError()
+
+    def get_indexes(self, connection, table_name, schema=None, **kw):
+        """Return information about indexes in `table_name`.
+
+        Given a :class:`~sqlalchemy.engine.Connection`, a string
+        `table_name` and an optional string `schema`, return index
+        information as a list of dictionaries with these keys:
+
+        name
+          the index's name
+
+        column_names
+          list of column names in order
+
+        unique
+          boolean
+        """
+
+        raise NotImplementedError()
+
+    def normalize_name(self, name):
+        """Convert the given name to lowercase if it is detected as
+        case insensitive.
+
+        This method is only used if the dialect defines
+        requires_name_normalize=True.
+
+        """
+        raise NotImplementedError()
+
+    def denormalize_name(self, name):
+        """Convert the given name to a case insensitive identifier for
+        the backend if it is an all-lowercase name.
+
+        This method is only used if the dialect defines
+        requires_name_normalize=True.
+
+        """
+        raise NotImplementedError()
+        
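
A minimal sketch of the normalize/denormalize contract, modeled on backends
that report case-insensitive names in UPPERCASE; this is a hypothetical
mixin, not the actual code of any shipped dialect:

    class UppercaseNamesMixin(object):
        requires_name_normalize = True

        def normalize_name(self, name):
            # the backend reports case-insensitive names in UPPERCASE;
            # present them to SQLAlchemy as lowercase
            if name is not None and name.upper() == name:
                return name.lower()
            return name

        def denormalize_name(self, name):
            # an all-lowercase name is treated as case insensitive;
            # hand it to the backend in its native UPPERCASE form
            if name is not None and name.lower() == name:
                return name.upper()
            return name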
     def has_table(self, connection, table_name, schema=None):
         """Check the existence of a particular table in the database.
 
@@ -178,7 +332,11 @@ class Dialect(object):
         raise NotImplementedError()
 
     def get_default_schema_name(self, connection):
-        """Return the string name of the currently selected schema given a :class:`~sqlalchemy.engine.Connection`."""
+        """Return the string name of the currently selected schema given a :class:`~sqlalchemy.engine.Connection`.
+        
+        DEPRECATED.  moving this towards dialect.default_schema_name (not complete).
+        
+        """
 
         raise NotImplementedError()
 
@@ -262,11 +420,14 @@ class Dialect(object):
 
         raise NotImplementedError()
 
+    def visit_pool(self, pool):
+        """Executed after a pool is created."""
+
 
 class ExecutionContext(object):
     """A messenger object for a Dialect that corresponds to a single execution.
 
-    ExecutionContext should have these datamembers:
+    ExecutionContext should have these data members:
 
     connection
       Connection object which can be freely used by default value
@@ -308,20 +469,19 @@ class ExecutionContext(object):
       True if the statement is an UPDATE.
 
     should_autocommit
-      True if the statement is a "committable" statement
+      True if the statement is a "committable" statement.
 
     postfetch_cols
-     a list of Column objects for which a server-side default
-     or inline SQL expression value was fired off.  applies to inserts and updates.
-
-
+      a list of Column objects for which a server-side default or
+      inline SQL expression value was fired off.  Applies to inserts
+      and updates.
     """
 
     def create_cursor(self):
         """Return a new cursor generated from this ExecutionContext's connection.
 
         Some dialects may wish to change the behavior of
-        connection.cursor(), such as postgres which may return a PG
+        connection.cursor(), such as postgresql which may return a PG
         "server side" cursor.
         """
 
@@ -357,22 +517,11 @@ class ExecutionContext(object):
 
     def handle_dbapi_exception(self, e):
         """Receive a DBAPI exception which occured upon execute, result fetch, etc."""
-        
-        raise NotImplementedError()
-        
-    def should_autocommit_text(self, statement):
-        """Parse the given textual statement and return True if it refers to a "committable" statement"""
 
         raise NotImplementedError()
 
-    def last_inserted_ids(self):
-        """Return the list of the primary key values for the last insert statement executed.
-
-        This does not apply to straight textual clauses; only to
-        ``sql.Insert`` objects compiled against a ``schema.Table``
-        object.  The order of items in the list is the same as that of
-        the Table's 'primary_key' attribute.
-        """
+    def should_autocommit_text(self, statement):
+        """Parse the given textual statement and return True if it refers to a "committable" statement"""
 
         raise NotImplementedError()
 
@@ -401,7 +550,7 @@ class ExecutionContext(object):
 
 
 class Compiled(object):
-    """Represent a compiled SQL expression.
+    """Represent a compiled SQL or DDL expression.
 
     The ``__str__`` method of the ``Compiled`` object should produce
     the actual text of the statement.  ``Compiled`` objects are
@@ -413,53 +562,49 @@ class Compiled(object):
     defaults.
     """
 
-    def __init__(self, dialect, statement, column_keys=None, bind=None):
+    def __init__(self, dialect, statement, bind=None):
         """Construct a new ``Compiled`` object.
 
-        dialect
-          ``Dialect`` to compile against.
-
-        statement
-          ``ClauseElement`` to be compiled.
+        :param dialect: ``Dialect`` to compile against.
 
-        column_keys
-          a list of column names to be compiled into an INSERT or UPDATE
-          statement.
+        :param statement: ``ClauseElement`` to be compiled.
 
-        bind
-          Optional Engine or Connection to compile this statement against.
-          
+        :param bind: Optional Engine or Connection to compile this statement against.
         """
+
         self.dialect = dialect
         self.statement = statement
-        self.column_keys = column_keys
         self.bind = bind
         self.can_execute = statement.supports_execution
 
     def compile(self):
         """Produce the internal string representation of this element."""
 
-        raise NotImplementedError()
+        self.string = self.process(self.statement)
 
-    def __str__(self):
-        """Return the string text of the generated SQL statement."""
+    def process(self, obj, **kwargs):
+        return obj._compiler_dispatch(self, **kwargs)
 
-        raise NotImplementedError()
+    def __str__(self):
+        """Return the string text of the generated SQL or DDL."""
 
-    @util.deprecated('Deprecated. Use construct_params(). '
-                     '(supports Unicode key names.)')
-    def get_params(self, **params):
-        return self.construct_params(params)
+        return self.string or ''
 
-    def construct_params(self, params):
+    def construct_params(self, params=None):
         """Return the bind params for this compiled object.
 
-        `params` is a dict of string/object pairs whos
-        values will override bind values compiled in
-        to the statement.
+        :param params: a dict of string/object pairs whose values will
+                       override bind values compiled into the
+                       statement.
         """
+
         raise NotImplementedError()
 
+    @property
+    def params(self):
+        """Return the bind params for this compiled object."""
+        return self.construct_params()
+
     def execute(self, *multiparams, **params):
         """Execute this compiled object."""
 
@@ -474,12 +619,24 @@ class Compiled(object):
         return self.execute(*multiparams, **params).scalar()
 
 
+class TypeCompiler(object):
+    """Produces DDL specification for TypeEngine objects."""
+
+    def __init__(self, dialect):
+        self.dialect = dialect
+
+    def process(self, type_):
+        return type_._compiler_dispatch(self)
+
+
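
The process() methods above rely on the visitor dispatch used throughout the
compiler: each visitable element names itself, and the compiler supplies a
matching visit_<name> method.  A self-contained sketch of that pattern (the
class names here are illustrative, not the actual visitor machinery):

    class Element(object):
        __visit_name__ = 'element'

        def _compiler_dispatch(self, visitor, **kw):
            # route to the visitor method named after this element
            return getattr(visitor, 'visit_%s' % self.__visit_name__)(self, **kw)

    class Printer(object):
        def visit_element(self, element):
            return 'ELEMENT'

    assert Element()._compiler_dispatch(Printer()) == 'ELEMENT'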
 class Connectable(object):
     """Interface for an object which supports execution of SQL constructs.
-    
+
     The two implementations of ``Connectable`` are :class:`Connection` and
     :class:`Engine`.
-    
+
+    Connectable must also implement the 'dialect' member which references a
+    :class:`Dialect` instance.
     """
 
     def contextual_connect(self):
@@ -503,6 +660,7 @@ class Connectable(object):
     def _execute_clauseelement(self, elem, multiparams=None, params=None):
         raise NotImplementedError()
 
+
 class Connection(Connectable):
     """Provides high-level functionality for a wrapped DB-API connection.
 
@@ -514,7 +672,6 @@ class Connection(Connectable):
 
     .. index::
       single: thread safety; Connection
-
     """
 
     def __init__(self, engine, connection=None, close_with_result=False,
@@ -524,7 +681,6 @@ class Connection(Connectable):
         Connection objects are typically constructed by an
         :class:`~sqlalchemy.engine.Engine`, see the ``connect()`` and
         ``contextual_connect()`` methods of Engine.
-        
         """
 
         self.engine = engine
@@ -534,7 +690,7 @@ class Connection(Connectable):
         self.__savepoint_seq = 0
         self.__branch = _branch
         self.__invalid = False
-        
+
     def _branch(self):
         """Return a new Connection which references this Connection's
         engine and connection; but does not have close_with_result enabled,
@@ -542,8 +698,8 @@ class Connection(Connectable):
 
         This is used to execute "sub" statements within a single execution,
         usually an INSERT statement.
-        
         """
+
         return self.engine.Connection(self.engine, self.__connection, _branch=True)
 
     @property
@@ -554,13 +710,13 @@ class Connection(Connectable):
 
     @property
     def closed(self):
-        """return True if this connection is closed."""
+        """Return True if this connection is closed."""
 
         return not self.__invalid and '_Connection__connection' not in self.__dict__
 
     @property
     def invalidated(self):
-        """return True if this connection was invalidated."""
+        """Return True if this connection was invalidated."""
 
         return self.__invalid
 
@@ -583,13 +739,14 @@ class Connection(Connectable):
     def should_close_with_result(self):
         """Indicates if this Connection should be closed when a corresponding
         ResultProxy is closed; this is essentially an auto-release mode.
-        
         """
+
         return self.__close_with_result
 
     @property
     def info(self):
         """A collection of per-DB-API connection instance properties."""
+
         return self.connection.info
 
     def connect(self):
@@ -598,8 +755,8 @@ class Connection(Connectable):
         This ``Connectable`` interface method returns self, allowing
         Connections to be used interchangeably with Engines in most
         situations that require a bind.
-
         """
+
         return self
 
     def contextual_connect(self, **kwargs):
@@ -608,8 +765,8 @@ class Connection(Connectable):
         This ``Connectable`` interface method returns self, allowing
         Connections to be used interchangeably with Engines in most
         situations that require a bind.
-
         """
+
         return self
 
     def invalidate(self, exception=None):
@@ -627,8 +784,8 @@ class Connection(Connectable):
         rolled back before a reconnect on this Connection can proceed.  This
         is to prevent applications from accidentally continuing their transactional
         operations in a non-transactional state.
-
         """
+
         if self.closed:
             raise exc.InvalidRequestError("This Connection is closed")
 
@@ -651,8 +808,8 @@ class Connection(Connectable):
         :class:`~sqlalchemy.interfaces.PoolListener` for a mechanism to modify
         connection state when connections leave and return to their
         connection pool.
-
         """
+
         self.__connection.detach()
 
     def begin(self):
@@ -663,8 +820,8 @@ class Connection(Connectable):
         outermost transaction may ``commit``.  Calls to ``commit`` on
         inner transactions are ignored.  Any transaction in the
         hierarchy may ``rollback``, however.
-
         """
+
         if self.__transaction is None:
             self.__transaction = RootTransaction(self)
         else:
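
A minimal sketch of the nesting behavior described above, assuming an
existing Engine named engine:

    conn = engine.connect()
    outer = conn.begin()
    inner = conn.begin()    # nested; participates in the outer transaction
    inner.commit()          # ignored: only the outermost commit commits
    outer.commit()          # the real COMMIT is emitted here
    conn.close()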
@@ -690,9 +847,8 @@ class Connection(Connectable):
     def begin_twophase(self, xid=None):
         """Begin a two-phase or XA transaction and return a Transaction handle.
 
-        xid
-          the two phase transaction id.  If not supplied, a random id
-          will be generated.
+        :param xid: the two phase transaction id.  If not supplied, a random id
+                    will be generated.
         """
 
         if self.__transaction is not None:
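
A sketch of the two-phase flow, assuming a backend and DBAPI with XA support
and a hypothetical 'users' table:

    conn = engine.connect()
    xact = conn.begin_twophase()        # xid generated via create_xid()
    conn.execute(users.insert(), name='jack')
    xact.prepare()                      # phase one
    xact.commit()                       # phase two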
@@ -813,9 +969,6 @@ class Connection(Connectable):
 
         return self.execute(object, *multiparams, **params).scalar()
 
-    def statement_compiler(self, statement, **kwargs):
-        return self.dialect.statement_compiler(self.dialect, statement, bind=self, **kwargs)
-
     def execute(self, object, *multiparams, **params):
         """Executes and returns a ResultProxy."""
 
@@ -826,11 +979,12 @@ class Connection(Connectable):
             raise exc.InvalidRequestError("Unexecutable object type: " + str(type(object)))
 
     def __distill_params(self, multiparams, params):
-        """given arguments from the calling form *multiparams, **params, return a list
+        """Given arguments from the calling form *multiparams, **params, return a list
         of bind parameter structures, usually a list of dictionaries.
 
-        in the case of 'raw' execution which accepts positional parameters,
-        it may be a list of tuples or lists."""
+        In the case of 'raw' execution which accepts positional parameters,
+        it may be a list of tuples or lists.
+        """
 
         if not multiparams:
             if params:
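
The calling forms being distilled here, as a sketch ('stmt' is a
hypothetical compiled construct):

    conn.execute(stmt, {'x': 1})                 # single dict of params
    conn.execute(stmt, [{'x': 1}, {'x': 2}])     # list of dicts -> executemany
    conn.execute("insert into t values (?)", 5)  # raw positional parameters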
@@ -858,7 +1012,19 @@ class Connection(Connectable):
         return self._execute_clauseelement(func.select(), multiparams, params)
 
     def _execute_default(self, default, multiparams, params):
-        return self.engine.dialect.defaultrunner(self.__create_execution_context()).traverse_single(default)
+        ret = self.engine.dialect.\
+                    defaultrunner(self.__create_execution_context()).\
+                    traverse_single(default)
+        if self.__close_with_result:
+            self.close()
+        return ret
+
+    def _execute_ddl(self, ddl, params, multiparams):
+        context = self.__create_execution_context(
+                        compiled_ddl=ddl.compile(dialect=self.dialect),
+                        parameters=None
+                    )
+        return self.__execute_context(context)
 
     def _execute_clauseelement(self, elem, multiparams, params):
         params = self.__distill_params(multiparams, params)
@@ -868,7 +1034,7 @@ class Connection(Connectable):
             keys = []
 
         context = self.__create_execution_context(
-                        compiled=elem.compile(dialect=self.dialect, column_keys=keys, inline=len(params) > 1), 
+                        compiled_sql=elem.compile(dialect=self.dialect, column_keys=keys, inline=len(params) > 1),
                         parameters=params
                     )
         return self.__execute_context(context)
@@ -877,7 +1043,7 @@ class Connection(Connectable):
         """Execute a sql.Compiled object."""
 
         context = self.__create_execution_context(
-                    compiled=compiled, 
+                    compiled_sql=compiled,
                     parameters=self.__distill_params(multiparams, params)
                 )
         return self.__execute_context(context)
@@ -886,38 +1052,42 @@ class Connection(Connectable):
         parameters = self.__distill_params(multiparams, params)
         context = self.__create_execution_context(statement=statement, parameters=parameters)
         return self.__execute_context(context)
-    
+
     def __execute_context(self, context):
         if context.compiled:
             context.pre_exec()
+            
         if context.executemany:
             self._cursor_executemany(context.cursor, context.statement, context.parameters, context=context)
         else:
             self._cursor_execute(context.cursor, context.statement, context.parameters[0], context=context)
+            
         if context.compiled:
             context.post_exec()
+            
+            if context.isinsert and not context.executemany:
+                context.post_insert()
+            
         if context.should_autocommit and not self.in_transaction():
             self._commit_impl()
-        return context.get_result_proxy()
+            
+        return context.get_result_proxy()._autoclose()
         
-    def _execute_ddl(self, ddl, params, multiparams):
-        if params:
-            schema_item, params = params[0], params[1:]
-        else:
-            schema_item = None
-        return ddl(None, schema_item, self, *params, **multiparams)
-
     def _handle_dbapi_exception(self, e, statement, parameters, cursor, context):
         if getattr(self, '_reentrant_error', False):
-            raise exc.DBAPIError.instance(None, None, e)
+            # Py3K
+            #raise exc.DBAPIError.instance(statement, parameters, e) from e
+            # Py2K
+            raise exc.DBAPIError.instance(statement, parameters, e), None, sys.exc_info()[2]
+            # end Py2K
         self._reentrant_error = True
         try:
             if not isinstance(e, self.dialect.dbapi.Error):
                 return
-                
+
             if context:
                 context.handle_dbapi_exception(e)
-                
+
             is_disconnect = self.dialect.is_disconnect(e)
             if is_disconnect:
                 self.invalidate(e)
@@ -928,7 +1098,12 @@ class Connection(Connectable):
                 self._autorollback()
                 if self.__close_with_result:
                     self.close()
-            raise exc.DBAPIError.instance(statement, parameters, e, connection_invalidated=is_disconnect)
+            # Py3K
+            #raise exc.DBAPIError.instance(statement, parameters, e, connection_invalidated=is_disconnect) from e
+            # Py2K
+            raise exc.DBAPIError.instance(statement, parameters, e, connection_invalidated=is_disconnect), None, sys.exc_info()[2]
+            # end Py2K
+            
         finally:
             del self._reentrant_error
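
The Py2K three-argument raise above re-raises the wrapped DBAPIError with the
original traceback attached, which is what the commented Py3K
"raise ... from e" form expresses natively.  A standalone sketch of the idiom:

    import sys

    def do_work():
        raise ValueError("boom")

    try:
        do_work()
    except ValueError, e:
        # wrap in a different exception type while preserving the
        # original traceback (the Py2K spelling of "raise X from e")
        raise RuntimeError(str(e)), None, sys.exc_info()[2]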
 
@@ -966,7 +1141,7 @@ class Connection(Connectable):
         expression.ClauseElement: _execute_clauseelement,
         Compiled: _execute_compiled,
         schema.SchemaItem: _execute_default,
-        schema.DDL: _execute_ddl,
+        schema.DDLElement: _execute_ddl,
         basestring: _execute_text
     }
 
@@ -991,6 +1166,7 @@ class Connection(Connectable):
     def run_callable(self, callable_):
         return callable_(self)
 
+
 class Transaction(object):
     """Represent a Transaction in progress.
 
@@ -998,14 +1174,13 @@ class Transaction(object):
 
     .. index::
       single: thread safety; Transaction
-
     """
 
     def __init__(self, connection, parent):
         self.connection = connection
         self._parent = parent or self
         self.is_active = True
-    
+
     def close(self):
         """Close this transaction.
 
@@ -1016,6 +1191,7 @@ class Transaction(object):
         This is used to cancel a Transaction without affecting the scope of
         an enclosing transaction.
         """
+
         if not self._parent.is_active:
             return
         if self._parent is self:
@@ -1048,6 +1224,7 @@ class Transaction(object):
         else:
             self.rollback()
 
+
 class RootTransaction(Transaction):
     def __init__(self, connection):
         super(RootTransaction, self).__init__(connection, None)
@@ -1059,6 +1236,7 @@ class RootTransaction(Transaction):
     def _do_commit(self):
         self.connection._commit_impl()
 
+
 class NestedTransaction(Transaction):
     def __init__(self, connection, parent):
         super(NestedTransaction, self).__init__(connection, parent)
@@ -1070,6 +1248,7 @@ class NestedTransaction(Transaction):
     def _do_commit(self):
         self.connection._release_savepoint_impl(self._savepoint, self._parent)
 
+
 class TwoPhaseTransaction(Transaction):
     def __init__(self, connection, xid):
         super(TwoPhaseTransaction, self).__init__(connection, None)
@@ -1089,9 +1268,10 @@ class TwoPhaseTransaction(Transaction):
     def commit(self):
         self.connection._commit_twophase_impl(self.xid, self._is_prepared)
 
+
 class Engine(Connectable):
     """
-    Connects a :class:`~sqlalchemy.pool.Pool` and :class:`~sqlalchemy.engine.base.Dialect` 
+    Connects a :class:`~sqlalchemy.pool.Pool` and :class:`~sqlalchemy.engine.base.Dialect`
     together to provide a source of database connectivity and behavior.
 
     """
@@ -1111,9 +1291,15 @@ class Engine(Connectable):
     @property
     def name(self):
         "String name of the :class:`~sqlalchemy.engine.Dialect` in use by this ``Engine``."
-        
+
         return self.dialect.name
 
+    @property
+    def driver(self):
+        "Driver name of the :class:`~sqlalchemy.engine.Dialect` in use by this ``Engine``."
+
+        return self.dialect.driver
+
     echo = log.echo_property()
 
     def __repr__(self):
@@ -1126,12 +1312,16 @@ class Engine(Connectable):
     def create(self, entity, connection=None, **kwargs):
         """Create a table or index within this engine's database connection given a schema.Table object."""
 
-        self._run_visitor(self.dialect.schemagenerator, entity, connection=connection, **kwargs)
+        from sqlalchemy.engine import ddl
+
+        self._run_visitor(ddl.SchemaGenerator, entity, connection=connection, **kwargs)
 
     def drop(self, entity, connection=None, **kwargs):
         """Drop a table or index within this engine's database connection given a schema.Table object."""
 
-        self._run_visitor(self.dialect.schemadropper, entity, connection=connection, **kwargs)
+        from sqlalchemy.engine import ddl
+
+        self._run_visitor(ddl.SchemaDropper, entity, connection=connection, **kwargs)
 
     def _execute_default(self, default):
         connection = self.contextual_connect()
@@ -1212,9 +1402,6 @@ class Engine(Connectable):
         connection = self.contextual_connect(close_with_result=True)
         return connection._execute_compiled(compiled, multiparams, params)
 
-    def statement_compiler(self, statement, **kwargs):
-        return self.dialect.statement_compiler(self.dialect, statement, bind=self, **kwargs)
-
     def connect(self, **kwargs):
         """Return a newly allocated Connection object."""
 
@@ -1231,12 +1418,10 @@ class Engine(Connectable):
     def table_names(self, schema=None, connection=None):
         """Return a list of all table names available in the database.
 
-        schema:
-          Optional, retrieve names from a non-default schema.
+        :param schema: Optional, retrieve names from a non-default schema.
 
-        connection:
-          Optional, use a specified connection.  Default is the
-          ``contextual_connect`` for this ``Engine``.
+        :param connection: Optional, use a specified connection.  Default is the
+                           ``contextual_connect`` for this ``Engine``.
         """
 
         if connection is None:
@@ -1275,22 +1460,24 @@ class Engine(Connectable):
 
         return self.pool.unique_connection()
 
+
 def _proxy_connection_cls(cls, proxy):
     class ProxyConnection(cls):
         def execute(self, object, *multiparams, **params):
             return proxy.execute(self, super(ProxyConnection, self).execute, object, *multiparams, **params)
+
         def _execute_clauseelement(self, elem, multiparams=None, params=None):
             return proxy.execute(self, super(ProxyConnection, self).execute, elem, *(multiparams or []), **(params or {}))
-            
+
         def _cursor_execute(self, cursor, statement, parameters, context=None):
             return proxy.cursor_execute(super(ProxyConnection, self)._cursor_execute, cursor, statement, parameters, context, False)
+
         def _cursor_executemany(self, cursor, statement, parameters, context=None):
             return proxy.cursor_execute(super(ProxyConnection, self)._cursor_executemany, cursor, statement, parameters, context, True)
 
     return ProxyConnection
 
+
 class RowProxy(object):
     """Proxy a single cursor row for a parent ResultProxy.
 
@@ -1302,7 +1489,7 @@ class RowProxy(object):
     """
 
     __slots__ = ['__parent', '__row']
-    
+
     def __init__(self, parent, row):
         """RowProxy objects are constructed by ResultProxy objects."""
 
@@ -1327,7 +1514,7 @@ class RowProxy(object):
             yield self.__parent._get_col(self.__row, i)
 
     __hash__ = None
-    
+
     def __eq__(self, other):
         return ((other is self) or
                 (other == tuple(self.__parent._get_col(self.__row, key)
@@ -1362,18 +1549,19 @@ class RowProxy(object):
         """Return the list of keys as strings represented by this RowProxy."""
 
         return self.__parent.keys
-    
+
     def iterkeys(self):
         return iter(self.__parent.keys)
-        
+
     def values(self):
         """Return the values represented by this RowProxy as a list."""
 
         return list(self)
-    
+
     def itervalues(self):
         return iter(self)
 
+
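
Rows therefore support both ordered and mapping-style access.  A sketch,
assuming a result selected from a hypothetical 'users' table:

    row = conn.execute("select id, name from users").fetchone()
    row[0]          # by position
    row['name']     # by column name
    list(row)       # full tuple-like expansion
    row.keys()      # ['id', 'name']
    row.values()    # same values as list(row)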
 class BufferedColumnRow(RowProxy):
     def __init__(self, parent, row):
         row = [ResultProxy._get_col(parent, row, i) for i in xrange(len(row))]
@@ -1403,9 +1591,8 @@ class ResultProxy(object):
     """
 
     _process_row = RowProxy
-
+    
     def __init__(self, context):
-        """ResultProxy objects are constructed via the execute() method on SQLEngine."""
         self.context = context
         self.dialect = context.dialect
         self.closed = False
@@ -1413,40 +1600,81 @@ class ResultProxy(object):
         self.connection = context.root_connection
         self._echo = context.engine._should_log_info
         self._init_metadata()
-    
-    @property
+            
+    @util.memoized_property
     def rowcount(self):
-        if self._rowcount is None:
-            return self.context.get_rowcount()
-        else:
-            return self._rowcount
+        """Return the 'rowcount' for this result.
+        
+        The 'rowcount' reports the number of rows affected
+        by an UPDATE or DELETE statement.  It has *no* other
+        uses and is not intended to provide the number of rows
+        present from a SELECT.
+        
+        Additionally, this value is only meaningful if the
+        dialect's supports_sane_rowcount flag is True for
+        single-parameter executions, or supports_sane_multi_rowcount
+        is true for multiple parameter executions - otherwise
+        results are undefined.
+        
+        rowcount may not work at this time for a statement
+        that uses ``returning()``.
+        
+        """
+        return self.context.rowcount
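
A sketch of guarded rowcount usage after an UPDATE, assuming a hypothetical
'users' table:

    result = conn.execute(users.update(users.c.id == 5), name='ed')
    if conn.dialect.supports_sane_rowcount:
        matched = result.rowcount   # rows matched by the UPDATE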
 
     @property
     def lastrowid(self):
+        """return the 'lastrowid' accessor on the DBAPI cursor.
+        
+        This is a DBAPI specific method and is only functional
+        for those backends which support it, for statements
+        where it is appropriate.  It's behavior is not 
+        consistent across backends.
+        
+        Usage of this method is normally unnecessary; the
+        inserted_primary_key method provides a
+        tuple of primary key values for a newly inserted row,
+        regardless of database backend.
+        
+        """
         return self.cursor.lastrowid
 
     @property
     def out_parameters(self):
         return self.context.out_parameters
-
+    
+    def _cursor_description(self):
+        return self.cursor.description
+            
+    def _autoclose(self):
+        if self.context.isinsert:
+            if self.context._is_implicit_returning:
+                self.context._fetch_implicit_returning(self)
+                self.close()
+            elif not self.context._is_explicit_returning:
+                self.close()
+        elif self._metadata is None:
+            # no results; get rowcount (which requires an open cursor
+            # on some DBs, such as Firebird), then close
+            self.rowcount
+            self.close() # autoclose
+            
+        return self
+    
+            
     def _init_metadata(self):
-        metadata = self.cursor.description
+        self._metadata = metadata = self._cursor_description()
         if metadata is None:
-            # no results, get rowcount (which requires open cursor on some DB's such as firebird),
-            # then close
-            self._rowcount = self.context.get_rowcount()
-            self.close()
             return
-            
-        self._rowcount = None
+        
         self._props = util.populate_column_dict(None)
         self._props.creator = self.__key_fallback()
         self.keys = []
 
         typemap = self.dialect.dbapi_type_map
 
-        for i, item in enumerate(metadata):
-            colname = item[0]
+        for i, (colname, coltype) in enumerate(m[0:2] for m in metadata):
+
             if self.dialect.description_encoding:
                 colname = colname.decode(self.dialect.description_encoding)
 
@@ -1461,9 +1689,9 @@ class ResultProxy(object):
                 try:
                     (name, obj, type_) = self.context.result_map[colname.lower()]
                 except KeyError:
-                    (name, obj, type_) = (colname, None, typemap.get(item[1], types.NULLTYPE))
+                    (name, obj, type_) = (colname, None, typemap.get(coltype, types.NULLTYPE))
             else:
-                (name, obj, type_) = (colname, None, typemap.get(item[1], types.NULLTYPE))
+                (name, obj, type_) = (colname, None, typemap.get(coltype, types.NULLTYPE))
 
             rec = (type_, type_.dialect_impl(self.dialect).result_processor(self.dialect), i)
 
@@ -1474,7 +1702,10 @@ class ResultProxy(object):
             if origname:
                 if self._props.setdefault(origname.lower(), rec) is not rec:
                     self._props[origname.lower()] = (type_, self.__ambiguous_processor(origname), 0)
-
+            
+            if self.dialect.requires_name_normalize:
+                colname = self.dialect.normalize_name(colname)
+                
             self.keys.append(colname)
             self._props[i] = rec
             if obj:
@@ -1484,11 +1715,11 @@ class ResultProxy(object):
         if self._echo:
             self.context.engine.logger.debug(
                 "Col " + repr(tuple(x[0] for x in metadata)))
-    
+
     def __key_fallback(self):
         # create a closure without 'self' to avoid circular references
         props = self._props
-        
+
         def fallback(key):
             if isinstance(key, basestring):
                 key = key.lower()
@@ -1515,19 +1746,22 @@ class ResultProxy(object):
 
     def close(self):
         """Close this ResultProxy.
-        
+
         Closes the underlying DBAPI cursor corresponding to the execution.
+        
+        Note that any data cached within this ResultProxy is still available.
+        For some types of results, this may include buffered rows.
 
         If this ResultProxy was generated from an implicit execution,
         the underlying Connection will also be closed (returns the
         underlying DBAPI connection to the connection pool.)
 
         This method is called automatically when:
-        
-            * all result rows are exhausted using the fetchXXX() methods.
-            * cursor.description is None.
-        
+
+        * all result rows are exhausted using the fetchXXX() methods.
+        * cursor.description is None.
         """
+
         if not self.closed:
             self.closed = True
             self.cursor.close()
@@ -1550,53 +1784,66 @@ class ResultProxy(object):
                 raise StopIteration
             else:
                 yield row
-
-    def last_inserted_ids(self):
-        """Return ``last_inserted_ids()`` from the underlying ExecutionContext.
-
-        See ExecutionContext for details.
+    
+    @util.memoized_property
+    def inserted_primary_key(self):
+        """Return the primary key for the row just inserted.
+        
+        This only applies to single row insert() constructs which
+        did not explicitly specify returning().
 
         """
-        return self.context.last_inserted_ids()
+        if not self.context.isinsert:
+            raise exc.InvalidRequestError("Statement is not an insert() expression construct.")
+        elif self.context._is_explicit_returning:
+            raise exc.InvalidRequestError("Can't call inserted_primary_key when returning() is used.")
+            
+        return self.context._inserted_primary_key
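
A sketch of the intended usage, assuming a hypothetical 'users' table with
an autoincrementing integer primary key:

    result = conn.execute(users.insert(), name='jack')
    pk = result.inserted_primary_key    # one value per primary key column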
 
+    @util.deprecated("Use inserted_primary_key")
+    def last_inserted_ids(self):
+        """deprecated.  use inserted_primary_key."""
+        
+        return self.inserted_primary_key
+        
     def last_updated_params(self):
         """Return ``last_updated_params()`` from the underlying ExecutionContext.
 
         See ExecutionContext for details.
-
         """
+
         return self.context.last_updated_params()
 
     def last_inserted_params(self):
         """Return ``last_inserted_params()`` from the underlying ExecutionContext.
 
         See ExecutionContext for details.
-
         """
+
         return self.context.last_inserted_params()
 
     def lastrow_has_defaults(self):
         """Return ``lastrow_has_defaults()`` from the underlying ExecutionContext.
 
         See ExecutionContext for details.
-        
         """
+
         return self.context.lastrow_has_defaults()
 
     def postfetch_cols(self):
         """Return ``postfetch_cols()`` from the underlying ExecutionContext.
 
         See ExecutionContext for details.
-        
         """
+
         return self.context.postfetch_cols
-    
+
     def prefetch_cols(self):
         return self.context.prefetch_cols
-        
+
     def supports_sane_rowcount(self):
         """Return ``supports_sane_rowcount`` from the dialect."""
-        
+
         return self.dialect.supports_sane_rowcount
 
     def supports_sane_multi_rowcount(self):
@@ -1643,7 +1890,12 @@ class ResultProxy(object):
             raise
 
     def fetchmany(self, size=None):
-        """Fetch many rows, just like DB-API ``cursor.fetchmany(size=cursor.arraysize)``."""
+        """Fetch many rows, just like DB-API ``cursor.fetchmany(size=cursor.arraysize)``.
+        
+        If rows are present, the cursor remains open after this is called.
+        Otherwise the cursor is automatically closed and an empty list is returned.
+        
+        """
 
         try:
             process_row = self._process_row
@@ -1656,7 +1908,13 @@ class ResultProxy(object):
             raise
 
     def fetchone(self):
-        """Fetch one row, just like DB-API ``cursor.fetchone()``."""
+        """Fetch one row, just like DB-API ``cursor.fetchone()``.
+        
+        If a row is present, the cursor remains open after this is called.
+        Otherwise the cursor is automatically closed and None is returned.
+        
+        """
+
         try:
             row = self._fetchone_impl()
             if row is not None:
@@ -1668,21 +1926,38 @@ class ResultProxy(object):
             self.connection._handle_dbapi_exception(e, None, None, self.cursor, self.context)
             raise
 
-    def scalar(self):
-        """Fetch the first column of the first row, and close the result set."""
+    def first(self):
+        """Fetch the first row and then close the result set unconditionally.
+        
+        Returns None if no row is present.
+        
+        """
         try:
             row = self._fetchone_impl()
         except Exception, e:
             self.connection._handle_dbapi_exception(e, None, None, self.cursor, self.context)
             raise
-            
+
         try:
             if row is not None:
-                return self._process_row(self, row)[0]
+                return self._process_row(self, row)
             else:
                 return None
         finally:
             self.close()
+        
+        
+    def scalar(self):
+        """Fetch the first column of the first row, and close the result set.
+        
+        Returns None if no row is present.
+        
+        """
+        row = self.first()
+        if row is not None:
+            return row[0]
+        else:
+            return None
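
A sketch contrasting first() and scalar() on simple statements against a
hypothetical 'users' table:

    total = conn.execute("select count(*) from users").scalar()
    row = conn.execute("select id, name from users").first()
    if row is not None:
        name = row['name']
    # both methods close the result set; further fetches are not possible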
 
 class BufferedRowResultProxy(ResultProxy):
     """A ResultProxy with row buffering behavior.
@@ -1697,7 +1972,6 @@ class BufferedRowResultProxy(ResultProxy):
     The pre-fetching behavior fetches only one row initially, and then
     grows its buffer size by a fixed amount with each successive need
     for additional rows up to a size of 100.
-    
     """
 
     def _init_metadata(self):
@@ -1740,7 +2014,44 @@ class BufferedRowResultProxy(ResultProxy):
         return result
 
     def _fetchall_impl(self):
-        return self.__rowbuffer + list(self.cursor.fetchall())
+        ret = self.__rowbuffer + list(self.cursor.fetchall())
+        self.__rowbuffer[:] = []
+        return ret
+
+class FullyBufferedResultProxy(ResultProxy):
+    """A result proxy that buffers rows fully upon creation.
+    
+    Used for operations where a result is to be delivered
+    after the database conversation cannot be continued,
+    such as MSSQL INSERT...OUTPUT after an autocommit.
+    
+    """
+    def _init_metadata(self):
+        super(FullyBufferedResultProxy, self)._init_metadata()
+        self.__rowbuffer = self._buffer_rows()
+        
+    def _buffer_rows(self):
+        return self.cursor.fetchall()
+        
+    def _fetchone_impl(self):
+        if self.__rowbuffer:
+            return self.__rowbuffer.pop(0)
+        else:
+            return None
+
+    def _fetchmany_impl(self, size=None):
+        if size is None:
+            # DB-API allows fetchmany() with no argument; return the
+            # remainder of the buffer rather than crash on range(0, None)
+            return self._fetchall_impl()
+        result = []
+        for x in range(0, size):
+            row = self._fetchone_impl()
+            if row is None:
+                break
+            result.append(row)
+        return result
+
+    def _fetchall_impl(self):
+        ret = self.__rowbuffer
+        self.__rowbuffer = []
+        return ret
 
 class BufferedColumnResultProxy(ResultProxy):
     """A ResultProxy with column buffering behavior.
@@ -1791,28 +2102,6 @@ class BufferedColumnResultProxy(ResultProxy):
         return l
 
 
-class SchemaIterator(schema.SchemaVisitor):
-    """A visitor that can gather text into a buffer and execute the contents of the buffer."""
-
-    def __init__(self, connection):
-        """Construct a new SchemaIterator."""
-        
-        self.connection = connection
-        self.buffer = StringIO.StringIO()
-
-    def append(self, s):
-        """Append content to the SchemaIterator's query buffer."""
-
-        self.buffer.write(s)
-
-    def execute(self):
-        """Execute the contents of the SchemaIterator's buffer."""
-
-        try:
-            return self.connection.execute(self.buffer.getvalue())
-        finally:
-            self.buffer.truncate(0)
-
 class DefaultRunner(schema.SchemaVisitor):
     """A visitor which accepts ColumnDefault objects, produces the
     dialect-specific SQL corresponding to their execution, and
@@ -1821,7 +2110,6 @@ class DefaultRunner(schema.SchemaVisitor):
     DefaultRunners are used internally by Engines and Dialects.
     Specific database modules should provide their own subclasses of
     DefaultRunner to allow database-specific behavior.
-
     """
 
     def __init__(self, context):
@@ -1854,7 +2142,7 @@ class DefaultRunner(schema.SchemaVisitor):
 
     def execute_string(self, stmt, params=None):
         """execute a string statement, using the raw cursor, and return a scalar result."""
-        
+
         conn = self.context._connection
         if isinstance(stmt, unicode) and not self.dialect.supports_unicode_statements:
             stmt = stmt.encode(self.dialect.encoding)
@@ -1883,8 +2171,8 @@ def connection_memoize(key):
 
     Only applicable to functions which take no arguments other than a
     connection.  The memo will be stored in ``connection.info[key]``.
-
     """
+
     @util.decorator
     def decorated(fn, self, connection):
         connection = connection.connect()
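
A sketch of applying the decorator to a dialect-level method (the method
name and memo key are hypothetical):

    from sqlalchemy.engine import base
    from sqlalchemy.engine.default import DefaultDialect

    class MyDialect(DefaultDialect):
        @base.connection_memoize('_mydialect_server_caps')
        def server_capabilities(self, connection):
            # computed once per DB-API connection, then served from
            # connection.info['_mydialect_server_caps']
            return connection.execute("select 1").scalar()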
diff --git a/lib/sqlalchemy/engine/ddl.py b/lib/sqlalchemy/engine/ddl.py
new file mode 100644 (file)
index 0000000..6e7253e
--- /dev/null
@@ -0,0 +1,128 @@
+# engine/ddl.py
+# Copyright (C) 2009 Michael Bayer mike_mp@zzzcomputing.com
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: http://www.opensource.org/licenses/mit-license.php
+
+"""Routines to handle CREATE/DROP workflow."""
+
+from sqlalchemy import engine, schema
+from sqlalchemy.sql import util as sql_util
+
+
+class DDLBase(schema.SchemaVisitor):
+    def __init__(self, connection):
+        self.connection = connection
+
+class SchemaGenerator(DDLBase):
+    def __init__(self, dialect, connection, checkfirst=False, tables=None, **kwargs):
+        super(SchemaGenerator, self).__init__(connection, **kwargs)
+        self.checkfirst = checkfirst
+        self.tables = tables and set(tables) or None
+        self.preparer = dialect.identifier_preparer
+        self.dialect = dialect
+
+    def _can_create(self, table):
+        self.dialect.validate_identifier(table.name)
+        if table.schema:
+            self.dialect.validate_identifier(table.schema)
+        return not self.checkfirst or not self.dialect.has_table(self.connection, table.name, schema=table.schema)
+
+    def visit_metadata(self, metadata):
+        if self.tables:
+            tables = self.tables
+        else:
+            tables = metadata.tables.values()
+        collection = [t for t in sql_util.sort_tables(tables) if self._can_create(t)]
+        
+        for listener in metadata.ddl_listeners['before-create']:
+            listener('before-create', metadata, self.connection, tables=collection)
+            
+        for table in collection:
+            self.traverse_single(table)
+
+        for listener in metadata.ddl_listeners['after-create']:
+            listener('after-create', metadata, self.connection, tables=collection)
+
+    def visit_table(self, table):
+        for listener in table.ddl_listeners['before-create']:
+            listener('before-create', table, self.connection)
+
+        for column in table.columns:
+            if column.default is not None:
+                self.traverse_single(column.default)
+
+        self.connection.execute(schema.CreateTable(table))
+
+        if hasattr(table, 'indexes'):
+            for index in table.indexes:
+                self.traverse_single(index)
+
+        for listener in table.ddl_listeners['after-create']:
+            listener('after-create', table, self.connection)
+
+    def visit_sequence(self, sequence):
+        if self.dialect.supports_sequences:
+            if ((not self.dialect.sequences_optional or
+                 not sequence.optional) and
+                (not self.checkfirst or
+                 not self.dialect.has_sequence(self.connection, sequence.name))):
+                self.connection.execute(schema.CreateSequence(sequence))
+
+    def visit_index(self, index):
+        self.connection.execute(schema.CreateIndex(index))
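
The listener hooks fired by this visitor can be populated on a Table or
MetaData; a sketch, with 'users' and 'metadata' hypothetical (the listener
signature matches the calls made above):

    def log_create(event, target, connection, **kw):
        print "about to emit CREATE for", target

    users.ddl_listeners['before-create'].append(log_create)
    metadata.ddl_listeners['after-create'].append(log_create)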
+
+
+class SchemaDropper(DDLBase):
+    def __init__(self, dialect, connection, checkfirst=False, tables=None, **kwargs):
+        super(SchemaDropper, self).__init__(connection, **kwargs)
+        self.checkfirst = checkfirst
+        self.tables = tables
+        self.preparer = dialect.identifier_preparer
+        self.dialect = dialect
+
+    def visit_metadata(self, metadata):
+        if self.tables:
+            tables = self.tables
+        else:
+            tables = metadata.tables.values()
+        collection = [t for t in reversed(sql_util.sort_tables(tables)) if self._can_drop(t)]
+        
+        for listener in metadata.ddl_listeners['before-drop']:
+            listener('before-drop', metadata, self.connection, tables=collection)
+        
+        for table in collection:
+            self.traverse_single(table)
+
+        for listener in metadata.ddl_listeners['after-drop']:
+            listener('after-drop', metadata, self.connection, tables=collection)
+
+    def _can_drop(self, table):
+        self.dialect.validate_identifier(table.name)
+        if table.schema:
+            self.dialect.validate_identifier(table.schema)
+        return not self.checkfirst or self.dialect.has_table(self.connection, table.name, schema=table.schema)
+
+    def visit_index(self, index):
+        self.connection.execute(schema.DropIndex(index))
+
+    def visit_table(self, table):
+        for listener in table.ddl_listeners['before-drop']:
+            listener('before-drop', table, self.connection)
+
+        for column in table.columns:
+            if column.default is not None:
+                self.traverse_single(column.default)
+
+        self.connection.execute(schema.DropTable(table))
+
+        for listener in table.ddl_listeners['after-drop']:
+            listener('after-drop', table, self.connection)
+
+    def visit_sequence(self, sequence):
+        if self.dialect.supports_sequences:
+            if ((not self.dialect.sequences_optional or
+                 not sequence.optional) and
+                (not self.checkfirst or
+                 self.dialect.has_sequence(self.connection, sequence.name))):
+                self.connection.execute(schema.DropSequence(sequence))
index 728b932a2e172532b2de1d04abc14a11776dd1a6..935d1e087d486303452c8a002ce5aff389076b9f 100644 (file)
@@ -13,36 +13,59 @@ as the base class for their own corresponding classes.
 """
 
 import re, random
-from sqlalchemy.engine import base
+from sqlalchemy.engine import base, reflection
 from sqlalchemy.sql import compiler, expression
-from sqlalchemy import exc
+from sqlalchemy import exc, types as sqltypes, util
 
 AUTOCOMMIT_REGEXP = re.compile(r'\s*(?:UPDATE|INSERT|CREATE|DELETE|DROP|ALTER)',
                                re.I | re.UNICODE)
 
+
 class DefaultDialect(base.Dialect):
     """Default implementation of Dialect"""
 
-    name = 'default'
-    schemagenerator = compiler.SchemaGenerator
-    schemadropper = compiler.SchemaDropper
-    statement_compiler = compiler.DefaultCompiler
+    statement_compiler = compiler.SQLCompiler
+    ddl_compiler = compiler.DDLCompiler
+    type_compiler = compiler.GenericTypeCompiler
     preparer = compiler.IdentifierPreparer
     defaultrunner = base.DefaultRunner
     supports_alter = True
+
+    supports_sequences = False
+    sequences_optional = False
+    preexecute_autoincrement_sequences = False
+    postfetch_lastrowid = True
+    implicit_returning = False
+    
+    # Py3K
+    #supports_unicode_statements = True
+    #supports_unicode_binds = True
+    # Py2K
     supports_unicode_statements = False
+    supports_unicode_binds = False
+    # end Py2K
+
+    name = 'default'
     max_identifier_length = 9999
     supports_sane_rowcount = True
     supports_sane_multi_rowcount = True
-    preexecute_pk_sequences = False
-    supports_pk_autoincrement = True
     dbapi_type_map = {}
     default_paramstyle = 'named'
-    supports_default_values = False 
+    supports_default_values = False
     supports_empty_insert = True
+    
+    # Indicates symbol names are UPPERCASEd if they are case
+    # insensitive within the database.  If this is True, the methods
+    # normalize_name() and denormalize_name() must be provided.
+    requires_name_normalize = False
+    
+    reflection_options = ()
 
     def __init__(self, convert_unicode=False, assert_unicode=False,
-                 encoding='utf-8', paramstyle=None, dbapi=None, 
+                 encoding='utf-8', paramstyle=None, dbapi=None,
+                 implicit_returning=None,
                  label_length=None, **kwargs):
         self.convert_unicode = convert_unicode
         self.assert_unicode = assert_unicode
@@ -56,28 +79,58 @@ class DefaultDialect(base.Dialect):
             self.paramstyle = self.dbapi.paramstyle
         else:
             self.paramstyle = self.default_paramstyle
+        if implicit_returning is not None:
+            self.implicit_returning = implicit_returning
         self.positional = self.paramstyle in ('qmark', 'format', 'numeric')
         self.identifier_preparer = self.preparer(self)
+        self.type_compiler = self.type_compiler(self)
+
         if label_length and label_length > self.max_identifier_length:
-            raise exc.ArgumentError("Label length of %d is greater than this dialect's maximum identifier length of %d" % (label_length, self.max_identifier_length))
+            raise exc.ArgumentError("Label length of %d is greater than this dialect's"
+                                    " maximum identifier length of %d" %
+                                    (label_length, self.max_identifier_length))
         self.label_length = label_length
-        self.description_encoding = getattr(self, 'description_encoding', encoding)
 
-    def type_descriptor(self, typeobj):
+        if not hasattr(self, 'description_encoding'):
+            # hasattr() already established there is no existing value,
+            # so the getattr() default would always apply
+            self.description_encoding = encoding
+
+        # Py3K
+        ## work around dialects that might change these values
+        #self.supports_unicode_statements = True
+        #self.supports_unicode_binds = True
+
+    def initialize(self, connection):
+        if hasattr(self, '_get_server_version_info'):
+            self.server_version_info = self._get_server_version_info(connection)
+        if hasattr(self, '_get_default_schema_name'):
+            self.default_schema_name = self._get_default_schema_name(connection)
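
initialize() discovers these hooks via hasattr(), so a dialect opts in simply
by defining them.  A minimal sketch (hypothetical dialect; the returned
values would normally be queried from the server):

    from sqlalchemy.engine.default import DefaultDialect

    class MyDialect(DefaultDialect):
        def _get_server_version_info(self, connection):
            return (1, 0, 0)

        def _get_default_schema_name(self, connection):
            return 'public'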
+        
+    @classmethod
+    def type_descriptor(cls, typeobj):
         """Provide a database-specific ``TypeEngine`` object, given
         the generic object which comes from the types module.
 
-        Subclasses will usually use the ``adapt_type()`` method in the
-        types module to make this job easy."""
+        This method looks for a dictionary called
+        ``colspecs`` as a class or instance-level variable,
+        and passes on to ``types.adapt_type()``.
 
-        if type(typeobj) is type:
-            typeobj = typeobj()
-        return typeobj
+        """
+        return sqltypes.adapt_type(typeobj, cls.colspecs)
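
A sketch of the colspecs mechanism (MyDateTime and MyDialect are
hypothetical):

    from sqlalchemy import types as sqltypes
    from sqlalchemy.engine.default import DefaultDialect

    class MyDateTime(sqltypes.DateTime):
        # dialect-specific refinement of the generic type
        pass

    class MyDialect(DefaultDialect):
        colspecs = {sqltypes.DateTime: MyDateTime}

    # type_descriptor() routes through types.adapt_type(), so a generic
    # DateTime is adapted to MyDateTime for this dialect:
    # MyDialect.type_descriptor(sqltypes.DateTime())  ->  MyDateTime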
+
+    def reflecttable(self, connection, table, include_columns):
+        insp = reflection.Inspector.from_engine(connection)
+        return insp.reflecttable(table, include_columns)
 
     def validate_identifier(self, ident):
         if len(ident) > self.max_identifier_length:
-            raise exc.IdentifierError("Identifier '%s' exceeds maximum length of %d characters" % (ident, self.max_identifier_length))
-        
+            raise exc.IdentifierError(
+                "Identifier '%s' exceeds maximum length of %d characters" % 
+                (ident, self.max_identifier_length)
+            )
+
+    def connect(self, *cargs, **cparams):
+        return self.dbapi.connect(*cargs, **cparams)
+
     def do_begin(self, connection):
         """Implementations might want to put logic here for turning
         autocommit on/off, etc.
@@ -103,7 +156,8 @@ class DefaultDialect(base.Dialect):
         """Create a random two-phase transaction ID.
 
         This id will be passed to do_begin_twophase(), do_rollback_twophase(),
-        do_commit_twophase().  Its format is unspecified."""
+        do_commit_twophase().  Its format is unspecified.
+        """
 
         return "_sa_%032x" % random.randint(0, 2 ** 128)
 
@@ -127,13 +181,30 @@ class DefaultDialect(base.Dialect):
 
 
 class DefaultExecutionContext(base.ExecutionContext):
-    def __init__(self, dialect, connection, compiled=None, statement=None, parameters=None):
+    
+    def __init__(self, dialect, connection, compiled_sql=None, compiled_ddl=None, statement=None, parameters=None):
         self.dialect = dialect
         self._connection = self.root_connection = connection
-        self.compiled = compiled
         self.engine = connection.engine
 
-        if compiled is not None:
+        if compiled_ddl is not None:
+            self.compiled = compiled = compiled_ddl
+            if not dialect.supports_unicode_statements:
+                self.statement = unicode(compiled).encode(self.dialect.encoding)
+            else:
+                self.statement = unicode(compiled)
+            self.isinsert = self.isupdate = self.isdelete = self.executemany = False
+            self.should_autocommit = True
+            self.result_map = None
+            self.cursor = self.create_cursor()
+            self.compiled_parameters = []
+            if self.dialect.positional:
+                self.parameters = [()]
+            else:
+                self.parameters = [{}]
+        elif compiled_sql is not None:
+            self.compiled = compiled = compiled_sql
+
             # compiled clauseelement.  process bind params, process table defaults,
             # track collections used by ResultProxy to target and process results
 
@@ -156,6 +227,7 @@ class DefaultExecutionContext(base.ExecutionContext):
 
             self.isinsert = compiled.isinsert
             self.isupdate = compiled.isupdate
+            self.isdelete = compiled.isdelete
             self.should_autocommit = compiled.statement._autocommit
             if isinstance(compiled.statement, expression._TextClause):
                 self.should_autocommit = self.should_autocommit or self.should_autocommit_text(self.statement)
@@ -173,31 +245,43 @@ class DefaultExecutionContext(base.ExecutionContext):
             self.parameters = self.__convert_compiled_params(self.compiled_parameters)
 
         elif statement is not None:
-            # plain text statement.
-            self.result_map = None
+            # plain text statement
+            self.result_map = self.compiled = None
             self.parameters = self.__encode_param_keys(parameters)
             self.executemany = len(parameters) > 1
             if isinstance(statement, unicode) and not dialect.supports_unicode_statements:
                 self.statement = statement.encode(self.dialect.encoding)
             else:
                 self.statement = statement
-            self.isinsert = self.isupdate = False
+            self.isinsert = self.isupdate = self.isdelete = False
             self.cursor = self.create_cursor()
             self.should_autocommit = self.should_autocommit_text(statement)
         else:
             # no statement. used for standalone ColumnDefault execution.
-            self.statement = None
-            self.isinsert = self.isupdate = self.executemany = self.should_autocommit = False
+            self.statement = self.compiled = None
+            self.isinsert = self.isupdate = self.isdelete = self.executemany = self.should_autocommit = False
             self.cursor = self.create_cursor()
-
+    
+    @util.memoized_property
+    def _is_explicit_returning(self):
+        return self.compiled and \
+            getattr(self.compiled.statement, '_returning', False)
+    
+    @util.memoized_property
+    def _is_implicit_returning(self):
+        return self.compiled and \
+            bool(self.compiled.returning) and \
+            not self.compiled.statement._returning
+    
     @property
     def connection(self):
         return self._connection._branch()
 
     def __encode_param_keys(self, params):
-        """apply string encoding to the keys of dictionary-based bind parameters.
+        """Apply string encoding to the keys of dictionary-based bind parameters.
 
-        This is only used executing textual, non-compiled SQL expressions."""
+        This is only used when executing textual, non-compiled SQL expressions.
+        """
 
         if self.dialect.positional or self.dialect.supports_unicode_statements:
             if params:
@@ -216,7 +300,7 @@ class DefaultExecutionContext(base.ExecutionContext):
             return [proc(d) for d in params] or [{}]
 
     def __convert_compiled_params(self, compiled_parameters):
-        """convert the dictionary of bind parameter values into a dict or list
+        """Convert the dictionary of bind parameter values into a dict or list
         to be sent to the DBAPI's execute() or executemany() method.
         """
 
@@ -263,26 +347,69 @@ class DefaultExecutionContext(base.ExecutionContext):
     def post_exec(self):
         pass
     
+    def get_lastrowid(self):
+        """Return self.cursor.lastrowid, or equivalent, after an INSERT.
+        
+        This may involve calling special cursor functions,
+        issuing a new SELECT on the cursor (or a new one),
+        or returning a stored value that was
+        calculated within post_exec().
+        
+        This function will only be called for dialects
+        which support "implicit" primary key generation,
+        which keep preexecute_autoincrement_sequences set to False,
+        and only when no explicit id value was bound to the
+        statement.
+        
+        The function is called once, directly after 
+        post_exec() and before the transaction is committed
+        or ResultProxy is generated.   If the post_exec()
+        method assigns a value to `self._lastrowid`, the
+        value is used in place of calling get_lastrowid().
+        
+        Note that this method is *not* equivalent to the
+        ``lastrowid`` method on ``ResultProxy``, which is a
+        direct proxy to the DBAPI ``lastrowid`` accessor
+        in all cases.
+        
+        """
+        
+        return self.cursor.lastrowid
+
     def handle_dbapi_exception(self, e):
         pass
 
     def get_result_proxy(self):
         return base.ResultProxy(self)
+    
+    @property
+    def rowcount(self):
+        return self.cursor.rowcount
 
-    def get_rowcount(self):
-        if hasattr(self, '_rowcount'):
-            return self._rowcount
-        else:
-            return self.cursor.rowcount
-        
     def supports_sane_rowcount(self):
         return self.dialect.supports_sane_rowcount
 
     def supports_sane_multi_rowcount(self):
         return self.dialect.supports_sane_multi_rowcount
-
-    def last_inserted_ids(self):
-        return self._last_inserted_ids
+    
+    def post_insert(self):
+        if self.dialect.postfetch_lastrowid and \
+            (not len(self._inserted_primary_key) or \
+                        None in self._inserted_primary_key):
+            
+            table = self.compiled.statement.table
+            lastrowid = self.get_lastrowid()
+            self._inserted_primary_key = [c is table._autoincrement_column and lastrowid or v
+                for c, v in zip(table.primary_key, self._inserted_primary_key)
+            ]
+            
+    def _fetch_implicit_returning(self, resultproxy):
+        table = self.compiled.statement.table
+        row = resultproxy.first()
+
+        self._inserted_primary_key = [v is not None and v or row[c] 
+            for c, v in zip(table.primary_key, self._inserted_primary_key)
+        ]
 
     def last_inserted_params(self):
         return self._last_inserted_params
@@ -293,12 +420,15 @@ class DefaultExecutionContext(base.ExecutionContext):
     def lastrow_has_defaults(self):
         return hasattr(self, 'postfetch_cols') and len(self.postfetch_cols)
 
-    def set_input_sizes(self):
+    def set_input_sizes(self, translate=None, exclude_types=None):
         """Given a cursor and ClauseParameters, call the appropriate
         style of ``setinputsizes()`` on the cursor, using DB-API types
         from the bind parameter's ``TypeEngine`` objects.
         """
 
+        if not hasattr(self.compiled, 'bind_names'):
+            return
+
         types = dict(
                 (self.compiled.bind_names[bindparam], bindparam.type)
                  for bindparam in self.compiled.bind_names)
@@ -308,7 +438,7 @@ class DefaultExecutionContext(base.ExecutionContext):
             for key in self.compiled.positiontup:
                 typeengine = types[key]
                 dbtype = typeengine.dialect_impl(self.dialect).get_dbapi_type(self.dialect.dbapi)
-                if dbtype is not None:
+                if dbtype is not None and (not exclude_types or dbtype not in exclude_types):
                     inputsizes.append(dbtype)
             try:
                 self.cursor.setinputsizes(*inputsizes)
@@ -320,7 +450,9 @@ class DefaultExecutionContext(base.ExecutionContext):
             for key in self.compiled.bind_names.values():
                 typeengine = types[key]
                 dbtype = typeengine.dialect_impl(self.dialect).get_dbapi_type(self.dialect.dbapi)
-                if dbtype is not None:
+                if dbtype is not None and (not exclude_types or dbtype not in exclude_types):
+                    if translate:
+                        key = translate.get(key, key)
                     inputsizes[key.encode(self.dialect.encoding)] = dbtype
             try:
                 self.cursor.setinputsizes(**inputsizes)
@@ -329,8 +461,9 @@ class DefaultExecutionContext(base.ExecutionContext):
                 raise
 
     def __process_defaults(self):
-        """generate default values for compiled insert/update statements,
-        and generate last_inserted_ids() collection."""
+        """Generate default values for compiled insert/update statements,
+        and generate inserted_primary_key collection.
+        """
 
         if self.executemany:
             if len(self.compiled.prefetch):
@@ -364,7 +497,8 @@ class DefaultExecutionContext(base.ExecutionContext):
                     compiled_parameters[c.key] = val
 
             if self.isinsert:
-                self._last_inserted_ids = [compiled_parameters.get(c.key, None) for c in self.compiled.statement.table.primary_key]
+                self._inserted_primary_key = [compiled_parameters.get(c.key, None) 
+                                            for c in self.compiled.statement.table.primary_key]
                 self._last_inserted_params = compiled_parameters
             else:
                 self._last_updated_params = compiled_parameters
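
The ``_inserted_primary_key`` collection assembled in ``post_insert()`` and
``__process_defaults()`` above surfaces publicly as
``ResultProxy.inserted_primary_key``.  A minimal usage sketch, assuming the
SQLite dialect::

    from sqlalchemy import (MetaData, Table, Column, Integer, String,
                            create_engine)

    engine = create_engine('sqlite://')
    meta = MetaData()
    users = Table('users', meta,
                  Column('id', Integer, primary_key=True),
                  Column('name', String(50)))
    meta.create_all(engine)

    result = engine.execute(users.insert(), name='jack')
    # one value per primary key column; derived from lastrowid here
    print result.inserted_primary_key       # [1]
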
diff --git a/lib/sqlalchemy/engine/reflection.py b/lib/sqlalchemy/engine/reflection.py
new file mode 100644 (file)
index 0000000..173e0fa
--- /dev/null
+++ b/lib/sqlalchemy/engine/reflection.py
@@ -0,0 +1,361 @@
+"""Provides an abstraction for obtaining database schema information.
+
+Usage Notes:
+
+Here are some general conventions when accessing the low level inspector
+methods such as get_table_names, get_columns, etc.
+
+1. Inspector methods return lists of dicts in most cases for the following
+   reasons:
+
+   * They're both standard types that can be serialized.
+   * Using a dict instead of a tuple allows easy expansion of attributes.
+   * Using a list for the outer structure maintains order and is easy to work
+     with (e.g. list comprehension [d['name'] for d in cols]).
+
+2. Records that contain a name, such as the column name in a column record,
+   use the key 'name'. So for most return values, each record will have a
+   'name' attribute.
+"""
+
+import sqlalchemy
+from sqlalchemy import exc, sql
+from sqlalchemy import util
+from sqlalchemy.types import TypeEngine
+from sqlalchemy import schema as sa_schema
+
+
+@util.decorator
+def cache(fn, self, con, *args, **kw):
+    info_cache = kw.get('info_cache', None)
+    if info_cache is None:
+        return fn(self, con, *args, **kw)
+    key = (
+            fn.__name__, 
+            tuple(a for a in args if isinstance(a, basestring)), 
+            tuple((k, v) for k, v in kw.iteritems() if isinstance(v, basestring))
+        )
+    ret = info_cache.get(key)
+    if ret is None:
+        ret = fn(self, con, *args, **kw)
+        info_cache[key] = ret
+    return ret
+
+
+class Inspector(object):
+    """Performs database schema inspection.
+
+    The Inspector acts as a proxy to the dialects' reflection methods and
+    provides higher level functions for accessing database schema information.
+    """
+
+    def __init__(self, conn):
+        """Initialize the instance.
+
+        :param conn: a :class:`~sqlalchemy.engine.base.Connectable`
+        """
+
+        self.conn = conn
+        # set the engine
+        if hasattr(conn, 'engine'):
+            self.engine = conn.engine
+        else:
+            self.engine = conn
+        self.dialect = self.engine.dialect
+        self.info_cache = {}
+
+    @classmethod
+    def from_engine(cls, engine):
+        if hasattr(engine.dialect, 'inspector'):
+            return engine.dialect.inspector(engine)
+        return Inspector(engine)
+
+    @property
+    def default_schema_name(self):
+        return self.dialect.get_default_schema_name(self.conn)
+
+    def get_schema_names(self):
+        """Return all schema names.
+        """
+
+        if hasattr(self.dialect, 'get_schema_names'):
+            return self.dialect.get_schema_names(self.conn,
+                                                    info_cache=self.info_cache)
+        return []
+
+    def get_table_names(self, schema=None, order_by=None):
+        """Return all table names in `schema`.
+
+        :param schema: Optional, retrieve names from a non-default schema.
+        :param order_by: Optional, may be the string "foreign_key" to sort
+                         the result on foreign key dependencies.
+
+        This should probably not return view names, or perhaps it should
+        return them with an indicator ('t' or 'v').
+        """
+
+        if hasattr(self.dialect, 'get_table_names'):
+            tnames = self.dialect.get_table_names(
+                self.conn, schema, info_cache=self.info_cache)
+        else:
+            tnames = self.engine.table_names(schema)
+        if order_by == 'foreign_key':
+            ordered_tnames = tnames[:]
+            # Order based on foreign key dependencies.
+            for tname in tnames:
+                table_pos = tnames.index(tname)
+                fkeys = self.get_foreign_keys(tname, schema)
+                for fkey in fkeys:
+                    rtable = fkey['referred_table']
+                    if rtable in ordered_tnames:
+                        ref_pos = ordered_tnames.index(rtable)
+                        # Make sure it's lower in the list than anything it
+                        # references.
+                        if table_pos > ref_pos:
+                            ordered_tnames.pop(table_pos) # rtable moves up 1
+                            # insert just below rtable
+                            ordered_tnames.insert(ref_pos, tname)
+            tnames = ordered_tnames
+        return tnames
+
+    def get_table_options(self, table_name, schema=None, **kw):
+        if hasattr(self.dialect, 'get_table_options'):
+            return self.dialect.get_table_options(self.conn, table_name, schema,
+                                                  info_cache=self.info_cache,
+                                                  **kw)
+        return {}
+
+    def get_view_names(self, schema=None):
+        """Return all view names in `schema`.
+
+        :param schema: Optional, retrieve names from a non-default schema.
+        """
+
+        return self.dialect.get_view_names(self.conn, schema,
+                                                  info_cache=self.info_cache)
+
+    def get_view_definition(self, view_name, schema=None):
+        """Return definition for `view_name`.
+
+        :param schema: Optional, retrieve names from a non-default schema.
+        """
+
+        return self.dialect.get_view_definition(
+            self.conn, view_name, schema, info_cache=self.info_cache)
+
+    def get_columns(self, table_name, schema=None, **kw):
+        """Return information about columns in `table_name`.
+
+        Given a string `table_name` and an optional string `schema`, return
+        column information as a list of dicts with these keys:
+
+        name
+          the column's name
+
+        type
+          :class:`~sqlalchemy.types.TypeEngine`
+
+        nullable
+          boolean
+
+        default
+          the column's default value
+
+        attrs
+          dict containing optional column attributes
+        """
+
+        col_defs = self.dialect.get_columns(self.conn, table_name, schema,
+                                            info_cache=self.info_cache,
+                                            **kw)
+        for col_def in col_defs:
+            # make this easy and only return instances for coltype
+            coltype = col_def['type']
+            if not isinstance(coltype, TypeEngine):
+                col_def['type'] = coltype()
+        return col_defs
+
+    def get_primary_keys(self, table_name, schema=None, **kw):
+        """Return information about primary keys in `table_name`.
+
+        Given a string `table_name`, and an optional string `schema`, return
+        primary key information as a list of column names.
+        """
+
+        pkeys = self.dialect.get_primary_keys(self.conn, table_name, schema,
+                                              info_cache=self.info_cache,
+                                              **kw)
+
+        return pkeys
+
+    def get_foreign_keys(self, table_name, schema=None, **kw):
+        """Return information about foreign_keys in `table_name`.
+
+        Given a string `table_name`, and an optional string `schema`, return
+        foreign key information as a list of dicts with these keys:
+
+        constrained_columns
+          a list of column names that make up the foreign key
+
+        referred_schema
+          the name of the referred schema
+
+        referred_table
+          the name of the referred table
+
+        referred_columns
+          a list of column names in the referred table that correspond to
+          constrained_columns
+        """
+
+        fk_defs = self.dialect.get_foreign_keys(self.conn, table_name, schema,
+                                                info_cache=self.info_cache,
+                                                **kw)
+        return fk_defs
+
+    def get_indexes(self, table_name, schema=None):
+        """Return information about indexes in `table_name`.
+
+        Given a string `table_name` and an optional string `schema`, return
+        index information as a list of dicts with these keys:
+
+        name
+          the index's name
+
+        column_names
+          list of column names in order
+
+        unique
+          boolean
+        """
+
+        indexes = self.dialect.get_indexes(self.conn, table_name, schema,
+                                           info_cache=self.info_cache)
+        return indexes
+
+    def reflecttable(self, table, include_columns):
+
+        dialect = self.conn.dialect
+
+        # MySQL dialect does this.  Applicable with other dialects?
+        if hasattr(dialect, '_connection_charset') \
+                                        and hasattr(dialect, '_adjust_casing'):
+            charset = dialect._connection_charset
+            dialect._adjust_casing(table)
+
+        # table attributes we might need.
+        reflection_options = dict(
+            (k, table.kwargs.get(k)) for k in dialect.reflection_options if k in table.kwargs)
+
+        schema = table.schema
+        table_name = table.name
+
+        # apply table options
+        tbl_opts = self.get_table_options(table_name, schema, **table.kwargs)
+        if tbl_opts:
+            table.kwargs.update(tbl_opts)
+
+        # table.kwargs will need to be passed to each reflection method.  Make
+        # sure keywords are strings.
+        tblkw = table.kwargs.copy()
+        for (k, v) in tblkw.items():
+            del tblkw[k]
+            tblkw[str(k)] = v
+
+        # Py2K
+        if isinstance(schema, str):
+            schema = schema.decode(dialect.encoding)
+        if isinstance(table_name, str):
+            table_name = table_name.decode(dialect.encoding)
+        # end Py2K
+
+        # columns
+        found_table = False
+        for col_d in self.get_columns(table_name, schema, **tblkw):
+            found_table = True
+            name = col_d['name']
+            if include_columns and name not in include_columns:
+                continue
+
+            coltype = col_d['type']
+            col_kw = {
+                'nullable':col_d['nullable'],
+            }
+            if 'autoincrement' in col_d:
+                col_kw['autoincrement'] = col_d['autoincrement']
+            
+            colargs = []
+            if col_d.get('default') is not None:
+                # the "default" value is assumed to be a literal SQL expression,
+                # so is wrapped in text() so that no quoting occurs on re-issuance.
+                colargs.append(sa_schema.DefaultClause(sql.text(col_d['default'])))
+                
+            if 'sequence' in col_d:
+                # TODO: who's using this?
+                seq = col_d['sequence']
+                sequence = sa_schema.Sequence(seq['name'], 1, 1)
+                if 'start' in seq:
+                    sequence.start = seq['start']
+                if 'increment' in seq:
+                    sequence.increment = seq['increment']
+                colargs.append(sequence)
+                
+            col = sa_schema.Column(name, coltype, *colargs, **col_kw)
+            table.append_column(col)
+
+        if not found_table:
+            raise exc.NoSuchTableError(table.name)
+
+        # Primary keys
+        primary_key_constraint = sa_schema.PrimaryKeyConstraint(*[
+            table.c[pk] for pk in self.get_primary_keys(table_name, schema, **tblkw)
+            if pk in table.c
+        ])
+
+        table.append_constraint(primary_key_constraint)
+
+        # Foreign keys
+        fkeys = self.get_foreign_keys(table_name, schema, **tblkw)
+        for fkey_d in fkeys:
+            conname = fkey_d['name']
+            constrained_columns = fkey_d['constrained_columns']
+            referred_schema = fkey_d['referred_schema']
+            referred_table = fkey_d['referred_table']
+            referred_columns = fkey_d['referred_columns']
+            refspec = []
+            if referred_schema is not None:
+                sa_schema.Table(referred_table, table.metadata,
+                                autoload=True, schema=referred_schema,
+                                autoload_with=self.conn,
+                                **reflection_options
+                                )
+                for column in referred_columns:
+                    refspec.append(".".join(
+                        [referred_schema, referred_table, column]))
+            else:
+                sa_schema.Table(referred_table, table.metadata, autoload=True,
+                                autoload_with=self.conn,
+                                **reflection_options
+                                )
+                for column in referred_columns:
+                    refspec.append(".".join([referred_table, column]))
+            table.append_constraint(
+                sa_schema.ForeignKeyConstraint(constrained_columns, refspec,
+                                               conname, link_to_name=True))
+        # Indexes 
+        indexes = self.get_indexes(table_name, schema)
+        for index_d in indexes:
+            name = index_d['name']
+            columns = index_d['column_names']
+            unique = index_d['unique']
+            flavor = index_d.get('type', 'unknown type')
+            if include_columns and \
+                            not set(columns).issubset(include_columns):
+                util.warn(
+                    "Omitting %s KEY for (%s), key covers omitted columns." %
+                    (flavor, ', '.join(columns)))
+                continue
+            sa_schema.Index(name, *[table.columns[c] for c in columns], 
+                         **dict(unique=unique))
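
Per the module docstring, the inspector methods return lists of dicts keyed
on 'name'.  A usage sketch of the new facade (the database URL is
hypothetical)::

    from sqlalchemy import create_engine
    from sqlalchemy.engine import reflection

    engine = create_engine('sqlite:///some.db')
    insp = reflection.Inspector.from_engine(engine)

    for tname in insp.get_table_names():
        # each column record carries 'name', 'type', 'nullable', 'default'
        cols = insp.get_columns(tname)
        print tname, [d['name'] for d in cols]
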
index fa608df65ee639f4d18edb93a252ac8bd869f202..ff62b265baf3207806ec566441732ca2bd78f28d 100644 (file)
@@ -6,31 +6,26 @@ underlying behavior for the "strategy" keyword argument available on
 ``plain``, ``threadlocal``, and ``mock``.
 
 New strategies can be added via new ``EngineStrategy`` classes.
-
 """
+
 from operator import attrgetter
 
 from sqlalchemy.engine import base, threadlocal, url
 from sqlalchemy import util, exc
 from sqlalchemy import pool as poollib
 
-
 strategies = {}
 
+
 class EngineStrategy(object):
     """An adaptor that processes input arguments and produces an Engine.
 
     Provides a ``create`` method that receives input arguments and
     produces an instance of base.Engine or a subclass.
+    
     """
 
-    def __init__(self, name):
-        """Construct a new EngineStrategy object.
-
-        Sets it in the list of available strategies under this name.
-        """
-
-        self.name = name
+    def __init__(self):
         strategies[self.name] = self
 
     def create(self, *args, **kwargs):
@@ -38,9 +33,12 @@ class EngineStrategy(object):
 
         raise NotImplementedError()
 
+
 class DefaultEngineStrategy(EngineStrategy):
     """Base class for built-in strategies."""
 
+    pool_threadlocal = False
+    
     def create(self, name_or_url, **kwargs):
         # create url.URL object
         u = url.make_url(name_or_url)
@@ -75,9 +73,15 @@ class DefaultEngineStrategy(EngineStrategy):
         if pool is None:
             def connect():
                 try:
-                    return dbapi.connect(*cargs, **cparams)
+                    return dialect.connect(*cargs, **cparams)
                 except Exception, e:
-                    raise exc.DBAPIError.instance(None, None, e)
+                    # Py3K
+                    #raise exc.DBAPIError.instance(None, None, e) from e
+                    # Py2K
+                    import sys
+                    raise exc.DBAPIError.instance(None, None, e), None, sys.exc_info()[2]
+                    # end Py2K
+                    
             creator = kwargs.pop('creator', connect)
 
             poolclass = (kwargs.pop('poolclass', None) or
@@ -94,7 +98,7 @@ class DefaultEngineStrategy(EngineStrategy):
                 tk = translate.get(k, k)
                 if tk in kwargs:
                     pool_args[k] = kwargs.pop(tk)
-            pool_args.setdefault('use_threadlocal', self.pool_threadlocal())
+            pool_args.setdefault('use_threadlocal', self.pool_threadlocal)
             pool = poolclass(creator, **pool_args)
         else:
             if isinstance(pool, poollib._DBProxy):
@@ -103,12 +107,14 @@ class DefaultEngineStrategy(EngineStrategy):
                 pool = pool
 
         # create engine.
-        engineclass = self.get_engine_cls()
+        engineclass = self.engine_cls
         engine_args = {}
         for k in util.get_cls_kwargs(engineclass):
             if k in kwargs:
                 engine_args[k] = kwargs.pop(k)
 
+        _initialize = kwargs.pop('_initialize', True)
+        
         # all kwargs should be consumed
         if kwargs:
             raise TypeError(
@@ -119,39 +125,38 @@ class DefaultEngineStrategy(EngineStrategy):
                                     dialect.__class__.__name__,
                                     pool.__class__.__name__,
                                     engineclass.__name__))
-        return engineclass(pool, dialect, u, **engine_args)
+                                    
+        engine = engineclass(pool, dialect, u, **engine_args)
 
-    def pool_threadlocal(self):
-        raise NotImplementedError()
+        if _initialize:
+            # some unit tests pass through _initialize=False
+            # to help mock engines work
+            class OnInit(object):
+                def first_connect(self, conn, rec):
+                    c = base.Connection(engine, connection=conn)
+                    dialect.initialize(c)
+            pool._on_first_connect.insert(0, OnInit())
 
-    def get_engine_cls(self):
-        raise NotImplementedError()
+        dialect.visit_pool(pool)
 
-class PlainEngineStrategy(DefaultEngineStrategy):
-    """Strategy for configuring a regular Engine."""
+        return engine
 
-    def __init__(self):
-        DefaultEngineStrategy.__init__(self, 'plain')
-
-    def pool_threadlocal(self):
-        return False
 
-    def get_engine_cls(self):
-        return base.Engine
+class PlainEngineStrategy(DefaultEngineStrategy):
+    """Strategy for configuring a regular Engine."""
 
+    name = 'plain'
+    engine_cls = base.Engine
+    
 PlainEngineStrategy()
 
+
 class ThreadLocalEngineStrategy(DefaultEngineStrategy):
     """Strategy for configuring an Engine with threadlocal behavior."""
-
-    def __init__(self):
-        DefaultEngineStrategy.__init__(self, 'threadlocal')
-
-    def pool_threadlocal(self):
-        return True
-
-    def get_engine_cls(self):
-        return threadlocal.TLEngine
+    
+    name = 'threadlocal'
+    pool_threadlocal = True
+    engine_cls = threadlocal.TLEngine
 
 ThreadLocalEngineStrategy()
 
@@ -161,11 +166,11 @@ class MockEngineStrategy(EngineStrategy):
 
     Produces a single mock Connectable object which dispatches
     statement execution to a passed-in function.
+    
     """
 
-    def __init__(self):
-        EngineStrategy.__init__(self, 'mock')
-
+    name = 'mock'
+    
     def create(self, name_or_url, executor, **kwargs):
         # create url.URL object
         u = url.make_url(name_or_url)
@@ -201,11 +206,14 @@ class MockEngineStrategy(EngineStrategy):
 
         def create(self, entity, **kwargs):
             kwargs['checkfirst'] = False
-            self.dialect.schemagenerator(self.dialect, self, **kwargs).traverse(entity)
+            from sqlalchemy.engine import ddl
+            
+            ddl.SchemaGenerator(self.dialect, self, **kwargs).traverse(entity)
 
         def drop(self, entity, **kwargs):
             kwargs['checkfirst'] = False
-            self.dialect.schemadropper(self.dialect, self, **kwargs).traverse(entity)
+            from sqlalchemy.engine import ddl
+            ddl.SchemaDropper(self.dialect, self, **kwargs).traverse(entity)
 
         def execute(self, object, *multiparams, **params):
             raise NotImplementedError()
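
Since ``EngineStrategy.__init__()`` now self-registers under the ``name``
class attribute, a new strategy reduces to a declarative subclass plus one
instantiation.  A hypothetical sketch (the 'audit' name and class are
illustrative only)::

    from sqlalchemy.engine import base, strategies

    class AuditEngineStrategy(strategies.DefaultEngineStrategy):
        """Hypothetical: a plain Engine on a threadlocal pool."""

        name = 'audit'
        pool_threadlocal = True
        engine_cls = base.Engine

    AuditEngineStrategy()   # registers itself in strategies.strategies

    # engine = create_engine('sqlite://', strategy='audit')
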
index 8ad14ad35f096fe7b9432d28c19321649d512948..27d857623e37f200c220b31c34794f8aa4a27c81 100644 (file)
@@ -8,6 +8,7 @@ invoked automatically when the threadlocal engine strategy is used.
 from sqlalchemy import util
 from sqlalchemy.engine import base
 
+
 class TLSession(object):
     def __init__(self, engine):
         self.engine = engine
@@ -17,7 +18,8 @@ class TLSession(object):
         try:
             return self.__transaction._increment_connect()
         except AttributeError:
-            return self.engine.TLConnection(self, self.engine.pool.connect(), close_with_result=close_with_result)
+            return self.engine.TLConnection(self, self.engine.pool.connect(),
+                                            close_with_result=close_with_result)
 
     def reset(self):
         try:
index 5c8e68ce45a6f15199b8a7303ede237ac90051ba..b0e21f5f7216cd77667eb07ec3b92291a5776354 100644 (file)
@@ -20,9 +20,9 @@ class URL(object):
     format of the URL is an RFC-1738-style string.
 
     All initialization parameters are available as public attributes.
-    
-    :param drivername: the name of the database backend.  
-      This name will correspond to a module in sqlalchemy/databases 
+
+    :param drivername: the name of the database backend.
+      This name will correspond to a module in sqlalchemy/databases
       or a third party plug-in.
 
     :param username: The user name.
@@ -35,12 +35,13 @@ class URL(object):
 
     :param database: The database name.
 
-    :param query: A dictionary of options to be passed to the 
+    :param query: A dictionary of options to be passed to the
       dialect and/or the DBAPI upon connect.
-        
+
     """
 
-    def __init__(self, drivername, username=None, password=None, host=None, port=None, database=None, query=None):
+    def __init__(self, drivername, username=None, password=None,
+                 host=None, port=None, database=None, query=None):
         self.drivername = drivername
         self.username = username
         self.password = password
@@ -70,10 +71,10 @@ class URL(object):
             keys.sort()
             s += '?' + "&".join("%s=%s" % (k, self.query[k]) for k in keys)
         return s
-    
+
     def __hash__(self):
         return hash(str(self))
-    
+
     def __eq__(self, other):
         return \
             isinstance(other, URL) and \
@@ -83,12 +84,22 @@ class URL(object):
             self.host == other.host and \
             self.database == other.database and \
             self.query == other.query
-            
+
     def get_dialect(self):
-        """Return the SQLAlchemy database dialect class corresponding to this URL's driver name."""
-        
+        """Return the SQLAlchemy database dialect class corresponding
+        to this URL's driver name.
+        """
+
         try:
-            module = getattr(__import__('sqlalchemy.databases.%s' % self.drivername).databases, self.drivername)
+            if '+' in self.drivername:
+                dialect, driver = self.drivername.split('+')
+            else:
+                dialect, driver = self.drivername, 'base'
+
+            module = __import__('sqlalchemy.dialects.%s' % (dialect, )).dialects
+            module = getattr(module, dialect)
+            module = getattr(module, driver)
+
             return module.dialect
         except ImportError:
             if sys.exc_info()[2].tb_next is None:
@@ -97,7 +108,7 @@ class URL(object):
                     if res.name == self.drivername:
                         return res.load()
             raise
-  
+
     def translate_connect_args(self, names=[], **kw):
         """Translate url attributes into a dictionary of connection arguments.
 
@@ -107,10 +118,9 @@ class URL(object):
         from the final dictionary.
 
         :param \**kw: Optional, alternate key names for url attributes.
-        
+
         :param names: Deprecated.  Same purpose as the keyword-based alternate names,
             but correlates the name to the original positionally.
-        
         """
 
         translated = {}
@@ -131,8 +141,8 @@ def make_url(name_or_url):
 
     The given string is parsed according to the RFC 1738 spec.  If an
     existing URL object is passed, just returns the object.
-    
     """
+
     if isinstance(name_or_url, basestring):
         return _parse_rfc1738_args(name_or_url)
     else:
@@ -140,7 +150,7 @@ def make_url(name_or_url):
 
 def _parse_rfc1738_args(name):
     pattern = re.compile(r'''
-            (?P<name>\w+)://
+            (?P<name>[\w\+]+)://
             (?:
                 (?P<username>[^:/]*)
                 (?::(?P<password>[^/]*))?
@@ -160,8 +170,10 @@ def _parse_rfc1738_args(name):
             tokens = components['database'].split('?', 2)
             components['database'] = tokens[0]
             query = (len(tokens) > 1 and dict(cgi.parse_qsl(tokens[1]))) or None
+            # Py2K
             if query is not None:
                 query = dict((k.encode('ascii'), query[k]) for k in query)
+            # end Py2K
         else:
             query = None
         components['query'] = query
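
The widened ``[\w\+]+`` pattern admits the new ``dialect+driver`` scheme,
which ``get_dialect()`` splits on '+', falling back to the dialect package's
``base`` module when no driver is named.  For example::

    from sqlalchemy.engine import url

    u = url.make_url('mysql+mysqldb://scott:tiger@localhost/test')
    print u.drivername      # 'mysql+mysqldb'
    print u.get_dialect()   # sqlalchemy.dialects.mysql.mysqldb.dialect

    # no '+': resolves via sqlalchemy.dialects.postgresql.base
    u = url.make_url('postgresql://scott:tiger@localhost/test')
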
index ce130ce3c2af9752c0c8146d0919aed89127f052..f1678743d9397cef9362e8eecfb58b792cca8981 100644 (file)
@@ -103,6 +103,7 @@ class DBAPIError(SQLAlchemyError):
 
     """
 
+    @classmethod
     def instance(cls, statement, params, orig, connection_invalidated=False):
         # Don't ever wrap these, just return them directly as if
         # DBAPIError didn't exist.
@@ -115,7 +116,6 @@ class DBAPIError(SQLAlchemyError):
                 cls = glob[name]
 
         return cls(statement, params, orig, connection_invalidated)
-    instance = classmethod(instance)
 
     def __init__(self, statement, params, orig, connection_invalidated=False):
         try:
index 0e3db00e02aeeec2e34143d22684e33ce6eda446..05df8d2be6b7f7fa7c5149e54af12aeb3e458722 100644 (file)
@@ -33,7 +33,7 @@ Produces::
 Compilers can also be made dialect-specific.  The appropriate compiler will be invoked
 for the dialect in use::
 
-    from sqlalchemy.schema import DDLElement  # this is a SQLA 0.6 construct
+    from sqlalchemy.schema import DDLElement
 
     class AlterColumn(DDLElement):
 
@@ -45,16 +45,16 @@ for the dialect in use::
     def visit_alter_column(element, compiler, **kw):
         return "ALTER COLUMN %s ..." % element.column.name
 
-    @compiles(AlterColumn, 'postgres')
+    @compiles(AlterColumn, 'postgresql')
     def visit_alter_column(element, compiler, **kw):
         return "ALTER TABLE %s ALTER COLUMN %s ..." % (element.table.name, element.column.name)
 
+The second ``visit_alter_column`` will be invoked when any ``postgresql`` dialect is used.
+The second ``visit_alter_table`` will be invoked when any ``postgresql`` dialect is used.
 
 The ``compiler`` argument is the :class:`~sqlalchemy.engine.base.Compiled` object
 in use.  This object can be inspected for any information about the in-progress 
 compilation, including ``compiler.dialect``, ``compiler.statement`` etc.
-The :class:`~sqlalchemy.sql.compiler.SQLCompiler` and :class:`~sqlalchemy.sql.compiler.DDLCompiler` (DDLCompiler is 0.6. only)
+The :class:`~sqlalchemy.sql.compiler.SQLCompiler` and :class:`~sqlalchemy.sql.compiler.DDLCompiler`
 both include a ``process()`` method which can be used for compilation of embedded attributes::
 
     class InsertFromSelect(ClauseElement):
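
The dialect dispatch can be exercised by compiling the construct explicitly.
A sketch reusing the ``AlterColumn`` example above (``some_table`` is
hypothetical, and a PostgreSQL DBAPI is assumed to be importable)::

    from sqlalchemy import create_engine

    ac = AlterColumn(some_table.c.name)
    print ac.compile()          # "ALTER COLUMN name ..."

    pg = create_engine('postgresql://scott:tiger@localhost/test')
    print ac.compile(bind=pg)   # "ALTER TABLE some_table ALTER COLUMN name ..."
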
index 07974caccea3b6082f7b073dfeb2d451a3c7874d..c37211ac3dbb6f430d9f8f4523516dff130e4bb2 100644 (file)
@@ -396,7 +396,7 @@ only intended as an optional syntax for the regular usage of mappers and Table
 objects.  A typical application setup using :func:`~sqlalchemy.orm.scoped_session` might look
 like::
 
-    engine = create_engine('postgres://scott:tiger@localhost/test')
+    engine = create_engine('postgresql://scott:tiger@localhost/test')
     Session = scoped_session(sessionmaker(autocommit=False,
                                           autoflush=False,
                                           bind=engine))
index a5d60bf82e2bedcd7ecaa7d8148e05ec7777a21e..8e63ed1c29c383ce3c5b538b0024c60d8b9eca07 100644 (file)
@@ -240,7 +240,8 @@ class OrderingList(list):
     _raw_append = collection.adds(1)(_raw_append)
 
     def insert(self, index, entity):
-        self[index:index] = [entity]
+        super(OrderingList, self).insert(index, entity)
+        self._reorder()
 
     def remove(self, entity):
         super(OrderingList, self).remove(entity)
@@ -253,7 +254,15 @@ class OrderingList(list):
 
     def __setitem__(self, index, entity):
         if isinstance(index, slice):
-            for i in range(index.start or 0, index.stop or 0, index.step or 1):
+            step = index.step or 1
+            start = index.start or 0
+            if start < 0:
+                start += len(self)
+            stop = index.stop or len(self)
+            if stop < 0:
+                stop += len(self)
+            
+            for i in xrange(start, stop, step):
                 self.__setitem__(i, entity[i])
         else:
             self._order_entity(index, entity, True)
@@ -263,6 +272,7 @@ class OrderingList(list):
         super(OrderingList, self).__delitem__(index)
         self._reorder()
 
+    # Py2K
     def __setslice__(self, start, end, values):
         super(OrderingList, self).__setslice__(start, end, values)
         self._reorder()
@@ -270,7 +280,8 @@ class OrderingList(list):
     def __delslice__(self, start, end):
         super(OrderingList, self).__delslice__(start, end)
         self._reorder()
-
+    # end Py2K
+    
     for func_name, func in locals().items():
         if (util.callable(func) and func.func_name == func_name and
             not func.__doc__ and hasattr(list, func_name)):
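
The new ``__setitem__`` slice branch normalizes negative bounds against
``len(self)`` before stepping.  The same normalization, isolated (note that
``index.stop or len(self)`` treats an explicit stop of 0 as "to the end",
mirroring the patch)::

    def _normalize(index, size):
        step = index.step or 1
        start = index.start or 0
        if start < 0:
            start += size
        stop = index.stop or size
        if stop < 0:
            stop += size
        return start, stop, step

    assert _normalize(slice(-3, -1), 5) == (2, 4, 1)
    assert _normalize(slice(None, None), 5) == (0, 5, 1)
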
index b62ee0ce6422e42a0eab0df98ae9840dbfecc562..fd456e385fae6070368fe1210ac4f21dba10047b 100644 (file)
@@ -43,10 +43,26 @@ from sqlalchemy.engine import Engine
 from sqlalchemy.util import pickle
 import re
 import base64
-from cStringIO import StringIO
+# Py3K
+#from io import BytesIO as byte_buffer
+# Py2K
+from cStringIO import StringIO as byte_buffer
+# end Py2K
+
+# Py3K
+#def b64encode(x):
+#    return base64.b64encode(x).decode('ascii')
+#def b64decode(x):
+#    return base64.b64decode(x.encode('ascii'))
+# Py2K
+b64encode = base64.b64encode
+b64decode = base64.b64decode
+# end Py2K
 
 __all__ = ['Serializer', 'Deserializer', 'dumps', 'loads']
 
+
+
 def Serializer(*args, **kw):
     pickler = pickle.Pickler(*args, **kw)
         
@@ -55,9 +71,9 @@ def Serializer(*args, **kw):
         if isinstance(obj, QueryableAttribute):
             cls = obj.impl.class_
             key = obj.impl.key
-            id = "attribute:" + key + ":" + base64.b64encode(pickle.dumps(cls))
+            id = "attribute:" + key + ":" + b64encode(pickle.dumps(cls))
         elif isinstance(obj, Mapper) and not obj.non_primary:
-            id = "mapper:" + base64.b64encode(pickle.dumps(obj.class_))
+            id = "mapper:" + b64encode(pickle.dumps(obj.class_))
         elif isinstance(obj, Table):
             id = "table:" + str(obj)
         elif isinstance(obj, Column) and isinstance(obj.table, Table):
@@ -96,10 +112,10 @@ def Deserializer(file, metadata=None, scoped_session=None, engine=None):
             type_, args = m.group(1, 2)
             if type_ == 'attribute':
                 key, clsarg = args.split(":")
-                cls = pickle.loads(base64.b64decode(clsarg))
+                cls = pickle.loads(b64decode(clsarg))
                 return getattr(cls, key)
             elif type_ == "mapper":
-                cls = pickle.loads(base64.b64decode(args))
+                cls = pickle.loads(b64decode(args))
                 return class_mapper(cls)
             elif type_ == "table":
                 return metadata.tables[args]
@@ -116,13 +132,13 @@ def Deserializer(file, metadata=None, scoped_session=None, engine=None):
     return unpickler
 
 def dumps(obj):
-    buf = StringIO()
+    buf = byte_buffer()
     pickler = Serializer(buf)
     pickler.dump(obj)
     return buf.getvalue()
     
 def loads(data, metadata=None, scoped_session=None, engine=None):
-    buf = StringIO(data)
+    buf = byte_buffer(data)
     unpickler = Deserializer(buf, metadata, scoped_session, engine)
     return unpickler.load()
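
With ``byte_buffer`` standing in for ``cStringIO`` on Py2K and
``io.BytesIO`` on Py3K, ``dumps()`` and ``loads()`` round-trip ORM
constructs as byte strings.  A usage sketch, assuming a mapped ``User``
class with its ``metadata`` and a configured ``Session``::

    from sqlalchemy.ext.serializer import dumps, loads

    query = Session.query(User).filter(User.name == 'ed')
    serialized = dumps(query)           # bytes, safe to store or ship

    # later, against the same metadata / session registry:
    restored = loads(serialized, metadata, Session)
    print restored.all()
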
     
index b3f2de743ef5c94e91b2162a1920c9eca0f97e21..6eef4657c3031b0f2bfc0c8108e75c28fa4d8d70 100644 (file)
@@ -86,7 +86,7 @@ Full query documentation
 Get, filter, filter_by, order_by, limit, and the rest of the
 query methods are explained in detail in the `SQLAlchemy documentation`__.
 
-__ http://www.sqlalchemy.org/docs/04/ormtutorial.html#datamapping_querying
+__ http://www.sqlalchemy.org/docs/05/ormtutorial.html#datamapping_querying
 
 
 Modifying objects
@@ -447,9 +447,11 @@ def _selectable_name(selectable):
 def class_for_table(selectable, **mapper_kwargs):
     selectable = expression._clause_element_as_expr(selectable)
     mapname = 'Mapped' + _selectable_name(selectable)
+    # Py2K
     if isinstance(mapname, unicode): 
         engine_encoding = selectable.metadata.bind.dialect.encoding 
         mapname = mapname.encode(engine_encoding)
+    # end Py2K
     if isinstance(selectable, Table):
         klass = TableClassType(mapname, (object,), {})
     else:
@@ -543,10 +545,14 @@ class SqlSoup:
     def entity(self, attr, schema=None):
         try:
             t = self._cache[attr]
-        except KeyError:
+        except KeyError, ke:
             table = Table(attr, self._metadata, autoload=True, schema=schema or self.schema)
             if not table.primary_key.columns:
+                # Py3K
+                #raise PKNotFoundError('table %r does not have a primary key defined [columns: %s]' % (attr, ','.join(table.c.keys()))) from ke
+                # Py2K
                 raise PKNotFoundError('table %r does not have a primary key defined [columns: %s]' % (attr, ','.join(table.c.keys())))
+                # end Py2K
             if table.columns:
                 t = class_for_table(table)
             else:
index dfceffe4451a0440d6d6b89a9a8c3adce9e4fb7f..e4a9adee1fb6351d5100632ff45d22341c82cf69 100644 (file)
@@ -71,6 +71,18 @@ class PoolListener(object):
 
         """
 
+    def first_connect(self, dbapi_con, con_record):
+        """Called exactly once for the first DB-API connection.
+
+        dbapi_con
+          A newly connected raw DB-API connection (not a SQLAlchemy
+          ``Connection`` wrapper).
+
+        con_record
+          The ``_ConnectionRecord`` that persistently manages the connection
+
+        """
+
     def checkout(self, dbapi_con, con_record, con_proxy):
         """Called when a connection is retrieved from the Pool.
 
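A listener implementing the new hook runs once for the first raw DB-API
connection, a natural home for per-pool DBAPI configuration.  A sketch,
assuming the sqlite3 DBAPI's ``text_factory`` attribute::

    from sqlalchemy import create_engine
    from sqlalchemy.interfaces import PoolListener

    class OnFirstConnect(PoolListener):
        def first_connect(self, dbapi_con, con_record):
            # fires exactly once, before any checkout
            dbapi_con.text_factory = str

    engine = create_engine('sqlite://', listeners=[OnFirstConnect()])
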
index 2a20b05ef8430f05c370ae9b637bd28337fe51c5..3c39316da39a3abcd38924460187cd75a76dad3c 100644 (file)
@@ -849,8 +849,13 @@ def clear_mappers():
     """
     mapperlib._COMPILE_MUTEX.acquire()
     try:
-        for mapper in list(_mapper_registry):
-            mapper.dispose()
+        while _mapper_registry:
+            try:
+                # can't even reliably call list(weakdict) in jython
+                mapper, b = _mapper_registry.popitem()
+                mapper.dispose()
+            except KeyError:
+                pass
     finally:
         mapperlib._COMPILE_MUTEX.release()
 
index 46e9b00de2bbc389e87a1f7a0b552082a478aa49..f6947dbc1156abf2e723ad0a422e7f724b36d61f 100644 (file)
@@ -159,7 +159,7 @@ class InstrumentedAttribute(QueryableAttribute):
 
 class _ProxyImpl(object):
     accepts_scalar_loader = False
-    dont_expire_missing = False
+    expire_missing = True
     
     def __init__(self, key):
         self.key = key
@@ -231,7 +231,7 @@ class AttributeImpl(object):
     def __init__(self, class_, key,
                     callable_, trackparent=False, extension=None,
                     compare_function=None, active_history=False, parent_token=None, 
-                    dont_expire_missing=False,
+                    expire_missing=True,
                     **kwargs):
         """Construct an AttributeImpl.
 
@@ -269,8 +269,8 @@ class AttributeImpl(object):
           Allows multiple AttributeImpls to all match a single 
           owner attribute.
           
-        dont_expire_missing
-          if True, don't add an "expiry" callable to this attribute
+        expire_missing
+          if False, don't add an "expiry" callable to this attribute
           during state.expire_attributes(None), if no value is present 
           for this key.
           
@@ -290,7 +290,7 @@ class AttributeImpl(object):
                 active_history = True
                 break
         self.active_history = active_history
-        self.dont_expire_missing = dont_expire_missing
+        self.expire_missing = expire_missing
         
     def hasparent(self, state, optimistic=False):
         """Return the boolean value of a `hasparent` flag attached to the given item.
@@ -991,9 +991,8 @@ class ClassManager(dict):
             self.local_attrs[key] = inst
             self.install_descriptor(key, inst)
         self[key] = inst
+        
         for cls in self.class_.__subclasses__():
-            if isinstance(cls, types.ClassType):
-                continue
             manager = self._subclass_manager(cls)
             manager.instrument_attribute(key, inst, True)
 
@@ -1013,8 +1012,6 @@ class ClassManager(dict):
         if key in self.mutable_attributes:
             self.mutable_attributes.remove(key)
         for cls in self.class_.__subclasses__():
-            if isinstance(cls, types.ClassType):
-                continue
             manager = self._subclass_manager(cls)
             manager.uninstrument_attribute(key, True)
 
@@ -1646,8 +1643,12 @@ def __init__(%(apply_pos)s):
     func_vars = util.format_argspec_init(original__init__, grouped=False)
     func_text = func_body % func_vars
 
+    # Py3K
+    #func_defaults = getattr(original__init__, '__defaults__', None)
+    # Py2K
     func = getattr(original__init__, 'im_func', original__init__)
     func_defaults = getattr(func, 'func_defaults', None)
+    # end Py2K
 
     env = locals().copy()
     exec func_text in env
index 4ca4c5719eb514c1e357b89effd2ca86b9208a82..6a770184681fe90520305164012b680ca94d8c7b 100644 (file)
@@ -529,7 +529,11 @@ class CollectionAdapter(object):
         if getattr(obj, '_sa_adapter', None) is not None:
             return getattr(obj, '_sa_adapter')
         elif setting_type == dict:
+            # Py3K
+            #return obj.values()
+            # Py2K
             return getattr(obj, 'itervalues', getattr(obj, 'values'))()
+            # end Py2K
         else:
             return iter(obj)
 
@@ -561,7 +565,9 @@ class CollectionAdapter(object):
 
     def __iter__(self):
         """Iterate over entities in the collection."""
-        return getattr(self._data(), '_sa_iterator')()
+        
+        # Py3K requires iter() here
+        return iter(getattr(self._data(), '_sa_iterator')())
 
     def __len__(self):
         """Count entities in the collection."""
@@ -938,22 +944,23 @@ def _list_decorators():
                 fn(self, index, value)
             else:
                 # slice assignment requires __delitem__, insert, __len__
-                if index.stop is None:
-                    stop = 0
-                elif index.stop < 0:
-                    stop = len(self) + index.stop
-                else:
-                    stop = index.stop
                 step = index.step or 1
-                rng = range(index.start or 0, stop, step)
+                start = index.start or 0
+                if start < 0:
+                    start += len(self)
+                stop = index.stop or len(self)
+                if stop < 0:
+                    stop += len(self)
+                
                 if step == 1:
-                    for i in rng:
-                        del self[index.start]
-                    i = index.start
-                    for item in value:
-                        self.insert(i, item)
-                        i += 1
+                    for i in xrange(start, stop, step):
+                        if len(self) > start:
+                            del self[start]
+                    
+                    for i, item in enumerate(value):
+                        self.insert(i + start, item)
                 else:
+                    rng = range(start, stop, step)
                     if len(value) != len(rng):
                         raise ValueError(
                             "attempt to assign sequence of size %s to "
@@ -980,6 +987,7 @@ def _list_decorators():
         _tidy(__delitem__)
         return __delitem__
 
+    # Py2K
     def __setslice__(fn):
         def __setslice__(self, start, end, values):
             for value in self[start:end]:
@@ -996,7 +1004,8 @@ def _list_decorators():
             fn(self, start, end)
         _tidy(__delslice__)
         return __delslice__
-
+    # end Py2K
+    
     def extend(fn):
         def extend(self, iterable):
             for value in iterable:
@@ -1319,9 +1328,14 @@ class InstrumentedSet(set):
 class InstrumentedDict(dict):
     """An instrumented version of the built-in dict."""
 
+    # Py3K
+    #__instrumentation__ = {
+    #    'iterator': 'values', }
+    # Py2K
     __instrumentation__ = {
         'iterator': 'itervalues', }
-
+    # end Py2K
+    
 __canned_instrumentation = {
     list: InstrumentedList,
     set: InstrumentedSet,
@@ -1338,8 +1352,13 @@ __interfaces = {
           'iterator': '__iter__',
           '_decorators': _set_decorators(), },
     # decorators are required for dicts and object collections.
+    # Py3K
+    #dict: {'iterator': 'values',
+    #       '_decorators': _dict_decorators(), },
+    # Py2K
     dict: {'iterator': 'itervalues',
            '_decorators': _dict_decorators(), },
+    # end Py2K
     # < 0.4 compatible naming, deprecated- use decorators instead.
     None: {}
     }
index f3820eb7cdae0b113584aa00f584d227ef1d2488..407a04ae4dcf180f3abd2089361f89b43385cda5 100644 (file)
@@ -27,7 +27,7 @@ def create_dependency_processor(prop):
     return types[prop.direction](prop)
 
 class DependencyProcessor(object):
-    no_dependencies = False
+    has_dependencies = True
 
     def __init__(self, prop):
         self.prop = prop
@@ -291,7 +291,7 @@ class DetectKeySwitch(DependencyProcessor):
     """a special DP that works for many-to-one relations, fires off for
     child items who have changed their referenced key."""
 
-    no_dependencies = True
+    has_dependencies = False
 
     def register_dependencies(self, uowcommit):
         pass
index 9076f610d722ebb65d12df68b5b764a0c4e5eb07..05af5d8ca7d8955fd3ed5b1530b9c79b8e8c3828 100644 (file)
@@ -7,7 +7,11 @@ class UnevaluatableError(Exception):
     pass
 
 _straight_ops = set(getattr(operators, op)
-                    for op in ('add', 'mul', 'sub', 'div', 'mod', 'truediv',
+                    for op in ('add', 'mul', 'sub', 
+                                # Py2K
+                                'div',
+                                # end Py2K 
+                                'mod', 'truediv',
                                'lt', 'le', 'ne', 'gt', 'ge', 'eq'))
 
 
index 50301a13c3ba85a27fad3b70669b7ba43978d213..b7d4234f4c1001b31fc28cfe59d0990d98887d17 100644 (file)
@@ -45,7 +45,7 @@ class IdentityMap(dict):
         self._modified.discard(state)
 
     def _dirty_states(self):
-        return self._modified.union(s for s in list(self._mutable_attrs)
+        return self._modified.union(s for s in self._mutable_attrs.copy()
                                     if s.modified)
 
     def check_modified(self):
@@ -54,7 +54,7 @@ class IdentityMap(dict):
         if self._modified:
             return True
         else:
-            for state in list(self._mutable_attrs):
+            for state in self._mutable_attrs.copy():
                 if state.modified:
                     return True
         return False
@@ -145,34 +145,49 @@ class WeakInstanceDict(IdentityMap):
             return self[key]
         except KeyError:
             return default
-            
+    
+    # Py2K        
     def items(self):
         return list(self.iteritems())
 
     def iteritems(self):
         for state in dict.itervalues(self):
+    # end Py2K
+    # Py3K
+    #def items(self):
+    #    for state in dict.values(self):
             value = state.obj()
             if value is not None:
                 yield state.key, value
 
+    # Py2K
+    def values(self):
+        return list(self.itervalues())
+
     def itervalues(self):
         for state in dict.itervalues(self):
+    # end Py2K
+    # Py3K
+    #def values(self):
+    #    for state in dict.values(self):
             instance = state.obj()
             if instance is not None:
                 yield instance
 
-    def values(self):
-        return list(self.itervalues())
-
     def all_states(self):
+        # Py3K
+        # return list(dict.values(self))
+        
+        # Py2K
         return dict.values(self)
+        # end Py2K
     
     def prune(self):
         return 0
         
 class StrongInstanceDict(IdentityMap):
     def all_states(self):
-        return [attributes.instance_state(o) for o in self.values()]
+        return [attributes.instance_state(o) for o in self.itervalues()]
     
     def contains_state(self, state):
         return state.key in self and attributes.instance_state(self[state.key]) is state
@@ -212,7 +227,11 @@ class StrongInstanceDict(IdentityMap):
         
         ref_count = len(self)
         dirty = [s.obj() for s in self.all_states() if s.check_modified()]
-        keepers = weakref.WeakValueDictionary(self)
+
+        # work around http://bugs.python.org/issue6149
+        keepers = weakref.WeakValueDictionary()
+        keepers.update(self)
+
         dict.clear(self)
         dict.update(self, keepers)
         self.modified = bool(dirty)
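
The ``prune()`` change sidesteps http://bugs.python.org/issue6149, where
constructing a ``WeakValueDictionary`` directly from a dict can lose
entries; building it empty and calling ``update()`` is the safe spelling.
Isolated::

    import weakref

    class Thing(object):
        pass

    strong = dict((i, Thing()) for i in range(3))

    # workaround: don't pass the source dict to the constructor
    keepers = weakref.WeakValueDictionary()
    keepers.update(strong)
    assert len(keepers) == 3    # values kept alive by 'strong'
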
index 5dffa6774a5efe898bd66fa29990d9915201081f..eaafe5761a43d8baed85e86f2c433ca583e36fb0 100644 (file)
@@ -455,7 +455,7 @@ class MapperProperty(object):
 
         return not self.parent.non_primary
 
-    def merge(self, session, source, dest, dont_load, _recursive):
+    def merge(self, session, source, dest, load, _recursive):
         """Merge the attribute represented by this ``MapperProperty``
         from source to destination object"""
 
index 078056a01c4272bdedac424871ce655ab163bad4..c2c57825e33ea0ff5e926a8a6ff5712870f439ba 100644 (file)
@@ -506,13 +506,13 @@ class Mapper(object):
         if self.polymorphic_on and self.polymorphic_on not in self._columntoproperty:
             col = self.mapped_table.corresponding_column(self.polymorphic_on)
             if not col:
-                dont_instrument = True
+                instrument = False
                 col = self.polymorphic_on
             else:
-                dont_instrument = False
+                instrument = True
             if self._should_exclude(col.key, col.key, local=False):
                 raise sa_exc.InvalidRequestError("Cannot exclude or override the discriminator column %r" % col.key)
-            self._configure_property(col.key, ColumnProperty(col, _no_instrument=dont_instrument), init=False, setparent=True)
+            self._configure_property(col.key, ColumnProperty(col, _instrument=instrument), init=False, setparent=True)
 
     def _adapt_inherited_property(self, key, prop, init):
         if not self.concrete:
@@ -1397,7 +1397,7 @@ class Mapper(object):
                 statement = table.insert()
                 for state, params, mapper, connection, value_params in insert:
                     c = connection.execute(statement.values(value_params), params)
-                    primary_key = c.last_inserted_ids()
+                    primary_key = c.inserted_primary_key
 
                     if primary_key is not None:
                         # set primary key attributes
@@ -1574,6 +1574,12 @@ class Mapper(object):
                 if state.load_options:
                     state.load_path = context.query._current_path + path
 
+            if isnew:
+                if context.options:
+                    state.load_options = context.options
+                if state.load_options:
+                    state.load_path = context.query._current_path + path
+
             if not new_populators:
                 new_populators[:], existing_populators[:] = self._populators(context, path, row, adapter)
 
index 0fa32f73f64e1317b817563263a5781f18bce394..3489d81f2e02c8ac7789ced0009bfe0888e89b58 100644 (file)
@@ -53,7 +53,7 @@ class ColumnProperty(StrategizedProperty):
         self.columns = [expression._labeled(c) for c in columns]
         self.group = kwargs.pop('group', None)
         self.deferred = kwargs.pop('deferred', False)
-        self.no_instrument = kwargs.pop('_no_instrument', False)
+        self.instrument = kwargs.pop('_instrument', True)
         self.comparator_factory = kwargs.pop('comparator_factory', self.__class__.Comparator)
         self.descriptor = kwargs.pop('descriptor', None)
         self.extension = kwargs.pop('extension', None)
@@ -63,7 +63,7 @@ class ColumnProperty(StrategizedProperty):
                     self.__class__.__name__, ', '.join(sorted(kwargs.keys()))))
 
         util.set_creation_order(self)
-        if self.no_instrument:
+        if not self.instrument:
             self.strategy_class = strategies.UninstrumentedColumnLoader
         elif self.deferred:
             self.strategy_class = strategies.DeferredColumnLoader
@@ -71,7 +71,7 @@ class ColumnProperty(StrategizedProperty):
             self.strategy_class = strategies.ColumnLoader
     
     def instrument_class(self, mapper):
-        if self.no_instrument:
+        if not self.instrument:
             return
         
         attributes.register_descriptor(
@@ -104,7 +104,7 @@ class ColumnProperty(StrategizedProperty):
     def setattr(self, state, value, column):
         state.get_impl(self.key).set(state, state.dict, value, None)
 
-    def merge(self, session, source, dest, dont_load, _recursive):
+    def merge(self, session, source, dest, load, _recursive):
         value = attributes.instance_state(source).value_as_iterable(
             self.key, passive=True)
         if value:
@@ -302,7 +302,7 @@ class SynonymProperty(MapperProperty):
             proxy_property=self.descriptor
             )
 
-    def merge(self, session, source, dest, dont_load, _recursive):
+    def merge(self, session, source, dest, load, _recursive):
         pass
         
 log.class_logger(SynonymProperty)
@@ -335,7 +335,7 @@ class ComparableProperty(MapperProperty):
     def create_row_processor(self, selectcontext, path, mapper, row, adapter):
         return (None, None)
 
-    def merge(self, session, source, dest, dont_load, _recursive):
+    def merge(self, session, source, dest, load, _recursive):
         pass
 
 
@@ -627,8 +627,8 @@ class RelationProperty(StrategizedProperty):
     def __str__(self):
         return str(self.parent.class_.__name__) + "." + self.key
 
-    def merge(self, session, source, dest, dont_load, _recursive):
-        if not dont_load:
+    def merge(self, session, source, dest, load, _recursive):
+        if load:
             # TODO: no test coverage for recursive check
             for r in self._reverse_property:
                 if (source, r) in _recursive:
@@ -650,10 +650,10 @@ class RelationProperty(StrategizedProperty):
             dest_list = []
             for current in instances:
                 _recursive[(current, self)] = True
-                obj = session._merge(current, dont_load=dont_load, _recursive=_recursive)
+                obj = session._merge(current, load=load, _recursive=_recursive)
                 if obj is not None:
                     dest_list.append(obj)
-            if dont_load:
+            if not load:
                 coll = attributes.init_collection(dest_state, self.key)
                 for c in dest_list:
                     coll.append_without_event(c)
@@ -663,9 +663,9 @@ class RelationProperty(StrategizedProperty):
             current = instances[0]
             if current is not None:
                 _recursive[(current, self)] = True
-                obj = session._merge(current, dont_load=dont_load, _recursive=_recursive)
+                obj = session._merge(current, load=load, _recursive=_recursive)
                 if obj is not None:
-                    if dont_load:
+                    if not load:
                         dest_state.dict[self.key] = obj
                     else:
                         setattr(dest, self.key, obj)
index e764856bf26b52b0e1bffef4263dbe4d64048019..21137bc28b19c90eec056c428d5447e28d1e671d 100644 (file)
@@ -648,7 +648,11 @@ class Query(object):
     def value(self, column):
         """Return a scalar result corresponding to the given column expression."""
         try:
+            # Py3K
+            #return self.values(column).__next__()[0]
+            # Py2K
             return self.values(column).next()[0]
+            # end Py2K
         except StopIteration:
             return None
 
@@ -1433,7 +1437,7 @@ class Query(object):
 
             session._finalize_loaded(context.progress)
 
-            for ii, (dict_, attrs) in context.partials.items():
+            for ii, (dict_, attrs) in context.partials.iteritems():
                 ii.commit(dict_, attrs)
 
             for row in rows:
@@ -1582,7 +1586,7 @@ class Query(object):
             self.session._autoflush()
         return self.session.scalar(s, params=self._params, mapper=self._mapper_zero())
 
-    def delete(self, synchronize_session='fetch'):
+    def delete(self, synchronize_session='evaluate'):
         """Perform a bulk delete query.
 
         Deletes rows matched by this query from the database.
@@ -1592,15 +1596,15 @@ class Query(object):
 
             False
               don't synchronize the session. This option is the most efficient and is reliable
-              once the session is expired, which typically occurs after a commit().   Before
-              the expiration, objects may still remain in the session which were in fact deleted
-              which can lead to confusing results if they are accessed via get() or already
-              loaded collections.
+              once the session is expired, which typically occurs after a commit(), or explicitly
+              using expire_all().  Before the expiration, objects which were in fact deleted may
+              still remain in the session, which can lead to confusing results if they are accessed
+              via get() or via already loaded collections.
 
             'fetch'
               performs a select query before the delete to find objects that are matched
               by the delete query and need to be removed from the session. Matched objects
-              are removed from the session. 'fetch' is the default strategy.
+              are removed from the session.
 
             'evaluate'
              experimental feature. Tries to evaluate the query's criteria in Python
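
Note the behavioral change in this hunk: 'evaluate' becomes the default for delete(), and, per the following hunk, criteria the in-Python evaluator cannot handle now raise InvalidRequestError instead of silently falling back to 'fetch'. A minimal sketch, assuming a mapped User class and an open session::

    # default synchronize_session='evaluate': simple criteria are
    # evaluated in Python against objects already in the session
    session.query(User).filter(User.name == 'ed').delete()

    # criteria the evaluator can't handle must opt into 'fetch' or False
    session.query(User).filter(User.name.like('e%')).delete(
        synchronize_session='fetch')
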
@@ -1642,9 +1646,10 @@ class Query(object):
         if synchronize_session == 'evaluate':
             try:
                 evaluator_compiler = evaluator.EvaluatorCompiler()
-                eval_condition = evaluator_compiler.process(self.whereclause)
+                eval_condition = evaluator_compiler.process(self.whereclause or expression._Null)
             except evaluator.UnevaluatableError:
-                synchronize_session = 'fetch'
+                raise sa_exc.InvalidRequestError("Could not evaluate current criteria in Python.  "
+                        "Specify 'fetch' or False for the synchronize_session parameter.")
 
         delete_stmt = sql.delete(primary_table, context.whereclause)
 
@@ -1677,7 +1682,7 @@ class Query(object):
 
         return result.rowcount
 
-    def update(self, values, synchronize_session='expire'):
+    def update(self, values, synchronize_session='evaluate'):
         """Perform a bulk update query.
 
         Updates rows matched by this query in the database.
@@ -1689,18 +1694,19 @@ class Query(object):
             attributes on objects in the session. Valid values are:
 
             False
-              don't synchronize the session. Use this when you don't need to use the
-              session after the update or you can be sure that none of the matched objects
-              are in the session.
-
-            'expire'
+              don't synchronize the session. This option is the most efficient and is reliable
+              once the session is expired, which typically occurs after a commit(), or explicitly
+              using expire_all().  Before the expiration, updated objects may still remain in the session 
+              with stale values on their attributes, which can lead to confusing results.
+              
+            'fetch'
               performs a select query before the update to find objects that are matched
               by the update query. The updated attributes are expired on matched objects.
 
             'evaluate'
-              experimental feature. Tries to evaluate the querys criteria in Python
+              Tries to evaluate the Query's criteria in Python
               straight on the objects in the session. If evaluation of the criteria isn't
-              implemented, the 'expire' strategy will be used as a fallback.
+              implemented, an exception is raised.
 
               The expression evaluator currently doesn't account for differing string
               collations between the database and Python.
@@ -1709,6 +1715,7 @@ class Query(object):
 
         The method does *not* offer in-Python cascading of relations - it is assumed that
         ON UPDATE CASCADE is configured for any foreign key references which require it.
+
         The Session needs to be expired (occurs automatically after commit(), or call expire_all())
         in order for the state of dependent objects subject to foreign key cascade to be
         correctly represented.
@@ -1723,9 +1730,9 @@ class Query(object):
         #TODO: updates of manytoone relations need to be converted to fk assignments
         #TODO: cascades need handling.
 
+        if synchronize_session not in [False, 'evaluate', 'fetch']:
+            raise sa_exc.ArgumentError("Valid strategies for session synchronization are False, 'evaluate' and 'fetch'")
         self._no_select_modifiers("update")
-        if synchronize_session not in [False, 'evaluate', 'expire']:
-            raise sa_exc.ArgumentError("Valid strategies for session synchronization are False, 'evaluate' and 'expire'")
 
         self = self.enable_eagerloads(False)
 
@@ -1739,18 +1746,19 @@ class Query(object):
         if synchronize_session == 'evaluate':
             try:
                 evaluator_compiler = evaluator.EvaluatorCompiler()
-                eval_condition = evaluator_compiler.process(self.whereclause)
+                eval_condition = evaluator_compiler.process(self.whereclause or expression._Null)
 
                 value_evaluators = {}
-                for key,value in values.items():
+                for key,value in values.iteritems():
                     key = expression._column_as_key(key)
                     value_evaluators[key] = evaluator_compiler.process(expression._literal_as_binds(value))
             except evaluator.UnevaluatableError:
-                synchronize_session = 'expire'
+                raise sa_exc.InvalidRequestError("Could not evaluate current criteria in Python.  "
+                        "Specify 'fetch' or False for the synchronize_session parameter.")
 
         update_stmt = sql.update(primary_table, context.whereclause, values)
 
-        if synchronize_session == 'expire':
+        if synchronize_session == 'fetch':
             select_stmt = context.statement.with_only_columns(primary_table.primary_key)
             matched_rows = session.execute(select_stmt, params=self._params).fetchall()
 
@@ -1777,7 +1785,7 @@ class Query(object):
                     # expire attributes with pending changes (there was no autoflush, so they are overwritten)
                     state.expire_attributes(set(evaluated_keys).difference(to_evaluate))
 
-        elif synchronize_session == 'expire':
+        elif synchronize_session == 'fetch':
             target_mapper = self._mapper_zero()
 
             for primary_key in matched_rows:
@@ -2142,7 +2150,7 @@ class _ColumnEntity(_QueryEntity):
             return entity is self.entity_zero
         else:
             return not _is_aliased_class(self.entity_zero) and entity.base_mapper.common_parent(self.entity_zero)
-            
+
     def _resolve_expr_against_query_aliases(self, query, expr, context):
         return query._adapt_clause(expr, False, True)
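
Rounding out the query.py changes above, update() mirrors delete(): the 'expire' strategy is renamed 'fetch', 'evaluate' becomes the default, and an unevaluatable criterion now raises rather than degrading. A minimal sketch, assuming the same mapped User class::

    # default synchronize_session='evaluate'
    session.query(User).filter(User.name == 'ed').update(
        {User.name: 'edward'})

    # equivalent of the old synchronize_session='expire'
    session.query(User).filter(User.name == 'ed').update(
        {User.name: 'edward'}, synchronize_session='fetch')
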
 
index 4339b68ebc5f4de4c12a88534deee5abd1c21293..28eb63819ec4ba5a52cb754c869f844383116415 100644 (file)
@@ -171,7 +171,7 @@ class _ScopedExt(MapperExtension):
 
     def _default__init__(ext, mapper):
         def __init__(self, **kwargs):
-            for key, value in kwargs.items():
+            for key, value in kwargs.iteritems():
                 if ext.validate:
                     if not mapper.get_property(key, resolve_synonyms=False,
                                                raiseerr=False):
index c010a217bd4fd1c7ba8d265d0f4a49773d2679c8..d3d653de4f6c051e1c158a0813fa465451c7917e 100644 (file)
@@ -109,9 +109,9 @@ def sessionmaker(bind=None, class_=None, autoflush=True, autocommit=False,
       like::
 
         sess = Session(binds={
-            SomeMappedClass: create_engine('postgres://engine1'),
-            somemapper: create_engine('postgres://engine2'),
-            some_table: create_engine('postgres://engine3'),
+            SomeMappedClass: create_engine('postgresql://engine1'),
+            somemapper: create_engine('postgresql://engine2'),
+            some_table: create_engine('postgresql://engine3'),
             })
 
       Also see the ``bind_mapper()`` and ``bind_table()`` methods.
@@ -1139,7 +1139,7 @@ class Session(object):
         for state, m, o in cascade_states:
             self._delete_impl(state)
 
-    def merge(self, instance, dont_load=False):
+    def merge(self, instance, load=True, **kw):
         """Copy the state an instance onto the persistent instance with the same identifier.
 
         If there is no persistent instance currently associated with the
@@ -1152,6 +1152,10 @@ class Session(object):
         mapped with ``cascade="merge"``.
 
         """
+        if 'dont_load' in kw:
+            load = not kw['dont_load']
+            util.warn_deprecated("dont_load=True has been renamed to load=False.")
+        
         # TODO: this should be an IdentityDict for instances, but will
         # need a separate dict for PropertyLoader tuples
         _recursive = {}
@@ -1159,11 +1163,11 @@ class Session(object):
         autoflush = self.autoflush
         try:
             self.autoflush = False
-            return self._merge(instance, dont_load=dont_load, _recursive=_recursive)
+            return self._merge(instance, load=load, _recursive=_recursive)
         finally:
             self.autoflush = autoflush
         
-    def _merge(self, instance, dont_load=False, _recursive=None):
+    def _merge(self, instance, load=True, _recursive=None):
         mapper = _object_mapper(instance)
         if instance in _recursive:
             return _recursive[instance]
@@ -1173,24 +1177,24 @@ class Session(object):
         key = state.key
 
         if key is None:
-            if dont_load:
+            if not load:
                 raise sa_exc.InvalidRequestError(
-                    "merge() with dont_load=True option does not support "
+                    "merge() with load=False option does not support "
                     "objects transient (i.e. unpersisted) objects.  flush() "
                     "all changes on mapped instances before merging with "
-                    "dont_load=True.")
+                    "load=False.")
             key = mapper._identity_key_from_state(state)
 
         merged = None
         if key:
             if key in self.identity_map:
                 merged = self.identity_map[key]
-            elif dont_load:
+            elif not load:
                 if state.modified:
                     raise sa_exc.InvalidRequestError(
-                        "merge() with dont_load=True option does not support "
+                        "merge() with load=False option does not support "
                         "objects marked as 'dirty'.  flush() all changes on "
-                        "mapped instances before merging with dont_load=True.")
+                        "mapped instances before merging with load=False.")
                 merged = mapper.class_manager.new_instance()
                 merged_state = attributes.instance_state(merged)
                 merged_state.key = key
@@ -1208,9 +1212,9 @@ class Session(object):
         _recursive[instance] = merged
 
         for prop in mapper.iterate_properties:
-            prop.merge(self, instance, merged, dont_load, _recursive)
+            prop.merge(self, instance, merged, load, _recursive)
 
-        if dont_load:
+        if not load:
             attributes.instance_state(merged).commit_all(attributes.instance_dict(merged), self.identity_map)  # remove any history
 
         if new_instance:
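
The merge() signature change above keeps dont_load working as a deprecated alias while new code states the intent positively. A minimal sketch, assuming a detached instance from a prior session::

    merged = session.merge(detached)              # loads current DB state
    merged = session.merge(detached, load=False)  # trusts the given state; no SELECT

    # still accepted, but emits a deprecation warning:
    merged = session.merge(detached, dont_load=True)
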
index 10a0f43eebaff6e56c65913697e39b6ca898d676..4d9fa5ade8532c22f6cf53e64d0a2db9577d136e 100644 (file)
@@ -31,11 +31,17 @@ class InstanceState(object):
         
     def detach(self):
         if self.session_id:
-            del self.session_id
+            try:
+                del self.session_id
+            except AttributeError:
+                pass
 
     def dispose(self):
         if self.session_id:
-            del self.session_id
+            try:
+                del self.session_id
+            except AttributeError:
+                pass
         del self.obj
     
     def _cleanup(self, ref):
@@ -230,7 +236,7 @@ class InstanceState(object):
         for key in attribute_names:
             impl = self.manager[key].impl
             if not filter_deferred or \
-                not impl.dont_expire_missing or \
+                impl.expire_missing or \
                 key in dict_:
                 self.expired_attributes.add(key)
                 if impl.accepts_scalar_loader:
index f739fb1dd0ac44c603b77c35bf905f6308d19587..e19e8fb31c052a7b3d8eb139b0809d32dd721e5e 100644 (file)
@@ -45,6 +45,7 @@ def _register_attribute(strategy, mapper, useobject,
     
     if useobject:
         attribute_ext.append(sessionlib.UOWEventHandler(prop.key))
+
     
     for m in mapper.polymorphic_iterator():
         if prop is m._props.get(prop.key):
@@ -235,7 +236,7 @@ class DeferredColumnLoader(LoaderStrategy):
              copy_function=self.columns[0].type.copy_value,
              mutable_scalars=self.columns[0].type.is_mutable(),
              callable_=self._class_level_loader,
-             dont_expire_missing=True
+             expire_missing=False
         )
 
     def setup_query(self, context, entity, path, adapter, only_load_props=None, **kwargs):
index d650f65a5456cdd62d88c13461f256cc548ea816..bca6b4f463da07ac973ec50ae4232491bcfc3ab4 100644 (file)
@@ -531,7 +531,7 @@ class UOWTask(object):
                     if subtask not in deps_by_targettask:
                         continue
                     for dep in deps_by_targettask[subtask]:
-                        if dep.processor.no_dependencies or not dependency_in_cycles(dep):
+                        if not dep.processor.has_dependencies or not dependency_in_cycles(dep):
                             continue
                         (processor, targettask) = (dep.processor, dep.targettask)
                         isdelete = taskelement.isdelete
index c4e1af20cf3fb27786d6154f2c7913e32c508315..dabdc6e35345f0b339ad1dd611ee4fc8f5de30bd 100644 (file)
@@ -19,7 +19,7 @@ SQLAlchemy connection pool.
 import weakref, time, threading
 
 from sqlalchemy import exc, log
-from sqlalchemy import queue as Queue
+from sqlalchemy import queue as sqla_queue
 from sqlalchemy.util import threading, pickle, as_interface
 
 proxies = {}
@@ -51,7 +51,7 @@ def clear_managers():
     All pools and connections are disposed.
     """
 
-    for manager in proxies.values():
+    for manager in proxies.itervalues():
         manager.close()
     proxies.clear()
 
@@ -108,6 +108,7 @@ class Pool(object):
         self.echo = echo
         self.listeners = []
         self._on_connect = []
+        self._on_first_connect = []
         self._on_checkout = []
         self._on_checkin = []
 
@@ -178,12 +179,14 @@ class Pool(object):
 
         """
 
-        listener = as_interface(
-            listener, methods=('connect', 'checkout', 'checkin'))
+        listener = as_interface(listener,
+            methods=('connect', 'first_connect', 'checkout', 'checkin'))
 
         self.listeners.append(listener)
         if hasattr(listener, 'connect'):
             self._on_connect.append(listener)
+        if hasattr(listener, 'first_connect'):
+            self._on_first_connect.append(listener)
         if hasattr(listener, 'checkout'):
             self._on_checkout.append(listener)
         if hasattr(listener, 'checkin'):
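
A listener opts into the new hook simply by defining a first_connect method; as_interface() and the hasattr() checks above wire up only the methods actually present. A minimal, self-contained sketch::

    import sqlite3
    from sqlalchemy import pool

    class Tracker(object):
        def first_connect(self, dbapi_con, con_record):
            # fires once, for the first connection the pool creates
            print "first connection"
        def connect(self, dbapi_con, con_record):
            # fires for every newly created connection
            print "new connection"

    p = pool.QueuePool(lambda: sqlite3.connect(':memory:'),
                       listeners=[Tracker()])
    c = p.connect()
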
@@ -197,6 +200,10 @@ class _ConnectionRecord(object):
         self.__pool = pool
         self.connection = self.__connect()
         self.info = {}
+        ls = pool.__dict__.pop('_on_first_connect', None)
+        if ls is not None:
+            for l in ls:
+                l.first_connect(self.connection, self)
         if pool._on_connect:
             for l in pool._on_connect:
                 l.connect(self.connection, self)
@@ -269,8 +276,11 @@ class _ConnectionRecord(object):
 
 
 def _finalize_fairy(connection, connection_record, pool, ref=None):
-    if ref is not None and connection_record.backref is not ref:
+    _refs.discard(connection_record)
+        
+    if ref is not None and (connection_record.fairy is not ref or isinstance(pool, AssertionPool)):
         return
+
     if connection is not None:
         try:
             if pool._reset_on_return:
@@ -284,7 +294,7 @@ def _finalize_fairy(connection, connection_record, pool, ref=None):
             if isinstance(e, (SystemExit, KeyboardInterrupt)):
                 raise
     if connection_record is not None:
-        connection_record.backref = None
+        connection_record.fairy = None
         if pool._should_log_info:
             pool.log("Connection %r being returned to pool" % connection)
         if pool._on_checkin:
@@ -292,6 +302,8 @@ def _finalize_fairy(connection, connection_record, pool, ref=None):
                 l.checkin(connection, connection_record)
         pool.return_conn(connection_record)
 
+_refs = set()
+
 class _ConnectionFairy(object):
     """Proxies a DB-API connection and provides return-on-dereference support."""
 
@@ -303,7 +315,8 @@ class _ConnectionFairy(object):
         try:
             rec = self._connection_record = pool.get()
             conn = self.connection = self._connection_record.get_connection()
-            self._connection_record.backref = weakref.ref(self, lambda ref:_finalize_fairy(conn, rec, pool, ref))
+            rec.fairy = weakref.ref(self, lambda ref:_finalize_fairy(conn, rec, pool, ref))
+            _refs.add(rec)
         except:
             self.connection = None # helps with endless __getattr__ loops later on
             self._connection_record = None
@@ -402,8 +415,9 @@ class _ConnectionFairy(object):
         """
 
         if self._connection_record is not None:
+            _refs.remove(self._connection_record)
+            self._connection_record.fairy = None
             self._connection_record.connection = None
-            self._connection_record.backref = None
             self._pool.do_return_conn(self._connection_record)
             self._detached_info = \
               self._connection_record.info.copy()
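
The module-level _refs set introduced above holds a strong reference to each checked-out _ConnectionRecord, so the record, and the weakref callback closed over it, survives until _finalize_fairy runs on dereference; detach() discards the record from the set explicitly. The general shape of the pattern, reduced to essentials::

    import weakref

    _refs = set()

    def _finalize(rec, ref):
        _refs.discard(rec)
        # ... return rec's underlying resource to its pool

    class Fairy(object):
        def __init__(self, rec):
            # the callback fires when this Fairy is garbage collected
            rec.fairy = weakref.ref(self, lambda ref: _finalize(rec, ref))
            _refs.add(rec)
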
@@ -501,10 +515,8 @@ class SingletonThreadPool(Pool):
             del self._conn.current
 
     def cleanup(self):
-        for conn in list(self._all_conns):
-            self._all_conns.discard(conn)
-            if len(self._all_conns) <= self.size:
-                return
+        while len(self._all_conns) > self.size:
+            self._all_conns.pop()
 
     def status(self):
         return "SingletonThreadPool id:%d size: %d" % (id(self), len(self._all_conns))
@@ -593,7 +605,7 @@ class QueuePool(Pool):
 
         """
         Pool.__init__(self, creator, **params)
-        self._pool = Queue.Queue(pool_size)
+        self._pool = sqla_queue.Queue(pool_size)
         self._overflow = 0 - pool_size
         self._max_overflow = max_overflow
         self._timeout = timeout
@@ -606,7 +618,7 @@ class QueuePool(Pool):
     def do_return_conn(self, conn):
         try:
             self._pool.put(conn, False)
-        except Queue.Full:
+        except sqla_queue.Full:
             if self._overflow_lock is None:
                 self._overflow -= 1
             else:
@@ -620,7 +632,7 @@ class QueuePool(Pool):
         try:
             wait = self._max_overflow > -1 and self._overflow >= self._max_overflow
             return self._pool.get(wait, self._timeout)
-        except Queue.Empty:
+        except sqla_queue.Empty:
             if self._max_overflow > -1 and self._overflow >= self._max_overflow:
                 if not wait:
                     return self.do_get()
@@ -648,7 +660,7 @@ class QueuePool(Pool):
             try:
                 conn = self._pool.get(False)
                 conn.close()
-            except Queue.Empty:
+            except sqla_queue.Empty:
                 break
 
         self._overflow = 0 - self.size()
@@ -747,7 +759,8 @@ class StaticPool(Pool):
         Pool.__init__(self, creator, **params)
         self._conn = creator()
         self.connection = _ConnectionRecord(self)
-
+        self.connection = None
+        
     def status(self):
         return "StaticPool"
 
@@ -788,68 +801,41 @@ class AssertionPool(Pool):
 
     ## TODO: modify this to handle an arbitrary connection count.
 
-    def __init__(self, creator, **params):
-        """
-        Construct an AssertionPool.
-
-        :param creator: a callable function that returns a DB-API
-          connection object.  The function will be called with
-          parameters.
-
-        :param recycle: If set to non -1, number of seconds between
-          connection recycling, which means upon checkout, if this
-          timeout is surpassed the connection will be closed and
-          replaced with a newly opened connection. Defaults to -1.
-
-        :param echo: If True, connections being pulled and retrieved
-          from the pool will be logged to the standard output, as well
-          as pool sizing information.  Echoing can also be achieved by
-          enabling logging for the "sqlalchemy.pool"
-          namespace. Defaults to False.
-
-        :param use_threadlocal: If set to True, repeated calls to
-          :meth:`connect` within the same application thread will be
-          guaranteed to return the same connection object, if one has
-          already been retrieved from the pool and has not been
-          returned yet.  Offers a slight performance advantage at the
-          cost of individual transactions by default.  The
-          :meth:`unique_connection` method is provided to bypass the
-          threadlocal behavior installed into :meth:`connect`.
-
-        :param reset_on_return: If true, reset the database state of
-          connections returned to the pool.  This is typically a
-          ROLLBACK to release locks and transaction resources.
-          Disable at your own peril.  Defaults to True.
-
-        :param listeners: A list of
-          :class:`~sqlalchemy.interfaces.PoolListener`-like objects or
-          dictionaries of callables that receive events when DB-API
-          connections are created, checked out and checked in to the
-          pool.
-
-        """
-        Pool.__init__(self, creator, **params)
-        self.connection = _ConnectionRecord(self)
-        self._conn = self.connection
-
+    def __init__(self, *args, **kw):
+        self._conn = None
+        self._checked_out = False
+        Pool.__init__(self, *args, **kw)
+        
     def status(self):
         return "AssertionPool"
 
-    def create_connection(self):
-        raise AssertionError("Invalid")
-
     def do_return_conn(self, conn):
-        assert conn is self._conn and self.connection is None
-        self.connection = conn
+        if not self._checked_out:
+            raise AssertionError("connection is not checked out")
+        self._checked_out = False
+        assert conn is self._conn
 
     def do_return_invalid(self, conn):
-        raise AssertionError("Invalid")
+        self._conn = None
+        self._checked_out = False
+    
+    def dispose(self):
+        self._checked_out = False
+        self._conn.close()
 
+    def recreate(self):
+        self.log("Pool recreating")
+        return AssertionPool(self._creator, echo=self._should_log_info, listeners=self.listeners)
+        
     def do_get(self):
-        assert self.connection is not None
-        c = self.connection
-        self.connection = None
-        return c
+        if self._checked_out:
+            raise AssertionError("connection is already checked out")
+            
+        if not self._conn:
+            self._conn = self.create_connection()
+        
+        self._checked_out = True
+        return self._conn
 
 class _DBProxy(object):
     """Layers connection pooling behavior on top of a standard DB-API module.
index c9ab82acf8a3b40f9fa87f20e4c768aff50096db..2aaeea9d0fc76ca15a8b7b110e3904d117df0d03 100644 (file)
@@ -2,7 +2,7 @@
 behavior, using RLock instead of Lock for its mutex object.
 
 This is to support the connection pool's usage of weakref callbacks to return
-connections to the underlying Queue, which can apparently in extremely
+connections to the underlying Queue, which can in extremely
 rare cases be invoked within the ``get()`` method of the Queue itself,
 producing a ``put()`` inside the ``get()`` and therefore a reentrant
 condition."""
index e641f119b349b1cf30f52cb9603f3bbab8b637aa..231496676c11c2f82fbebacd152edb0cedcf8171 100644 (file)
@@ -28,7 +28,7 @@ expressions.
 
 """
 import re, inspect
-from sqlalchemy import types, exc, util, databases
+from sqlalchemy import types, exc, util, dialects
 from sqlalchemy.sql import expression, visitors
 
 URL = None
@@ -65,13 +65,6 @@ class SchemaItem(visitors.Visitable):
     def __repr__(self):
         return "%s()" % self.__class__.__name__
 
-    @property
-    def bind(self):
-        """Return the connectable associated with this SchemaItem."""
-
-        m = self.metadata
-        return m and m.bind or None
-
     @util.memoized_property
     def info(self):
         return {}
@@ -82,127 +75,140 @@ def _get_table_key(name, schema):
     else:
         return schema + "." + name
 
-class _TableSingleton(visitors.VisitableType):
-    """A metaclass used by the ``Table`` object to provide singleton behavior."""
+class Table(SchemaItem, expression.TableClause):
+    """Represent a table in a database.
+    
+    e.g.::
+    
+        mytable = Table("mytable", metadata, 
+                        Column('mytable_id', Integer, primary_key=True),
+                        Column('value', String(50))
+                   )
+
+    The Table object constructs a unique instance of itself based on its
+    name within the given MetaData object.   Constructor
+    arguments are as follows:
+    
+    :param name: The name of this table as represented in the database. 
+
+        This property, along with the *schema*, indicates the *singleton
+        identity* of this table in relation to its parent :class:`MetaData`.
+        Additional calls to :class:`Table` with the same name, metadata,
+        and schema name will return the same :class:`Table` object.
+
+        Names which contain no upper case characters
+        will be treated as case insensitive names, and will not be quoted
+        unless they are a reserved word.  Names with any number of upper
+        case characters will be quoted and sent exactly.  Note that this
+        behavior applies even for databases which standardize upper 
+        case names as case insensitive such as Oracle.
+
+    :param metadata: a :class:`MetaData` object which will contain this 
+        table.  The metadata is used as a point of association of this table
+        with other tables which are referenced via foreign key.  It also
+        may be used to associate this table with a particular 
+        :class:`~sqlalchemy.engine.base.Connectable`.
+
+    :param \*args: Additional positional arguments are used primarily
+        to add the list of :class:`Column` objects contained within this table.
+        Similar to the style of a CREATE TABLE statement, other :class:`SchemaItem`
+        constructs may be added here, including :class:`PrimaryKeyConstraint`,
+        and :class:`ForeignKeyConstraint`.
+        
+    :param autoload: Defaults to False: the Columns for this table should be reflected
+        from the database.  Usually there will be no Column objects in the
+        constructor if this property is set.
+
+    :param autoload_with: If autoload==True, this is an optional Engine or Connection
+        instance to be used for the table reflection.  If ``None``, the
+        underlying MetaData's bound connectable will be used.
+
+    :param implicit_returning: True by default - indicates that 
+        RETURNING can be used by default to fetch newly inserted primary key 
+        values, for backends which support this.  Note that 
+        create_engine() also provides an implicit_returning flag.
+
+    :param include_columns: A list of strings indicating a subset of columns to be loaded via
+        the ``autoload`` operation; table columns which aren't present in
+        this list will not be represented on the resulting ``Table``
+        object.  Defaults to ``None`` which indicates all columns should
+        be reflected.
+
+    :param info: A dictionary which defaults to ``{}``.  A space to store application 
+        specific data. This must be a dictionary.
+
+    :param mustexist: When ``True``, indicates that this Table must already 
+        be present in the given :class:`MetaData` collection.
+
+    :param prefixes:
+        A list of strings to insert after CREATE in the CREATE TABLE
+        statement.  They will be separated by spaces.
+
+    :param quote: Force quoting of this table's name on or off, corresponding
+        to ``True`` or ``False``.  When left at its default of ``None``,
+        the column identifier will be quoted according to whether the name is
+        case sensitive (identifiers with at least one upper case character are 
+        treated as case sensitive), or if it's a reserved word.  This flag 
+        is only needed to force quoting of a reserved word which is not known
+        by the SQLAlchemy dialect.
+
+    :param quote_schema: same as 'quote' but applies to the schema identifier.
+
+    :param schema: The *schema name* for this table, which is required if the table
+        resides in a schema other than the default selected schema for the
+        engine's database connection.  Defaults to ``None``.
+
+    :param useexisting: When ``True``, indicates that if this Table is already
+        present in the given :class:`MetaData`, apply further arguments within
+        the constructor to the existing :class:`Table`.  If this flag is not 
+        set, an error is raised when the parameters of an existing :class:`Table`
+        are overwritten.
+
+    """
+    
+    __visit_name__ = 'table'
+
+    ddl_events = ('before-create', 'after-create', 'before-drop', 'after-drop')
 
-    def __call__(self, name, metadata, *args, **kwargs):
-        schema = kwargs.get('schema', kwargs.get('owner', None))
-        useexisting = kwargs.pop('useexisting', False)
-        mustexist = kwargs.pop('mustexist', False)
+    def __new__(cls, name, metadata, *args, **kw):
+        schema = kw.get('schema', None)
+        useexisting = kw.pop('useexisting', False)
+        mustexist = kw.pop('mustexist', False)
         key = _get_table_key(name, schema)
-        try:
-            table = metadata.tables[key]
-            if not useexisting and table._cant_override(*args, **kwargs):
+        if key in metadata.tables:
+            if not useexisting and bool(args):
                 raise exc.InvalidRequestError(
                     "Table '%s' is already defined for this MetaData instance.  "
                     "Specify 'useexisting=True' to redefine options and "
                     "columns on an existing Table object." % key)
-            else:
-                table._init_existing(*args, **kwargs)
+            table = metadata.tables[key]
+            table._init_existing(*args, **kw)
             return table
-        except KeyError:
+        else:
             if mustexist:
                 raise exc.InvalidRequestError(
                     "Table '%s' not defined" % (key))
+            metadata.tables[key] = table = object.__new__(cls)
             try:
-                return type.__call__(self, name, metadata, *args, **kwargs)
+                table._init(name, metadata, *args, **kw)
+                return table
             except:
-                if key in metadata.tables:
-                    del metadata.tables[key]
+                metadata.tables.pop(key)
                 raise
-
-
-class Table(SchemaItem, expression.TableClause):
-    """Represent a table in a database."""
-
-    __metaclass__ = _TableSingleton
-
-    __visit_name__ = 'table'
-
-    ddl_events = ('before-create', 'after-create', 'before-drop', 'after-drop')
-
-    def __init__(self, name, metadata, *args, **kwargs):
-        """
-        Construct a Table.
-
-        :param name: The name of this table as represented in the database. 
-
-            This property, along with the *schema*, indicates the *singleton
-            identity* of this table in relation to its parent :class:`MetaData`.
-            Additional calls to :class:`Table` with the same name, metadata,
-            and schema name will return the same :class:`Table` object.
-
-            Names which contain no upper case characters
-            will be treated as case insensitive names, and will not be quoted
-            unless they are a reserved word.  Names with any number of upper
-            case characters will be quoted and sent exactly.  Note that this
-            behavior applies even for databases which standardize upper 
-            case names as case insensitive such as Oracle.
-
-        :param metadata: a :class:`MetaData` object which will contain this 
-            table.  The metadata is used as a point of association of this table
-            with other tables which are referenced via foreign key.  It also
-            may be used to associate this table with a particular 
-            :class:`~sqlalchemy.engine.base.Connectable`.
-
-        :param \*args: Additional positional arguments are used primarily
-            to add the list of :class:`Column` objects contained within this table.
-            Similar to the style of a CREATE TABLE statement, other :class:`SchemaItem`
-            constructs may be added here, including :class:`PrimaryKeyConstraint`,
-            and :class:`ForeignKeyConstraint`.
-            
-        :param autoload: Defaults to False: the Columns for this table should be reflected
-            from the database.  Usually there will be no Column objects in the
-            constructor if this property is set.
-
-        :param autoload_with: If autoload==True, this is an optional Engine or Connection
-            instance to be used for the table reflection.  If ``None``, the
-            underlying MetaData's bound connectable will be used.
-
-        :param include_columns: A list of strings indicating a subset of columns to be loaded via
-            the ``autoload`` operation; table columns who aren't present in
-            this list will not be represented on the resulting ``Table``
-            object.  Defaults to ``None`` which indicates all columns should
-            be reflected.
-
-        :param info: A dictionary which defaults to ``{}``.  A space to store application 
-            specific data. This must be a dictionary.
-
-        :param mustexist: When ``True``, indicates that this Table must already 
-            be present in the given :class:`MetaData`` collection.
-
-        :param prefixes:
-            A list of strings to insert after CREATE in the CREATE TABLE
-            statement.  They will be separated by spaces.
-
-        :param quote: Force quoting of this table's name on or off, corresponding
-            to ``True`` or ``False``.  When left at its default of ``None``,
-            the column identifier will be quoted according to whether the name is
-            case sensitive (identifiers with at least one upper case character are 
-            treated as case sensitive), or if it's a reserved word.  This flag 
-            is only needed to force quoting of a reserved word which is not known
-            by the SQLAlchemy dialect.
-
-        :param quote_schema: same as 'quote' but applies to the schema identifier.
-
-        :param schema: The *schema name* for this table, which is required if the table
-            resides in a schema other than the default selected schema for the
-            engine's database connection.  Defaults to ``None``.
-
-        :param useexisting: When ``True``, indicates that if this Table is already
-            present in the given :class:`MetaData`, apply further arguments within
-            the constructor to the existing :class:`Table`.  If this flag is not 
-            set, an error is raised when the parameters of an existing :class:`Table`
-            are overwritten.
-
-        """
+                
+    def __init__(self, *args, **kw):
+        # __init__ is a no-op: __new__ may return an existing Table,
+        # and the superclass constructor must not run a second time.
+        pass
+        
+    def _init(self, name, metadata, *args, **kwargs):
         super(Table, self).__init__(name)
         self.metadata = metadata
-        self.schema = kwargs.pop('schema', kwargs.pop('owner', None))
+        self.schema = kwargs.pop('schema', None)
         self.indexes = set()
         self.constraints = set()
         self._columns = expression.ColumnCollection()
-        self.primary_key = PrimaryKeyConstraint()
+        self._set_primary_key(PrimaryKeyConstraint())
         self._foreign_keys = util.OrderedSet()
         self.ddl_listeners = util.defaultdict(list)
         self.kwargs = {}
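
The metaclass is gone, but the singleton contract is preserved in __new__ above: the same name/schema/metadata combination yields the same Table, and re-specifying columns requires useexisting=True. A minimal sketch::

    from sqlalchemy import MetaData, Table, Column, Integer

    meta = MetaData()
    t1 = Table('mytable', meta, Column('id', Integer, primary_key=True))

    t2 = Table('mytable', meta)     # same key: the same object comes back
    assert t2 is t1

    # passing columns again without useexisting=True raises
    # InvalidRequestError; with it, the existing Table is amended
    t3 = Table('mytable', meta, Column('x', Integer), useexisting=True)
    assert t3 is t1
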
@@ -215,8 +221,7 @@ class Table(SchemaItem, expression.TableClause):
         autoload_with = kwargs.pop('autoload_with', None)
         include_columns = kwargs.pop('include_columns', None)
 
-        self._set_parent(metadata)
-
+        self.implicit_returning = kwargs.pop('implicit_returning', True)
         self.quote = kwargs.pop('quote', None)
         self.quote_schema = kwargs.pop('quote_schema', None)
         if 'info' in kwargs:
@@ -224,7 +229,7 @@ class Table(SchemaItem, expression.TableClause):
 
         self._prefixes = kwargs.pop('prefixes', [])
 
-        self.__extra_kwargs(**kwargs)
+        self._extra_kwargs(**kwargs)
 
         # load column definitions from the database if 'autoload' is defined
         # we do it after the table is in the singleton dictionary to support
@@ -237,7 +242,7 @@ class Table(SchemaItem, expression.TableClause):
 
         # initialize all the column, etc. objects.  done after reflection to
         # allow user-overrides
-        self.__post_init(*args, **kwargs)
+        self._init_items(*args)
 
     def _init_existing(self, *args, **kwargs):
         autoload = kwargs.pop('autoload', False)
@@ -261,43 +266,43 @@ class Table(SchemaItem, expression.TableClause):
         if 'info' in kwargs:
             self.info = kwargs.pop('info')
 
-        self.__extra_kwargs(**kwargs)
-        self.__post_init(*args, **kwargs)
-
-    def _cant_override(self, *args, **kwargs):
-        """Return True if any argument is not supported as an override.
-
-        Takes arguments that would be sent to Table.__init__, and returns
-        True if any of them would be disallowed if sent to an existing
-        Table singleton.
-        """
-        return bool(args) or bool(set(kwargs).difference(
-            ['autoload', 'autoload_with', 'schema', 'owner']))
+        self._extra_kwargs(**kwargs)
+        self._init_items(*args)
 
-    def __extra_kwargs(self, **kwargs):
+    def _extra_kwargs(self, **kwargs):
         # validate remaining kwargs that they all specify DB prefixes
         if len([k for k in kwargs
-                if not re.match(r'^(?:%s)_' % '|'.join(databases.__all__), k)]):
+                if not re.match(r'^(?:%s)_' % '|'.join(dialects.__all__), k)]):
             raise TypeError(
-                "Invalid argument(s) for Table: %s" % repr(kwargs.keys()))
+                "Invalid argument(s) for Table: %r" % kwargs.keys())
         self.kwargs.update(kwargs)
 
-    def __post_init(self, *args, **kwargs):
-        self._init_items(*args)
-
-    @property
-    def key(self):
-        return _get_table_key(self.name, self.schema)
-
     def _set_primary_key(self, pk):
         if getattr(self, '_primary_key', None) in self.constraints:
             self.constraints.remove(self._primary_key)
         self._primary_key = pk
         self.constraints.add(pk)
 
+        for c in pk.columns:
+            c.primary_key = True
+
+    @util.memoized_property
+    def _autoincrement_column(self):
+        for col in self.primary_key:
+            if col.autoincrement and \
+                isinstance(col.type, types.Integer) and \
+                not col.foreign_keys and \
+                isinstance(col.default, (type(None), Sequence)):
+
+                return col
+
+    @property
+    def key(self):
+        return _get_table_key(self.name, self.schema)
+
+    @property
     def primary_key(self):
         return self._primary_key
-    primary_key = property(primary_key, _set_primary_key)
 
     def __repr__(self):
         return "Table(%s)" % ', '.join(
@@ -308,6 +313,12 @@ class Table(SchemaItem, expression.TableClause):
     def __str__(self):
         return _get_table_key(self.description, self.schema)
 
+    @property
+    def bind(self):
+        """Return the connectable associated with this Table."""
+
+        return self.metadata and self.metadata.bind or None
+
     def append_column(self, column):
         """Append a ``Column`` to this ``Table``."""
 
@@ -359,7 +370,7 @@ class Table(SchemaItem, expression.TableClause):
                 self, column_collections=column_collections, **kwargs)
         else:
             if column_collections:
-                return [c for c in self.columns]
+                return list(self.columns)
             else:
                 return []
 
@@ -407,7 +418,7 @@ class Column(SchemaItem, expression.ColumnClause):
     """Represents a column in a database table."""
 
     __visit_name__ = 'column'
-
+    
     def __init__(self, *args, **kwargs):
         """
         Construct a new ``Column`` object.
@@ -478,7 +489,7 @@ class Column(SchemaItem, expression.ColumnClause):
           
           Contrast this argument to ``server_default`` which creates a 
           default generator on the database side.
-
+        
         :param key: An optional string identifier which will identify this ``Column`` 
             object on the :class:`Table`.  When a key is provided, this is the
             only identifier referencing the ``Column`` within the application,
@@ -568,10 +579,6 @@ class Column(SchemaItem, expression.ColumnClause):
         if args:
             coltype = args[0]
             
-            # adjust for partials
-            if util.callable(coltype):
-                coltype = args[0]()
-
             if (isinstance(coltype, types.AbstractType) or
                 (isinstance(coltype, type) and
                  issubclass(coltype, types.AbstractType))):
@@ -581,7 +588,6 @@ class Column(SchemaItem, expression.ColumnClause):
                 type_ = args.pop(0)
 
         super(Column, self).__init__(name, None, type_)
-        self.args = args
         self.key = kwargs.pop('key', name)
         self.primary_key = kwargs.pop('primary_key', False)
         self.nullable = kwargs.pop('nullable', not self.primary_key)
@@ -595,6 +601,28 @@ class Column(SchemaItem, expression.ColumnClause):
         self.autoincrement = kwargs.pop('autoincrement', True)
         self.constraints = set()
         self.foreign_keys = util.OrderedSet()
+        self._table_events = set()
+        
+        if self.default is not None:
+            if isinstance(self.default, ColumnDefault):
+                args.append(self.default)
+            else:
+                args.append(ColumnDefault(self.default))
+        if self.server_default is not None:
+            if isinstance(self.server_default, FetchedValue):
+                args.append(self.server_default)
+            else:
+                args.append(DefaultClause(self.server_default))
+        if self.onupdate is not None:
+            args.append(ColumnDefault(self.onupdate, for_update=True))
+        if self.server_onupdate is not None:
+            if isinstance(self.server_onupdate, FetchedValue):
+                args.append(self.server_default)
+            else:
+                args.append(DefaultClause(self.server_onupdate,
+                                            for_update=True))
+        self._init_items(*args)
+
         util.set_creation_order(self)
 
         if 'info' in kwargs:
@@ -615,10 +643,6 @@ class Column(SchemaItem, expression.ColumnClause):
         else:
             return self.description
 
-    @property
-    def bind(self):
-        return self.table.bind
-
     def references(self, column):
         """Return True if this Column references the given column via foreign key."""
         for fk in self.foreign_keys:
@@ -658,24 +682,26 @@ class Column(SchemaItem, expression.ColumnClause):
                 "before adding to a Table.")
         if self.key is None:
             self.key = self.name
-        self.metadata = table.metadata
+
         if getattr(self, 'table', None) is not None:
             raise exc.ArgumentError("this Column already has a table!")
 
         if self.key in table._columns:
-            # note the column being replaced, if any
-            self._pre_existing_column = table._columns.get(self.key)
+            col = table._columns.get(self.key)
+            for fk in col.foreign_keys:
+                col.foreign_keys.remove(fk)
+                table.foreign_keys.remove(fk)
+                table.constraints.remove(fk.constraint)
+            
         table._columns.replace(self)
 
         if self.primary_key:
-            table.primary_key.replace(self)
+            table.primary_key._replace(self)
         elif self.key in table.primary_key:
             raise exc.ArgumentError(
                 "Trying to redefine primary-key column '%s' as a "
                 "non-primary-key column on table '%s'" % (
                 self.key, table.fullname))
-            # if we think this should not raise an error, we'd instead do this:
-            #table.primary_key.remove(self)
         self.table = table
 
         if self.index:
@@ -695,48 +721,47 @@ class Column(SchemaItem, expression.ColumnClause):
                     "external to the Table.")
             table.append_constraint(UniqueConstraint(self.key))
 
-        toinit = list(self.args)
-        if self.default is not None:
-            if isinstance(self.default, ColumnDefault):
-                toinit.append(self.default)
-            else:
-                toinit.append(ColumnDefault(self.default))
-        if self.server_default is not None:
-            if isinstance(self.server_default, FetchedValue):
-                toinit.append(self.server_default)
-            else:
-                toinit.append(DefaultClause(self.server_default))
-        if self.onupdate is not None:
-            toinit.append(ColumnDefault(self.onupdate, for_update=True))
-        if self.server_onupdate is not None:
-            if isinstance(self.server_onupdate, FetchedValue):
-                toinit.append(self.server_default)
-            else:
-                toinit.append(DefaultClause(self.server_onupdate,
-                                            for_update=True))
-        self._init_items(*toinit)
-        self.args = None
-
+        for fn in self._table_events:
+            fn(table)
+        del self._table_events
+    
+    def _on_table_attach(self, fn):
+        if self.table:
+            fn(self.table)
+        else:
+            self._table_events.add(fn)
+            
     def copy(self, **kw):
         """Create a copy of this ``Column``, unitialized.
 
         This is used in ``Table.tometadata``.
 
         """
-        return Column(self.name, self.type, self.default, key = self.key, primary_key = self.primary_key, nullable = self.nullable, quote=self.quote, index=self.index, autoincrement=self.autoincrement, *[c.copy(**kw) for c in self.constraints])
+        return Column(
+                self.name, 
+                self.type, 
+                self.default, 
+                key = self.key, 
+                primary_key = self.primary_key, 
+                nullable = self.nullable, 
+                quote=self.quote, 
+                index=self.index, 
+                autoincrement=self.autoincrement, 
+                *[c.copy(**kw) for c in self.constraints])
 
     def _make_proxy(self, selectable, name=None):
         """Create a *proxy* for this column.
 
         This is a copy of this ``Column`` referenced by a different parent
-        (such as an alias or select statement).
-
+        (such as an alias or select statement).  The column should
+        be used only in select scenarios, as its full DDL/default
+        information is not transferred.
+        
         """
         fk = [ForeignKey(f.column) for f in self.foreign_keys]
         c = Column(
             name or self.name, 
             self.type, 
-            self.default, 
             key = name or self.key, 
             primary_key = self.primary_key, 
             nullable = self.nullable, 
@@ -746,7 +771,9 @@ class Column(SchemaItem, expression.ColumnClause):
         selectable.columns.add(c)
         if self.primary_key:
             selectable.primary_key.add(c)
-        [c._init_items(f) for f in fk]
+        for fn in c._table_events:
+            fn(selectable)
+        del c._table_events
         return c
 
     def get_children(self, schema_visitor=False, **kwargs):
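
_on_table_attach is the mechanism that lets a ForeignKey declared on a detached Column defer its constraint wiring until the column actually joins a Table; _make_proxy above replays the same pending events against the new selectable. A minimal sketch of the ordering::

    from sqlalchemy import MetaData, Table, Column, Integer, ForeignKey

    meta = MetaData()
    parent = Table('parent', meta, Column('id', Integer, primary_key=True))

    # the ForeignKey exists before its Column has a Table; its
    # _set_table callback runs later, during Table construction
    child = Table('child', meta,
                  Column('id', Integer, primary_key=True),
                  Column('parent_id', Integer, ForeignKey('parent.id')))

    assert len(child.foreign_keys) == 1
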
@@ -775,9 +802,13 @@ class ForeignKey(SchemaItem):
 
     __visit_name__ = 'foreign_key'
 
-    def __init__(self, column, constraint=None, use_alter=False, name=None, onupdate=None, ondelete=None, deferrable=None, initially=None, link_to_name=False):
+    def __init__(self, column, constraint=None, use_alter=False, name=None, onupdate=None, 
+                    ondelete=None, deferrable=None, initially=None, link_to_name=False):
         """
-        Construct a column-level FOREIGN KEY.
+        Construct a column-level FOREIGN KEY.  
+        
+        The :class:`ForeignKey` object when constructed generates a :class:`ForeignKeyConstraint`
+        which is associated with the parent :class:`Table` object's collection of constraints.
 
         :param column: A single target column for the key relationship.  A :class:`Column`
           object or a column name as a string: ``tablename.columnkey`` or
@@ -809,10 +840,10 @@ class ForeignKey(SchemaItem):
         :param link_to_name: if True, the string name given in ``column`` is the rendered
           name of the referenced column, not its locally assigned ``key``.
           
-        :param use_alter: If True, do not emit this key as part of the CREATE TABLE
-          definition.  Instead, use ALTER TABLE after table creation to add
-          the key.  Useful for circular dependencies.
-          
+        :param use_alter: passed to the underlying :class:`ForeignKeyConstraint` to indicate the
+          constraint should be generated/dropped externally from the CREATE TABLE/
+          DROP TABLE statement.  See that class's constructor for details.
+        
         """
 
         self._colspec = column
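
As a usage aside, the column-level form documented above typically appears as follows; a minimal sketch using the public API (the table and column names are illustrative)::

    from sqlalchemy import MetaData, Table, Column, Integer, ForeignKey

    meta = MetaData()
    user = Table('user', meta,
        Column('id', Integer, primary_key=True))
    address = Table('address', meta,
        Column('id', Integer, primary_key=True),
        # generates a ForeignKeyConstraint on the 'address' table
        Column('user_id', Integer, ForeignKey('user.id', ondelete='CASCADE')))
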
@@ -946,39 +977,35 @@ class ForeignKey(SchemaItem):
 
     def _set_parent(self, column):
         if hasattr(self, 'parent'):
+            if self.parent is column:
+                return
             raise exc.InvalidRequestError("This ForeignKey already has a parent!")
         self.parent = column
 
-        if hasattr(self.parent, '_pre_existing_column'):
-            # remove existing FK which matches us
-            for fk in self.parent._pre_existing_column.foreign_keys:
-                if fk.target_fullname == self.target_fullname:
-                    self.parent.table.foreign_keys.remove(fk)
-                    self.parent.table.constraints.remove(fk.constraint)
-
-        if self.constraint is None and isinstance(self.parent.table, Table):
+        self.parent.foreign_keys.add(self)
+        self.parent._on_table_attach(self._set_table)
+    
+    def _set_table(self, table):
+        if self.constraint is None and isinstance(table, Table):
             self.constraint = ForeignKeyConstraint(
                 [], [], use_alter=self.use_alter, name=self.name,
                 onupdate=self.onupdate, ondelete=self.ondelete,
-                deferrable=self.deferrable, initially=self.initially)
-            self.parent.table.append_constraint(self.constraint)
-            self.constraint._append_fk(self)
-
-        self.parent.foreign_keys.add(self)
-        self.parent.table.foreign_keys.add(self)
-
+                deferrable=self.deferrable, initially=self.initially,
+                )
+            self.constraint._elements[self.parent] = self
+            self.constraint._set_parent(table)
+        table.foreign_keys.add(self)
+        
 class DefaultGenerator(SchemaItem):
     """Base class for column *default* values."""
 
     __visit_name__ = 'default_generator'
 
-    def __init__(self, for_update=False, metadata=None):
+    def __init__(self, for_update=False):
         self.for_update = for_update
-        self.metadata = util.assert_arg_type(metadata, (MetaData, type(None)), 'metadata')
 
     def _set_parent(self, column):
         self.column = column
-        self.metadata = self.column.table.metadata
         if self.for_update:
             self.column.onupdate = self
         else:
@@ -989,6 +1016,14 @@ class DefaultGenerator(SchemaItem):
             bind = _bind_or_error(self)
         return bind._execute_default(self, **kwargs)
 
+    @property
+    def bind(self):
+        """Return the connectable associated with this default."""
+        if getattr(self, 'column', None):
+            return self.column.table.bind
+        else:
+            return None
+
     def __repr__(self):
         return "DefaultGenerator()"
 
@@ -1026,7 +1061,9 @@ class ColumnDefault(DefaultGenerator):
             return lambda ctx: fn()
 
         positionals = len(argspec[0])
-        if inspect.ismethod(inspectable):
+        
+        # Py3K compat - no unbound methods
+        if inspect.ismethod(inspectable) or inspect.isclass(fn):
             positionals -= 1
 
         if positionals == 0:
@@ -1047,7 +1084,7 @@ class ColumnDefault(DefaultGenerator):
     __visit_name__ = property(_visit_name)
 
     def __repr__(self):
-        return "ColumnDefault(%s)" % repr(self.arg)
+        return "ColumnDefault(%r)" % self.arg
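
The argspec inspection above is what lets a column default callable take either zero arguments or a single execution-context argument. A sketch of both forms against the public ``Column(default=...)`` API (the ``datetime`` usage is illustrative)::

    import datetime
    from sqlalchemy import MetaData, Table, Column, DateTime

    def no_args():
        return datetime.datetime.now()

    def with_context(context):
        # 'context' is the statement's execution context
        return datetime.datetime.now()

    meta = MetaData()
    t = Table('stamps', meta,
        Column('created', DateTime, default=no_args),
        Column('updated', DateTime, onupdate=with_context))
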
 
 class Sequence(DefaultGenerator):
     """Represents a named database sequence."""
@@ -1055,15 +1092,15 @@ class Sequence(DefaultGenerator):
     __visit_name__ = 'sequence'
 
     def __init__(self, name, start=None, increment=None, schema=None,
-                 optional=False, quote=None, **kwargs):
-        super(Sequence, self).__init__(**kwargs)
+                 optional=False, quote=None, metadata=None, for_update=False):
+        super(Sequence, self).__init__(for_update=for_update)
         self.name = name
         self.start = start
         self.increment = increment
         self.optional = optional
         self.quote = quote
         self.schema = schema
-        self.kwargs = kwargs
+        self.metadata = metadata
 
     def __repr__(self):
         return "Sequence(%s)" % ', '.join(
@@ -1074,7 +1111,19 @@ class Sequence(DefaultGenerator):
     def _set_parent(self, column):
         super(Sequence, self)._set_parent(column)
         column.sequence = self
-
+        
+        column._on_table_attach(self._set_table)
+    
+    def _set_table(self, table):
+        self.metadata = table.metadata
+        
+    @property
+    def bind(self):
+        if self.metadata:
+            return self.metadata.bind
+        else:
+            return None
+        
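
With the change above, a :class:`Sequence` picks up its ``metadata`` (and therefore its ``bind``) when its parent column is attached to a table, rather than at construction time. A sketch, assuming a bound ``engine``::

    from sqlalchemy import MetaData, Table, Column, Integer, Sequence

    meta = MetaData()
    seq = Sequence('user_id_seq')
    user = Table('user', meta,
        Column('id', Integer, seq, primary_key=True))
    # seq.metadata is now meta; seq.bind resolves through meta.bind
    seq.create(bind=engine)   # emits CREATE SEQUENCE explicitly
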
     def create(self, bind=None, checkfirst=True):
         """Creates this sequence in the database."""
 
@@ -1123,17 +1172,12 @@ class DefaultClause(FetchedValue):
 # alias; deprecated starting 0.5.0
 PassiveDefault = DefaultClause
 
-
 class Constraint(SchemaItem):
-    """A table-level SQL constraint, such as a KEY.
-
-    Implements a hybrid of dict/setlike behavior with regards to the list of
-    underying columns.
-    """
+    """A table-level SQL constraint."""
 
     __visit_name__ = 'constraint'
 
-    def __init__(self, name=None, deferrable=None, initially=None):
+    def __init__(self, name=None, deferrable=None, initially=None, inline_ddl=True):
         """Create a SQL constraint.
 
         name
@@ -1146,33 +1190,87 @@ class Constraint(SchemaItem):
         initially
           Optional string.  If set, emit INITIALLY <value> when issuing DDL
           for this constraint.
+          
+        inline_ddl
+          If True, DDL for this Constraint will be generated within the span of a
+          CREATE TABLE or DROP TABLE statement, when the associated table's
+          DDL is generated.  If False, no DDL is issued within that process.
+          Instead, it is expected that an AddConstraint or DropConstraint
+          construct will be used to issue DDL for this Constraint.
+          The AddConstraint/DropConstraint constructs set this flag to False
+          automatically as well.
         """
 
         self.name = name
-        self.columns = expression.ColumnCollection()
         self.deferrable = deferrable
         self.initially = initially
+        self.inline_ddl = inline_ddl
+
+    @property
+    def table(self):
+        if isinstance(self.parent, Table):
+            return self.parent
+        else:
+            raise exc.InvalidRequestError("This constraint is not bound to a table.")
+
+    def _set_parent(self, parent):
+        self.parent = parent
+        parent.constraints.add(self)
+
+    def copy(self, **kw):
+        raise NotImplementedError()
+
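
The ``inline_ddl`` flag described above pairs with the :class:`AddConstraint` and :class:`DropConstraint` constructs defined later in this module; constructing one of those against a constraint sets ``inline_ddl=False`` automatically. A sketch of the intended flow, assuming a bound ``engine``::

    from sqlalchemy import MetaData, Table, Column, String, UniqueConstraint
    from sqlalchemy.schema import AddConstraint, DropConstraint

    meta = MetaData()
    user = Table('user', meta, Column('name', String(30)))
    uq = UniqueConstraint(user.c.name, name='uq_user_name')

    add_uq = AddConstraint(uq)   # marks uq as inline_ddl=False
    meta.create_all(engine)      # CREATE TABLE omits the UNIQUE clause
    engine.execute(add_uq)       # ALTER TABLE user ADD CONSTRAINT ...
    engine.execute(DropConstraint(uq))
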
+class ColumnCollectionConstraint(Constraint):
+    """A constraint that proxies a ColumnCollection."""
+    
+    def __init__(self, *columns, **kw):
+        """
+        \*columns
+          A sequence of column names or Column objects.
+
+        name
+          Optional, the in-database name of this constraint.
+
+        deferrable
+          Optional bool.  If set, emit DEFERRABLE or NOT DEFERRABLE when
+          issuing DDL for this constraint.
+
+        initially
+          Optional string.  If set, emit INITIALLY <value> when issuing DDL
+          for this constraint.
+        
+        """
+        super(ColumnCollectionConstraint, self).__init__(**kw)
+        self.columns = expression.ColumnCollection()
+        self._pending_colargs = [_to_schema_column_or_string(c) for c in columns]
+        if self._pending_colargs and \
+                isinstance(self._pending_colargs[0], Column) and \
+                self._pending_colargs[0].table is not None:
+            self._set_parent(self._pending_colargs[0].table)
+        
+    def _set_parent(self, table):
+        super(ColumnCollectionConstraint, self)._set_parent(table)
+        for col in self._pending_colargs:
+            if isinstance(col, basestring):
+                col = table.c[col]
+            self.columns.add(col)
 
     def __contains__(self, x):
         return x in self.columns
 
+    def copy(self, **kw):
+        return self.__class__(name=self.name, deferrable=self.deferrable,
+                              initially=self.initially, *self.columns.keys())
+
     def contains_column(self, col):
         return self.columns.contains_column(col)
 
-    def keys(self):
-        return self.columns.keys()
-
-    def __add__(self, other):
-        return self.columns + other
-
     def __iter__(self):
         return iter(self.columns)
 
     def __len__(self):
         return len(self.columns)
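
Per ``_set_parent`` above, the column arguments may be given as plain strings, resolved against the parent table's ``c`` collection at attachment time, or as :class:`Column` objects directly. A sketch::

    from sqlalchemy import MetaData, Table, Column, String, UniqueConstraint

    meta = MetaData()
    user = Table('user', meta,
        Column('name', String(50)),
        Column('email', String(50)),
        # string names resolve via table.c when the constraint attaches
        UniqueConstraint('name', 'email', name='uq_name_email'))
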
 
-    def copy(self, **kw):
-        raise NotImplementedError()
 
 class CheckConstraint(Constraint):
     """A table- or column-level CHECK constraint.
@@ -1180,7 +1278,7 @@ class CheckConstraint(Constraint):
     Can be included in the definition of a Table or Column.
     """
 
-    def __init__(self, sqltext, name=None, deferrable=None, initially=None):
+    def __init__(self, sqltext, name=None, deferrable=None, initially=None, table=None):
         """Construct a CHECK constraint.
 
         sqltext
@@ -1197,6 +1295,7 @@ class CheckConstraint(Constraint):
         initially
           Optional string.  If set, emit INITIALLY <value> when issuing DDL
           for this constraint.
+          
         """
 
         super(CheckConstraint, self).__init__(name, deferrable, initially)
@@ -1204,7 +1303,9 @@ class CheckConstraint(Constraint):
             raise exc.ArgumentError(
                 "sqltext must be a string and will be used verbatim.")
         self.sqltext = sqltext
-
+        if table:
+            self._set_parent(table)
+            
     def __visit_name__(self):
         if isinstance(self.parent, Table):
             return "check_constraint"
@@ -1212,10 +1313,6 @@ class CheckConstraint(Constraint):
             return "column_check_constraint"
     __visit_name__ = property(__visit_name__)
 
-    def _set_parent(self, parent):
-        self.parent = parent
-        parent.constraints.add(self)
-
     def copy(self, **kw):
         return CheckConstraint(self.sqltext, name=self.name)
 
@@ -1232,7 +1329,8 @@ class ForeignKeyConstraint(Constraint):
     """
     __visit_name__ = 'foreign_key_constraint'
 
-    def __init__(self, columns, refcolumns, name=None, onupdate=None, ondelete=None, use_alter=False, deferrable=None, initially=None, link_to_name=False):
+    def __init__(self, columns, refcolumns, name=None, onupdate=None, ondelete=None, 
+                    deferrable=None, initially=None, use_alter=False, link_to_name=False, table=None):
         """Construct a composite-capable FOREIGN KEY.
 
         :param columns: A sequence of local column names.  The named columns must be defined
@@ -1261,42 +1359,72 @@ class ForeignKeyConstraint(Constraint):
         :param link_to_name: if True, the string name given in ``column`` is the rendered
           name of the referenced column, not its locally assigned ``key``.
 
-        :param use_alter: If True, do not emit this key as part of the CREATE TABLE
-          definition.  Instead, use ALTER TABLE after table creation to add
-          the key.  Useful for circular dependencies.
+        :param use_alter: If True, do not emit the DDL for this constraint
+          as part of the CREATE TABLE definition.  Instead, generate it via an 
+          ALTER TABLE statement issued after the full collection of tables has been
+          created, and drop it via an ALTER TABLE statement before the full collection
+          of tables is dropped.   This is shorthand for the usage of 
+          :class:`AddConstraint` and :class:`DropConstraint` applied as "after-create"
+          and "before-drop" events on the MetaData object.   This is normally used to
+          generate/drop constraints on objects that are mutually dependent on each other.
           
         """
         super(ForeignKeyConstraint, self).__init__(name, deferrable, initially)
-        self.__colnames = columns
-        self.__refcolnames = refcolumns
-        self.elements = util.OrderedSet()
+
         self.onupdate = onupdate
         self.ondelete = ondelete
         self.link_to_name = link_to_name
         if self.name is None and use_alter:
-            raise exc.ArgumentError("Alterable ForeignKey/ForeignKeyConstraint requires a name")
+            raise exc.ArgumentError("Alterable Constraint requires a name")
         self.use_alter = use_alter
 
+        self._elements = util.OrderedDict()
+        for col, refcol in zip(columns, refcolumns):
+            self._elements[col] = ForeignKey(
+                    refcol, 
+                    constraint=self, 
+                    name=self.name, 
+                    onupdate=self.onupdate, 
+                    ondelete=self.ondelete, 
+                    use_alter=self.use_alter, 
+                    link_to_name=self.link_to_name
+                )
+
+        if table:
+            self._set_parent(table)
+    
+    @property
+    def columns(self):
+        return self._elements.keys()
+        
+    @property
+    def elements(self):
+        return self._elements.values()
+        
     def _set_parent(self, table):
-        self.table = table
-        if self not in table.constraints:
-            table.constraints.add(self)
-            for (c, r) in zip(self.__colnames, self.__refcolnames):
-                self.append_element(c, r)
-
-    def append_element(self, col, refcol):
-        fk = ForeignKey(refcol, constraint=self, name=self.name, onupdate=self.onupdate, ondelete=self.ondelete, use_alter=self.use_alter, link_to_name=self.link_to_name)
-        fk._set_parent(self.table.c[col])
-        self._append_fk(fk)
-
-    def _append_fk(self, fk):
-        self.columns.add(self.table.c[fk.parent.key])
-        self.elements.add(fk)
-
+        super(ForeignKeyConstraint, self)._set_parent(table)
+        for col, fk in self._elements.iteritems():
+            if isinstance(col, basestring):
+                col = table.c[col]
+            fk._set_parent(col)
+            
+        if self.use_alter:
+            def supports_alter(event, schema_item, bind, **kw):
+                return table in set(kw['tables']) and bind.dialect.supports_alter
+            AddConstraint(self, on=supports_alter).execute_at('after-create', table.metadata)
+            DropConstraint(self, on=supports_alter).execute_at('before-drop', table.metadata)
+            
     def copy(self, **kw):
-        return ForeignKeyConstraint([x.parent.name for x in self.elements], [x._get_colspec(**kw) for x in self.elements], name=self.name, onupdate=self.onupdate, ondelete=self.ondelete, use_alter=self.use_alter)
-
-class PrimaryKeyConstraint(Constraint):
+        return ForeignKeyConstraint(
+                    [x.parent.name for x in self._elements.values()], 
+                    [x._get_colspec(**kw) for x in self._elements.values()], 
+                    name=self.name, 
+                    onupdate=self.onupdate, 
+                    ondelete=self.ondelete, 
+                    use_alter=self.use_alter
+                )
+
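
Tying the pieces together: ``use_alter`` as implemented in ``_set_parent`` above registers :class:`AddConstraint` / :class:`DropConstraint` listeners against the MetaData, which is what makes mutually-dependent tables creatable. A sketch, assuming a bound ``engine``::

    from sqlalchemy import (MetaData, Table, Column, Integer,
                            ForeignKey, ForeignKeyConstraint)

    meta = MetaData()
    a = Table('a', meta,
        Column('id', Integer, primary_key=True),
        Column('b_id', Integer))
    b = Table('b', meta,
        Column('id', Integer, primary_key=True),
        Column('a_id', Integer, ForeignKey('a.id')))
    a.append_constraint(
        ForeignKeyConstraint(['b_id'], ['b.id'], name='fk_a_b', use_alter=True))

    meta.create_all(engine)   # fk_a_b emitted via ALTER TABLE after both CREATEs
    meta.drop_all(engine)     # and dropped via ALTER TABLE before the DROPs
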
+class PrimaryKeyConstraint(ColumnCollectionConstraint):
     """A table-level PRIMARY KEY constraint.
 
     Defines a single column or composite PRIMARY KEY constraint. For a
@@ -1307,63 +1435,14 @@ class PrimaryKeyConstraint(Constraint):
 
     __visit_name__ = 'primary_key_constraint'
 
-    def __init__(self, *columns, **kwargs):
-        """Construct a composite-capable PRIMARY KEY.
-
-        \*columns
-          A sequence of column names.  All columns named must be defined and
-          present within the parent Table.
-
-        name
-          Optional, the in-database name of the key.
-
-        deferrable
-          Optional bool.  If set, emit DEFERRABLE or NOT DEFERRABLE when
-          issuing DDL for this constraint.
-
-        initially
-          Optional string.  If set, emit INITIALLY <value> when issuing DDL
-          for this constraint.
-        """
-
-        constraint_args = dict(name=kwargs.pop('name', None),
-                               deferrable=kwargs.pop('deferrable', None),
-                               initially=kwargs.pop('initially', None))
-        if kwargs:
-            raise exc.ArgumentError(
-                'Unknown PrimaryKeyConstraint argument(s): %s' %
-                ', '.join(repr(x) for x in kwargs.keys()))
-
-        super(PrimaryKeyConstraint, self).__init__(**constraint_args)
-        self.__colnames = list(columns)
-
     def _set_parent(self, table):
-        self.table = table
-        table.primary_key = self
-        for name in self.__colnames:
-            self.add(table.c[name])
-
-    def add(self, col):
-        self.columns.add(col)
-        col.primary_key = True
-    append_column = add
+        super(PrimaryKeyConstraint, self)._set_parent(table)
+        table._set_primary_key(self)
 
-    def replace(self, col):
+    def _replace(self, col):
         self.columns.replace(col)
 
-    def remove(self, col):
-        col.primary_key = False
-        del self.columns[col.key]
-
-    def copy(self, **kw):
-        return PrimaryKeyConstraint(name=self.name, *[c.key for c in self])
-
-    __hash__ = Constraint.__hash__
-    
-    def __eq__(self, other):
-        return self.columns == other
-
-class UniqueConstraint(Constraint):
+class UniqueConstraint(ColumnCollectionConstraint):
     """A table-level UNIQUE constraint.
 
     Defines a single column or composite UNIQUE constraint. For a no-frills,
@@ -1374,48 +1453,6 @@ class UniqueConstraint(Constraint):
 
     __visit_name__ = 'unique_constraint'
 
-    def __init__(self, *columns, **kwargs):
-        """Construct a UNIQUE constraint.
-
-        \*columns
-          A sequence of column names.  All columns named must be defined and
-          present within the parent Table.
-
-        name
-          Optional, the in-database name of the key.
-
-        deferrable
-          Optional bool.  If set, emit DEFERRABLE or NOT DEFERRABLE when
-          issuing DDL for this constraint.
-
-        initially
-          Optional string.  If set, emit INITIALLY <value> when issuing DDL
-          for this constraint.
-        """
-
-        constraint_args = dict(name=kwargs.pop('name', None),
-                               deferrable=kwargs.pop('deferrable', None),
-                               initially=kwargs.pop('initially', None))
-        if kwargs:
-            raise exc.ArgumentError(
-                'Unknown UniqueConstraint argument(s): %s' %
-                ', '.join(repr(x) for x in kwargs.keys()))
-
-        super(UniqueConstraint, self).__init__(**constraint_args)
-        self.__colnames = list(columns)
-
-    def _set_parent(self, table):
-        self.table = table
-        table.constraints.add(self)
-        for c in self.__colnames:
-            self.append_column(table.c[c])
-
-    def append_column(self, col):
-        self.columns.add(col)
-
-    def copy(self, **kw):
-        return UniqueConstraint(name=self.name, *self.__colnames)
-
 class Index(SchemaItem):
     """A table-level INDEX.
 
@@ -1436,7 +1473,7 @@ class Index(SchemaItem):
 
         \*columns
           Columns to include in the index. All columns must belong to the same
-          table, and no column may appear more than once.
+          table.
 
         \**kwargs
           Keyword arguments include:
@@ -1444,42 +1481,36 @@ class Index(SchemaItem):
           unique
             Defaults to False: create a unique index.
 
-          postgres_where
+          postgresql_where
             Defaults to None: create a partial index when using PostgreSQL
         """
 
         self.name = name
-        self.columns = []
+        self.columns = expression.ColumnCollection()
         self.table = None
         self.unique = kwargs.pop('unique', False)
         self.kwargs = kwargs
 
-        self._init_items(*columns)
-
-    def _init_items(self, *args):
-        for column in args:
-            self.append_column(_to_schema_column(column))
+        for column in columns:
+            column = _to_schema_column(column)
+            if self.table is None:
+                self._set_parent(column.table)
+            elif column.table != self.table:
+                # all columns must be from same table
+                raise exc.ArgumentError(
+                    "All index columns must be from same table. "
+                    "%s is from %s not %s" % (column, column.table, self.table))
+            self.columns.add(column)
 
     def _set_parent(self, table):
         self.table = table
-        self.metadata = table.metadata
         table.indexes.add(self)
 
-    def append_column(self, column):
-        # make sure all columns are from the same table
-        # and no column is repeated
-        if self.table is None:
-            self._set_parent(column.table)
-        elif column.table != self.table:
-            # all columns muse be from same table
-            raise exc.ArgumentError(
-                "All index columns must be from same table. "
-                "%s is from %s not %s" % (column, column.table, self.table))
-        elif column.name in [ c.name for c in self.columns ]:
-            raise exc.ArgumentError(
-                "A column may not appear twice in the "
-                "same index (%s already has column %s)" % (self.name, column))
-        self.columns.append(column)
+    @property
+    def bind(self):
+        """Return the connectable associated with this Index."""
+        
+        return self.table.bind
 
     def create(self, bind=None):
         if bind is None:
@@ -1492,9 +1523,6 @@ class Index(SchemaItem):
             bind = _bind_or_error(self)
         bind.drop(self)
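
Usage of :class:`Index` is unchanged by the refactor; ``bind`` now simply resolves through the parent table. A sketch, assuming a bound ``engine``::

    from sqlalchemy import MetaData, Table, Column, String, Index

    meta = MetaData()
    user = Table('user', meta, Column('name', String(50)))
    ix = Index('ix_user_name', user.c.name, unique=True)
    # ix.table is now the 'user' table; ix.bind resolves via user.bind
    ix.create(bind=engine)
    ix.drop(bind=engine)
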
 
-    def __str__(self):
-        return repr(self)
-
     def __repr__(self):
         return 'Index("%s", %s%s)' % (self.name,
                                       ', '.join(repr(c) for c in self.columns),
@@ -1576,27 +1604,6 @@ class MetaData(SchemaItem):
 
         return self._bind is not None
 
-    @util.deprecated('Deprecated. Use ``metadata.bind = <engine>`` or '
-                     '``metadata.bind = <url>``.')
-    def connect(self, bind, **kwargs):
-        """Bind this MetaData to an Engine.
-
-        bind
-          A string, ``URL``, ``Engine`` or ``Connection`` instance.  If a
-          string or ``URL``, will be passed to ``create_engine()`` along with
-          ``\**kwargs`` to produce the engine which to connect to.  Otherwise
-          connects directly to the given ``Engine``.
-          
-        """
-        global URL
-        if URL is None:
-            from sqlalchemy.engine.url import URL
-        if isinstance(bind, (basestring, URL)):
-            from sqlalchemy import create_engine
-            self._bind = create_engine(bind, **kwargs)
-        else:
-            self._bind = bind
-
     def bind(self):
         """An Engine or Connection to which this MetaData is bound.
 
@@ -1633,27 +1640,13 @@ class MetaData(SchemaItem):
         # TODO: scan all other tables and remove FK _column
         del self.tables[table.key]
 
-    @util.deprecated('Deprecated. Use ``metadata.sorted_tables``')
-    def table_iterator(self, reverse=True, tables=None):
-        """Deprecated - use metadata.sorted_tables()."""
-        
-        from sqlalchemy.sql.util import sort_tables
-        if tables is None:
-            tables = self.tables.values()
-        else:
-            tables = set(tables).intersection(self.tables.values())
-        ret = sort_tables(tables)
-        if reverse:
-            ret = reversed(ret)
-        return iter(ret)
-    
     @property
     def sorted_tables(self):
         """Returns a list of ``Table`` objects sorted in order of
         dependency.
         """
         from sqlalchemy.sql.util import sort_tables
-        return sort_tables(self.tables.values())
+        return sort_tables(self.tables.itervalues())
         
     def reflect(self, bind=None, schema=None, only=None):
         """Load all available table definitions from the database.
@@ -1699,7 +1692,7 @@ class MetaData(SchemaItem):
 
         available = util.OrderedSet(bind.engine.table_names(schema,
                                                             connection=conn))
-        current = set(self.tables.keys())
+        current = set(self.tables.iterkeys())
 
         if only is None:
             load = [name for name in available if name not in current]
@@ -1777,11 +1770,7 @@ class MetaData(SchemaItem):
         """
         if bind is None:
             bind = _bind_or_error(self)
-        for listener in self.ddl_listeners['before-create']:
-            listener('before-create', self, bind)
         bind.create(self, checkfirst=checkfirst, tables=tables)
-        for listener in self.ddl_listeners['after-create']:
-            listener('after-create', self, bind)
 
     def drop_all(self, bind=None, tables=None, checkfirst=True):
         """Drop all tables stored in this metadata.
@@ -1804,11 +1793,7 @@ class MetaData(SchemaItem):
         """
         if bind is None:
             bind = _bind_or_error(self)
-        for listener in self.ddl_listeners['before-drop']:
-            listener('before-drop', self, bind)
         bind.drop(self, checkfirst=checkfirst, tables=tables)
-        for listener in self.ddl_listeners['after-drop']:
-            listener('after-drop', self, bind)
 
 class ThreadLocalMetaData(MetaData):
     """A MetaData variant that presents a different ``bind`` in every thread.
@@ -1833,31 +1818,6 @@ class ThreadLocalMetaData(MetaData):
         self.__engines = {}
         super(ThreadLocalMetaData, self).__init__()
 
-    @util.deprecated('Deprecated. Use ``metadata.bind = <engine>`` or '
-                     '``metadata.bind = <url>``.')
-    def connect(self, bind, **kwargs):
-        """Bind to an Engine in the caller's thread.
-
-        bind
-          A string, ``URL``, ``Engine`` or ``Connection`` instance.  If a
-          string or ``URL``, will be passed to ``create_engine()`` along with
-          ``\**kwargs`` to produce the engine which to connect to.  Otherwise
-          connects directly to the given ``Engine``.
-        """
-
-        global URL
-        if URL is None:
-            from sqlalchemy.engine.url import URL
-
-        if isinstance(bind, (basestring, URL)):
-            try:
-                engine = self.__engines[bind]
-            except KeyError:
-                from sqlalchemy import create_engine
-                engine = create_engine(bind, **kwargs)
-            bind = engine
-        self._bind_to(bind)
-
     def bind(self):
         """The bound Engine or Connection for this thread.
 
@@ -1899,7 +1859,7 @@ class ThreadLocalMetaData(MetaData):
     def dispose(self):
         """Dispose all bound engines, in all thread contexts."""
 
-        for e in self.__engines.values():
+        for e in self.__engines.itervalues():
             if hasattr(e, 'dispose'):
                 e.dispose()
 
@@ -1909,87 +1869,15 @@ class SchemaVisitor(visitors.ClauseVisitor):
     __traverse_options__ = {'schema_visitor':True}
 
 
-class DDL(object):
-    """A literal DDL statement.
-
-    Specifies literal SQL DDL to be executed by the database.  DDL objects can
-    be attached to ``Tables`` or ``MetaData`` instances, conditionally
-    executing SQL as part of the DDL lifecycle of those schema items.  Basic
-    templating support allows a single DDL instance to handle repetitive tasks
-    for multiple tables.
-
-    Examples::
-
-      tbl = Table('users', metadata, Column('uid', Integer)) # ...
-      DDL('DROP TRIGGER users_trigger').execute_at('before-create', tbl)
-
-      spow = DDL('ALTER TABLE %(table)s SET secretpowers TRUE', on='somedb')
-      spow.execute_at('after-create', tbl)
-
-      drop_spow = DDL('ALTER TABLE users SET secretpowers FALSE')
-      connection.execute(drop_spow)
-    """
-
-    def __init__(self, statement, on=None, context=None, bind=None):
-        """Create a DDL statement.
-
-        statement
-          A string or unicode string to be executed.  Statements will be
-          processed with Python's string formatting operator.  See the
-          ``context`` argument and the ``execute_at`` method.
-
-          A literal '%' in a statement must be escaped as '%%'.
-
-          SQL bind parameters are not available in DDL statements.
-
-        on
-          Optional filtering criteria.  May be a string or a callable
-          predicate.  If a string, it will be compared to the name of the
-          executing database dialect::
-
-            DDL('something', on='postgres')
-
-          If a callable, it will be invoked with three positional arguments:
-
-            event
-              The name of the event that has triggered this DDL, such as
-              'after-create' Will be None if the DDL is executed explicitly.
-
-            schema_item
-              A SchemaItem instance, such as ``Table`` or ``MetaData``. May be
-              None if the DDL is executed explicitly.
-
-            connection
-              The ``Connection`` being used for DDL execution
-
-          If the callable returns a true value, the DDL statement will be
-          executed.
-
-        context
-          Optional dictionary, defaults to None.  These values will be
-          available for use in string substitutions on the DDL statement.
-
-        bind
-          Optional. A :class:`~sqlalchemy.engine.base.Connectable`, used by default when ``execute()``
-          is invoked without a bind argument.
-          
-        """
-
-        if not isinstance(statement, basestring):
-            raise exc.ArgumentError(
-                "Expected a string or unicode SQL statement, got '%r'" %
-                statement)
-        if (on is not None and
-            (not isinstance(on, basestring) and not util.callable(on))):
-            raise exc.ArgumentError(
-                "Expected the name of a database dialect or a callable for "
-                "'on' criteria, got type '%s'." % type(on).__name__)
-
-        self.statement = statement
-        self.on = on
-        self.context = context or {}
-        self._bind = bind
+class DDLElement(expression.ClauseElement):
+    """Base class for DDL expression constructs."""
+    
+    supports_execution = True
+    _autocommit = True
 
+    schema_item = None
+    on = None
+    
     def execute(self, bind=None, schema_item=None):
         """Execute this DDL immediately.
 
@@ -2010,10 +1898,9 @@ class DDL(object):
 
         if bind is None:
             bind = _bind_or_error(self)
-        # no SQL bind params are supported
+
         if self._should_execute(None, schema_item, bind):
-            executable = expression.text(self._expand(schema_item, bind))
-            return bind.execute(executable)
+            return bind.execute(self.against(schema_item))
         else:
             bind.engine.logger.info("DDL execution skipped, criteria not met.")
 
@@ -2025,7 +1912,7 @@ class DDL(object):
         will be executed using the same Connection and transactional context
         as the Table create/drop itself.  The ``.bind`` property of this
         statement is ignored.
-
+        
         event
           One of the events defined in the schema item's ``.ddl_events``;
           e.g. 'before-create', 'after-create', 'before-drop' or 'after-drop'
@@ -2066,65 +1953,143 @@ class DDL(object):
         schema_item.ddl_listeners[event].append(self)
         return self
 
-    def bind(self):
-        """An Engine or Connection to which this DDL is bound.
-
-        This property may be assigned an ``Engine`` or ``Connection``, or
-        assigned a string or URL to automatically create a basic ``Engine``
-        for this bind with ``create_engine()``.
-        """
-        return self._bind
-
-    def _bind_to(self, bind):
-        """Bind this MetaData to an Engine, Connection, string or URL."""
-
-        global URL
-        if URL is None:
-            from sqlalchemy.engine.url import URL
+    @expression._generative
+    def against(self, schema_item):
+        """Return a copy of this DDL against a specific schema item."""
 
-        if isinstance(bind, (basestring, URL)):
-            from sqlalchemy import create_engine
-            self._bind = create_engine(bind)
-        else:
-            self._bind = bind
-    bind = property(bind, _bind_to)
+        self.schema_item = schema_item
 
-    def __call__(self, event, schema_item, bind):
+    def __call__(self, event, schema_item, bind, **kw):
         """Execute the DDL as a ddl_listener."""
 
-        if self._should_execute(event, schema_item, bind):
-            statement = expression.text(self._expand(schema_item, bind))
-            return bind.execute(statement)
+        if self._should_execute(event, schema_item, bind, **kw):
+            return bind.execute(self.against(schema_item))
 
-    def _expand(self, schema_item, bind):
-        return self.statement % self._prepare_context(schema_item, bind)
+    def _check_ddl_on(self, on):
+        if (on is not None and
+            (not isinstance(on, (basestring, tuple, list, set)) and not util.callable(on))):
+            raise exc.ArgumentError(
+                "Expected the name of a database dialect, a tuple of names, or a callable for "
+                "'on' criteria, got type '%s'." % type(on).__name__)
 
-    def _should_execute(self, event, schema_item, bind):
+    def _should_execute(self, event, schema_item, bind, **kw):
         if self.on is None:
             return True
         elif isinstance(self.on, basestring):
             return self.on == bind.engine.name
+        elif isinstance(self.on, (tuple, list, set)):
+            return bind.engine.name in self.on
         else:
-            return self.on(event, schema_item, bind)
+            return self.on(event, schema_item, bind, **kw)
 
-    def _prepare_context(self, schema_item, bind):
-        # table events can substitute table and schema name
-        if isinstance(schema_item, Table):
-            context = self.context.copy()
+    def bind(self):
+        if self._bind:
+            return self._bind
+    def _set_bind(self, bind):
+        self._bind = bind
+    bind = property(bind, _set_bind)
 
-            preparer = bind.dialect.identifier_preparer
-            path = preparer.format_table_seq(schema_item)
-            if len(path) == 1:
-                table, schema = path[0], ''
-            else:
-                table, schema = path[-1], path[0]
+    def _generate(self):
+        s = self.__class__.__new__(self.__class__)
+        s.__dict__ = self.__dict__.copy()
+        return s
+    
+    def _compiler(self, dialect, **kw):
+        """Return a compiler appropriate for this ClauseElement, given a Dialect."""
+        
+        return dialect.ddl_compiler(dialect, self, **kw)
+
+class DDL(DDLElement):
+    """A literal DDL statement.
+
+    Specifies literal SQL DDL to be executed by the database.  DDL objects can
+    be attached to ``Tables`` or ``MetaData`` instances, conditionally
+    executing SQL as part of the DDL lifecycle of those schema items.  Basic
+    templating support allows a single DDL instance to handle repetitive tasks
+    for multiple tables.
+
+    Examples::
+
+      tbl = Table('users', metadata, Column('uid', Integer)) # ...
+      DDL('DROP TRIGGER users_trigger').execute_at('before-create', tbl)
+
+      spow = DDL('ALTER TABLE %(table)s SET secretpowers TRUE', on='somedb')
+      spow.execute_at('after-create', tbl)
+
+      drop_spow = DDL('ALTER TABLE users SET secretpowers FALSE')
+      connection.execute(drop_spow)
+    """
+
+    __visit_name__ = "ddl"
+    
+    def __init__(self, statement, on=None, context=None, bind=None):
+        """Create a DDL statement.
+
+        statement
+          A string or unicode string to be executed.  Statements will be
+          processed with Python's string formatting operator.  See the
+          ``context`` argument and the ``execute_at`` method.
+
+          A literal '%' in a statement must be escaped as '%%'.
+
+          SQL bind parameters are not available in DDL statements.
+
+        on
+          Optional filtering criteria.  May be a string, tuple or a callable
+          predicate.  If a string, it will be compared to the name of the
+          executing database dialect::
+
+            DDL('something', on='postgresql')
+        
+          If a tuple, specifies multiple dialect names::
+          
+            DDL('something', on=('postgresql', 'mysql'))
+            
+          If a callable, it will be invoked with three positional arguments
+          as well as optional keyword arguments:
+
+            event
+              The name of the event that has triggered this DDL, such as
+              'after-create'. Will be None if the DDL is executed explicitly.
+
+            schema_item
+              A SchemaItem instance, such as ``Table`` or ``MetaData``. May be
+              None if the DDL is executed explicitly.
+
+            connection
+              The ``Connection`` being used for DDL execution
+
+            **kw
+              Keyword arguments which may be sent include:
+                tables - a list of Table objects which are to be created/
+                dropped within a MetaData.create_all() or drop_all() method
+                call.
+              
+          If the callable returns a true value, the DDL statement will be
+          executed.
+
+        context
+          Optional dictionary, defaults to None.  These values will be
+          available for use in string substitutions on the DDL statement.
+
+        bind
+          Optional. A :class:`~sqlalchemy.engine.base.Connectable`, used by default when ``execute()``
+          is invoked without a bind argument.
+          
+        """
+
+        if not isinstance(statement, basestring):
+            raise exc.ArgumentError(
+                "Expected a string or unicode SQL statement, got '%r'" %
+                statement)
+
+        self.statement = statement
+        self.context = context or {}
+
+        self._check_ddl_on(on)
+        self.on = on
+        self._bind = bind
 
-            context.setdefault('table', table)
-            context.setdefault('schema', schema)
-            context.setdefault('fullname', preparer.format_table(schema_item))
-            return context
-        else:
-            return self.context
 
     def __repr__(self):
         return '<%s@%s; %s>' % (
@@ -2135,12 +2100,81 @@ class DDL(object):
                        if getattr(self, key)]))
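
Putting the expanded ``on`` argument to work, a DDL statement can now be gated on a single dialect name, a tuple of names, or an arbitrary predicate; the statement text below is illustrative::

    from sqlalchemy import MetaData, Table, Column, Integer
    from sqlalchemy.schema import DDL

    meta = MetaData()
    tbl = Table('users', meta, Column('uid', Integer))

    # tuple form: one or more dialect names
    DDL('ALTER TABLE %(table)s ENGINE=InnoDB',
        on=('mysql',)).execute_at('after-create', tbl)

    # callable form: gate on an arbitrary predicate
    def when_alterable(event, schema_item, bind, **kw):
        return bind.dialect.supports_alter

    DDL('ALTER TABLE %(table)s ADD extra INTEGER',
        on=when_alterable).execute_at('after-create', tbl)
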
 
 def _to_schema_column(element):
-    if hasattr(element, '__clause_element__'):
-        element = element.__clause_element__()
-    if not isinstance(element, Column):
-        raise exc.ArgumentError("schema.Column object expected")
-    return element
+    if hasattr(element, '__clause_element__'):
+        element = element.__clause_element__()
+    if not isinstance(element, Column):
+        raise exc.ArgumentError("schema.Column object expected")
+    return element
+
+def _to_schema_column_or_string(element):
+    if hasattr(element, '__clause_element__'):
+        element = element.__clause_element__()
+    return element
+
+class _CreateDropBase(DDLElement):
+    """Base class for DDL constructs that represent CREATE and DROP or equivalents.
+
+    The common theme of _CreateDropBase is a single
+    ``element`` attribute which refers to the element
+    to be created or dropped.
     
+    """
+    
+    def __init__(self, element, on=None, bind=None):
+        self.element = element
+        self._check_ddl_on(on)
+        self.on = on
+        self.bind = bind
+        element.inline_ddl = False
+
+class CreateTable(_CreateDropBase):
+    """Represent a CREATE TABLE statement."""
+    
+    __visit_name__ = "create_table"
+    
+class DropTable(_CreateDropBase):
+    """Represent a DROP TABLE statement."""
+
+    __visit_name__ = "drop_table"
+
+    def __init__(self, element, cascade=False, **kw):
+        self.cascade = cascade
+        super(DropTable, self).__init__(element, **kw)
+
+class CreateSequence(_CreateDropBase):
+    """Represent a CREATE SEQUENCE statement."""
+    
+    __visit_name__ = "create_sequence"
+
+class DropSequence(_CreateDropBase):
+    """Represent a DROP SEQUENCE statement."""
+
+    __visit_name__ = "drop_sequence"
+    
+class CreateIndex(_CreateDropBase):
+    """Represent a CREATE INDEX statement."""
+    
+    __visit_name__ = "create_index"
+
+class DropIndex(_CreateDropBase):
+    """Represent a DROP INDEX statement."""
+
+    __visit_name__ = "drop_index"
+
+class AddConstraint(_CreateDropBase):
+    """Represent an ALTER TABLE ADD CONSTRAINT statement."""
+    
+    __visit_name__ = "add_constraint"
+
+class DropConstraint(_CreateDropBase):
+    """Represent an ALTER TABLE DROP CONSTRAINT statement."""
+
+    __visit_name__ = "drop_constraint"
+    
+    def __init__(self, element, cascade=False, **kw):
+        self.cascade = cascade
+        super(DropConstraint, self).__init__(element, **kw)
+
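
Because these constructs are ClauseElements with ``supports_execution`` enabled, they can be rendered as strings or executed directly. A sketch, assuming a bound ``engine``::

    from sqlalchemy import MetaData, Table, Column, Integer
    from sqlalchemy.schema import CreateTable, DropTable

    meta = MetaData()
    user = Table('user', meta, Column('id', Integer, primary_key=True))

    print CreateTable(user)            # render the CREATE TABLE string
    engine.execute(CreateTable(user))  # or execute it directly
    engine.execute(DropTable(user))
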
 def _bind_or_error(schemaitem):
     bind = schemaitem.bind
     if not bind:
index 6af65ec140b36cb729118079bfee8de5e56c8193..6bfad4a76cb07909dd533dab2fde6054f3133e04 100644 (file)
@@ -6,19 +6,23 @@
 
 """Base SQL and DDL compiler implementations.
 
-Provides the :class:`~sqlalchemy.sql.compiler.DefaultCompiler` class, which is
-responsible for generating all SQL query strings, as well as
-:class:`~sqlalchemy.sql.compiler.SchemaGenerator` and :class:`~sqlalchemy.sql.compiler.SchemaDropper`
-which issue CREATE and DROP DDL for tables, sequences, and indexes.
-
-The elements in this module are used by public-facing constructs like
-:class:`~sqlalchemy.sql.expression.ClauseElement` and :class:`~sqlalchemy.engine.Engine`.
-While dialect authors will want to be familiar with this module for the purpose of
-creating database-specific compilers and schema generators, the module
-is otherwise internal to SQLAlchemy.
+Classes provided include:
+
+:class:`~sqlalchemy.sql.compiler.SQLCompiler` - renders SQL
+strings
+
+:class:`~sqlalchemy.sql.compiler.DDLCompiler` - renders DDL
+(data definition language) strings
+
+:class:`~sqlalchemy.sql.compiler.GenericTypeCompiler` - renders
+type specification strings.
+
+To generate user-defined SQL strings, see 
+:mod:`~sqlalchemy.ext.compiler`.
+
 """
 
-import string, re
+import re
 from sqlalchemy import schema, engine, util, exc
 from sqlalchemy.sql import operators, functions, util as sql_util, visitors
 from sqlalchemy.sql import expression as sql
@@ -58,40 +62,43 @@ BIND_TEMPLATES = {
 
 
 OPERATORS =  {
-    operators.and_ : 'AND',
-    operators.or_ : 'OR',
-    operators.inv : 'NOT',
-    operators.add : '+',
-    operators.mul : '*',
-    operators.sub : '-',
-    operators.div : '/',
-    operators.mod : '%',
-    operators.truediv : '/',
-    operators.lt : '<',
-    operators.le : '<=',
-    operators.ne : '!=',
-    operators.gt : '>',
-    operators.ge : '>=',
-    operators.eq : '=',
-    operators.distinct_op : 'DISTINCT',
-    operators.concat_op : '||',
-    operators.like_op : lambda x, y, escape=None: '%s LIKE %s' % (x, y) + (escape and ' ESCAPE \'%s\'' % escape or ''),
-    operators.notlike_op : lambda x, y, escape=None: '%s NOT LIKE %s' % (x, y) + (escape and ' ESCAPE \'%s\'' % escape or ''),
-    operators.ilike_op : lambda x, y, escape=None: "lower(%s) LIKE lower(%s)" % (x, y) + (escape and ' ESCAPE \'%s\'' % escape or ''),
-    operators.notilike_op : lambda x, y, escape=None: "lower(%s) NOT LIKE lower(%s)" % (x, y) + (escape and ' ESCAPE \'%s\'' % escape or ''),
-    operators.between_op : 'BETWEEN',
-    operators.match_op : 'MATCH',
-    operators.in_op : 'IN',
-    operators.notin_op : 'NOT IN',
+    # binary
+    operators.and_ : ' AND ',
+    operators.or_ : ' OR ',
+    operators.add : ' + ',
+    operators.mul : ' * ',
+    operators.sub : ' - ',
+    # Py2K
+    operators.div : ' / ',
+    # end Py2K
+    operators.mod : ' % ',
+    operators.truediv : ' / ',
+    operators.lt : ' < ',
+    operators.le : ' <= ',
+    operators.ne : ' != ',
+    operators.gt : ' > ',
+    operators.ge : ' >= ',
+    operators.eq : ' = ',
+    operators.concat_op : ' || ',
+    operators.between_op : ' BETWEEN ',
+    operators.match_op : ' MATCH ',
+    operators.in_op : ' IN ',
+    operators.notin_op : ' NOT IN ',
     operators.comma_op : ', ',
-    operators.desc_op : 'DESC',
-    operators.asc_op : 'ASC',
-    operators.from_ : 'FROM',
-    operators.as_ : 'AS',
-    operators.exists : 'EXISTS',
-    operators.is_ : 'IS',
-    operators.isnot : 'IS NOT',
-    operators.collate : 'COLLATE',
+    operators.from_ : ' FROM ',
+    operators.as_ : ' AS ',
+    operators.is_ : ' IS ',
+    operators.isnot : ' IS NOT ',
+    operators.collate : ' COLLATE ',
+
+    # unary
+    operators.exists : 'EXISTS ',
+    operators.distinct_op : 'DISTINCT ',
+    operators.inv : 'NOT ',
+
+    # modifiers
+    operators.desc_op : ' DESC',
+    operators.asc_op : ' ASC',
 }
 
 FUNCTIONS = {
@@ -140,7 +147,7 @@ class _CompileLabel(visitors.Visitable):
     def quote(self):
         return self.element.quote
 
-class DefaultCompiler(engine.Compiled):
+class SQLCompiler(engine.Compiled):
     """Default implementation of Compiled.
 
     Compiles ClauseElements into SQL strings.   Uses a similar visit
@@ -148,14 +155,14 @@ class DefaultCompiler(engine.Compiled):
 
     """
 
-    operators = OPERATORS
-    functions = FUNCTIONS
     extract_map = EXTRACT_MAP
 
-    # if we are insert/update/delete. 
-    # set to true when we visit an INSERT, UPDATE or DELETE
+    # class-level defaults which can be set at the instance
+    # level to define if this Compiled instance represents
+    # INSERT/UPDATE/DELETE
     isdelete = isinsert = isupdate = False
-
+    returning = None
+    
     def __init__(self, dialect, statement, column_keys=None, inline=False, **kwargs):
         """Construct a new ``SQLCompiler`` object.
 
@@ -170,7 +177,9 @@ class DefaultCompiler(engine.Compiled):
           statement.
 
         """
-        engine.Compiled.__init__(self, dialect, statement, column_keys, **kwargs)
+        engine.Compiled.__init__(self, dialect, statement, **kwargs)
+
+        self.column_keys = column_keys
 
         # compile INSERT/UPDATE defaults/sequences inlined (no pre-execute)
         self.inline = inline or getattr(statement, 'inline', False)
@@ -210,12 +219,6 @@ class DefaultCompiler(engine.Compiled):
         # or dialect.max_identifier_length
         self.truncated_names = {}
 
-    def compile(self):
-        self.string = self.process(self.statement)
-
-    def process(self, obj, **kwargs):
-        return obj._compiler_dispatch(self, **kwargs)
-
     def is_subquery(self):
         return len(self.stack) > 1
 
@@ -223,7 +226,6 @@ class DefaultCompiler(engine.Compiled):
         """return a dictionary of bind parameter keys and values"""
 
         if params:
-            params = util.column_dict(params)
             pd = {}
             for bindparam, name in self.bind_names.iteritems():
                 for paramname in (bindparam.key, bindparam.shortname, name):
@@ -245,7 +247,10 @@ class DefaultCompiler(engine.Compiled):
                     pd[self.bind_names[bindparam]] = bindparam.value
             return pd
 
-    params = property(construct_params)
+    params = property(construct_params, doc="""
+        Return the bind params for this compiled object.
+
+    """)
 
     def default_from(self):
         """Called when a SELECT statement has no froms, and no FROM clause is to be appended.
@@ -267,10 +272,11 @@ class DefaultCompiler(engine.Compiled):
                     self._truncated_identifier("colident", label.name) or label.name
 
             if result_map is not None:
-                result_map[labelname.lower()] = (label.name, (label, label.element, labelname), label.element.type)
+                result_map[labelname.lower()] = \
+                        (label.name, (label, label.element, labelname), label.element.type)
 
-            return self.process(label.element) + " " + \
-                        self.operator_string(operators.as_) + " " + \
+            return self.process(label.element) + \
+                        OPERATORS[operators.as_] + \
                         self.preparer.format_label(label, labelname)
         else:
             return self.process(label.element)
@@ -292,14 +298,17 @@ class DefaultCompiler(engine.Compiled):
             return name
         else:
             if column.table.schema:
-                schema_prefix = self.preparer.quote_schema(column.table.schema, column.table.quote_schema) + '.'
+                schema_prefix = self.preparer.quote_schema(
+                                    column.table.schema, 
+                                    column.table.quote_schema) + '.'
             else:
                 schema_prefix = ''
             tablename = column.table.name
             tablename = isinstance(tablename, sql._generated_label) and \
                             self._truncated_identifier("alias", tablename) or tablename
             
-            return schema_prefix + self.preparer.quote(tablename, column.table.quote) + "." + name
+            return schema_prefix + \
+                    self.preparer.quote(tablename, column.table.quote) + "." + name
 
     def escape_literal_column(self, text):
         """provide escaping for the literal_column() construct."""
@@ -314,7 +323,7 @@ class DefaultCompiler(engine.Compiled):
         return index.name
 
     def visit_typeclause(self, typeclause, **kwargs):
-        return typeclause.type.dialect_impl(self.dialect).get_col_spec()
+        return self.dialect.type_compiler.process(typeclause.type)
 
     def post_process_text(self, text):
         return text
@@ -343,10 +352,8 @@ class DefaultCompiler(engine.Compiled):
         sep = clauselist.operator
         if sep is None:
             sep = " "
-        elif sep is operators.comma_op:
-            sep = ', '
         else:
-            sep = " " + self.operator_string(clauselist.operator) + " "
+            sep = OPERATORS[clauselist.operator]
         return sep.join(s for s in (self.process(c) for c in clauselist.clauses)
                         if s is not None)
 
@@ -362,7 +369,8 @@ class DefaultCompiler(engine.Compiled):
         return x
 
     def visit_cast(self, cast, **kwargs):
-        return "CAST(%s AS %s)" % (self.process(cast.clause), self.process(cast.typeclause))
+        return "CAST(%s AS %s)" % \
+                    (self.process(cast.clause), self.process(cast.typeclause))
 
     def visit_extract(self, extract, **kwargs):
         field = self.extract_map.get(extract.field, extract.field)
@@ -372,26 +380,26 @@ class DefaultCompiler(engine.Compiled):
         if result_map is not None:
             result_map[func.name.lower()] = (func.name, None, func.type)
 
-        name = self.function_string(func)
-
-        if util.callable(name):
-            return name(*[self.process(x) for x in func.clauses])
+        disp = getattr(self, "visit_%s_func" % func.name.lower(), None)
+        if disp:
+            return disp(func, **kwargs)
         else:
-            return ".".join(func.packagenames + [name]) % {'expr':self.function_argspec(func)}
+            name = FUNCTIONS.get(func.__class__, func.name + "%(expr)s")
+            return ".".join(func.packagenames + [name]) % \
+                            {'expr':self.function_argspec(func, **kwargs)}
 
     def function_argspec(self, func, **kwargs):
         return self.process(func.clause_expr, **kwargs)
 
-    def function_string(self, func):
-        return self.functions.get(func.__class__, self.functions.get(func.name, func.name + "%(expr)s"))
-
     def visit_compound_select(self, cs, asfrom=False, parens=True, **kwargs):
         entry = self.stack and self.stack[-1] or {}
         self.stack.append({'from':entry.get('from', None), 'iswrapper':True})
 
-        text = string.join((self.process(c, asfrom=asfrom, parens=False, compound_index=i)
-                            for i, c in enumerate(cs.selects)),
-                           " " + cs.keyword + " ")
+        text = (" " + cs.keyword + " ").join(
+                            (self.process(c, asfrom=asfrom, parens=False, compound_index=i)
+                            for i, c in enumerate(cs.selects))
+                        )
+                        
         group_by = self.process(cs._group_by_clause, asfrom=asfrom)
         if group_by:
             text += " GROUP BY " + group_by
@@ -408,27 +416,57 @@ class DefaultCompiler(engine.Compiled):
     def visit_unary(self, unary, **kw):
         s = self.process(unary.element, **kw)
         if unary.operator:
-            s = self.operator_string(unary.operator) + " " + s
+            s = OPERATORS[unary.operator] + s
         if unary.modifier:
-            s = s + " " + self.operator_string(unary.modifier)
+            s = s + OPERATORS[unary.modifier]
         return s
 
     def visit_binary(self, binary, **kwargs):
-        op = self.operator_string(binary.operator)
-        if util.callable(op):
-            return op(self.process(binary.left), self.process(binary.right), **binary.modifiers)
-        else:
-            return self.process(binary.left) + " " + op + " " + self.process(binary.right)
+        
+        return self._operator_dispatch(binary.operator,
+                    binary,
+                    lambda opstr: self.process(binary.left) + opstr + self.process(binary.right),
+                    **kwargs
+        )
 
-    def operator_string(self, operator):
-        return self.operators.get(operator, str(operator))
+    def visit_like_op(self, binary, **kw):
+        escape = binary.modifiers.get("escape", None)
+        return '%s LIKE %s' % (self.process(binary.left), self.process(binary.right)) \
+            + (escape and ' ESCAPE \'%s\'' % escape or '')
 
+    def visit_notlike_op(self, binary, **kw):
+        escape = binary.modifiers.get("escape", None)
+        return '%s NOT LIKE %s' % (self.process(binary.left), self.process(binary.right)) \
+            + (escape and ' ESCAPE \'%s\'' % escape or '')
+        
+    def visit_ilike_op(self, binary, **kw):
+        escape = binary.modifiers.get("escape", None)
+        return 'lower(%s) LIKE lower(%s)' % (self.process(binary.left), self.process(binary.right)) \
+            + (escape and ' ESCAPE \'%s\'' % escape or '')
+    
+    def visit_notilike_op(self, binary, **kw):
+        escape = binary.modifiers.get("escape", None)
+        return 'lower(%s) NOT LIKE lower(%s)' % (self.process(binary.left), self.process(binary.right)) \
+            + (escape and ' ESCAPE \'%s\'' % escape or '')
+        
+    def _operator_dispatch(self, operator, element, fn, **kw):
+        if util.callable(operator):
+            disp = getattr(self, "visit_%s" % operator.__name__, None)
+            if disp:
+                return disp(element, **kw)
+            else:
+                return fn(OPERATORS[operator])
+        else:
+            return fn(" " + operator + " ")
+        
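
``_operator_dispatch`` applies the same convention to operators: a callable operator is first looked up as a ``visit_<opname>`` method, as with the LIKE family above, before falling back to the ``OPERATORS`` table. A dialect can therefore override a single operator's rendering; a sketch (``MyCompiler`` is illustrative)::

    class MyCompiler(SQLCompiler):
        def visit_ilike_op(self, binary, **kw):
            # use a native case-insensitive LIKE instead of lower()/LIKE
            return '%s ILIKE %s' % (self.process(binary.left),
                                    self.process(binary.right))
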
     def visit_bindparam(self, bindparam, **kwargs):
         name = self._truncate_bindparam(bindparam)
         if name in self.binds:
             existing = self.binds[name]
             if existing is not bindparam and (existing.unique or bindparam.unique):
-                raise exc.CompileError("Bind parameter '%s' conflicts with unique bind parameter of the same name" % bindparam.key)
+                raise exc.CompileError(
+                        "Bind parameter '%s' conflicts with unique bind parameter of the same name" % bindparam.key
+                    )
         self.binds[bindparam.key] = self.binds[name] = bindparam
         return self.bindparam_string(name)
 
@@ -491,7 +529,7 @@ class DefaultCompiler(engine.Compiled):
         if isinstance(column, sql._Label):
             return column
 
-        if select.use_labels and column._label:
+        if select and select.use_labels and column._label:
             return _CompileLabel(column, column._label)
 
         if \
@@ -501,13 +539,15 @@ class DefaultCompiler(engine.Compiled):
             column.table is not None and \
             not isinstance(column.table, sql.Select):
             return _CompileLabel(column, sql._generated_label(column.name))
-        elif not isinstance(column, (sql._UnaryExpression, sql._TextClause, sql._BindParamClause)) \
+        elif not isinstance(column, 
+                    (sql._UnaryExpression, sql._TextClause, sql._BindParamClause)) \
                 and (not hasattr(column, 'name') or isinstance(column, sql.Function)):
             return _CompileLabel(column, column.anon_label)
         else:
             return column
 
-    def visit_select(self, select, asfrom=False, parens=True, iswrapper=False, compound_index=1, **kwargs):
+    def visit_select(self, select, asfrom=False, parens=True, 
+                            iswrapper=False, compound_index=1, **kwargs):
 
         entry = self.stack and self.stack[-1] or {}
         
@@ -583,8 +623,10 @@ class DefaultCompiler(engine.Compiled):
             return text
 
     def get_select_precolumns(self, select):
-        """Called when building a ``SELECT`` statement, position is just before column list."""
-
+        """Called when building a ``SELECT`` statement, position is just before 
+        column list.
+        
+        """
         return select._distinct and "DISTINCT " or ""
 
     def order_by_clause(self, select):
@@ -613,14 +655,16 @@ class DefaultCompiler(engine.Compiled):
     def visit_table(self, table, asfrom=False, **kwargs):
         if asfrom:
             if getattr(table, "schema", None):
-                return self.preparer.quote_schema(table.schema, table.quote_schema) + "." + self.preparer.quote(table.name, table.quote)
+                return self.preparer.quote_schema(table.schema, table.quote_schema) + \
+                                "." + self.preparer.quote(table.name, table.quote)
             else:
                 return self.preparer.quote(table.name, table.quote)
         else:
             return ""
 
     def visit_join(self, join, asfrom=False, **kwargs):
-        return (self.process(join.left, asfrom=True) + (join.isouter and " LEFT OUTER JOIN " or " JOIN ") + \
+        return (self.process(join.left, asfrom=True) + \
+                (join.isouter and " LEFT OUTER JOIN " or " JOIN ") + \
             self.process(join.right, asfrom=True) + " ON " + self.process(join.onclause))
 
     def visit_sequence(self, seq):
@@ -629,41 +673,75 @@ class DefaultCompiler(engine.Compiled):
     def visit_insert(self, insert_stmt):
         self.isinsert = True
         colparams = self._get_colparams(insert_stmt)
-        preparer = self.preparer
-
-        insert = ' '.join(["INSERT"] +
-                          [self.process(x) for x in insert_stmt._prefixes])
 
-        if not colparams and not self.dialect.supports_default_values and not self.dialect.supports_empty_insert:
+        if not colparams and \
+                not self.dialect.supports_default_values and \
+                not self.dialect.supports_empty_insert:
             raise exc.CompileError(
                 "The version of %s you are using does not support empty inserts." % self.dialect.name)
-        elif not colparams and self.dialect.supports_default_values:
-            return (insert + " INTO %s DEFAULT VALUES" % (
-                (preparer.format_table(insert_stmt.table),)))
-        else: 
-            return (insert + " INTO %s (%s) VALUES (%s)" %
-                (preparer.format_table(insert_stmt.table),
-                 ', '.join([preparer.format_column(c[0])
-                           for c in colparams]),
-                 ', '.join([c[1] for c in colparams])))
 
+        preparer = self.preparer
+        supports_default_values = self.dialect.supports_default_values
+        
+        text = "INSERT"
+        
+        prefixes = [self.process(x) for x in insert_stmt._prefixes]
+        if prefixes:
+            text += " " + " ".join(prefixes)
+        
+        text += " INTO " + preparer.format_table(insert_stmt.table)
+         
+        if colparams or not supports_default_values:
+            text += " (%s)" % ', '.join([preparer.format_column(c[0])
+                       for c in colparams])
+
+        if self.returning or insert_stmt._returning:
+            self.returning = self.returning or insert_stmt._returning
+            returning_clause = self.returning_clause(insert_stmt, self.returning)
+            
+            # "cheating": an OUTPUT-style clause (SQL Server) must render
+            # before VALUES rather than at the end of the statement
+            if returning_clause.startswith("OUTPUT"):
+                text += " " + returning_clause
+                returning_clause = None
+
+        if not colparams and supports_default_values:
+            text += " DEFAULT VALUES"
+        else:
+            text += " VALUES (%s)" % \
+                     ', '.join([c[1] for c in colparams])
+        
+        if self.returning and returning_clause:
+            text += " " + returning_clause
+        
+        return text
+        
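
A sketch of the new RETURNING path, assuming the returning() generative
method this branch introduces:

    from sqlalchemy.sql import table, column

    t = table('t', column('id'), column('x'))
    stmt = t.insert().values(x=5).returning(t.c.id)
    # on a RETURNING-capable backend (e.g. postgresql) this renders roughly:
    #   INSERT INTO t (x) VALUES (%(x)s) RETURNING t.id
    # on MSSQL the clause starts with OUTPUT and moves before VALUES, which
    # is the case the startswith("OUTPUT") branch above handles.
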
     def visit_update(self, update_stmt):
         self.stack.append({'from': set([update_stmt.table])})
 
         self.isupdate = True
         colparams = self._get_colparams(update_stmt)
 
-        text = ' '.join((
-            "UPDATE",
-            self.preparer.format_table(update_stmt.table),
-            'SET',
-            ', '.join(self.preparer.quote(c[0].name, c[0].quote) + '=' + c[1]
-                      for c in colparams)
-            ))
-
+        text = "UPDATE " + self.preparer.format_table(update_stmt.table)
+        
+        text += ' SET ' + \
+                ', '.join(
+                        self.preparer.quote(c[0].name, c[0].quote) + '=' + c[1]
+                      for c in colparams
+                )
+
+        if update_stmt._returning:
+            self.returning = update_stmt._returning
+            returning_clause = self.returning_clause(update_stmt, update_stmt._returning)
+            if returning_clause.startswith("OUTPUT"):
+                text += " " + returning_clause
+                returning_clause = None
+                
         if update_stmt._whereclause:
             text += " WHERE " + self.process(update_stmt._whereclause)
 
+        if self.returning and returning_clause:
+            text += " " + returning_clause
+            
         self.stack.pop(-1)
 
         return text
@@ -681,7 +759,8 @@ class DefaultCompiler(engine.Compiled):
 
         self.postfetch = []
         self.prefetch = []
-
+        self.returning = []
+        
         # no parameters in the statement, no parameters in the
         # compiled params - return binds for all columns
         if self.column_keys is None and stmt.parameters is None:
@@ -701,6 +780,15 @@ class DefaultCompiler(engine.Compiled):
 
         # create a list of column assignment clauses as tuples
         values = []
+        
+        need_pks = self.isinsert and \
+                        not self.inline and \
+                        not self.statement._returning
+        
+        implicit_returning = need_pks and \
+                                self.dialect.implicit_returning and \
+                                stmt.table.implicit_returning
+        
         for c in stmt.table.columns:
             if c.key in parameters:
                 value = parameters[c.key]
@@ -710,19 +798,48 @@ class DefaultCompiler(engine.Compiled):
                     self.postfetch.append(c)
                     value = self.process(value.self_group())
                 values.append((c, value))
+
             elif isinstance(c, schema.Column):
                 if self.isinsert:
-                    if (c.primary_key and self.dialect.preexecute_pk_sequences and not self.inline):
-                        if (((isinstance(c.default, schema.Sequence) and
-                              not c.default.optional) or
-                             not self.dialect.supports_pk_autoincrement) or
-                            (c.default is not None and
-                             not isinstance(c.default, schema.Sequence))):
-                            values.append((c, create_bind_param(c, None)))
-                            self.prefetch.append(c)
+                    if c.primary_key and \
+                        need_pks and \
+                        (
+                            c is not stmt.table._autoincrement_column or 
+                            not self.dialect.postfetch_lastrowid
+                        ):
+                        
+                        if implicit_returning:
+                            if isinstance(c.default, schema.Sequence):
+                                proc = self.process(c.default)
+                                if proc is not None:
+                                    values.append((c, proc))
+                                self.returning.append(c)
+                            elif isinstance(c.default, schema.ColumnDefault) and \
+                                        isinstance(c.default.arg, sql.ClauseElement):
+                                values.append((c, self.process(c.default.arg.self_group())))
+                                self.returning.append(c)
+                            elif c.default is not None:
+                                values.append((c, create_bind_param(c, None)))
+                                self.prefetch.append(c)
+                            else:
+                                self.returning.append(c)
+                        else:
+                            if (
+                                c.default is not None and \
+                                    (
+                                        self.dialect.supports_sequences or 
+                                        not isinstance(c.default, schema.Sequence)
+                                    )
+                                ) or \
+                                self.dialect.preexecute_autoincrement_sequences:
+
+                                values.append((c, create_bind_param(c, None)))
+                                self.prefetch.append(c)
+                                
                     elif isinstance(c.default, schema.ColumnDefault):
                         if isinstance(c.default.arg, sql.ClauseElement):
                             values.append((c, self.process(c.default.arg.self_group())))
+                            
                             if not c.primary_key:
                                 # dont add primary key column to postfetch
                                 self.postfetch.append(c)
@@ -759,9 +876,19 @@ class DefaultCompiler(engine.Compiled):
 
         text = "DELETE FROM " + self.preparer.format_table(delete_stmt.table)
 
+        if delete_stmt._returning:
+            self.returning = delete_stmt._returning
+            returning_clause = self.returning_clause(delete_stmt, delete_stmt._returning)
+            if returning_clause.startswith("OUTPUT"):
+                text += " " + returning_clause
+                returning_clause = None
+                
         if delete_stmt._whereclause:
             text += " WHERE " + self.process(delete_stmt._whereclause)
 
+        if self.returning and returning_clause:
+            text += " " + returning_clause
+            
         self.stack.pop(-1)
 
         return text
@@ -775,110 +902,146 @@ class DefaultCompiler(engine.Compiled):
     def visit_release_savepoint(self, savepoint_stmt):
         return "RELEASE SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt)
 
-    def __str__(self):
-        return self.string or ''
-
-class DDLBase(engine.SchemaIterator):
-    def find_alterables(self, tables):
-        alterables = []
-        class FindAlterables(schema.SchemaVisitor):
-            def visit_foreign_key_constraint(self, constraint):
-                if constraint.use_alter and constraint.table in tables:
-                    alterables.append(constraint)
-        findalterables = FindAlterables()
-        for table in tables:
-            for c in table.constraints:
-                findalterables.traverse(c)
-        return alterables
-
-    def _validate_identifier(self, ident, truncate):
-        if truncate:
-            if len(ident) > self.dialect.max_identifier_length:
-                counter = getattr(self, 'counter', 0)
-                self.counter = counter + 1
-                return ident[0:self.dialect.max_identifier_length - 6] + "_" + hex(self.counter)[2:]
-            else:
-                return ident
-        else:
-            self.dialect.validate_identifier(ident)
-            return ident
-
-
-class SchemaGenerator(DDLBase):
-    def __init__(self, dialect, connection, checkfirst=False, tables=None, **kwargs):
-        super(SchemaGenerator, self).__init__(connection, **kwargs)
-        self.checkfirst = checkfirst
-        self.tables = tables and set(tables) or None
-        self.preparer = dialect.identifier_preparer
-        self.dialect = dialect
 
-    def get_column_specification(self, column, first_pk=False):
-        raise NotImplementedError()
+class DDLCompiler(engine.Compiled):
+    @property
+    def preparer(self):
+        return self.dialect.identifier_preparer
 
-    def _can_create(self, table):
-        self.dialect.validate_identifier(table.name)
-        if table.schema:
-            self.dialect.validate_identifier(table.schema)
-        return not self.checkfirst or not self.dialect.has_table(self.connection, table.name, schema=table.schema)
+    def construct_params(self, params=None):
+        return None
+        
+    def visit_ddl(self, ddl, **kwargs):
+        # table events can substitute table and schema name
+        context = ddl.context
+        if isinstance(ddl.schema_item, schema.Table):
+            context = context.copy()
+
+            preparer = self.dialect.identifier_preparer
+            path = preparer.format_table_seq(ddl.schema_item)
+            if len(path) == 1:
+                table, sch = path[0], ''
+            else:
+                table, sch = path[-1], path[0]
 
-    def visit_metadata(self, metadata):
-        if self.tables:
-            tables = self.tables
-        else:
-            tables = metadata.tables.values()
-        collection = [t for t in sql_util.sort_tables(tables) if self._can_create(t)]
-        for table in collection:
-            self.traverse_single(table)
-        if self.dialect.supports_alter:
-            for alterable in self.find_alterables(collection):
-                self.add_foreignkey(alterable)
-
-    def visit_table(self, table):
-        for listener in table.ddl_listeners['before-create']:
-            listener('before-create', table, self.connection)
+            context.setdefault('table', table)
+            context.setdefault('schema', sch)
+            context.setdefault('fullname', preparer.format_table(ddl.schema_item))
+        
+        return ddl.statement % context
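
For Table targets the context keys filled in above can be used directly in
DDL() strings; a sketch (assuming DDL from sqlalchemy.schema):

    from sqlalchemy import MetaData, Table, Column, Integer
    from sqlalchemy.schema import DDL

    m = MetaData()
    t = Table('users', m, Column('id', Integer, primary_key=True),
              schema='app')
    ddl = DDL("GRANT SELECT ON %(fullname)s TO readonly")
    # executed against t, visit_ddl substitutes table/schema/fullname:
    #   GRANT SELECT ON app.users TO readonly
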
 
-        for column in table.columns:
-            if column.default is not None:
-                self.traverse_single(column.default)
+    def visit_create_table(self, create):
+        table = create.element
+        preparer = self.dialect.identifier_preparer
 
-        self.append("\n" + " ".join(['CREATE'] +
-                                    table._prefixes +
+        text = "\n" + " ".join(['CREATE'] + \
+                                    table._prefixes + \
                                     ['TABLE',
-                                     self.preparer.format_table(table),
-                                     "("]))
+                                     preparer.format_table(table),
+                                     "("])
         separator = "\n"
 
         # if only one primary key, specify it along with the column
         first_pk = False
         for column in table.columns:
-            self.append(separator)
+            text += separator
             separator = ", \n"
-            self.append("\t" + self.get_column_specification(column, first_pk=column.primary_key and not first_pk))
+            text += "\t" + self.get_column_specification(column, first_pk=column.primary_key and not first_pk)
             if column.primary_key:
                 first_pk = True
-            for constraint in column.constraints:
-                self.traverse_single(constraint)
+            const = " ".join(self.process(constraint) for constraint in column.constraints)
+            if const:
+                text += " " + const
 
         # On some DB order is significant: visit PK first, then the
         # other constraints (engine.ReflectionTest.testbasic failed on FB2)
         if table.primary_key:
-            self.traverse_single(table.primary_key)
-        for constraint in [c for c in table.constraints if c is not table.primary_key]:
-            self.traverse_single(constraint)
+            text += ", \n\t" + self.process(table.primary_key)
+        
+        const = ", \n\t".join(
+                        self.process(constraint) for constraint in table.constraints 
+                        if constraint is not table.primary_key
+                        and constraint.inline_ddl
+                        and (not self.dialect.supports_alter or not getattr(constraint, 'use_alter', False))
+                )
+        if const:
+            text += ", \n\t" + const
+
+        text += "\n)%s\n\n" % self.post_create_table(table)
+        return text
+        
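
str() on the new CreateTable construct exercises this method against the
default dialect; a sketch (whitespace representative -- the compiler uses
tab indents):

    from sqlalchemy import MetaData, Table, Column, Integer, String
    from sqlalchemy.schema import CreateTable

    m = MetaData()
    t = Table('users', m,
              Column('id', Integer, primary_key=True),
              Column('name', String(40)))
    print(CreateTable(t))
    # CREATE TABLE users (
    #     id INTEGER NOT NULL,
    #     name VARCHAR(40),
    #     PRIMARY KEY (id)
    # )
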
+    def visit_drop_table(self, drop):
+        ret = "\nDROP TABLE " + self.preparer.format_table(drop.element)
+        if drop.cascade:
+            ret += " CASCADE CONSTRAINTS"
+        return ret
+        
+    def visit_create_index(self, create):
+        index = create.element
+        preparer = self.preparer
+        text = "CREATE "
+        if index.unique:
+            text += "UNIQUE "
+        text += "INDEX %s ON %s (%s)" \
+                    % (preparer.quote(self._validate_identifier(index.name, True), index.quote),
+                       preparer.format_table(index.table),
+                       ', '.join(preparer.quote(c.name, c.quote)
+                                 for c in index.columns))
+        return text
 
-        self.append("\n)%s\n\n" % self.post_create_table(table))
-        self.execute()
+    def visit_drop_index(self, drop):
+        index = drop.element
+        return "\nDROP INDEX " + \
+                    self.preparer.quote(self._validate_identifier(index.name, False), index.quote)
 
-        if hasattr(table, 'indexes'):
-            for index in table.indexes:
-                self.traverse_single(index)
+    def visit_add_constraint(self, create):
+        return "ALTER TABLE %s ADD %s" % (
+            self.preparer.format_table(create.element.table),
+            self.process(create.element)
+        )
+        
+    def visit_drop_constraint(self, drop):
+        return "ALTER TABLE %s DROP CONSTRAINT %s%s" % (
+            self.preparer.format_table(drop.element.table),
+            self.preparer.format_constraint(drop.element),
+            " CASCADE" if drop.cascade else ""
+        )
+    
+    def get_column_specification(self, column, **kwargs):
+        colspec = self.preparer.format_column(column) + " " + \
+                        self.dialect.type_compiler.process(column.type)
+        default = self.get_column_default_string(column)
+        if default is not None:
+            colspec += " DEFAULT " + default
 
-        for listener in table.ddl_listeners['after-create']:
-            listener('after-create', table, self.connection)
+        if not column.nullable:
+            colspec += " NOT NULL"
+        return colspec
 
     def post_create_table(self, table):
         return ''
 
+    def _compile(self, tocompile, parameters):
+        """compile the given string/parameters using this SchemaGenerator's dialect."""
+        
+        compiler = self.dialect.statement_compiler(self.dialect, tocompile, parameters)
+        compiler.compile()
+        return compiler
+
+    def _validate_identifier(self, ident, truncate):
+        if truncate:
+            if len(ident) > self.dialect.max_identifier_length:
+                counter = getattr(self, 'counter', 0)
+                self.counter = counter + 1
+                return ident[0:self.dialect.max_identifier_length - 6] + "_" + hex(self.counter)[2:]
+            else:
+                return ident
+        else:
+            self.dialect.validate_identifier(ident)
+            return ident
+
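
The truncation branch keeps max_identifier_length - 6 characters and appends
a hex counter so shortened names remain distinct; a standalone mirror of just
that branch:

    def truncate_ident(ident, max_len, _counter=[0]):
        # mirrors the truncate=True branch above (the mutable default is
        # deliberate here: it stands in for self.counter)
        if len(ident) <= max_len:
            return ident
        _counter[0] += 1
        return ident[0:max_len - 6] + "_" + hex(_counter[0])[2:]

    print(truncate_ident('a_very_long_index_name', 10))   # a_ve_1
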
     def get_column_default_string(self, column):
         if isinstance(column.server_default, schema.DefaultClause):
             if isinstance(column.server_default.arg, basestring):
@@ -888,149 +1051,190 @@ class SchemaGenerator(DDLBase):
         else:
             return None
 
-    def _compile(self, tocompile, parameters):
-        """compile the given string/parameters using this SchemaGenerator's dialect."""
-        compiler = self.dialect.statement_compiler(self.dialect, tocompile, parameters)
-        compiler.compile()
-        return compiler
-
     def visit_check_constraint(self, constraint):
-        self.append(", \n\t")
+        text = ""
         if constraint.name is not None:
-            self.append("CONSTRAINT %s " %
-                        self.preparer.format_constraint(constraint))
-        self.append(" CHECK (%s)" % constraint.sqltext)
-        self.define_constraint_deferrability(constraint)
+            text += "CONSTRAINT %s " % \
+                        self.preparer.format_constraint(constraint)
+        text += " CHECK (%s)" % constraint.sqltext
+        text += self.define_constraint_deferrability(constraint)
+        return text
 
     def visit_column_check_constraint(self, constraint):
-        self.append(" CHECK (%s)" % constraint.sqltext)
-        self.define_constraint_deferrability(constraint)
+        text = " CHECK (%s)" % constraint.sqltext
+        text += self.define_constraint_deferrability(constraint)
+        return text
 
     def visit_primary_key_constraint(self, constraint):
         if len(constraint) == 0:
-            return
-        self.append(", \n\t")
+            return ''
+        text = ""
         if constraint.name is not None:
-            self.append("CONSTRAINT %s " % self.preparer.format_constraint(constraint))
-        self.append("PRIMARY KEY ")
-        self.append("(%s)" % ', '.join(self.preparer.quote(c.name, c.quote)
-                                       for c in constraint))
-        self.define_constraint_deferrability(constraint)
+            text += "CONSTRAINT %s " % self.preparer.format_constraint(constraint)
+        text += "PRIMARY KEY "
+        text += "(%s)" % ', '.join(self.preparer.quote(c.name, c.quote)
+                                       for c in constraint)
+        text += self.define_constraint_deferrability(constraint)
+        return text
 
     def visit_foreign_key_constraint(self, constraint):
-        if constraint.use_alter and self.dialect.supports_alter:
-            return
-        self.append(", \n\t ")
-        self.define_foreign_key(constraint)
-
-    def add_foreignkey(self, constraint):
-        self.append("ALTER TABLE %s ADD " % self.preparer.format_table(constraint.table))
-        self.define_foreign_key(constraint)
-        self.execute()
-
-    def define_foreign_key(self, constraint):
-        preparer = self.preparer
+        preparer = self.dialect.identifier_preparer
+        text = ""
         if constraint.name is not None:
-            self.append("CONSTRAINT %s " %
-                        preparer.format_constraint(constraint))
-        table = list(constraint.elements)[0].column.table
-        self.append("FOREIGN KEY(%s) REFERENCES %s (%s)" % (
+            text += "CONSTRAINT %s " % \
+                        preparer.format_constraint(constraint)
+        remote_table = list(constraint._elements.values())[0].column.table
+        text += "FOREIGN KEY(%s) REFERENCES %s (%s)" % (
             ', '.join(preparer.quote(f.parent.name, f.parent.quote)
-                      for f in constraint.elements),
-            preparer.format_table(table),
+                      for f in constraint._elements.values()),
+            preparer.format_table(remote_table),
             ', '.join(preparer.quote(f.column.name, f.column.quote)
-                      for f in constraint.elements)
-        ))
-        if constraint.ondelete is not None:
-            self.append(" ON DELETE %s" % constraint.ondelete)
-        if constraint.onupdate is not None:
-            self.append(" ON UPDATE %s" % constraint.onupdate)
-        self.define_constraint_deferrability(constraint)
+                      for f in constraint._elements.values())
+        )
+        text += self.define_constraint_cascades(constraint)
+        text += self.define_constraint_deferrability(constraint)
+        return text
 
     def visit_unique_constraint(self, constraint):
-        self.append(", \n\t")
+        text = ""
         if constraint.name is not None:
-            self.append("CONSTRAINT %s " %
-                        self.preparer.format_constraint(constraint))
-        self.append(" UNIQUE (%s)" % (', '.join(self.preparer.quote(c.name, c.quote) for c in constraint)))
-        self.define_constraint_deferrability(constraint)
+            text += "CONSTRAINT %s " % self.preparer.format_constraint(constraint)
+        text += " UNIQUE (%s)" % (', '.join(self.preparer.quote(c.name, c.quote) for c in constraint))
+        text += self.define_constraint_deferrability(constraint)
+        return text
 
+    def define_constraint_cascades(self, constraint):
+        text = ""
+        if constraint.ondelete is not None:
+            text += " ON DELETE %s" % constraint.ondelete
+        if constraint.onupdate is not None:
+            text += " ON UPDATE %s" % constraint.onupdate
+        return text
+        
     def define_constraint_deferrability(self, constraint):
+        text = ""
         if constraint.deferrable is not None:
             if constraint.deferrable:
-                self.append(" DEFERRABLE")
+                text += " DEFERRABLE"
             else:
-                self.append(" NOT DEFERRABLE")
+                text += " NOT DEFERRABLE"
         if constraint.initially is not None:
-            self.append(" INITIALLY %s" % constraint.initially)
+            text += " INITIALLY %s" % constraint.initially
+        return text
+        
+        
+class GenericTypeCompiler(engine.TypeCompiler):
+    def visit_CHAR(self, type_):
+        return "CHAR" + (type_.length and "(%d)" % type_.length or "")
 
-    def visit_column(self, column):
-        pass
+    def visit_NCHAR(self, type_):
+        return "NCHAR" + (type_.length and "(%d)" % type_.length or "")
+    
+    def visit_FLOAT(self, type_):
+        return "FLOAT"
 
-    def visit_index(self, index):
-        preparer = self.preparer
-        self.append("CREATE ")
-        if index.unique:
-            self.append("UNIQUE ")
-        self.append("INDEX %s ON %s (%s)" \
-                    % (preparer.quote(self._validate_identifier(index.name, True), index.quote),
-                       preparer.format_table(index.table),
-                       ', '.join(preparer.quote(c.name, c.quote)
-                                 for c in index.columns)))
-        self.execute()
+    def visit_NUMERIC(self, type_):
+        if type_.precision is None:
+            return "NUMERIC"
+        else:
+            return "NUMERIC(%(precision)s, %(scale)s)" % {'precision': type_.precision, 'scale' : type_.scale}
 
+    def visit_DECIMAL(self, type_):
+        return "DECIMAL"
+        
+    def visit_INTEGER(self, type_):
+        return "INTEGER"
 
-class SchemaDropper(DDLBase):
-    def __init__(self, dialect, connection, checkfirst=False, tables=None, **kwargs):
-        super(SchemaDropper, self).__init__(connection, **kwargs)
-        self.checkfirst = checkfirst
-        self.tables = tables
-        self.preparer = dialect.identifier_preparer
-        self.dialect = dialect
+    def visit_SMALLINT(self, type_):
+        return "SMALLINT"
 
-    def visit_metadata(self, metadata):
-        if self.tables:
-            tables = self.tables
-        else:
-            tables = metadata.tables.values()
-        collection = [t for t in reversed(sql_util.sort_tables(tables)) if self._can_drop(t)]
-        if self.dialect.supports_alter:
-            for alterable in self.find_alterables(collection):
-                self.drop_foreignkey(alterable)
-        for table in collection:
-            self.traverse_single(table)
-
-    def _can_drop(self, table):
-        self.dialect.validate_identifier(table.name)
-        if table.schema:
-            self.dialect.validate_identifier(table.schema)
-        return not self.checkfirst or self.dialect.has_table(self.connection, table.name, schema=table.schema)
-
-    def visit_index(self, index):
-        self.append("\nDROP INDEX " + self.preparer.quote(self._validate_identifier(index.name, False), index.quote))
-        self.execute()
-
-    def drop_foreignkey(self, constraint):
-        self.append("ALTER TABLE %s DROP CONSTRAINT %s" % (
-            self.preparer.format_table(constraint.table),
-            self.preparer.format_constraint(constraint)))
-        self.execute()
-
-    def visit_table(self, table):
-        for listener in table.ddl_listeners['before-drop']:
-            listener('before-drop', table, self.connection)
+    def visit_BIGINT(self, type_):
+        return "BIGINT"
 
-        for column in table.columns:
-            if column.default is not None:
-                self.traverse_single(column.default)
+    def visit_TIMESTAMP(self, type_):
+        return 'TIMESTAMP'
+
+    def visit_DATETIME(self, type_):
+        return "DATETIME"
+
+    def visit_DATE(self, type_):
+        return "DATE"
+
+    def visit_TIME(self, type_):
+        return "TIME"
+
+    def visit_CLOB(self, type_):
+        return "CLOB"
 
-        self.append("\nDROP TABLE " + self.preparer.format_table(table))
-        self.execute()
+    def visit_NCLOB(self, type_):
+        return "NCLOB"
 
-        for listener in table.ddl_listeners['after-drop']:
-            listener('after-drop', table, self.connection)
+    def visit_VARCHAR(self, type_):
+        return "VARCHAR" + (type_.length and "(%d)" % type_.length or "")
 
+    def visit_NVARCHAR(self, type_):
+        return "NVARCHAR" + (type_.length and "(%d)" % type_.length or "")
 
+    def visit_BLOB(self, type_):
+        return "BLOB"
+    
+    def visit_BOOLEAN(self, type_):
+        return "BOOLEAN"
+    
+    def visit_TEXT(self, type_):
+        return "TEXT"
+    
+    def visit_binary(self, type_):
+        return self.visit_BLOB(type_)
+        
+    def visit_boolean(self, type_): 
+        return self.visit_BOOLEAN(type_)
+        
+    def visit_time(self, type_): 
+        return self.visit_TIME(type_)
+        
+    def visit_datetime(self, type_): 
+        return self.visit_DATETIME(type_)
+        
+    def visit_date(self, type_): 
+        return self.visit_DATE(type_)
+
+    def visit_big_integer(self, type_): 
+        return self.visit_BIGINT(type_)
+        
+    def visit_small_integer(self, type_): 
+        return self.visit_SMALLINT(type_)
+        
+    def visit_integer(self, type_): 
+        return self.visit_INTEGER(type_)
+        
+    def visit_float(self, type_):
+        return self.visit_FLOAT(type_)
+        
+    def visit_numeric(self, type_): 
+        return self.visit_NUMERIC(type_)
+        
+    def visit_string(self, type_): 
+        return self.visit_VARCHAR(type_)
+        
+    def visit_unicode(self, type_): 
+        return self.visit_VARCHAR(type_)
+
+    def visit_text(self, type_): 
+        return self.visit_TEXT(type_)
+
+    def visit_unicode_text(self, type_): 
+        return self.visit_TEXT(type_)
+    
+    def visit_null(self, type_):
+        raise NotImplementedError("Can't generate DDL for the null type")
+        
+    def visit_type_decorator(self, type_):
+        return self.process(type_.type_engine(self.dialect))
+        
+    def visit_user_defined(self, type_):
+        return type_.get_col_spec()
+    
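
The lowercase visit_* methods map the generic types onto their uppercase
SQL-standard renderings, so compiling a type is a one-liner; a sketch against
the default dialect:

    from sqlalchemy.types import String, Numeric
    from sqlalchemy.engine.default import DefaultDialect

    d = DefaultDialect()
    print(d.type_compiler.process(String(30)))       # VARCHAR(30)
    print(d.type_compiler.process(Numeric(10, 2)))   # NUMERIC(10, 2)
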
 class IdentifierPreparer(object):
     """Handle quoting and case-folding of identifiers based on options."""
 
@@ -1176,24 +1380,24 @@ class IdentifierPreparer(object):
         else:
             return (self.format_table(table, use_schema=False), )
 
+    @util.memoized_property
+    def _r_identifiers(self):
+        initial, final, escaped_final = \
+                 [re.escape(s) for s in
+                  (self.initial_quote, self.final_quote,
+                   self._escape_identifier(self.final_quote))]
+        r = re.compile(
+            r'(?:'
+            r'(?:%(initial)s((?:%(escaped)s|[^%(final)s])+)%(final)s'
+            r'|([^\.]+))(?=\.|$))+' %
+            { 'initial': initial,
+              'final': final,
+              'escaped': escaped_final })
+        return r
+        
     def unformat_identifiers(self, identifiers):
         """Unpack 'schema.table.column'-like strings into components."""
 
-        try:
-            r = self._r_identifiers
-        except AttributeError:
-            initial, final, escaped_final = \
-                     [re.escape(s) for s in
-                      (self.initial_quote, self.final_quote,
-                       self._escape_identifier(self.final_quote))]
-            r = re.compile(
-                r'(?:'
-                r'(?:%(initial)s((?:%(escaped)s|[^%(final)s])+)%(final)s'
-                r'|([^\.]+))(?=\.|$))+' %
-                { 'initial': initial,
-                  'final': final,
-                  'escaped': escaped_final })
-            self._r_identifiers = r
-
+        r = self._r_identifiers
         return [self._unescape_identifier(i)
                 for i in [a or b for a, b in r.findall(identifiers)]]
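
Behavior is unchanged by the memoization; a sketch:

    from sqlalchemy.engine.default import DefaultDialect
    from sqlalchemy.sql.compiler import IdentifierPreparer

    prep = IdentifierPreparer(DefaultDialect())
    print(prep.unformat_identifiers('"my schema"."my table".col'))
    # ['my schema', 'my table', 'col']
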
index 83897ef051cca420576616ee5c7b8dcaf5749f58..91e0e74ae45e7812d09aeddf56c5bd17c6ea43c4 100644 (file)
@@ -29,10 +29,9 @@ to stay the same in future releases.
 import itertools, re
 from operator import attrgetter
 
-from sqlalchemy import util, exc
+from sqlalchemy import util, exc, types as sqltypes
 from sqlalchemy.sql import operators
 from sqlalchemy.sql.visitors import Visitable, cloned_traverse
-from sqlalchemy import types as sqltypes
 import operator
 
 functions, schema, sql_util = None, None, None
@@ -128,7 +127,8 @@ def select(columns=None, whereclause=None, from_obj=[], **kwargs):
     Similar functionality is also available via the ``select()``
     method on any :class:`~sqlalchemy.sql.expression.FromClause`.
 
-    The returned object is an instance of :class:`~sqlalchemy.sql.expression.Select`.
+    The returned object is an instance of
+    :class:`~sqlalchemy.sql.expression.Select`.
 
     All arguments which accept ``ClauseElement`` arguments also accept
     string arguments, which will be converted as appropriate into
@@ -241,7 +241,8 @@ def select(columns=None, whereclause=None, from_obj=[], **kwargs):
 
     """
     if 'scalar' in kwargs:
-        util.warn_deprecated('scalar option is deprecated; see docs for details')
+        util.warn_deprecated(
+            'scalar option is deprecated; see docs for details')
     scalar = kwargs.pop('scalar', False)
     s = Select(columns, whereclause=whereclause, from_obj=from_obj, **kwargs)
     if scalar:
@@ -250,15 +251,16 @@ def select(columns=None, whereclause=None, from_obj=[], **kwargs):
         return s
 
 def subquery(alias, *args, **kwargs):
-    """Return an :class:`~sqlalchemy.sql.expression.Alias` object derived from a :class:`~sqlalchemy.sql.expression.Select`.
+    """Return an :class:`~sqlalchemy.sql.expression.Alias` object derived 
+    from a :class:`~sqlalchemy.sql.expression.Select`.
 
     name
       alias name
 
     \*args, \**kwargs
 
-      all other arguments are delivered to the :func:`~sqlalchemy.sql.expression.select`
-      function.
+      all other arguments are delivered to the
+      :func:`~sqlalchemy.sql.expression.select` function.
 
     """
     return Select(*args, **kwargs).alias(alias)
@@ -280,12 +282,12 @@ def insert(table, values=None, inline=False, **kwargs):
       table columns.  Note that the :meth:`~Insert.values()` generative method
       may also be used for this.
 
-    :param prefixes: A list of modifier keywords to be inserted between INSERT and INTO.
-      Alternatively, the :meth:`~Insert.prefix_with` generative method may be used.
+    :param prefixes: A list of modifier keywords to be inserted between INSERT
+      and INTO. Alternatively, the :meth:`~Insert.prefix_with` generative
+      method may be used.
 
-    :param inline:
-      if True, SQL defaults will be compiled 'inline' into the statement
-      and not pre-executed.
+    :param inline: if True, SQL defaults will be compiled 'inline' into the
+      statement and not pre-executed.
 
     If both `values` and compile-time bind parameters are present, the
     compile-time bind parameters override the information specified
@@ -313,9 +315,9 @@ def update(table, whereclause=None, values=None, inline=False, **kwargs):
 
     :param table: The table to be updated.
 
-    :param whereclause: A ``ClauseElement`` describing the ``WHERE`` condition of the
-      ``UPDATE`` statement.  Note that the :meth:`~Update.where()` generative
-      method may also be used for this.
+    :param whereclause: A ``ClauseElement`` describing the ``WHERE`` condition
+      of the ``UPDATE`` statement. Note that the :meth:`~Update.where()`
+      generative method may also be used for this.
 
     :param values:
       A dictionary which specifies the ``SET`` conditions of the
@@ -347,7 +349,12 @@ def update(table, whereclause=None, values=None, inline=False, **kwargs):
     against the ``UPDATE`` statement.
 
     """
-    return Update(table, whereclause=whereclause, values=values, inline=inline, **kwargs)
+    return Update(
+            table, 
+            whereclause=whereclause, 
+            values=values, 
+            inline=inline, 
+            **kwargs)
 
 def delete(table, whereclause = None, **kwargs):
     """Return a :class:`~sqlalchemy.sql.expression.Delete` clause element.
@@ -357,9 +364,9 @@ def delete(table, whereclause = None, **kwargs):
 
     :param table: The table to be deleted from.
 
-    :param whereclause: A :class:`ClauseElement` describing the ``WHERE`` condition of the
-      ``UPDATE`` statement.  Note that the :meth:`~Delete.where()` generative method
-      may be used instead.
+    :param whereclause: A :class:`ClauseElement` describing the ``WHERE``
+      condition of the ``DELETE`` statement. Note that the
+      :meth:`~Delete.where()` generative method may be used instead.
 
     """
     return Delete(table, whereclause, **kwargs)
@@ -368,8 +375,8 @@ def and_(*clauses):
     """Join a list of clauses together using the ``AND`` operator.
 
     The ``&`` operator is also overloaded on all
-    :class:`~sqlalchemy.sql.expression._CompareMixin` subclasses to produce the same
-    result.
+    :class:`~sqlalchemy.sql.expression._CompareMixin` subclasses to produce the
+    same result.
 
     """
     if len(clauses) == 1:
@@ -380,8 +387,8 @@ def or_(*clauses):
     """Join a list of clauses together using the ``OR`` operator.
 
     The ``|`` operator is also overloaded on all
-    :class:`~sqlalchemy.sql.expression._CompareMixin` subclasses to produce the same
-    result.
+    :class:`~sqlalchemy.sql.expression._CompareMixin` subclasses to produce the
+    same result.
 
     """
     if len(clauses) == 1:
@@ -392,8 +399,8 @@ def not_(clause):
     """Return a negation of the given clause, i.e. ``NOT(clause)``.
 
     The ``~`` operator is also overloaded on all
-    :class:`~sqlalchemy.sql.expression._CompareMixin` subclasses to produce the same
-    result.
+    :class:`~sqlalchemy.sql.expression._CompareMixin` subclasses to produce the
+    same result.
 
     """
     return operators.inv(_literal_as_binds(clause))
@@ -408,8 +415,9 @@ def between(ctest, cleft, cright):
 
     Equivalent of SQL ``clausetest BETWEEN clauseleft AND clauseright``.
 
-    The ``between()`` method on all :class:`~sqlalchemy.sql.expression._CompareMixin` subclasses
-    provides similar functionality.
+    The ``between()`` method on all
+    :class:`~sqlalchemy.sql.expression._CompareMixin` subclasses provides
+    similar functionality.
 
     """
     ctest = _literal_as_binds(ctest)
@@ -517,7 +525,8 @@ def exists(*args, **kwargs):
 def union(*selects, **kwargs):
     """Return a ``UNION`` of multiple selectables.
 
-    The returned object is an instance of :class:`~sqlalchemy.sql.expression.CompoundSelect`.
+    The returned object is an instance of
+    :class:`~sqlalchemy.sql.expression.CompoundSelect`.
 
     A similar ``union()`` method is available on all
     :class:`~sqlalchemy.sql.expression.FromClause` subclasses.
@@ -535,7 +544,8 @@ def union(*selects, **kwargs):
 def union_all(*selects, **kwargs):
     """Return a ``UNION ALL`` of multiple selectables.
 
-    The returned object is an instance of :class:`~sqlalchemy.sql.expression.CompoundSelect`.
+    The returned object is an instance of
+    :class:`~sqlalchemy.sql.expression.CompoundSelect`.
 
     A similar ``union_all()`` method is available on all
     :class:`~sqlalchemy.sql.expression.FromClause` subclasses.
@@ -553,7 +563,8 @@ def union_all(*selects, **kwargs):
 def except_(*selects, **kwargs):
     """Return an ``EXCEPT`` of multiple selectables.
 
-    The returned object is an instance of :class:`~sqlalchemy.sql.expression.CompoundSelect`.
+    The returned object is an instance of
+    :class:`~sqlalchemy.sql.expression.CompoundSelect`.
 
     \*selects
       a list of :class:`~sqlalchemy.sql.expression.Select` instances.
@@ -568,7 +579,8 @@ def except_(*selects, **kwargs):
 def except_all(*selects, **kwargs):
     """Return an ``EXCEPT ALL`` of multiple selectables.
 
-    The returned object is an instance of :class:`~sqlalchemy.sql.expression.CompoundSelect`.
+    The returned object is an instance of
+    :class:`~sqlalchemy.sql.expression.CompoundSelect`.
 
     \*selects
       a list of :class:`~sqlalchemy.sql.expression.Select` instances.
@@ -583,7 +595,8 @@ def except_all(*selects, **kwargs):
 def intersect(*selects, **kwargs):
     """Return an ``INTERSECT`` of multiple selectables.
 
-    The returned object is an instance of :class:`~sqlalchemy.sql.expression.CompoundSelect`.
+    The returned object is an instance of
+    :class:`~sqlalchemy.sql.expression.CompoundSelect`.
 
     \*selects
       a list of :class:`~sqlalchemy.sql.expression.Select` instances.
@@ -598,7 +611,8 @@ def intersect(*selects, **kwargs):
 def intersect_all(*selects, **kwargs):
     """Return an ``INTERSECT ALL`` of multiple selectables.
 
-    The returned object is an instance of :class:`~sqlalchemy.sql.expression.CompoundSelect`.
+    The returned object is an instance of
+    :class:`~sqlalchemy.sql.expression.CompoundSelect`.
 
     \*selects
       a list of :class:`~sqlalchemy.sql.expression.Select` instances.
@@ -613,8 +627,8 @@ def intersect_all(*selects, **kwargs):
 def alias(selectable, alias=None):
     """Return an :class:`~sqlalchemy.sql.expression.Alias` object.
 
-    An ``Alias`` represents any :class:`~sqlalchemy.sql.expression.FromClause` with
-    an alternate name assigned within SQL, typically using the ``AS``
+    An ``Alias`` represents any :class:`~sqlalchemy.sql.expression.FromClause`
+    with an alternate name assigned within SQL, typically using the ``AS``
     clause when generated, e.g. ``SELECT * FROM table AS aliasname``.
 
     Similar functionality is available via the ``alias()`` method
@@ -656,7 +670,8 @@ def literal(value, type_=None):
     return _BindParamClause(None, value, type_=type_, unique=True)
 
 def label(name, obj):
-    """Return a :class:`~sqlalchemy.sql.expression._Label` object for the given :class:`~sqlalchemy.sql.expression.ColumnElement`.
+    """Return a :class:`~sqlalchemy.sql.expression._Label` object for the given
+    :class:`~sqlalchemy.sql.expression.ColumnElement`.
 
     A label changes the name of an element in the columns clause of a
     ``SELECT`` statement, typically via the ``AS`` SQL keyword.
@@ -674,11 +689,13 @@ def label(name, obj):
     return _Label(name, obj)
 
 def column(text, type_=None):
-    """Return a textual column clause, as would be in the columns clause of a ``SELECT`` statement.
+    """Return a textual column clause, as would be in the columns clause of a
+    ``SELECT`` statement.
 
-    The object returned is an instance of :class:`~sqlalchemy.sql.expression.ColumnClause`,
-    which represents the "syntactical" portion of the schema-level
-    :class:`~sqlalchemy.schema.Column` object.
+    The object returned is an instance of
+    :class:`~sqlalchemy.sql.expression.ColumnClause`, which represents the
+    "syntactical" portion of the schema-level :class:`~sqlalchemy.schema.Column`
+    object.
 
     text
       the name of the column.  Quoting rules will be applied to the
@@ -710,9 +727,9 @@ def literal_column(text, type_=None):
       :func:`~sqlalchemy.sql.expression.column` function.
 
     type\_
-      an optional :class:`~sqlalchemy.types.TypeEngine` object which will provide
-      result-set translation and additional expression semantics for this
-      column.  If left as None the type will be NullType.
+      an optional :class:`~sqlalchemy.types.TypeEngine` object which will
+      provide result-set translation and additional expression semantics for
+      this column. If left as None the type will be NullType.
 
     """
     return ColumnClause(text, type_=type_, is_literal=True)
@@ -752,7 +769,8 @@ def bindparam(key, value=None, shortname=None, type_=None, unique=False):
         return _BindParamClause(key, value, type_=type_, unique=unique, shortname=shortname)
 
 def outparam(key, type_=None):
-    """Create an 'OUT' parameter for usage in functions (stored procedures), for databases which support them.
+    """Create an 'OUT' parameter for usage in functions (stored procedures), for
+    databases which support them.
 
     The ``outparam`` can be used like a regular function parameter.
     The "output" value will be available from the
@@ -760,7 +778,8 @@ def outparam(key, type_=None):
     attribute, which returns a dictionary containing the values.
 
     """
-    return _BindParamClause(key, None, type_=type_, unique=False, isoutparam=True)
+    return _BindParamClause(
+                key, None, type_=type_, unique=False, isoutparam=True)
 
 def text(text, bind=None, *args, **kwargs):
     """Create literal text to be inserted into a query.
@@ -803,8 +822,10 @@ def text(text, bind=None, *args, **kwargs):
     return _TextClause(text, bind=bind, *args, **kwargs)
 
 def null():
-    """Return a :class:`_Null` object, which compiles to ``NULL`` in a sql statement."""
-
+    """Return a :class:`_Null` object, which compiles to ``NULL`` in a sql
+    statement.
+    
+    """
     return _Null()
 
 class _FunctionGenerator(object):
@@ -839,7 +860,8 @@ class _FunctionGenerator(object):
             if func is not None:
                 return func(*c, **o)
 
-        return Function(self.__names[-1], packagenames=self.__names[0:-1], *c, **o)
+        return Function(
+                    self.__names[-1], packagenames=self.__names[0:-1], *c, **o)
 
 # "func" global - i.e. func.count()
 func = _FunctionGenerator()
@@ -861,10 +883,19 @@ def _clone(element):
     return element._clone()
 
 def _expand_cloned(elements):
-    """expand the given set of ClauseElements to be the set of all 'cloned' predecessors."""
-
+    """expand the given set of ClauseElements to be the set of all 'cloned'
+    predecessors.
+    
+    """
     return itertools.chain(*[x._cloned_set for x in elements])
 
+def _select_iterables(elements):
+    """expand tables into individual columns in the 
+    given list of column expressions.
+    
+    """
+    return itertools.chain(*[c._select_iterable for c in elements])
+    
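
This helper is what lets a whole table stand in for its columns in a column
list, e.g.:

    from sqlalchemy.sql import table, column, select

    t = table('t', column('a'), column('b'))
    print(select([t]))   # SELECT t.a, t.b FROM t
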
 def _cloned_intersection(a, b):
     """return the intersection of sets a and b, counting
     any overlap between 'cloned' predecessors.
@@ -879,7 +910,8 @@ def _compound_select(keyword, *selects, **kwargs):
     return CompoundSelect(keyword, *selects, **kwargs)
 
 def _is_literal(element):
-    return not isinstance(element, Visitable) and not hasattr(element, '__clause_element__')
+    return not isinstance(element, Visitable) and \
+            not hasattr(element, '__clause_element__')
 
 def _from_objects(*elements):
     return itertools.chain(*[element._from_objects for element in elements])
@@ -940,27 +972,36 @@ def _no_literals(element):
         return element
 
 def _corresponding_column_or_error(fromclause, column, require_embedded=False):
-    c = fromclause.corresponding_column(column, require_embedded=require_embedded)
+    c = fromclause.corresponding_column(column,
+            require_embedded=require_embedded)
     if not c:
-        raise exc.InvalidRequestError("Given column '%s', attached to table '%s', "
+        raise exc.InvalidRequestError(
+                "Given column '%s', attached to table '%s', "
                 "failed to locate a corresponding column from table '%s'"
-                % (column, getattr(column, 'table', None), fromclause.description))
+                % 
+                (column, 
+                    getattr(column, 'table', None), fromclause.description)
+                )
     return c
 
 def is_column(col):
     """True if ``col`` is an instance of ``ColumnElement``."""
+    
     return isinstance(col, ColumnElement)
 
 
 class ClauseElement(Visitable):
-    """Base class for elements of a programmatically constructed SQL expression."""
-
+    """Base class for elements of a programmatically constructed SQL
+    expression.
+    
+    """
     __visit_name__ = 'clause'
 
     _annotations = {}
     supports_execution = False
     _from_objects = []
-
+    _bind = None
+    
     def _clone(self):
         """Create a shallow copy of this ClauseElement.
 
@@ -984,7 +1025,8 @@ class ClauseElement(Visitable):
 
     @util.memoized_property
     def _cloned_set(self):
-        """Return the set consisting all cloned anscestors of this ClauseElement.
+        """Return the set consisting all cloned anscestors of this
+        ClauseElement.
 
         Includes this ClauseElement.  This accessor tends to be used for
         FromClause objects to identify 'equivalent' FROM clauses, regardless
@@ -1004,15 +1046,20 @@ class ClauseElement(Visitable):
         return d
 
     def _annotate(self, values):
-        """return a copy of this ClauseElement with the given annotations dictionary."""
-
+        """return a copy of this ClauseElement with the given annotations
+        dictionary.
+        
+        """
         global Annotated
         if Annotated is None:
             from sqlalchemy.sql.util import Annotated
         return Annotated(self, values)
 
     def _deannotate(self):
-        """return a copy of this ClauseElement with an empty annotations dictionary."""
+        """return a copy of this ClauseElement with an empty annotations
+        dictionary.
+        
+        """
         return self._clone()
 
     def unique_params(self, *optionaldict, **kwargs):
@@ -1044,7 +1091,8 @@ class ClauseElement(Visitable):
         if len(optionaldict) == 1:
             kwargs.update(optionaldict[0])
         elif len(optionaldict) > 1:
-            raise exc.ArgumentError("params() takes zero or one positional dictionary argument")
+            raise exc.ArgumentError(
+                "params() takes zero or one positional dictionary argument")
 
         def visit_bindparam(bind):
             if bind.key in kwargs:
@@ -1088,15 +1136,20 @@ class ClauseElement(Visitable):
     def self_group(self, against=None):
         return self
 
+    # TODO: remove .bind as a method from the root ClauseElement.
+    # we should only be deriving binds from FromClause elements
+    # and certain SchemaItem subclasses.
+    # the "search_for_bind" functionality can still be used by
+    # execute(), however.
     @property
     def bind(self):
-        """Returns the Engine or Connection to which this ClauseElement is bound, or None if none found."""
+        """Returns the Engine or Connection to which this ClauseElement is
+        bound, or None if none found.
+        
+        """
+        if self._bind is not None:
+            return self._bind
 
-        try:
-            if self._bind is not None:
-                return self._bind
-        except AttributeError:
-            pass
         for f in _from_objects(self):
             if f is self:
                 continue
@@ -1121,68 +1174,82 @@ class ClauseElement(Visitable):
         return e._execute_clauseelement(self, multiparams, params)
 
     def scalar(self, *multiparams, **params):
-        """Compile and execute this ``ClauseElement``, returning the result's scalar representation."""
-
+        """Compile and execute this ``ClauseElement``, returning the result's
+        scalar representation.
+        
+        """
         return self.execute(*multiparams, **params).scalar()
 
-    def compile(self, bind=None, column_keys=None, compiler=None, dialect=None, inline=False):
+    def compile(self, bind=None, dialect=None, **kw):
         """Compile this SQL expression.
 
         The return value is a :class:`~sqlalchemy.engine.Compiled` object.
-        Calling `str()` or `unicode()` on the returned value will yield
-        a string representation of the result.   The :class:`~sqlalchemy.engine.Compiled`
-        object also can return a dictionary of bind parameter names and
-        values using the `params` accessor.
+        Calling `str()` or `unicode()` on the returned value will yield a string
+        representation of the result. The :class:`~sqlalchemy.engine.Compiled`
+        object also can return a dictionary of bind parameter names and values
+        using the `params` accessor.
 
         :param bind: An ``Engine`` or ``Connection`` from which a
           ``Compiled`` will be acquired.  This argument
           takes precedence over this ``ClauseElement``'s
           bound engine, if any.
 
-        :param column_keys: Used for INSERT and UPDATE statements, a list of
-          column names which should be present in the VALUES clause
-          of the compiled statement.  If ``None``, all columns
-          from the target table object are rendered.
-
-        :param compiler: A ``Compiled`` instance which will be used to compile
-          this expression.  This argument takes precedence
-          over the `bind` and `dialect` arguments as well as
-          this ``ClauseElement``'s bound engine, if
-          any.
-
         :param dialect: A ``Dialect`` instance from which a ``Compiled``
           will be acquired.  This argument takes precedence
           over the `bind` argument as well as this
           ``ClauseElement``'s bound engine, if any.
 
-        :param inline: Used for INSERT statements, for a dialect which does
-          not support inline retrieval of newly generated
-          primary key columns, will force the expression used
-          to create the new primary key value to be rendered
-          inline within the INSERT statement's VALUES clause.
-          This typically refers to Sequence execution but
-          may also refer to any server-side default generation
-          function associated with a primary key `Column`.
+        \**kw
+        
+          Keyword arguments are passed along to the compiler, 
+          which can affect the string produced.
+          
+          Keywords for a statement compiler are:
+        
+          column_keys
+            Used for INSERT and UPDATE statements, a list of
+            column names which should be present in the VALUES clause
+            of the compiled statement.  If ``None``, all columns
+            from the target table object are rendered.
+
+          inline
+            Used for INSERT statements, for a dialect which does
+            not support inline retrieval of newly generated
+            primary key columns, will force the expression used
+            to create the new primary key value to be rendered
+            inline within the INSERT statement's VALUES clause.
+            This typically refers to Sequence execution but
+            may also refer to any server-side default generation
+            function associated with a primary key `Column`.
 
         """
-        if compiler is None:
-            if dialect is not None:
-                compiler = dialect.statement_compiler(dialect, self, column_keys=column_keys, inline=inline)
-            elif bind is not None:
-                compiler = bind.statement_compiler(self, column_keys=column_keys, inline=inline)
-            elif self.bind is not None:
-                compiler = self.bind.statement_compiler(self, column_keys=column_keys, inline=inline)
+        
+        if not dialect:
+            if bind:
+                dialect = bind.dialect
+            elif self.bind:
+                dialect = self.bind.dialect
+                bind = self.bind
             else:
                 global DefaultDialect
                 if DefaultDialect is None:
                     from sqlalchemy.engine.default import DefaultDialect
                 dialect = DefaultDialect()
-                compiler = dialect.statement_compiler(dialect, self, column_keys=column_keys, inline=inline)
+        compiler = self._compiler(dialect, bind=bind, **kw)
         compiler.compile()
         return compiler
-
+    
+    def _compiler(self, dialect, **kw):
+        """Return a compiler appropriate for this ClauseElement, given a Dialect."""
+        
+        return dialect.statement_compiler(dialect, self, **kw)
+        
     def __str__(self):
+        # Py3K
+        #return unicode(self.compile())
+        # Py2K
         return unicode(self.compile()).encode('ascii', 'backslashreplace')
+        # end Py2K
 
     def __and__(self, other):
         return and_(self, other)
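
For illustration, a minimal sketch of the reworked compile() flow above: keyword
arguments now travel through ``**kw`` to the compiler, and with no bind present the
DefaultDialect is used (the table and column names here are hypothetical):

    from sqlalchemy.sql import table, column

    users = table('users', column('id'), column('name'))
    stmt = users.insert()

    # column_keys is passed through **kw to the statement compiler
    compiled = stmt.compile(column_keys=['name'])
    print(str(compiled))      # INSERT INTO users (name) VALUES (:name)
    print(compiled.params)    # {'name': None}
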
@@ -1193,11 +1260,25 @@ class ClauseElement(Visitable):
     def __invert__(self):
         return self._negate()
 
+    if util.jython:
+        def __hash__(self):
+            """Return a distinct hash code.
+
+            ClauseElements may have special equality comparisons which
+            make us rely on them having unique hash codes for use in
+            hash-based collections. Stock __hash__ doesn't guarantee
+            unique values on platforms with moving GCs.
+            """
+            return id(self)
+
     def _negate(self):
         if hasattr(self, 'negation_clause'):
             return self.negation_clause
         else:
-            return _UnaryExpression(self.self_group(against=operators.inv), operator=operators.inv, negate=None)
+            return _UnaryExpression(
+                        self.self_group(against=operators.inv), 
+                        operator=operators.inv, 
+                        negate=None)
 
     def __repr__(self):
         friendly = getattr(self, 'description', None)
@@ -1211,6 +1292,12 @@ class ClauseElement(Visitable):
 class _Immutable(object):
     """mark a ClauseElement as 'immutable' when expressions are cloned."""
 
+    def unique_params(self, *optionaldict, **kwargs):
+        raise NotImplementedError("Immutable objects do not support copying")
+
+    def params(self, *optionaldict, **kwargs):
+        raise NotImplementedError("Immutable objects do not support copying")
+
     def _clone(self):
         return self
 
@@ -1330,6 +1417,9 @@ class ColumnOperators(Operators):
     def __truediv__(self, other):
         return self.operate(operators.truediv, other)
 
+    def __rtruediv__(self, other):
+        return self.reverse_operate(operators.truediv, other)
+
 class _CompareMixin(ColumnOperators):
     """Defines comparison and math operations for ``ClauseElement`` instances."""
 
@@ -1365,7 +1455,9 @@ class _CompareMixin(ColumnOperators):
         operators.add : (__operate,),
         operators.mul : (__operate,),
         operators.sub : (__operate,),
+        # Py2K
         operators.div : (__operate,),
+        # end Py2K
         operators.mod : (__operate,),
         operators.truediv : (__operate,),
         operators.lt : (__compare, operators.ge),
@@ -1632,7 +1724,7 @@ class ColumnCollection(util.OrderedProperties):
 
     def __init__(self, *cols):
         super(ColumnCollection, self).__init__()
-        [self.add(c) for c in cols]
+        self.update((c.key, c) for c in cols)
 
     def __str__(self):
         return repr([str(c) for c in self])
@@ -1734,8 +1826,10 @@ class Selectable(ClauseElement):
     __visit_name__ = 'selectable'
 
 class FromClause(Selectable):
-    """Represent an element that can be used within the ``FROM`` clause of a ``SELECT`` statement."""
-
+    """Represent an element that can be used within the ``FROM`` 
+    clause of a ``SELECT`` statement.
+    
+    """
     __visit_name__ = 'fromclause'
     named_with_column = False
     _hide_froms = []
@@ -1749,7 +1843,11 @@ class FromClause(Selectable):
             col = list(self.primary_key)[0]
         else:
             col = list(self.columns)[0]
-        return select([func.count(col).label('tbl_row_count')], whereclause, from_obj=[self], **params)
+        return select(
+                    [func.count(col).label('tbl_row_count')], 
+                    whereclause, 
+                    from_obj=[self], 
+                    **params)
 
     def select(self, whereclause=None, **params):
         """return a SELECT of this ``FromClause``."""
@@ -1794,8 +1892,10 @@ class FromClause(Selectable):
         return fromclause in self._cloned_set
 
     def replace_selectable(self, old, alias):
-        """replace all occurences of FromClause 'old' with the given Alias object, returning a copy of this ``FromClause``."""
-
+        """replace all occurences of FromClause 'old' with the given Alias 
+        object, returning a copy of this ``FromClause``.
+        
+        """
         global ClauseAdapter
         if ClauseAdapter is None:
             from sqlalchemy.sql.util import ClauseAdapter
@@ -1846,24 +1946,30 @@ class FromClause(Selectable):
                     col, intersect = c, i
                 elif len(i) > len(intersect):
                     # 'c' has a larger field of correspondence than 'col'.
-                    # i.e. selectable.c.a1_x->a1.c.x->table.c.x matches a1.c.x->table.c.x better than 
+                    # i.e. selectable.c.a1_x->a1.c.x->table.c.x matches
+                    # a1.c.x->table.c.x better than 
                     # selectable.c.x->table.c.x does.
                     col, intersect = c, i
                 elif i == intersect:
                     # they have the same field of correspondence.
-                    # see which proxy_set has fewer columns in it, which indicates a
-                    # closer relationship with the root column.  Also take into account the 
-                    # "weight" attribute which CompoundSelect() uses to give higher precedence to
-                    # columns based on vertical position in the compound statement, and discard columns
-                    # that have no reference to the target column (also occurs with CompoundSelect)
+                    # see which proxy_set has fewer columns in it, which indicates
+                    # a closer relationship with the root column. Also take into
+                    # account the "weight" attribute which CompoundSelect() uses to
+                    # give higher precedence to columns based on vertical position
+                    # in the compound statement, and discard columns that have no
+                    # reference to the target column (also occurs with
+                    # CompoundSelect)
                     col_distance = util.reduce(operator.add, 
-                                        [sc._annotations.get('weight', 1) for sc in col.proxy_set if sc.shares_lineage(column)]
+                                        [sc._annotations.get('weight', 1) 
+                                            for sc in col.proxy_set 
+                                            if sc.shares_lineage(column)]
                                     )
                     c_distance = util.reduce(operator.add, 
-                                        [sc._annotations.get('weight', 1) for sc in c.proxy_set if sc.shares_lineage(column)]
+                                        [sc._annotations.get('weight', 1) 
+                                            for sc in c.proxy_set 
+                                            if sc.shares_lineage(column)]
                                     )
-                    if \
-                        c_distance < col_distance:
+                    if c_distance < col_distance:
                         col, intersect = c, i
         return col
 
@@ -2011,7 +2117,9 @@ class _BindParamClause(ColumnElement):
         the same type.
 
         """
-        return isinstance(other, _BindParamClause) and other.type.__class__ == self.type.__class__ and self.value == other.value
+        return isinstance(other, _BindParamClause) and \
+                    other.type.__class__ == self.type.__class__ and \
+                    self.value == other.value
 
     def __getstate__(self):
         """execute a deferred value for serialization purposes."""
@@ -2024,7 +2132,9 @@ class _BindParamClause(ColumnElement):
         return d
 
     def __repr__(self):
-        return "_BindParamClause(%s, %s, type_=%s)" % (repr(self.key), repr(self.value), repr(self.type))
+        return "_BindParamClause(%r, %r, type_=%r)" % (
+            self.key, self.value, self.type
+            )
 
 class _TypeClause(ClauseElement):
     """Handle a type keyword in a SQL statement.
@@ -2057,7 +2167,8 @@ class _TextClause(ClauseElement):
 
     _hide_froms = []
 
-    def __init__(self, text = "", bind=None, bindparams=None, typemap=None, autocommit=False):
+    def __init__(self, text = "", bind=None, 
+                    bindparams=None, typemap=None, autocommit=False):
         self._bind = bind
         self.bindparams = {}
         self.typemap = typemap
@@ -2157,7 +2268,8 @@ class ClauseList(ClauseElement):
         return list(itertools.chain(*[c._from_objects for c in self.clauses]))
 
     def self_group(self, against=None):
-        if self.group and self.operator is not against and operators.is_precedent(self.operator, against):
+        if self.group and self.operator is not against and \
+                operators.is_precedent(self.operator, against):
             return _Grouping(self)
         else:
             return self
@@ -2200,9 +2312,13 @@ class _Case(ColumnElement):
             pass
 
         if value:
-            whenlist = [(_literal_as_binds(c).self_group(), _literal_as_binds(r)) for (c, r) in whens]
+            whenlist = [
+                (_literal_as_binds(c).self_group(), _literal_as_binds(r)) for (c, r) in whens
+            ]
         else:
-            whenlist = [(_no_literals(c).self_group(), _literal_as_binds(r)) for (c, r) in whens]
+            whenlist = [
+                (_no_literals(c).self_group(), _literal_as_binds(r)) for (c, r) in whens
+            ]
 
         if whenlist:
             type_ = list(whenlist[-1])[-1].type
@@ -2472,16 +2588,19 @@ class _Exists(_UnaryExpression):
         return e
 
     def select_from(self, clause):
-        """return a new exists() construct with the given expression set as its FROM clause."""
-
+        """return a new exists() construct with the given expression set as its FROM
+        clause.
+        
+        """
         e = self._clone()
         e.element = self.element.select_from(clause).self_group()
         return e
 
     def where(self, clause):
-        """return a new exists() construct with the given expression added to its WHERE clause, joined
-        to the existing clause via AND, if any."""
-
+        """return a new exists() construct with the given expression added to its WHERE
+        clause, joined to the existing clause via AND, if any.
+        
+        """
         e = self._clone()
         e.element = self.element.where(clause).self_group()
         return e
@@ -2517,7 +2636,9 @@ class Join(FromClause):
             id(self.right))
 
     def is_derived_from(self, fromclause):
-        return fromclause is self or self.left.is_derived_from(fromclause) or self.right.is_derived_from(fromclause)
+        return fromclause is self or \
+                self.left.is_derived_from(fromclause) or\
+                self.right.is_derived_from(fromclause)
 
     def self_group(self, against=None):
         return _FromGrouping(self)
@@ -2634,7 +2755,11 @@ class Alias(FromClause):
 
     @property
     def description(self):
+        # Py3K
+        #return self.name
+        # Py2K
         return self.name.encode('ascii', 'backslashreplace')
+        # end Py2K
 
     def as_scalar(self):
         try:
@@ -2762,14 +2887,19 @@ class _Label(ColumnElement):
     def __init__(self, name, element, type_=None):
         while isinstance(element, _Label):
             element = element.element
-        self.name = self.key = self._label = name or _generated_label("%%(%d %s)s" % (id(self), getattr(element, 'name', 'anon')))
+        self.name = self.key = self._label = name or \
+                                _generated_label("%%(%d %s)s" % (
+                                    id(self), getattr(element, 'name', 'anon'))
+                                )
         self._element = element
         self._type = type_
         self.quote = element.quote
 
     @util.memoized_property
     def type(self):
-        return sqltypes.to_instance(self._type or getattr(self._element, 'type', None))
+        return sqltypes.to_instance(
+                    self._type or getattr(self._element, 'type', None)
+                )
 
     @util.memoized_property
     def element(self):
@@ -2842,7 +2972,11 @@ class ColumnClause(_Immutable, ColumnElement):
 
     @util.memoized_property
     def description(self):
+        # Py3K
+        #return self.name
+        # Py2K
         return self.name.encode('ascii', 'backslashreplace')
+        # end Py2K
 
     @util.memoized_property
     def _label(self):
@@ -2891,7 +3025,12 @@ class ColumnClause(_Immutable, ColumnElement):
         # propagate the "is_literal" flag only if we are keeping our name,
         # otherwise it's considered to be a label
         is_literal = self.is_literal and (name is None or name == self.name)
-        c = ColumnClause(name or self.name, selectable=selectable, type_=self.type, is_literal=is_literal)
+        c = ColumnClause(
+                    name or self.name, 
+                    selectable=selectable, 
+                    type_=self.type, 
+                    is_literal=is_literal
+                )
         c.proxies = [self]
         if attach:
             selectable.columns[c.name] = c
@@ -2927,7 +3066,11 @@ class TableClause(_Immutable, FromClause):
 
     @util.memoized_property
     def description(self):
+        # Py3K
+        #return self.name
+        # Py2K
         return self.name.encode('ascii', 'backslashreplace')
+        # end Py2K
 
     def append_column(self, c):
         self._columns[c.name] = c
@@ -2944,7 +3087,11 @@ class TableClause(_Immutable, FromClause):
             col = list(self.primary_key)[0]
         else:
             col = list(self.columns)[0]
-        return select([func.count(col).label('tbl_row_count')], whereclause, from_obj=[self], **params)
+        return select(
+                    [func.count(col).label('tbl_row_count')], 
+                    whereclause, 
+                    from_obj=[self], 
+                    **params)
 
     def insert(self, values=None, inline=False, **kwargs):
         """Generate an :func:`~sqlalchemy.sql.expression.insert()` construct."""
@@ -2954,7 +3101,8 @@ class TableClause(_Immutable, FromClause):
     def update(self, whereclause=None, values=None, inline=False, **kwargs):
         """Generate an :func:`~sqlalchemy.sql.expression.update()` construct."""
 
-        return update(self, whereclause=whereclause, values=values, inline=inline, **kwargs)
+        return update(self, whereclause=whereclause, 
+                            values=values, inline=inline, **kwargs)
 
     def delete(self, whereclause=None, **kwargs):
         """Generate a :func:`~sqlalchemy.sql.expression.delete()` construct."""
@@ -3004,7 +3152,8 @@ class _SelectBaseMixin(object):
         Typically, a select statement which has only one column in its columns clause
         is eligible to be used as a scalar expression.
 
-        The returned object is an instance of :class:`~sqlalchemy.sql.expression._ScalarSelect`.
+        The returned object is an instance of 
+        :class:`~sqlalchemy.sql.expression._ScalarSelect`.
 
         """
         return _ScalarSelect(self)
@@ -3013,10 +3162,10 @@ class _SelectBaseMixin(object):
     def apply_labels(self):
         """return a new selectable with the 'use_labels' flag set to True.
 
-        This will result in column expressions being generated using labels against their table
-        name, such as "SELECT somecolumn AS tablename_somecolumn".  This allows selectables which
-        contain multiple FROM clauses to produce a unique set of column names regardless of name conflicts
-        among the individual FROM clauses.
+        This will result in column expressions being generated using labels against their
+        table name, such as "SELECT somecolumn AS tablename_somecolumn". This allows
+        selectables which contain multiple FROM clauses to produce a unique set of column
+        names regardless of name conflicts among the individual FROM clauses.
 
         """
         self.use_labels = True
@@ -3127,7 +3276,8 @@ class _ScalarSelect(_Grouping):
         return list(self.inner_columns)[0]._make_proxy(selectable, name)
 
 class CompoundSelect(_SelectBaseMixin, FromClause):
-    """Forms the basis of ``UNION``, ``UNION ALL``, and other SELECT-based set operations."""
+    """Forms the basis of ``UNION``, ``UNION ALL``, and other 
+    SELECT-based set operations."""
 
     __visit_name__ = 'compound_select'
 
@@ -3147,7 +3297,8 @@ class CompoundSelect(_SelectBaseMixin, FromClause):
             elif len(s.c) != numcols:
                 raise exc.ArgumentError(
                         "All selectables passed to CompoundSelect must "
-                        "have identical numbers of columns; select #%d has %d columns, select #%d has %d" %
+                        "have identical numbers of columns; select #%d has %d columns,"
+                        " select #%d has %d" %
                         (1, len(self.selects[0].c), n+1, len(s.c))
                 )
 
@@ -3222,7 +3373,15 @@ class Select(_SelectBaseMixin, FromClause):
 
     __visit_name__ = 'select'
 
-    def __init__(self, columns, whereclause=None, from_obj=None, distinct=False, having=None, correlate=True, prefixes=None, **kwargs):
+    def __init__(self, 
+                columns, 
+                whereclause=None, 
+                from_obj=None, 
+                distinct=False, 
+                having=None, 
+                correlate=True, 
+                prefixes=None, 
+                **kwargs):
         """Construct a Select object.
 
         The public constructor for Select is the
@@ -3241,9 +3400,9 @@ class Select(_SelectBaseMixin, FromClause):
 
         if columns:
             self._raw_columns = [
-                isinstance(c, _ScalarSelect) and c.self_group(against=operators.comma_op) or c
-                for c in
-                [_literal_as_column(c) for c in columns]
+                isinstance(c, _ScalarSelect) and 
+                c.self_group(against=operators.comma_op) or c
+                for c in [_literal_as_column(c) for c in columns]
             ]
 
             self._froms.update(_from_objects(*self._raw_columns))
@@ -3331,8 +3490,7 @@ class Select(_SelectBaseMixin, FromClause):
         be rendered into the columns clause of the resulting SELECT statement.
 
         """
-
-        return itertools.chain(*[c._select_iterable for c in self._raw_columns])
+        return _select_iterables(self._raw_columns)
 
     def is_derived_from(self, fromclause):
         if self in fromclause._cloned_set:
@@ -3347,7 +3505,7 @@ class Select(_SelectBaseMixin, FromClause):
         self._reset_exported()
         from_cloned = dict((f, clone(f))
                            for f in self._froms.union(self._correlate))
-        self._froms = set(from_cloned[f] for f in self._froms)
+        self._froms = util.OrderedSet(from_cloned[f] for f in self._froms)
         self._correlate = set(from_cloned[f] for f in self._correlate)
         self._raw_columns = [clone(c) for c in self._raw_columns]
         for attr in ('_whereclause', '_having', '_order_by_clause', '_group_by_clause'):
@@ -3359,11 +3517,17 @@ class Select(_SelectBaseMixin, FromClause):
 
         return (column_collections and list(self.columns) or []) + \
             self._raw_columns + list(self._froms) + \
-            [x for x in (self._whereclause, self._having, self._order_by_clause, self._group_by_clause) if x is not None]
+            [x for x in 
+                (self._whereclause, self._having, 
+                    self._order_by_clause, self._group_by_clause) 
+            if x is not None]
 
     @_generative
     def column(self, column):
-        """return a new select() construct with the given column expression added to its columns clause."""
+        """return a new select() construct with the given column expression 
+            added to its columns clause.
+            
+        """
 
         column = _literal_as_column(column)
 
@@ -3375,63 +3539,73 @@ class Select(_SelectBaseMixin, FromClause):
 
     @_generative
     def with_only_columns(self, columns):
-        """return a new select() construct with its columns clause replaced with the given columns."""
+        """return a new select() construct with its columns clause replaced 
+            with the given columns.
+            
+        """
 
         self._raw_columns = [
-                isinstance(c, _ScalarSelect) and c.self_group(against=operators.comma_op) or c
-                for c in
-                [_literal_as_column(c) for c in columns]
+                isinstance(c, _ScalarSelect) and 
+                c.self_group(against=operators.comma_op) or c
+                for c in [_literal_as_column(c) for c in columns]
             ]
 
     @_generative
     def where(self, whereclause):
-        """return a new select() construct with the given expression added to its WHERE clause, joined
-        to the existing clause via AND, if any."""
+        """return a new select() construct with the given expression added to its 
+        WHERE clause, joined to the existing clause via AND, if any.
+        
+        """
 
         self.append_whereclause(whereclause)
 
     @_generative
     def having(self, having):
-        """return a new select() construct with the given expression added to its HAVING clause, joined
-        to the existing clause via AND, if any."""
-
+        """return a new select() construct with the given expression added to its HAVING
+          clause, joined to the existing clause via AND, if any.
+         
+        """
         self.append_having(having)
 
     @_generative
     def distinct(self):
-        """return a new select() construct which will apply DISTINCT to its columns clause."""
-
+        """return a new select() construct which will apply DISTINCT to its columns
+         clause.
+         
+         """
         self._distinct = True
 
     @_generative
     def prefix_with(self, clause):
-        """return a new select() construct which will apply the given expression to the start of its
-        columns clause, not using any commas."""
+        """return a new select() construct which will apply the given expression to the
+         start of its columns clause, not using any commas.
 
+         """
         clause = _literal_as_text(clause)
         self._prefixes = self._prefixes + [clause]
 
     @_generative
     def select_from(self, fromclause):
-        """return a new select() construct with the given FROM expression applied to its list of
-        FROM objects."""
+        """return a new select() construct with the given FROM expression applied to its
+         list of FROM objects.
 
+         """
         fromclause = _literal_as_text(fromclause)
         self._froms = self._froms.union([fromclause])
 
     @_generative
     def correlate(self, *fromclauses):
-        """return a new select() construct which will correlate the given FROM clauses to that
-        of an enclosing select(), if a match is found.
-
-        By "match", the given fromclause must be present in this select's list of FROM objects
-        and also present in an enclosing select's list of FROM objects.
-
-        Calling this method turns off the select's default behavior of "auto-correlation".  Normally,
-        select() auto-correlates all of its FROM clauses to those of an embedded select when
-        compiled.
-
-        If the fromclause is None, correlation is disabled for the returned select().
+        """return a new select() construct which will correlate the given FROM clauses to
+        that of an enclosing select(), if a match is found.
+        
+         By "match", the given fromclause must be present in this select's list of FROM
+        objects and also present in an enclosing select's list of FROM objects.
+        
+         Calling this method turns off the select's default behavior of
+        "auto-correlation". Normally, select() auto-correlates all of its FROM clauses to
+        those of an embedded select when compiled.
+        
+         If the fromclause is None, correlation is disabled for the returned select().
 
         """
         self._should_correlate = False
@@ -3447,8 +3621,10 @@ class Select(_SelectBaseMixin, FromClause):
         self._correlate = self._correlate.union([fromclause])
 
     def append_column(self, column):
-        """append the given column expression to the columns clause of this select() construct."""
-
+        """append the given column expression to the columns clause of this select()
+        construct.
+        
+        """
         column = _literal_as_column(column)
 
         if isinstance(column, _ScalarSelect):
@@ -3459,8 +3635,10 @@ class Select(_SelectBaseMixin, FromClause):
         self._reset_exported()
 
     def append_prefix(self, clause):
-        """append the given columns clause prefix expression to this select() construct."""
-
+        """append the given columns clause prefix expression to this select()
+        construct.
+        
+        """
         clause = _literal_as_text(clause)
         self._prefixes = self._prefixes.union([clause])
 
@@ -3490,7 +3668,8 @@ class Select(_SelectBaseMixin, FromClause):
             self._having = _literal_as_text(having)
 
     def append_from(self, fromclause):
-        """append the given FromClause expression to this select() construct's FROM clause.
+        """append the given FromClause expression to this select() construct's FROM
+        clause.
 
         """
         if _is_literal(fromclause):
@@ -3529,8 +3708,10 @@ class Select(_SelectBaseMixin, FromClause):
         return union(self, other, **kwargs)
 
     def union_all(self, other, **kwargs):
-        """return a SQL UNION ALL of this select() construct against the given selectable."""
-
+        """return a SQL UNION ALL of this select() construct against the given
+        selectable.
+        
+        """
         return union_all(self, other, **kwargs)
 
     def except_(self, other, **kwargs):
@@ -3539,18 +3720,24 @@ class Select(_SelectBaseMixin, FromClause):
         return except_(self, other, **kwargs)
 
     def except_all(self, other, **kwargs):
-        """return a SQL EXCEPT ALL of this select() construct against the given selectable."""
-
+        """return a SQL EXCEPT ALL of this select() construct against the given
+        selectable.
+        
+        """
         return except_all(self, other, **kwargs)
 
     def intersect(self, other, **kwargs):
-        """return a SQL INTERSECT of this select() construct against the given selectable."""
-
+        """return a SQL INTERSECT of this select() construct against the given
+        selectable.
+        
+        """
         return intersect(self, other, **kwargs)
 
     def intersect_all(self, other, **kwargs):
-        """return a SQL INTERSECT ALL of this select() construct against the given selectable."""
-
+        """return a SQL INTERSECT ALL of this select() construct against the given
+        selectable.
+        
+        """
         return intersect_all(self, other, **kwargs)
 
     def bind(self):
@@ -3581,7 +3768,7 @@ class _UpdateBase(ClauseElement):
 
     supports_execution = True
     _autocommit = True
-
+    
     def _generate(self):
         s = self.__class__.__new__(self.__class__)
         s.__dict__ = self.__dict__.copy()
@@ -3597,8 +3784,10 @@ class _UpdateBase(ClauseElement):
             return parameters
 
     def params(self, *arg, **kw):
-        raise NotImplementedError("params() is not supported for INSERT/UPDATE/DELETE statements."
-            "  To set the values for an INSERT or UPDATE statement, use stmt.values(**parameters).")
+        raise NotImplementedError(
+            "params() is not supported for INSERT/UPDATE/DELETE statements."
+            " To set the values for an INSERT or UPDATE statement, use"
+            " stmt.values(**parameters).")
 
     def bind(self):
         return self._bind or self.table.bind
@@ -3607,6 +3796,51 @@ class _UpdateBase(ClauseElement):
         self._bind = bind
     bind = property(bind, _set_bind)
 
+    _returning_re = re.compile(r'(?:firebird|postgres(?:ql)?)_returning')
+    def _process_deprecated_kw(self, kwargs):
+        for k in list(kwargs):
+            m = self._returning_re.match(k)
+            if m:
+                self._returning = kwargs.pop(k)
+                util.warn_deprecated(
+                    "The %r argument is deprecated.  Please use statement.returning(col1, col2, ...)" % k
+                )
+        return kwargs
+    
+    @_generative
+    def returning(self, *cols):
+        """Add a RETURNING or equivalent clause to this statement.
+        
+        The given list of columns represent columns within the table
+        that is the target of the INSERT, UPDATE, or DELETE.  Each 
+        element can be any column expression.  ``Table`` objects
+        will be expanded into their individual columns.
+        
+        Upon compilation, a RETURNING clause, or database equivalent, 
+        will be rendered within the statement.   For INSERT and UPDATE, 
+        the values are the newly inserted/updated values.  For DELETE, 
+        the values are those of the rows which were deleted.
+        
+        Upon execution, the values of the columns to be returned
+        are made available via the result set and can be iterated
+        using ``fetchone()`` and similar.   For DBAPIs which do not
+        natively support returning values (e.g. cx_oracle), 
+        SQLAlchemy will approximate this behavior at the result level
+        so that a reasonable amount of behavioral neutrality is 
+        provided.
+        
+        Note that not all databases/DBAPIs
+        support RETURNING.   For those backends with no support,
+        an exception is raised upon compilation and/or execution.
+        For those who do support it, the functionality across backends
+        varies greatly, including restrictions on executemany()
+        and other statements which return multiple rows. Please 
+        read the documentation notes for the database in use in 
+        order to determine the availability of RETURNING.
+        
+        """
+        self._returning = cols
+        
 class _ValuesBase(_UpdateBase):
 
     __visit_name__ = 'values_base'
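
For illustration, a sketch of the new generative returning() replacing the
deprecated dialect-prefixed keywords handled by _process_deprecated_kw() above
(the users table and conn connection here are hypothetical):

    # deprecated spelling, still accepted with a warning:
    #   users.insert(postgres_returning=[users.c.id])

    # new dialect-neutral spelling:
    stmt = users.insert().returning(users.c.id)
    result = conn.execute(stmt, name='jack')
    new_id = result.fetchone()[0]   # value of the RETURNING column
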
@@ -3617,14 +3851,15 @@ class _ValuesBase(_UpdateBase):
 
     @_generative
     def values(self, *args, **kwargs):
-        """specify the VALUES clause for an INSERT statement, or the SET clause for an UPDATE.
+        """specify the VALUES clause for an INSERT statement, or the SET clause for an
+        UPDATE.
 
             \**kwargs
                 key=<somevalue> arguments
 
             \*args
-                A single dictionary can be sent as the first positional argument.  This allows
-                non-string based keys, such as Column objects, to be used.
+                A single dictionary can be sent as the first positional argument. This
+                allows non-string based keys, such as Column objects, to be used.
 
         """
         if args:
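
For illustration, the two calling styles values() accepts per the docstring
above (the users table is hypothetical):

    # keyword style: column names as keys
    stmt = users.update().values(name='ed')

    # single-dictionary style: permits Column objects as keys
    stmt = users.update().values({users.c.name: 'ed'})
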
@@ -3648,16 +3883,25 @@ class Insert(_ValuesBase):
     """
     __visit_name__ = 'insert'
 
-    def __init__(self, table, values=None, inline=False, bind=None, prefixes=None, **kwargs):
+    def __init__(self, 
+                table, 
+                values=None, 
+                inline=False, 
+                bind=None, 
+                prefixes=None, 
+                returning=None,
+                **kwargs):
         _ValuesBase.__init__(self, table, values)
         self._bind = bind
         self.select = None
         self.inline = inline
+        self._returning = returning
         if prefixes:
             self._prefixes = [_literal_as_text(p) for p in prefixes]
         else:
             self._prefixes = []
-        self.kwargs = kwargs
+            
+        self.kwargs = self._process_deprecated_kw(kwargs)
 
     def get_children(self, **kwargs):
         if self.select is not None:
@@ -3688,15 +3932,24 @@ class Update(_ValuesBase):
     """
     __visit_name__ = 'update'
 
-    def __init__(self, table, whereclause, values=None, inline=False, bind=None, **kwargs):
+    def __init__(self, 
+                table, 
+                whereclause, 
+                values=None, 
+                inline=False, 
+                bind=None, 
+                returning=None,
+                **kwargs):
         _ValuesBase.__init__(self, table, values)
         self._bind = bind
+        self._returning = returning
         if whereclause:
             self._whereclause = _literal_as_text(whereclause)
         else:
             self._whereclause = None
         self.inline = inline
-        self.kwargs = kwargs
+
+        self.kwargs = self._process_deprecated_kw(kwargs)
 
     def get_children(self, **kwargs):
         if self._whereclause is not None:
@@ -3711,9 +3964,10 @@ class Update(_ValuesBase):
 
     @_generative
     def where(self, whereclause):
-        """return a new update() construct with the given expression added to its WHERE clause, joined
-        to the existing clause via AND, if any."""
-
+        """return a new update() construct with the given expression added to its WHERE
+        clause, joined to the existing clause via AND, if any.
+        
+        """
         if self._whereclause is not None:
             self._whereclause = and_(self._whereclause, _literal_as_text(whereclause))
         else:
@@ -3729,15 +3983,22 @@ class Delete(_UpdateBase):
 
     __visit_name__ = 'delete'
 
-    def __init__(self, table, whereclause, bind=None, **kwargs):
+    def __init__(self, 
+            table, 
+            whereclause, 
+            bind=None, 
+            returning=None,
+            **kwargs):
         self._bind = bind
         self.table = table
+        self._returning = returning
+        
         if whereclause:
             self._whereclause = _literal_as_text(whereclause)
         else:
             self._whereclause = None
 
-        self.kwargs = kwargs
+        self.kwargs = self._process_deprecated_kw(kwargs)
 
     def get_children(self, **kwargs):
         if self._whereclause is not None:
index 7c21e8233a9851cb57d75d95c754f88494c89f0a..879f0f3e517c0ecb2970e4684fef8d56a5df9979 100644 (file)
@@ -4,8 +4,13 @@
 """Defines operators used in SQL expressions."""
 
 from operator import (
-    and_, or_, inv, add, mul, sub, div, mod, truediv, lt, le, ne, gt, ge, eq
+    and_, or_, inv, add, mul, sub, mod, truediv, lt, le, ne, gt, ge, eq
     )
+    
+# Py2K
+from operator import (div,)
+# end Py2K
+
 from sqlalchemy.util import symbol
 
 
@@ -88,7 +93,10 @@ _largest = symbol('_largest')
 _PRECEDENCE = {
     from_: 15,
     mul: 7,
+    truediv: 7,
+    # Py2K
     div: 7,
+    # end Py2K
     mod: 7,
     add: 6,
     sub: 6,
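
For illustration, is_precedent() in this module consults this table to decide
whether a subexpression must be parenthesized against an enclosing operator
(a sketch, assuming the 0.6 module layout):

    from sqlalchemy.sql import operators

    # add (6) binds more loosely than mul (7), so an addition placed
    # inside a multiplication needs grouping: (a + b) * c
    operators.is_precedent(operators.add, operators.mul)   # True

    # the reverse needs no grouping: a * b + c
    operators.is_precedent(operators.mul, operators.add)   # False
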
index a5bd497aedf9bc59f2e1c36d7a549da89d406884..4471d4fb0d3425c87f6e03ecfacca3dbc97903f3 100644 (file)
@@ -34,13 +34,10 @@ class VisitableType(type):
     """
     
     def __init__(cls, clsname, bases, clsdict):
-        if cls.__name__ == 'Visitable':
+        if cls.__name__ == 'Visitable' or not hasattr(cls, '__visit_name__'):
             super(VisitableType, cls).__init__(clsname, bases, clsdict)
             return
         
-        assert hasattr(cls, '__visit_name__'), "`Visitable` descendants " \
-                                               "should define `__visit_name__`"
-        
         # set up an optimized visit dispatch function
         # for use by the compiler
         visit_name = cls.__visit_name__
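
For illustration, the effect of the relaxed check above: a Visitable subclass
that defines __visit_name__ still receives the optimized dispatch to
visitor.visit_<name>(), while an abstract intermediate class may now omit the
attribute entirely (a sketch; the classes here are hypothetical):

    from sqlalchemy.sql.visitors import Visitable

    class MyElement(Visitable):
        __visit_name__ = 'my_element'
        # instances dispatch to visitor.visit_my_element(self, **kw)

    class AbstractBase(Visitable):
        # no __visit_name__: previously an assertion error, now allowed;
        # no dispatch function is generated for this class
        pass
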
index dc2c6d40f8a8c38058ebacb9c6eaefb3ad4cc3f8..1af28794eda0cbc21145aae77ac820191ed77443 100644 (file)
@@ -3,7 +3,6 @@ from sqlalchemy.interfaces import ConnectionProxy
 from sqlalchemy.engine.default import DefaultDialect
 from sqlalchemy.engine.base import Connection
 from sqlalchemy import util
-import testing
 import re
 
 class AssertRule(object):
index 6ea5667cc3357e9ed68632011b599f8014dd47ce..eec962d807af2b6304d718e46082f10f90c9e188 100644 (file)
@@ -1,4 +1,8 @@
-import optparse, os, sys, re, ConfigParser, StringIO, time, warnings
+import optparse, os, sys, re, ConfigParser, time, warnings
+
+# 2to3
+import StringIO
+
 logging = None
 
 __all__ = 'parser', 'configure', 'options',
@@ -13,7 +17,11 @@ base_config = """
 [db]
 sqlite=sqlite:///:memory:
 sqlite_file=sqlite:///querytest.db
-postgres=postgres://scott:tiger@127.0.0.1:5432/test
+postgresql=postgresql://scott:tiger@127.0.0.1:5432/test
+postgres=postgresql://scott:tiger@127.0.0.1:5432/test
+pg8000=postgresql+pg8000://scott:tiger@127.0.0.1:5432/test
+postgresql_jython=postgresql+zxjdbc://scott:tiger@127.0.0.1:5432/test
+mysql_jython=mysql+zxjdbc://scott:tiger@127.0.0.1:3306/test
 mysql=mysql://scott:tiger@127.0.0.1:3306/test
 oracle=oracle://scott:tiger@127.0.0.1:1521
 oracle8=oracle://scott:tiger@127.0.0.1:1521/?use_ansi=0
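
For illustration, the new dialect+driver URL scheme these entries adopt
(credentials mirror the test config above):

    from sqlalchemy import create_engine

    # dialect only: the default DBAPI for postgresql is selected
    e1 = create_engine('postgresql://scott:tiger@127.0.0.1:5432/test')

    # dialect+driver: pin a specific DBAPI such as pg8000
    e2 = create_engine('postgresql+pg8000://scott:tiger@127.0.0.1:5432/test')
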
@@ -125,28 +133,22 @@ def _prep_testing_database(options, file_config):
     from sqlalchemy.test import engines
     from sqlalchemy import schema
 
-    try:
-        # also create alt schemas etc. here?
-        if options.dropfirst:
-            e = engines.utf8_engine()
-            existing = e.table_names()
-            if existing:
-                print "Dropping existing tables in database: " + db_url
-                try:
-                    print "Tables: %s" % ', '.join(existing)
-                except:
-                    pass
-                print "Abort within 5 seconds..."
-                time.sleep(5)
-                md = schema.MetaData(e, reflect=True)
-                md.drop_all()
-            e.dispose()
-    except (KeyboardInterrupt, SystemExit):
-        raise
-    except Exception, e:
-        warnings.warn(RuntimeWarning(
-            "Error checking for existing tables in testing "
-            "database: %s" % e))
+    # also create alt schemas etc. here?
+    if options.dropfirst:
+        e = engines.utf8_engine()
+        existing = e.table_names()
+        if existing:
+            print "Dropping existing tables in database: " + db_url
+            try:
+                print "Tables: %s" % ', '.join(existing)
+            except:
+                pass
+            print "Abort within 5 seconds..."
+            time.sleep(5)
+            md = schema.MetaData(e, reflect=True)
+            md.drop_all()
+        e.dispose()
+
 post_configure['prep_db'] = _prep_testing_database
 
 def _set_table_options(options, file_config):
index f0001978bf4214428fe592fc2d773bf2435ff833..187ad2ff036a40d8c8739cc0d941e5171423747f 100644 (file)
@@ -2,6 +2,7 @@ import sys, types, weakref
 from collections import deque
 import config
 from sqlalchemy.util import function_named, callable
+import re
 
 class ConnectionKiller(object):
     def __init__(self):
@@ -11,7 +12,8 @@ class ConnectionKiller(object):
         self.proxy_refs[con_proxy] = True
 
     def _apply_all(self, methods):
-        for rec in self.proxy_refs:
+        # must copy keys atomically
+        for rec in self.proxy_refs.keys():
             if rec is not None and rec.is_valid:
                 try:
                     for name in methods:
@@ -38,6 +40,10 @@ class ConnectionKiller(object):
 
 testing_reaper = ConnectionKiller()
 
+def drop_all_tables(metadata):
+    testing_reaper.close_all()
+    metadata.drop_all()
+    
 def assert_conns_closed(fn):
     def decorated(*args, **kw):
         try:
@@ -56,6 +62,14 @@ def rollback_open_connections(fn):
             testing_reaper.rollback_all()
     return function_named(decorated, fn.__name__)
 
+def close_first(fn):
+    """Decorator that closes all connections before fn execution."""
+    def decorated(*args, **kw):
+        testing_reaper.close_all()
+        fn(*args, **kw)
+    return function_named(decorated, fn.__name__)
+    
+    
 def close_open_connections(fn):
     """Decorator that closes all connections after fn execution."""
 
@@ -69,7 +83,10 @@ def close_open_connections(fn):
 def all_dialects():
     import sqlalchemy.databases as d
     for name in d.__all__:
-        mod = getattr(__import__('sqlalchemy.databases.%s' % name).databases, name)
+        # TEMPORARY
+        mod = getattr(d, name, None)
+        if not mod:
+            mod = getattr(__import__('sqlalchemy.databases.%s' % name).databases, name)
         yield mod.dialect()
         
 class ReconnectFixture(object):
@@ -115,7 +132,11 @@ def testing_engine(url=None, options=None):
     listeners.append(testing_reaper)
 
     engine = create_engine(url, **options)
-
+    
+    # may want to call this, results
+    # in first-connect initializers
+    #engine.connect()
+    
     return engine
 
 def utf8_engine(url=None, options=None):
@@ -123,7 +144,7 @@ def utf8_engine(url=None, options=None):
 
     from sqlalchemy.engine import url as engine_url
 
-    if config.db.name == 'mysql':
+    if config.db.driver == 'mysqldb':
         dbapi_ver = config.db.dialect.dbapi.version_info
         if (dbapi_ver < (1, 2, 1) or
             dbapi_ver in ((1, 2, 1, 'gamma', 1), (1, 2, 1, 'gamma', 2),
@@ -139,19 +160,35 @@ def utf8_engine(url=None, options=None):
 
     return testing_engine(url, options)
 
-def mock_engine(db=None):
-    """Provides a mocking engine based on the current testing.db."""
+def mock_engine(dialect_name=None):
+    """Provides a mocking engine based on the current testing.db.
+    
+    This is normally used to test DDL generation flow as emitted
+    by an Engine.
+    
+    It should not be used in other cases, as assert_compile() and
+    assert_sql_execution() are much better choices with fewer 
+    moving parts.
+    
+    """
     
     from sqlalchemy import create_engine
     
-    dbi = db or config.db
+    if not dialect_name:
+        dialect_name = config.db.name
+
     buffer = []
     def executor(sql, *a, **kw):
         buffer.append(sql)
-    engine = create_engine(dbi.name + '://',
+    def assert_sql(stmts):
+        recv = [re.sub(r'[\n\t]', '', str(s)) for s in buffer]
+        assert recv == stmts, recv
+        
+    engine = create_engine(dialect_name + '://',
                            strategy='mock', executor=executor)
     assert not hasattr(engine, 'mock')
     engine.mock = buffer
+    engine.assert_sql = assert_sql
     return engine
 
 class ReplayableSession(object):
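
For illustration, how the reworked mock_engine() is driven in a DDL test:
statements are captured rather than executed, and assert_sql() compares them
with newlines and tabs stripped (metadata here is hypothetical, and the exact
DDL string varies with the schema):

    from sqlalchemy.test.engines import mock_engine

    e = mock_engine(dialect_name='sqlite')
    metadata.create_all(e)
    e.assert_sql([
        'CREATE TABLE users (id INTEGER NOT NULL, PRIMARY KEY (id))'
    ])
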
@@ -168,9 +205,16 @@ class ReplayableSession(object):
     Natives = set([getattr(types, t)
                    for t in dir(types) if not t.startswith('_')]). \
                    difference([getattr(types, t)
+                            # Py3K
+                            #for t in ('FunctionType', 'BuiltinFunctionType',
+                            #          'MethodType', 'BuiltinMethodType',
+                            #          'LambdaType', )])
+                            
+                            # Py2K
                                for t in ('FunctionType', 'BuiltinFunctionType',
                                          'MethodType', 'BuiltinMethodType',
                                          'LambdaType', 'UnboundMethodType',)])
+                            # end Py2K
     def __init__(self):
         self.buffer = deque()
 
index 263d2d783138bc35976f8e32700d97dd933fcd9f..c4f32a1630a86c09ecf2b84926dc474c277d7ba0 100644 (file)
@@ -14,7 +14,7 @@ from config import db, db_label, db_url, file_config, base_config, \
                            _set_table_options, _reverse_topological, _log
 from sqlalchemy.test import testing, config, requires
 from nose.plugins import Plugin
-from nose.util import tolist
+from sqlalchemy import util
 import nose.case
 
 log = logging.getLogger('nose.plugins.sqlalchemy')
@@ -30,9 +30,6 @@ class NoseSQLAlchemy(Plugin):
     def options(self, parser, env=os.environ):
         Plugin.options(self, parser, env)
         opt = parser.add_option
-        #opt("--verbose", action="store_true", dest="verbose",
-            #help="enable stdout echoing/printing")
-        #opt("--quiet", action="store_true", dest="quiet", help="suppress output")
         opt("--log-info", action="callback", type="string", callback=_log,
             help="turn on info logging for <LOG> (multiple OK)")
         opt("--log-debug", action="callback", type="string", callback=_log,
@@ -77,15 +74,16 @@ class NoseSQLAlchemy(Plugin):
         
     def configure(self, options, conf):
         Plugin.configure(self, options, conf)
-
-        import testing, requires
+        self.options = options
+        
+    def begin(self):
         testing.db = db
         testing.requires = requires
 
         # Lazy setup of other options (post coverage)
         for fn in post_configure:
-            fn(options, file_config)
-        
+            fn(self.options, file_config)
+
     def describeTest(self, test):
         return ""
         
@@ -117,15 +115,20 @@ class NoseSQLAlchemy(Plugin):
                 if check(test_suite)() != 'ok':
                     # The requirement will perform messaging.
                     return True
-        if (hasattr(cls, '__unsupported_on__') and
-            testing.db.name in cls.__unsupported_on__):
-            print "'%s' unsupported on DB implementation '%s'" % (
-                cls.__class__.__name__, testing.db.name)
-            return True
-        if (getattr(cls, '__only_on__', None) not in (None, testing.db.name)):
-            print "'%s' unsupported on DB implementation '%s'" % (
-                cls.__class__.__name__, testing.db.name)
-            return True
+
+        if cls.__unsupported_on__:
+            spec = testing.db_spec(*cls.__unsupported_on__)
+            if spec(testing.db):
+                print "'%s' unsupported on DB implementation '%s'" % (
+                     cls.__class__.__name__, testing.db.name)
+                return True
+        if getattr(cls, '__only_on__', None):
+            spec = testing.db_spec(*util.to_list(cls.__only_on__))
+            if not spec(testing.db):
+                print "'%s' unsupported on DB implementation '%s'" % (
+                     cls.__class__.__name__, testing.db.name)
+                return True                    
+
         if (getattr(cls, '__skip_if__', False)):
             for c in getattr(cls, '__skip_if__'):
                 if c():
@@ -140,15 +143,15 @@ class NoseSQLAlchemy(Plugin):
                 return True
         return False
 
-    #def begin(self):
-        #pass
-
     def beforeTest(self, test):
         testing.resetwarnings()
 
     def afterTest(self, test):
         testing.resetwarnings()
         
+    def afterContext(self):
+        testing.global_cleanup_assertions()
+        
     #def handleError(self, test, err):
         #pass
 
index ca4b31cbd8c1744973f3b190cc5552d7402fde3c..8cab6ceba1525686d333be32865b895372230622 100644 (file)
@@ -6,8 +6,9 @@ in a more fine-grained way than nose's profiling plugin.
 """
 
 import os, sys
-from sqlalchemy.util import function_named
-import config
+from sqlalchemy.test import config
+from sqlalchemy.test.util import function_named, gc_collect
+from nose import SkipTest
 
 __all__ = 'profiled', 'function_call_count', 'conditional_call_count'
 
@@ -162,15 +163,22 @@ def conditional_call_count(discriminator, categories):
 def _profile(filename, fn, *args, **kw):
     global profiler
     if not profiler:
-        profiler = 'hotshot'
         if sys.version_info > (2, 5):
             try:
                 import cProfile
                 profiler = 'cProfile'
             except ImportError:
                 pass
+        if not profiler:
+            try:
+                import hotshot
+                profiler = 'hotshot'
+            except ImportError:
+                profiler = 'skip'
 
-    if profiler == 'cProfile':
+    if profiler == 'skip':
+        raise SkipTest('Profiling not supported on this platform')
+    elif profiler == 'cProfile':
         return _profile_cProfile(filename, fn, *args, **kw)
     else:
         return _profile_hotshot(filename, fn, *args, **kw)
@@ -179,7 +187,7 @@ def _profile_cProfile(filename, fn, *args, **kw):
     import cProfile, gc, pstats, time
 
     load_stats = lambda: pstats.Stats(filename)
-    gc.collect()
+    gc_collect()
 
     began = time.time()
     cProfile.runctx('result = fn(*args, **kw)', globals(), locals(),
@@ -192,7 +200,7 @@ def _profile_hotshot(filename, fn, *args, **kw):
     import gc, hotshot, hotshot.stats, time
     load_stats = lambda: hotshot.stats.load(filename)
 
-    gc.collect()
+    gc_collect()
     prof = hotshot.Profile(filename)
     began = time.time()
     prof.start()
index b23b8620da054d762016ba9453aa6d2d8c3f7a14..f3f4ec1911c9f0a7243408d282429dee322ba1cd 100644 (file)
@@ -28,6 +28,25 @@ def foreign_keys(fn):
         no_support('sqlite', 'not supported by database'),
         )
 
+
+def unbounded_varchar(fn):
+    """Target database must support VARCHAR with no length"""
+    return _chain_decorators_on(
+        fn,
+        no_support('firebird', 'not supported by database'),
+        no_support('oracle', 'not supported by database'),
+        no_support('mysql', 'not supported by database'),
+    )
+
+def boolean_col_expressions(fn):
+    """Target database must support boolean expressions as columns"""
+    return _chain_decorators_on(
+        fn,
+        no_support('firebird', 'not supported by database'),
+        no_support('oracle', 'not supported by database'),
+        no_support('mssql', 'not supported by database'),
+    )
+    
 def identity(fn):
     """Target database must support GENERATED AS IDENTITY or a facsimile.
 
@@ -40,7 +59,7 @@ def identity(fn):
         fn,
         no_support('firebird', 'not supported by database'),
         no_support('oracle', 'not supported by database'),
-        no_support('postgres', 'not supported by database'),
+        no_support('postgresql', 'not supported by database'),
         no_support('sybase', 'not supported by database'),
         )
 
@@ -61,9 +80,19 @@ def row_triggers(fn):
         # no access to same table
         no_support('mysql', 'requires SUPER priv'),
         exclude('mysql', '<', (5, 0, 10), 'not supported by database'),
-        no_support('postgres', 'not supported by database: no statements'),
+        
+        # huh?  TODO: implement triggers for PG tests, remove this
+        no_support('postgresql', 'PG triggers need to be implemented for tests'),  
         )
 
+def correlated_outer_joins(fn):
+    """Target must support an outer join to a subquery which correlates to the parent."""
+    
+    return _chain_decorators_on(
+        fn,
+        no_support('oracle', 'Raises "ORA-01799: a column may not be outer-joined to a subquery"')
+    )
+    
 def savepoints(fn):
     """Target database must support savepoints."""
     return _chain_decorators_on(
@@ -75,6 +104,15 @@ def savepoints(fn):
         exclude('mysql', '<', (5, 0, 3), 'not supported by database'),
         )
 
+def schemas(fn):
+    """Target database must support external schemas, and have one named 'test_schema'."""
+    
+    return _chain_decorators_on(
+        fn,
+        no_support('sqlite', 'no schema support'),
+        no_support('firebird', 'no schema support')
+    )
+    
 def sequences(fn):
     """Target database must support SEQUENCEs."""
     return _chain_decorators_on(
@@ -93,6 +131,17 @@ def subqueries(fn):
         exclude('mysql', '<', (4, 1, 1), 'no subquery support'),
         )
 
+def returning(fn):
+    return _chain_decorators_on(
+        fn,
+        no_support('access', 'not supported by database'),
+        no_support('sqlite', 'not supported by database'),
+        no_support('mysql', 'not supported by database'),
+        no_support('maxdb', 'not supported by database'),
+        no_support('sybase', 'not supported by database'),
+        no_support('informix', 'not supported by database'),
+    )
+    
 def two_phase_transactions(fn):
     """Target database must support two-phase transactions."""
     return _chain_decorators_on(
@@ -104,6 +153,8 @@ def two_phase_transactions(fn):
         no_support('oracle', 'no SA implementation'),
         no_support('sqlite', 'not supported by database'),
         no_support('sybase', 'FIXME: guessing, needs confirmation'),
+        no_support('postgresql+zxjdbc', 'FIXME: JDBC driver confuses the transaction state, may '
+                   'need separate XA implementation'),
         exclude('mysql', '<', (5, 0, 3), 'not supported by database'),
         )
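
For illustration, how these requirement decorators attach to a test (the test
function here is hypothetical):

    from sqlalchemy.test import requires

    @requires.returning
    def test_insert_returning(self):
        # excluded on access, sqlite, mysql, maxdb, sybase and informix
        pass
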
 
index f96805fe4947f9f57b750d169aecec14aff9ce35..35b4060d2bd7c06c9c93f1372dc846299edcbf1a 100644 (file)
@@ -33,7 +33,7 @@ def Table(*args, **kw):
         # expand to ForeignKeyConstraint too.
         fks = [fk
                for col in args if isinstance(col, schema.Column)
-               for fk in col.args if isinstance(fk, schema.ForeignKey)]
+               for fk in col.foreign_keys]
 
         for fk in fks:
             # root around in raw spec
@@ -51,13 +51,6 @@ def Table(*args, **kw):
                 if fk.onupdate is None:
                     fk.onupdate = 'CASCADE'
 
-    if testing.against('firebird', 'oracle'):
-        pk_seqs = [col for col in args
-                   if (isinstance(col, schema.Column)
-                       and col.primary_key
-                       and getattr(col, '_needs_autoincrement', False))]
-        for c in pk_seqs:
-            c.args.append(schema.Sequence(args[0] + '_' + c.name + '_seq', optional=True))
     return schema.Table(*args, **kw)
 
 
@@ -67,8 +60,20 @@ def Column(*args, **kw):
     test_opts = dict([(k,kw.pop(k)) for k in kw.keys()
                       if k.startswith('test_')])
 
-    c = schema.Column(*args, **kw)
-    if testing.against('firebird', 'oracle'):
-        if 'test_needs_autoincrement' in test_opts:
-            c._needs_autoincrement = True
-    return c
+    col = schema.Column(*args, **kw)
+    if 'test_needs_autoincrement' in test_opts and \
+        kw.get('primary_key', False) and \
+        testing.against('firebird', 'oracle'):
+        def add_seq(tbl):
+            col._init_items(
+                schema.Sequence(_truncate_name(testing.db.dialect, tbl.name + '_' + col.name + '_seq'), optional=True)
+            )
+        col._on_table_attach(add_seq)
+    return col
+
+def _truncate_name(dialect, name):
+    if len(name) > dialect.max_identifier_length:
+        return name[0:max(dialect.max_identifier_length - 6, 0)] + "_" + hex(hash(name) % 64)[2:]
+    else:
+        return name
+    
\ No newline at end of file
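
For illustration, what the reworked helpers above do with
test_needs_autoincrement (a sketch, assuming these wrappers are imported from
the test library; the table and metadata names are hypothetical): on firebird
and oracle, once the column is attached to a table, a Sequence named
<table>_<column>_seq, truncated via _truncate_name() to the dialect's
identifier limit, is appended to the column:

    users = Table('users', metadata,
        Column('id', Integer, primary_key=True,
               test_needs_autoincrement=True),
        Column('name', String(30)),
    )
    # on firebird/oracle this acts as if
    # Sequence('users_id_seq', optional=True) had been given to 'id'
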
index 36c7d340a3bc8f80a3fe98cf1e6ce1442928294e..16a13d9d3b8f6bbd1b52a255593490b30227db97 100644 (file)
@@ -8,10 +8,12 @@ import types
 import warnings
 from cStringIO import StringIO
 
-from sqlalchemy.test import config, assertsql
+from sqlalchemy.test import config, assertsql, util as testutil
 from sqlalchemy.util import function_named
+from engines import drop_all_tables
 
-from sqlalchemy import exc as sa_exc, util, types as sqltypes, schema
+from sqlalchemy import exc as sa_exc, util, types as sqltypes, schema, pool
+from nose import SkipTest
 
 _ops = { '<': operator.lt,
          '>': operator.gt,
@@ -80,6 +82,19 @@ def future(fn):
                 "Unexpected success for future test '%s'" % fn_name)
     return function_named(decorated, fn_name)
 
+def db_spec(*dbs):
+    dialects = set([x for x in dbs if '+' not in x])
+    drivers = set([x[1:] for x in dbs if x.startswith('+')])
+    specs = set([tuple(x.split('+')) for x in dbs if '+' in x and x not in drivers])
+
+    def check(engine):
+        return engine.name in dialects or \
+            engine.driver in drivers or \
+            (engine.name, engine.driver) in specs
+    
+    return check
+        
+
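
For illustration, the three spec forms db_spec() recognizes (testing.db is
whatever engine the suite is configured against):

    check = db_spec('sqlite', '+pyodbc', 'postgresql+zxjdbc')

    # True when engine.name == 'sqlite'                  (dialect form)
    #   or engine.driver == 'pyodbc'                     ('+driver' form)
    #   or (name, driver) == ('postgresql', 'zxjdbc')    (combined form)
    check(testing.db)
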
 def fails_on(dbs, reason):
     """Mark a test as expected to fail on the specified database 
     implementation.
@@ -90,23 +105,25 @@ def fails_on(dbs, reason):
     succeeds, a failure is reported.
     """
 
+    spec = db_spec(dbs)
+     
     def decorate(fn):
         fn_name = fn.__name__
         def maybe(*args, **kw):
-            if config.db.name != dbs:
+            if not spec(config.db):
                 return fn(*args, **kw)
             else:
                 try:
                     fn(*args, **kw)
                 except Exception, ex:
                     print ("'%s' failed as expected on DB implementation "
-                           "'%s': %s" % (
-                        fn_name, config.db.name, reason))
+                            "'%s+%s': %s" % (
+                        fn_name, config.db.name, config.db.driver, reason))
                     return True
                 else:
                     raise AssertionError(
-                        "Unexpected success for '%s' on DB implementation '%s'" %
-                        (fn_name, config.db.name))
+                         "Unexpected success for '%s' on DB implementation '%s+%s'" %
+                         (fn_name, config.db.name, config.db.driver))
         return function_named(maybe, fn_name)
     return decorate
 
@@ -117,23 +134,25 @@ def fails_on_everything_except(*dbs):
     databases except those listed.
     """
 
+    spec = db_spec(*dbs)
+    
     def decorate(fn):
         fn_name = fn.__name__
         def maybe(*args, **kw):
-            if config.db.name in dbs:
+            if spec(config.db):
                 return fn(*args, **kw)
             else:
                 try:
                     fn(*args, **kw)
                 except Exception, ex:
                     print ("'%s' failed as expected on DB implementation "
-                           "'%s': %s" % (
-                        fn_name, config.db.name, str(ex)))
+                            "'%s+%s': %s" % (
+                        fn_name, config.db.name, config.db.driver, str(ex)))
                     return True
                 else:
                     raise AssertionError(
-                        "Unexpected success for '%s' on DB implementation '%s'" %
-                        (fn_name, config.db.name))
+                      "Unexpected success for '%s' on DB implementation '%s+%s'" %
+                      (fn_name, config.db.name, config.db.driver))
         return function_named(maybe, fn_name)
     return decorate
 
@@ -145,12 +164,13 @@ def crashes(db, reason):
 
     """
     carp = _should_carp_about_exclusion(reason)
+    spec = db_spec(db)
     def decorate(fn):
         fn_name = fn.__name__
         def maybe(*args, **kw):
-            if config.db.name == db:
-                msg = "'%s' unsupported on DB implementation '%s': %s" % (
-                    fn_name, config.db.name, reason)
+            if spec(config.db):
+                msg = "'%s' unsupported on DB implementation '%s+%s': %s" % (
+                    fn_name, config.db.name, config.db.driver, reason)
                 print msg
                 if carp:
                     print >> sys.stderr, msg
@@ -169,12 +189,13 @@ def _block_unconditionally(db, reason):
 
     """
     carp = _should_carp_about_exclusion(reason)
+    spec = db_spec(db)
     def decorate(fn):
         fn_name = fn.__name__
         def maybe(*args, **kw):
-            if config.db.name == db:
-                msg = "'%s' unsupported on DB implementation '%s': %s" % (
-                    fn_name, config.db.name, reason)
+            if spec(config.db):
+                msg = "'%s' unsupported on DB implementation '%s+%s': %s" % (
+                    fn_name, config.db.name, config.db.driver, reason)
                 print msg
                 if carp:
                     print >> sys.stderr, msg
@@ -198,6 +219,7 @@ def exclude(db, op, spec, reason):
 
     """
     carp = _should_carp_about_exclusion(reason)
+    
     def decorate(fn):
         fn_name = fn.__name__
         def maybe(*args, **kw):
@@ -242,7 +264,9 @@ def _is_excluded(db, op, spec):
       _is_excluded('yikesdb', 'in', ((0, 3, 'alpha2'), (0, 3, 'alpha3')))
     """
 
-    if config.db.name != db:
+    vendor_spec = db_spec(db)
+
+    if not vendor_spec(config.db):
         return False
 
     version = _server_version()
@@ -255,7 +279,12 @@ def _server_version(bind=None):
 
     if bind is None:
         bind = config.db
-    return bind.dialect.server_version_info(bind.contextual_connect())
+    
+    # connect once, so that the dialect initializes itself and
+    # populates server_version_info
+    conn = bind.connect()
+    version = getattr(bind.dialect, 'server_version_info', ())
+    conn.close()
+    return version
 
 def skip_if(predicate, reason=None):
     """Skip a test if predicate is true."""
@@ -266,8 +295,7 @@ def skip_if(predicate, reason=None):
             if predicate():
                 msg = "'%s' skipped on DB %s version '%s': %s" % (
                     fn_name, config.db.name, _server_version(), reason)
-                print msg
-                return True
+                raise SkipTest(msg)
             else:
                 return fn(*args, **kw)
         return function_named(maybe, fn_name)
@@ -315,10 +343,12 @@ def emits_warning_on(db, *warnings):
     strings; these will be matched to the root of the warning description by
     warnings.filterwarnings().
     """
+    spec = db_spec(db)
+    
     def decorate(fn):
         def maybe(*args, **kw):
             if isinstance(db, basestring):
-                if config.db.name != db:
+                if not spec(config.db):
                     return fn(*args, **kw)
                 else:
                     wrapped = emits_warning(*warnings)(fn)
@@ -384,6 +414,19 @@ def resetwarnings():
     if sys.version_info < (2, 4):
         warnings.filterwarnings('ignore', category=FutureWarning)
 
+def global_cleanup_assertions():
+    """Check things that have to be finalized at the end of a test suite.
+    
+    Hardcoded at the moment; a modular system could be built here
+    to support things like PG prepared transactions, tables all
+    dropped, etc.
+    
+    """
+
+    testutil.lazy_gc()
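+    # no connection pool should be holding references to checked-out
+    # connection records once a test completes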
+    assert not pool._refs
+    
+    
 
 def against(*queries):
     """Boolean predicate, compares to testing database configuration.
@@ -394,21 +437,20 @@ def against(*queries):
     Also supports comparison to database version when provided with one or
     more 3-tuples of dialect name, operator, and version specification::
 
-      testing.against('mysql', 'postgres')
+      testing.against('mysql', 'postgresql')
       testing.against(('mysql', '>=', (5, 0, 0))
     """
 
     for query in queries:
         if isinstance(query, basestring):
-            if config.db.name == query:
+            if db_spec(query)(config.db):
                 return True
         else:
             name, op, spec = query
-            if config.db.name != name:
+            if not db_spec(name)(config.db):
                 continue
 
-            have = config.db.dialect.server_version_info(
-                config.db.contextual_connect())
+            have = _server_version()
 
             oper = hasattr(op, '__call__') and op or _ops[op]
             if oper(have, spec):
@@ -545,16 +587,15 @@ class AssertsCompiledSQL(object):
         if dialect is None:
             dialect = getattr(self, '__dialect__', None)
 
-        if params is None:
-            keys = None
-        else:
-            keys = params.keys()
+        kw = {}
+        if params is not None:
+            kw['column_keys'] = params.keys()
 
-        c = clause.compile(column_keys=keys, dialect=dialect)
+        c = clause.compile(dialect=dialect, **kw)
 
-        print "\nSQL String:\n" + str(c) + repr(c.params)
+        print "\nSQL String:\n" + str(c) + repr(getattr(c, 'params', {}))
 
-        cc = re.sub(r'\n', '', str(c))
+        cc = re.sub(r'[\n\t]', '', str(c))
 
         eq_(cc, result, "%r != %r on dialect %r" % (cc, result, dialect))
 
@@ -563,18 +604,13 @@ class AssertsCompiledSQL(object):
 
 class ComparesTables(object):
     def assert_tables_equal(self, table, reflected_table):
-        base_mro = sqltypes.TypeEngine.__mro__
         assert len(table.c) == len(reflected_table.c)
         for c, reflected_c in zip(table.c, reflected_table.c):
             eq_(c.name, reflected_c.name)
             assert reflected_c is reflected_table.c[c.name]
             eq_(c.primary_key, reflected_c.primary_key)
             eq_(c.nullable, reflected_c.nullable)
-            assert len(
-                set(type(reflected_c.type).__mro__).difference(base_mro).intersection(
-                set(type(c.type).__mro__).difference(base_mro)
-                )
-            ) > 0, "Type '%s' doesn't correspond to type '%s'" % (reflected_c.type, c.type)
+            self.assert_types_base(reflected_c, c)
 
             if isinstance(c.type, sqltypes.String):
                 eq_(c.type.length, reflected_c.type.length)
@@ -586,14 +622,21 @@ class ComparesTables(object):
             elif against(('mysql', '<', (5, 0))):
                 # ignore reflection of bogus db-generated DefaultClause()
                 pass
-            elif not c.primary_key or not against('postgres'):
-                print repr(c)
+            elif not c.primary_key or not against('postgresql', 'mssql'):
+                #print repr(c)
                 assert reflected_c.default is None, reflected_c.default
 
         assert len(table.primary_key) == len(reflected_table.primary_key)
         for c in table.primary_key:
             assert reflected_table.primary_key.columns[c.name]
-
+    
+    def assert_types_base(self, c1, c2):
+        base_mro = sqltypes.TypeEngine.__mro__
+        assert len(
+            set(type(c1.type).__mro__).difference(base_mro).intersection(
+            set(type(c2.type).__mro__).difference(base_mro)
+            )
+        ) > 0, "On column %r, type '%s' doesn't correspond to type '%s'" % (c1.name, c1.type, c2.type)
 
 class AssertsExecutionResults(object):
     def assert_result(self, result, class_, *objects):
@@ -678,7 +721,7 @@ class AssertsExecutionResults(object):
             assertsql.asserter.clear_rules()
             
     def assert_sql(self, db, callable_, list_, with_sequences=None):
-        if with_sequences is not None and config.db.name in ('firebird', 'oracle', 'postgres'):
+        if with_sequences is not None and config.db.name in ('firebird', 'oracle', 'postgresql'):
             rules = with_sequences
         else:
             rules = list_
diff --git a/lib/sqlalchemy/test/util.py b/lib/sqlalchemy/test/util.py
new file mode 100644 (file)
index 0000000..60b0a4e
--- /dev/null
@@ -0,0 +1,24 @@
+from sqlalchemy.util import jython, function_named
+
+import gc
+import time
+
+if jython:
+    def gc_collect(*args):
+        """aggressive gc.collect for tests."""
+        gc.collect()
+        time.sleep(0.1)
+        gc.collect()
+        gc.collect()
+        return 0
+        
+    # "lazy" gc, for VM's that don't GC on refcount == 0
+    lazy_gc = gc_collect
+
+else:
+    # assume CPython - straight gc.collect, lazy_gc() is a pass
+    gc_collect = gc.collect
+    def lazy_gc():
+        pass
+
+
index f9b9ad7b3673979a2baef542fb17bb62db0d0e54..fbdb17963b8b899e479ca2c0d178437db659348f 100644 (file)
@@ -161,20 +161,21 @@ def _sort(tuples, allitems, allow_cycles=False, ignore_self_cycles=False):
     edges = _EdgeCollection()
 
     for item in list(allitems) + [t[0] for t in tuples] + [t[1] for t in tuples]:
-        if id(item) not in nodes:
-            node = _Node(item)
-            nodes[item] = node
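+        # key the node map consistently by id(item); the old code
+        # checked membership by id(item) but stored under the item
+        # itself, which could raise KeyError (ticket:1380)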
+        item_id = id(item)
+        if item_id not in nodes:
+            nodes[item_id] = _Node(item)
 
     for t in tuples:
+        id0, id1 = id(t[0]), id(t[1])
         if t[0] is t[1]:
             if allow_cycles:
-                n = nodes[t[0]]
+                n = nodes[id0]
                 n.cycles = set([n])
             elif not ignore_self_cycles:
                 raise CircularDependencyError("Self-referential dependency detected " + repr(t))
             continue
-        childnode = nodes[t[1]]
-        parentnode = nodes[t[0]]
+        childnode = nodes[id1]
+        parentnode = nodes[id0]
         edges.add((parentnode, childnode))
 
     queue = []
@@ -210,7 +211,7 @@ def _sort(tuples, allitems, allow_cycles=False, ignore_self_cycles=False):
         node = queue.pop()
         if not hasattr(node, '_cyclical'):
             output.append(node)
-        del nodes[node.item]
+        del nodes[id(node.item)]
         for childnode in edges.pop_node(node):
             queue.append(childnode)
     return output
@@ -293,8 +294,8 @@ def _find_cycles(edges):
     for parent in edges.get_parents():
         traverse(parent)
 
-    # sets are not hashable, so uniquify with id
-    unique_cycles = dict((id(s), s) for s in cycles.values()).values()
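+    # each cycle is a set, which isn't hashable; convert to tuples so
+    # the outer set can de-duplicate them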
+    unique_cycles = set(tuple(s) for s in cycles.values())
+    
     for cycle in unique_cycles:
         edgecollection = [edge for edge in edges
                           if edge[0] in cycle and edge[1] in cycle]
index a03d6137dfa93d7204c0e43a4061c30f26a4d698..692e63347bb2f15879c102798868624376164c72 100644 (file)
@@ -11,11 +11,11 @@ types.
 For more information see the SQLAlchemy documentation on types.
 
 """
-__all__ = [ 'TypeEngine', 'TypeDecorator', 'AbstractType',
-            'INT', 'CHAR', 'VARCHAR', 'NCHAR', 'TEXT', 'Text', 'FLOAT',
-            'NUMERIC', 'DECIMAL', 'TIMESTAMP', 'DATETIME', 'CLOB', 'BLOB',
-            'BOOLEAN', 'SMALLINT', 'DATE', 'TIME',
-            'String', 'Integer', 'SmallInteger','Smallinteger',
+__all__ = [ 'TypeEngine', 'TypeDecorator', 'AbstractType', 'UserDefinedType',
+            'INT', 'CHAR', 'VARCHAR', 'NCHAR', 'NVARCHAR','TEXT', 'Text', 'FLOAT',
+            'NUMERIC', 'DECIMAL', 'TIMESTAMP', 'DATETIME', 'CLOB', 'NCLOB', 'BLOB',
+            'BOOLEAN', 'SMALLINT', 'INTEGER','DATE', 'TIME',
+            'String', 'Integer', 'SmallInteger',
             'Numeric', 'Float', 'DateTime', 'Date', 'Time', 'Binary',
             'Boolean', 'Unicode', 'MutableType', 'Concatenable', 'UnicodeText', 'PickleType', 'Interval',
             'type_map'
@@ -24,17 +24,23 @@ __all__ = [ 'TypeEngine', 'TypeDecorator', 'AbstractType',
 import inspect
 import datetime as dt
 from decimal import Decimal as _python_Decimal
-import weakref
+
 from sqlalchemy import exc
 from sqlalchemy.util import pickle
+from sqlalchemy.sql.visitors import Visitable
 import sqlalchemy.util as util
 NoneType = type(None)
+if util.jython:
+    import array
     
-class AbstractType(object):
+class AbstractType(Visitable):
 
     def __init__(self, *args, **kwargs):
         pass
 
+    def compile(self, dialect):
+        return dialect.type_compiler.process(self)
+        
     def copy_value(self, value):
         return value
 
@@ -89,56 +95,23 @@ class AbstractType(object):
                       for k in inspect.getargspec(self.__init__)[0][1:]))
 
 class TypeEngine(AbstractType):
-    """Base for built-in types.
-
-    May be sub-classed to create entirely new types.  Example::
-
-      import sqlalchemy.types as types
-
-      class MyType(types.TypeEngine):
-          def __init__(self, precision = 8):
-              self.precision = precision
-
-          def get_col_spec(self):
-              return "MYTYPE(%s)" % self.precision
-
-          def bind_processor(self, dialect):
-              def process(value):
-                  return value
-              return process
-
-          def result_processor(self, dialect):
-              def process(value):
-                  return value
-              return process
-
-    Once the type is made, it's immediately usable::
-
-      table = Table('foo', meta,
-          Column('id', Integer, primary_key=True),
-          Column('data', MyType(16))
-          )
-
-    """
+    """Base for built-in types."""
 
+    @util.memoized_property
+    def _impl_dict(self):
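+        # per-type cache of dialect-specific implementations, keyed
+        # by dialect class (formerly a WeakKeyDictionary keyed by
+        # dialect instance)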
+        return {}
+        
     def dialect_impl(self, dialect, **kwargs):
         try:
-            return self._impl_dict[dialect]
-        except AttributeError:
-            self._impl_dict = weakref.WeakKeyDictionary()   # will be optimized in 0.6
-            return self._impl_dict.setdefault(dialect, dialect.type_descriptor(self))
+            return self._impl_dict[dialect.__class__]
         except KeyError:
-            return self._impl_dict.setdefault(dialect, dialect.type_descriptor(self))
+            return self._impl_dict.setdefault(dialect.__class__, dialect.__class__.type_descriptor(self))
 
     def __getstate__(self):
         d = self.__dict__.copy()
         d.pop('_impl_dict', None)
         return d
 
-    def get_col_spec(self):
-        """Return the DDL representation for this type."""
-        raise NotImplementedError()
-
     def bind_processor(self, dialect):
         """Return a conversion function for processing bind values.
 
@@ -166,14 +139,42 @@ class TypeEngine(AbstractType):
     def adapt(self, cls):
         return cls()
 
-    def get_search_list(self):
-        """return a list of classes to test for a match
-        when adapting this type to a dialect-specific type.
+class UserDefinedType(TypeEngine):
+    """Base for user defined types.
+    
+    This should be the base of new types.  Note that
+    for most cases, :class:`TypeDecorator` is probably
+    more appropriate::
 
-        """
+      import sqlalchemy.types as types
 
-        return self.__class__.__mro__[0:-1]
+      class MyType(types.UserDefinedType):
+          def __init__(self, precision = 8):
+              self.precision = precision
 
+          def get_col_spec(self):
+              return "MYTYPE(%s)" % self.precision
+
+          def bind_processor(self, dialect):
+              def process(value):
+                  return value
+              return process
+
+          def result_processor(self, dialect):
+              def process(value):
+                  return value
+              return process
+
+    Once the type is made, it's immediately usable::
+
+      table = Table('foo', meta,
+          Column('id', Integer, primary_key=True),
+          Column('data', MyType(16))
+          )
+
+    """
+    __visit_name__ = "user_defined"
+    
 class TypeDecorator(AbstractType):
     """Allows the creation of types which add additional functionality
     to an existing type.
@@ -214,19 +215,33 @@ class TypeDecorator(AbstractType):
 
     """
 
+    __visit_name__ = "type_decorator"
+    
     def __init__(self, *args, **kwargs):
         if not hasattr(self.__class__, 'impl'):
-            raise AssertionError("TypeDecorator implementations require a class-level variable 'impl' which refers to the class of type being decorated")
+            raise AssertionError("TypeDecorator implementations require a class-level "
+                        "variable 'impl' which refers to the class of type being decorated")
         self.impl = self.__class__.impl(*args, **kwargs)
 
-    def dialect_impl(self, dialect, **kwargs):
+    def dialect_impl(self, dialect):
         try:
-            return self._impl_dict[dialect]
+            return self._impl_dict[dialect.__class__]
         except AttributeError:
-            self._impl_dict = weakref.WeakKeyDictionary()   # will be optimized in 0.6
+            self._impl_dict = {}
         except KeyError:
             pass
 
+        # adapt the TypeDecorator first, in 
+        # the case that the dialect maps the TD
+        # to one of its native types (i.e. PGInterval)
+        adapted = dialect.__class__.type_descriptor(self)
+        if adapted is not self:
+            self._impl_dict[dialect.__class__] = adapted
+            return adapted
+        
+        # otherwise adapt the impl type, link
+        # to a copy of this TypeDecorator and return
+        # that.
         typedesc = self.load_dialect_impl(dialect)
         tt = self.copy()
         if not isinstance(tt, self.__class__):
@@ -236,26 +251,30 @@ class TypeDecorator(AbstractType):
-        self._impl_dict[dialect] = tt
+        self._impl_dict[dialect.__class__] = tt
         return tt
 
+    def type_engine(self, dialect):
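+        # return the ultimate TypeEngine for this dialect, unwrapping
+        # any TypeDecorator returned by dialect_impl()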
+        impl = self.dialect_impl(dialect)
+        if not isinstance(impl, TypeDecorator):
+            return impl
+        else:
+            return impl.impl
+
     def load_dialect_impl(self, dialect):
         """Loads the dialect-specific implementation of this type.
 
         By default calls dialect.type_descriptor(self.impl), but
         can be overridden to provide different behavior.
+        
         """
-
         if isinstance(self.impl, TypeDecorator):
             return self.impl.dialect_impl(dialect)
         else:
-            return dialect.type_descriptor(self.impl)
+            return dialect.__class__.type_descriptor(self.impl)
 
     def __getattr__(self, key):
         """Proxy all other undefined accessors to the underlying implementation."""
 
         return getattr(self.impl, key)
 
-    def get_col_spec(self):
-        return self.impl.get_col_spec()
-
     def process_bind_param(self, value, dialect):
         raise NotImplementedError()
 
@@ -339,7 +358,7 @@ def to_instance(typeobj):
 def adapt_type(typeobj, colspecs):
     if isinstance(typeobj, type):
         typeobj = typeobj()
-    for t in typeobj.get_search_list():
+    for t in typeobj.__class__.__mro__[0:-1]:
         try:
             impltype = colspecs[t]
             break
@@ -370,9 +389,7 @@ class NullType(TypeEngine):
     encountered during a :meth:`~sqlalchemy.Table.create` operation.
 
     """
-
-    def get_col_spec(self):
-        raise NotImplementedError()
+    __visit_name__ = 'null'
 
 NullTypeEngine = NullType
 
@@ -400,6 +417,8 @@ class String(Concatenable, TypeEngine):
 
     """
 
+    __visit_name__ = 'string'
+    
     def __init__(self, length=None, convert_unicode=False, assert_unicode=None):
         """
         Create a string-holding type.
@@ -439,7 +458,10 @@ class String(Concatenable, TypeEngine):
         self.assert_unicode = assert_unicode
 
     def adapt(self, impltype):
-        return impltype(length=self.length, convert_unicode=self.convert_unicode, assert_unicode=self.assert_unicode)
+        return impltype(
+                    length=self.length, 
+                    convert_unicode=self.convert_unicode, 
+                    assert_unicode=self.assert_unicode)
 
     def bind_processor(self, dialect):
         if self.convert_unicode or dialect.convert_unicode:
@@ -447,18 +469,33 @@ class String(Concatenable, TypeEngine):
                 assert_unicode = dialect.assert_unicode
             else:
                 assert_unicode = self.assert_unicode
-            def process(value):
-                if isinstance(value, unicode):
-                    return value.encode(dialect.encoding)
-                elif assert_unicode and not isinstance(value, (unicode, NoneType)):
-                    if assert_unicode == 'warn':
-                        util.warn("Unicode type received non-unicode bind "
-                                  "param value %r" % value)
+            
+            if dialect.supports_unicode_binds and assert_unicode:
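+                # the driver accepts unicode directly; only the
+                # requested assertion against non-unicode values
+                # needs to be applied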
+                def process(value):
+                    if not isinstance(value, (unicode, NoneType)):
+                        if assert_unicode == 'warn':
+                            util.warn("Unicode type received non-unicode bind "
+                                      "param value %r" % value)
+                            return value
+                        else:
+                            raise exc.InvalidRequestError("Unicode type received non-unicode bind param value %r" % value)
+                    else:
                         return value
+            elif dialect.supports_unicode_binds:
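+                # unicode passes through natively and no assertion
+                # was requested; no bind processing is needed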
+                return None
+            else:
+                def process(value):
+                    if isinstance(value, unicode):
+                        return value.encode(dialect.encoding)
+                    elif assert_unicode and not isinstance(value, (unicode, NoneType)):
+                        if assert_unicode == 'warn':
+                            util.warn("Unicode type received non-unicode bind "
+                                      "param value %r" % value)
+                            return value
+                        else:
+                            raise exc.InvalidRequestError("Unicode type received non-unicode bind param value %r" % value)
                     else:
-                        raise exc.InvalidRequestError("Unicode type received non-unicode bind param value %r" % value)
-                else:
-                    return value
+                        return value
             return process
         else:
             return None
@@ -485,6 +522,7 @@ class Text(String):
     params (and the reverse for result sets.)
 
     """
+    __visit_name__ = 'text'
 
 class Unicode(String):
     """A variable length Unicode string.
@@ -511,10 +549,10 @@ class Unicode(String):
     :attr:`~sqlalchemy.engine.base.Dialect.encoding`, which defaults
     to `utf-8`.
 
-    A synonym for String(length, convert_unicode=True, assert_unicode='warn').
-
     """
 
+    __visit_name__ = 'unicode'
+    
     def __init__(self, length=None, **kwargs):
         """
         Create a Unicode-converting String type.
@@ -532,7 +570,14 @@ class Unicode(String):
         super(Unicode, self).__init__(length=length, **kwargs)
 
 class UnicodeText(Text):
-    """A synonym for Text(convert_unicode=True, assert_unicode='warn')."""
+    """An unbounded-length Unicode string.
+    
+    See :class:`Unicode` for details on the unicode
+    behavior of this object.
+    
+    """
+
+    __visit_name__ = 'unicode_text'
 
     def __init__(self, length=None, **kwargs):
         """
@@ -553,7 +598,9 @@ class UnicodeText(Text):
 
 class Integer(TypeEngine):
     """A type for ``int`` integers."""
-
+    
+    __visit_name__ = 'integer'
+    
     def get_dbapi_type(self, dbapi):
         return dbapi.NUMBER
 
@@ -566,7 +613,17 @@ class SmallInteger(Integer):
 
     """
 
-Smallinteger = SmallInteger
+    __visit_name__ = 'small_integer'
+
+class BigInteger(Integer):
+    """A type for bigger ``int`` integers.
+
+    Typically generates a ``BIGINT`` in DDL, and otherwise acts like
+    a normal :class:`Integer` on the Python side.
+
+    """
+
+    __visit_name__ = 'big_integer'
 
 class Numeric(TypeEngine):
     """A type for fixed precision numbers.
@@ -576,7 +633,9 @@ class Numeric(TypeEngine):
 
     """
 
-    def __init__(self, precision=10, scale=2, asdecimal=True, length=None):
+    __visit_name__ = 'numeric'
+    
+    def __init__(self, precision=None, scale=None, asdecimal=True):
         """
         Construct a Numeric.
 
@@ -590,9 +649,6 @@ class Numeric(TypeEngine):
           use.
 
         """
-        if length:
-            util.warn_deprecated("'length' is deprecated for Numeric.  Use 'scale'.")
-            scale = length
         self.precision = precision
         self.scale = scale
         self.asdecimal = asdecimal
@@ -626,7 +682,9 @@ class Numeric(TypeEngine):
 class Float(Numeric):
     """A type for ``float`` numbers."""
 
-    def __init__(self, precision=10, asdecimal=False, **kwargs):
+    __visit_name__ = 'float'
+    
+    def __init__(self, precision=None, asdecimal=False, **kwargs):
         """
         Construct a Float.
 
@@ -650,7 +708,9 @@ class DateTime(TypeEngine):
     converted back to datetime objects when rows are returned.
 
     """
-
+    
+    __visit_name__ = 'datetime'
+    
     def __init__(self, timezone=False):
         self.timezone = timezone
 
@@ -664,6 +724,8 @@ class DateTime(TypeEngine):
 class Date(TypeEngine):
     """A type for ``datetime.date()`` objects."""
 
+    __visit_name__ = 'date'
+    
     def get_dbapi_type(self, dbapi):
         return dbapi.DATETIME
 
@@ -671,6 +733,8 @@ class Date(TypeEngine):
 class Time(TypeEngine):
     """A type for ``datetime.time()`` objects."""
 
+    __visit_name__ = 'time'
+
     def __init__(self, timezone=False):
         self.timezone = timezone
 
@@ -690,6 +754,8 @@ class Binary(TypeEngine):
 
     """
 
+    __visit_name__ = 'binary'
+
     def __init__(self, length=None):
         """
         Construct a Binary type.
@@ -775,7 +841,15 @@ class PickleType(MutableType, TypeDecorator):
         loads = self.pickler.loads
         if value is None:
             return None
-        return loads(str(value))
+        # Py3K
+        #return loads(value)
+        # Py2K
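+        # Jython DBAPIs may return binary data as array objects;
+        # convert to a string before unpickling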
+        if util.jython and isinstance(value, array.ArrayType):
+            value = value.tostring()
+        else:
+            value = str(value)
+        return loads(value)
+        # end Py2K
 
     def copy_value(self, value):
         if self.mutable:
@@ -786,9 +860,6 @@ class PickleType(MutableType, TypeDecorator):
     def compare_values(self, x, y):
         if self.comparator:
             return self.comparator(x, y)
-        elif self.mutable and not hasattr(x, '__eq__') and x is not None:
-            util.warn_deprecated("Objects stored with PickleType when mutable=True must implement __eq__() for reliable comparison.")
-            return self.pickler.dumps(x, self.protocol) == self.pickler.dumps(y, self.protocol)
         else:
             return x == y
 
@@ -804,6 +875,7 @@ class Boolean(TypeEngine):
 
     """
 
+    __visit_name__ = 'boolean'
 
 class Interval(TypeDecorator):
     """A type for ``datetime.timedelta()`` objects.
@@ -815,106 +887,133 @@ class Interval(TypeDecorator):
 
     """
 
-    impl = TypeEngine
-
-    def __init__(self):
-        super(Interval, self).__init__()
-        import sqlalchemy.databases.postgres as pg
-        self.__supported = {pg.PGDialect:pg.PGInterval}
-        del pg
-
-    def load_dialect_impl(self, dialect):
-        if dialect.__class__ in self.__supported:
-            return self.__supported[dialect.__class__]()
-        else:
-            return dialect.type_descriptor(DateTime)
+    impl = DateTime
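+    # the timedelta is persisted as (epoch + delta) in a DATETIME
+    # column; dialects with a native interval type, such as
+    # postgresql's INTERVAL, adapt this decorator directly via
+    # dialect_impl()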
 
     def process_bind_param(self, value, dialect):
-        if dialect.__class__ in self.__supported:
-            return value
-        else:
-            if value is None:
-                return None
-            return dt.datetime.utcfromtimestamp(0) + value
+        if value is None:
+            return None
+        return dt.datetime.utcfromtimestamp(0) + value
 
     def process_result_value(self, value, dialect):
-        if dialect.__class__ in self.__supported:
-            return value
-        else:
-            if value is None:
-                return None
-            return value - dt.datetime.utcfromtimestamp(0)
+        if value is None:
+            return None
+        return value - dt.datetime.utcfromtimestamp(0)
 
 class FLOAT(Float):
     """The SQL FLOAT type."""
 
+    __visit_name__ = 'FLOAT'
 
 class NUMERIC(Numeric):
     """The SQL NUMERIC type."""
 
+    __visit_name__ = 'NUMERIC'
+
 
 class DECIMAL(Numeric):
     """The SQL DECIMAL type."""
 
+    __visit_name__ = 'DECIMAL'
 
-class INT(Integer):
+
+class INTEGER(Integer):
     """The SQL INT or INTEGER type."""
 
+    __visit_name__ = 'INTEGER'
+INT = INTEGER
 
-INTEGER = INT
 
-class SMALLINT(Smallinteger):
+class SMALLINT(SmallInteger):
     """The SQL SMALLINT type."""
 
+    __visit_name__ = 'SMALLINT'
+
+
+class BIGINT(BigInteger):
+    """The SQL BIGINT type."""
+
+    __visit_name__ = 'BIGINT'
 
 class TIMESTAMP(DateTime):
     """The SQL TIMESTAMP type."""
 
+    __visit_name__ = 'TIMESTAMP'
+
+    def get_dbapi_type(self, dbapi):
+        return dbapi.TIMESTAMP
 
 class DATETIME(DateTime):
     """The SQL DATETIME type."""
 
+    __visit_name__ = 'DATETIME'
+
 
 class DATE(Date):
     """The SQL DATE type."""
 
+    __visit_name__ = 'DATE'
+
 
 class TIME(Time):
     """The SQL TIME type."""
 
+    __visit_name__ = 'TIME'
 
-TEXT = Text
+class TEXT(Text):
+    """The SQL TEXT type."""
+    
+    __visit_name__ = 'TEXT'
 
 class CLOB(Text):
-    """The SQL CLOB type."""
+    """The CLOB type.
+    
+    This type is found in Oracle and Informix.
+    """
 
+    __visit_name__ = 'CLOB'
 
 class VARCHAR(String):
     """The SQL VARCHAR type."""
 
+    __visit_name__ = 'VARCHAR'
+
+class NVARCHAR(Unicode):
+    """The SQL NVARCHAR type."""
+
+    __visit_name__ = 'NVARCHAR'
 
 class CHAR(String):
     """The SQL CHAR type."""
 
+    __visit_name__ = 'CHAR'
+
 
 class NCHAR(Unicode):
     """The SQL NCHAR type."""
 
+    __visit_name__ = 'NCHAR'
+
 
 class BLOB(Binary):
     """The SQL BLOB type."""
 
+    __visit_name__ = 'BLOB'
+
 
 class BOOLEAN(Boolean):
     """The SQL BOOLEAN type."""
 
+    __visit_name__ = 'BOOLEAN'
+
 NULLTYPE = NullType()
 
 # using VARCHAR/NCHAR so that we don't get the genericized "String"
 # type which usually resolves to TEXT/CLOB
 type_map = {
-    str : VARCHAR,
-    unicode : NCHAR,
+    str: String,
+    # Py2K
+    unicode : String,
+    # end Py2K
     int : Integer,
     float : Numeric,
     bool: Boolean,
@@ -923,5 +1022,6 @@ type_map = {
     dt.datetime : DateTime,
     dt.time : Time,
     dt.timedelta : Interval,
-    type(None): NullType
+    NoneType: NullType
 }
+
index 8eeeda45559ce2bd378dde2a7869929e3e013bdf..f970f3737d9df43e2fd187db9722beed74076950 100644 (file)
@@ -4,8 +4,10 @@
 # This module is part of SQLAlchemy and is released under
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
 
-import inspect, itertools, operator, sys, warnings, weakref
+import inspect, itertools, operator, sys, warnings, weakref, gc
+# Py2K
 import __builtin__
+# end Py2K
 types = __import__('types')
 
 from sqlalchemy import exc
@@ -16,6 +18,7 @@ except ImportError:
     import dummy_threading as threading
 
 py3k = getattr(sys, 'py3kwarning', False) or sys.version_info >= (3, 0)
+jython = sys.platform.startswith('java')
 
 if py3k:
     set_types = set
@@ -38,6 +41,8 @@ else:
 
 EMPTY_SET = frozenset()
 
+NoneType = type(None)
+
 if py3k:
     import pickle
 else:
@@ -46,11 +51,13 @@ else:
     except ImportError:
         import pickle
 
+# Py2K
 # a controversial feature, required by MySQLdb currently
 def buffer(x):
     return x 
     
 buffer = getattr(__builtin__, 'buffer', buffer)
+# end Py2K
         
 if sys.version_info >= (2, 5):
     class PopulateDict(dict):
@@ -84,12 +91,13 @@ else:
 if py3k:
     def callable(fn):
         return hasattr(fn, '__call__')
-else:
-    callable = __builtin__.callable
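+    # Python 3 removes the built-in cmp(); this is the standard
+    # (a > b) - (a < b) replacement idiom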
+    def cmp(a, b):
+        return (a > b) - (a < b)
 
-if py3k:
     from functools import reduce
 else:
+    callable = __builtin__.callable
+    cmp = __builtin__.cmp
     reduce = __builtin__.reduce
 
 try:
@@ -262,6 +270,15 @@ else:
     def decode_slice(slc):
         return (slc.start, slc.stop, slc.step)
 
+def update_copy(d, _new=None, **kw):
+    """Copy the given dict and update with the given values."""
+    
+    d = d.copy()
+    if _new:
+        d.update(_new)
+    d.update(**kw)
+    return d
+    
 def flatten_iterator(x):
     """Given an iterator of which further sub-elements may also be
     iterators, flatten the sub-elements into a single iterator.
@@ -296,6 +313,7 @@ def get_cls_kwargs(cls):
         class_ = stack.pop()
         ctr = class_.__dict__.get('__init__', False)
         if not ctr or not isinstance(ctr, types.FunctionType):
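+            # class defines no __init__ of its own; keep walking the
+            # bases so inherited constructor arguments are collected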
+            stack.update(class_.__bases__)
             continue
         names, _, has_kw, _ = inspect.getargspec(ctr)
         args.update(names)
@@ -419,20 +437,32 @@ def class_hierarchy(cls):
     will not be descended.
 
     """
+    # Py2K
     if isinstance(cls, types.ClassType):
         return list()
+    # end Py2K
     hier = set([cls])
     process = list(cls.__mro__)
     while process:
         c = process.pop()
+        # Py2K
         if isinstance(c, types.ClassType):
             continue
         for b in (_ for _ in c.__bases__
                   if _ not in hier and not isinstance(_, types.ClassType)):
+        # end Py2K
+        # Py3K
+        #for b in (_ for _ in c.__bases__
+        #          if _ not in hier):
             process.append(b)
             hier.add(b)
+        # Py3K
+        #if c.__module__ == 'builtins' or not hasattr(c, '__subclasses__'):
+        #    continue
+        # Py2K
         if c.__module__ == '__builtin__' or not hasattr(c, '__subclasses__'):
             continue
+        # end Py2K
         for s in [_ for _ in c.__subclasses__() if _ not in hier]:
             process.append(s)
             hier.add(s)
@@ -664,12 +694,11 @@ class OrderedProperties(object):
         return self._data.keys()
 
     def has_key(self, key):
-        return self._data.has_key(key)
+        return key in self._data
 
     def clear(self):
         self._data.clear()
 
-
 class OrderedDict(dict):
     """A dict that returns keys/values/items in the order they were added."""
 
@@ -735,7 +764,12 @@ class OrderedDict(dict):
 
     def __setitem__(self, key, object):
         if key not in self:
-            self._list.append(key)
+            try:
+                self._list.append(key)
+            except AttributeError:
+                # work around Python pickle loads() with 
+                # dict subclass (seems to ignore __setstate__?)
+                self._list = [key]
         dict.__setitem__(self, key, object)
 
     def __delitem__(self, key):
@@ -915,7 +949,7 @@ class IdentitySet(object):
 
         if len(self) > len(other):
             return False
-        for m in itertools.ifilterfalse(other._members.has_key,
+        for m in itertools.ifilterfalse(other._members.__contains__,
                                         self._members.iterkeys()):
             return False
         return True
@@ -1416,8 +1450,11 @@ class WeakIdentityMapping(weakref.WeakKeyDictionary):
         return item
 
     def clear(self):
+        # Py2K
+        # in 3k, MutableMapping calls popitem()
         self._weakrefs.clear()
         self.by_id.clear()
+        # end Py2K
         weakref.WeakKeyDictionary.clear(self)
 
     def update(self, *a, **kw):
diff --git a/sa2to3.py b/sa2to3.py
new file mode 100644 (file)
index 0000000..9c06daf
--- /dev/null
+++ b/sa2to3.py
@@ -0,0 +1,76 @@
+"""SQLAlchemy 2to3 tool.
+
+Relax!  This just calls the regular 2to3 tool with a preprocessor bolted onto it.
+
+I originally wanted to write a custom fixer to accomplish this,
+but the Fixer classes seem to understand only the grammar file
+included with 2to3, and that grammar does not appear to cover
+Python comments (and of course, huge hacks would be needed to get
+out-of-package fixers in there).  While that may be an option
+later on, this is a pretty simple approach for what is a pretty
+simple problem.
+
+"""
+
+from lib2to3 import main, refactor
+
+import re
+
+py3k_pattern = re.compile(r'\s*# Py3K')
+comment_pattern = re.compile(r'(\s*)#(?! ?Py2K)(.*)')
+py2k_pattern = re.compile(r'\s*# Py2K')
+end_py2k_pattern = re.compile(r'\s*# end Py2K')
+
+def preprocess(data):
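+    # uncomment "# Py3K" blocks and comment out "Py2K" blocks, so
+    # that the source handed to 2to3 is the Python 3 variant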
+    lines = data.split('\n')
+    def consume_normal():
+        while lines:
+            line = lines.pop(0)
+            if py3k_pattern.match(line):
+                for line in consume_py3k():
+                    yield line
+            elif py2k_pattern.match(line):
+                for line in consume_py2k():
+                    yield line
+            else:
+                yield line
+    
+    def consume_py3k():
+        yield "# start Py3K"
+        while lines:
+            line = lines.pop(0)
+            m = comment_pattern.match(line)
+            if m:
+                yield "%s%s" % m.group(1, 2)
+            else:
+                # pushback
+                lines.insert(0, line)
+                break
+        yield "# end Py3K"
+    
+    def consume_py2k():
+        yield "# start Py2K"
+        while lines:
+            line = lines.pop(0)
+            if not end_py2k_pattern.match(line):
+                yield "#%s" % line
+            else:
+                break
+        yield "# end Py2K"
+
+    return "\n".join(consume_normal())
+
+old_refactor_string = main.StdoutRefactoringTool.refactor_string
+
+def refactor_string(self, data, name):
+    newdata = preprocess(data)
+    tree = old_refactor_string(self, newdata, name)
+    if tree:
+        if newdata != data:
+            tree.was_changed = True
+    return tree
+    
+main.StdoutRefactoringTool.refactor_string = refactor_string
+
+main.main("lib2to3.fixes")
index 3d65f022e0694778e369a875aff5e61e47bb2ee7..12925a11544b2c2a683af44f812b8cacee0e4877 100644 (file)
--- a/setup.py
+++ b/setup.py
@@ -18,7 +18,7 @@ def find_packages(dir_):
 if sys.version_info < (2, 4):
     raise Exception("SQLAlchemy requires Python 2.4 or higher.")
 
-v = file(os.path.join(os.path.dirname(__file__), 'lib', 'sqlalchemy', '__init__.py'))
+v = open(os.path.join(os.path.dirname(__file__), 'lib', 'sqlalchemy', '__init__.py'))
 VERSION = re.compile(r".*__version__ = '(.*?)'", re.S).match(v.read()).group(1)
 v.close()
 
@@ -31,7 +31,8 @@ setup(name = "SQLAlchemy",
       packages = find_packages('lib'),
       package_dir = {'':'lib'},
       license = "MIT License",
-
+      tests_require = ['nose >= 0.10'],
+      test_suite = "nose.collector",
       entry_points = {
           'nose.plugins.0.10': [
               'sqlalchemy = sqlalchemy.test.noseplugin:NoseSQLAlchemy',
diff --git a/sqla_nose.py b/sqla_nose.py
new file mode 100644 (file)
index 0000000..0542b4e
--- /dev/null
@@ -0,0 +1,22 @@
+"""
+nose runner script.
+
+Only use this script if setuptools is not available, e.g. on
+Python 3K.  Otherwise consult README.unittests for the
+recommended methods of running tests.
+
+"""
+
+import nose
+from sqlalchemy.test.noseplugin import NoseSQLAlchemy
+from sqlalchemy.util import py3k
+
+if __name__ == '__main__':
+    if py3k:
+        # this version breaks verbose output, 
+        # but is the only API that nose3 currently supports
+        nose.main(plugins=[NoseSQLAlchemy()])
+    else:
+        # this is the "correct" API
+        nose.main(addplugins=[NoseSQLAlchemy()])
+
index 3e4274d47dd3c686992d71ed4560f728168c4b47..79ae09b0544b366702dc38845cf66d72ac9384f3 100644 (file)
@@ -15,15 +15,15 @@ class CompileTest(TestBase, AssertsExecutionResults):
             Column('c1', Integer, primary_key=True),
             Column('c2', String(30)))
 
-    @profiling.function_call_count(68, {'2.4': 42})
+    @profiling.function_call_count(72, {'2.4': 42, '3.0':77})
     def test_insert(self):
         t1.insert().compile()
 
-    @profiling.function_call_count(68, {'2.4': 45})
+    @profiling.function_call_count(72, {'2.4': 45})
     def test_update(self):
         t1.update().compile()
 
-    @profiling.function_call_count(185, versions={'2.4':118})
+    @profiling.function_call_count(195, versions={'2.4':118, '3.0':208})
     def test_select(self):
         s = select([t1], t1.c.c2==t2.c.c1)
         s.compile()
index 70a3cf8cd68ffbc4d5e86c8351ba43f307874e6c..fbf0560ca1b079f7c08e74a675086e72dd4fed79 100644 (file)
@@ -1,17 +1,22 @@
 from sqlalchemy.test.testing import eq_
-import gc
 from sqlalchemy.orm import mapper, relation, create_session, clear_mappers, sessionmaker
 from sqlalchemy.orm.mapper import _mapper_registry
 from sqlalchemy.orm.session import _sessions
+from sqlalchemy.util import jython
 import operator
 from sqlalchemy.test import testing
 from sqlalchemy import MetaData, Integer, String, ForeignKey, PickleType
-from sqlalchemy.test.schema import Table
-from sqlalchemy.test.schema import Column
+from sqlalchemy.test.schema import Table, Column
 import sqlalchemy as sa
 from sqlalchemy.sql import column
+from sqlalchemy.test.util import gc_collect
+import gc
 from test.orm import _base
 
+if jython:
+    from nose import SkipTest
+    raise SkipTest("Profiling not supported on this platform")
+
 
 class A(_base.ComparableEntity):
     pass
@@ -22,11 +27,11 @@ def profile_memory(func):
     # run the test 50 times.  if length of gc.get_objects()
     # keeps growing, assert false
     def profile(*args):
-        gc.collect()
+        gc_collect()
         samples = [0 for x in range(0, 50)]
         for x in range(0, 50):
             func(*args)
-            gc.collect()
+            gc_collect()
             samples[x] = len(gc.get_objects())
         print "sample gc sizes:", samples
 
@@ -50,7 +55,7 @@ def profile_memory(func):
 
 def assert_no_mappers():
     clear_mappers()
-    gc.collect()
+    gc_collect()
     assert len(_mapper_registry) == 0
 
 class EnsureZeroed(_base.ORMTest):
@@ -61,7 +66,7 @@ class EnsureZeroed(_base.ORMTest):
 class MemUsageTest(EnsureZeroed):
     
     # ensure a pure growing test trips the assertion
-    @testing.fails_if(lambda:True)
+    @testing.fails_if(lambda: True)
     def test_fixture(self):
         class Foo(object):
             pass
@@ -76,11 +81,11 @@ class MemUsageTest(EnsureZeroed):
         metadata = MetaData(testing.db)
 
         table1 = Table("mytable", metadata,
-            Column('col1', Integer, primary_key=True),
+            Column('col1', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('col2', String(30)))
 
         table2 = Table("mytable2", metadata,
-            Column('col1', Integer, primary_key=True),
+            Column('col1', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('col2', String(30)),
             Column('col3', Integer, ForeignKey("mytable.col1")))
 
@@ -129,11 +134,11 @@ class MemUsageTest(EnsureZeroed):
         metadata = MetaData(testing.db)
 
         table1 = Table("mytable", metadata,
-            Column('col1', Integer, primary_key=True),
+            Column('col1', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('col2', String(30)))
 
         table2 = Table("mytable2", metadata,
-            Column('col1', Integer, primary_key=True),
+            Column('col1', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('col2', String(30)),
             Column('col3', Integer, ForeignKey("mytable.col1")))
 
@@ -184,13 +189,13 @@ class MemUsageTest(EnsureZeroed):
         metadata = MetaData(testing.db)
 
         table1 = Table("mytable", metadata,
-            Column('col1', Integer, primary_key=True),
+            Column('col1', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('col2', String(30))
             )
 
         table2 = Table("mytable2", metadata,
             Column('col1', Integer, ForeignKey('mytable.col1'),
-                   primary_key=True),
+                   primary_key=True, test_needs_autoincrement=True),
             Column('col3', String(30)),
             )
 
@@ -244,12 +249,12 @@ class MemUsageTest(EnsureZeroed):
         metadata = MetaData(testing.db)
 
         table1 = Table("mytable", metadata,
-            Column('col1', Integer, primary_key=True),
+            Column('col1', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('col2', String(30))
             )
 
         table2 = Table("mytable2", metadata,
-            Column('col1', Integer, primary_key=True),
+            Column('col1', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('col2', String(30)),
             )
 
@@ -308,12 +313,12 @@ class MemUsageTest(EnsureZeroed):
         metadata = MetaData(testing.db)
 
         table1 = Table("table1", metadata,
-            Column('id', Integer, primary_key=True),
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('data', String(30))
             )
 
         table2 = Table("table2", metadata,
-            Column('id', Integer, primary_key=True),
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('data', String(30)),
             Column('t1id', Integer, ForeignKey('table1.id'))
             )
@@ -347,7 +352,7 @@ class MemUsageTest(EnsureZeroed):
         metadata = MetaData(testing.db)
 
         table1 = Table("mytable", metadata,
-            Column('col1', Integer, primary_key=True),
+            Column('col1', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('col2', PickleType(comparator=operator.eq))
             )
         
@@ -382,7 +387,7 @@ class MemUsageTest(EnsureZeroed):
             testing.eq_(len(session.identity_map._mutable_attrs), 12)
             testing.eq_(len(session.identity_map), 12)
             obj = None
-            gc.collect()
+            gc_collect()
             testing.eq_(len(session.identity_map._mutable_attrs), 0)
             testing.eq_(len(session.identity_map), 0)
             
@@ -392,7 +397,7 @@ class MemUsageTest(EnsureZeroed):
             metadata.drop_all()
 
     def test_type_compile(self):
-        from sqlalchemy.databases.sqlite import SQLiteDialect
+        from sqlalchemy.dialects.sqlite.base import dialect as SQLiteDialect
         cast = sa.cast(column('x'), sa.Integer)
         @profile_memory
         def go():
index 7bb61deb28d0151c88b0e659029df4afe9c5685d..6ae3edc989ddb33e9b93dae2acfe150f186f51ac 100644 (file)
@@ -5,6 +5,9 @@ from sqlalchemy.pool import QueuePool
 
 class QueuePoolTest(TestBase, AssertsExecutionResults):
     class Connection(object):
+        def rollback(self):
+            pass
+            
         def close(self):
             pass
 
@@ -15,7 +18,7 @@ class QueuePoolTest(TestBase, AssertsExecutionResults):
                          use_threadlocal=True)
 
 
-    @profiling.function_call_count(54, {'2.4': 38})
+    @profiling.function_call_count(54, {'2.4': 38, '3.0':57})
     def test_first_connect(self):
         conn = pool.connect()
 
@@ -23,7 +26,7 @@ class QueuePoolTest(TestBase, AssertsExecutionResults):
         conn = pool.connect()
         conn.close()
 
-        @profiling.function_call_count(31, {'2.4': 21})
+        @profiling.function_call_count(29, {'2.4': 21})
         def go():
             conn2 = pool.connect()
             return conn2
@@ -32,7 +35,7 @@ class QueuePoolTest(TestBase, AssertsExecutionResults):
     def test_second_samethread_connect(self):
         conn = pool.connect()
 
-        @profiling.function_call_count(5, {'2.4': 3})
+        @profiling.function_call_count(5, {'2.4': 3, '3.0':6})
         def go():
             return pool.connect()
         c2 = go()
index be29318964a8613bce44d8e72ffb6246b9f696a6..e413031926bd78fa85a07cf15f54ba08dfdda528 100644 (file)
@@ -26,7 +26,7 @@ class ZooMarkTest(TestBase):
 
     """
 
-    __only_on__ = 'postgres'
+    __only_on__ = 'postgresql+psycopg2'
     __skip_if__ = ((lambda: sys.version_info < (2, 4)), )
 
     def test_baseline_0_setup(self):
@@ -75,15 +75,15 @@ class ZooMarkTest(TestBase):
                            Opens=datetime.time(8, 15, 59),
                            LastEscape=datetime.datetime(2004, 7, 29, 5, 6, 7),
                            Admission=4.95,
-                           ).last_inserted_ids()[0]
+                           ).inserted_primary_key[0]
 
         sdz = Zoo.insert().execute(Name =u'San Diego Zoo',
                            Founded = datetime.date(1935, 9, 13),
                            Opens = datetime.time(9, 0, 0),
                            Admission = 0,
-                           ).last_inserted_ids()[0]
+                           ).inserted_primary_key[0]
 
-        Zoo.insert().execute(
+        Zoo.insert(inline=True).execute(
                   Name = u'Montr\xe9al Biod\xf4me',
                   Founded = datetime.date(1992, 6, 19),
                   Opens = datetime.time(9, 0, 0),
@@ -91,48 +91,48 @@ class ZooMarkTest(TestBase):
                   )
 
         seaworld = Zoo.insert().execute(
-                Name =u'Sea_World', Admission = 60).last_inserted_ids()[0]
+                Name =u'Sea_World', Admission = 60).inserted_primary_key[0]
 
         # Let's add a crazy futuristic Zoo to test large date values.
         lp = Zoo.insert().execute(Name =u'Luna Park',
                                   Founded = datetime.date(2072, 7, 17),
                                   Opens = datetime.time(0, 0, 0),
                                   Admission = 134.95,
-                                  ).last_inserted_ids()[0]
+                                  ).inserted_primary_key[0]
 
         # Animals
         leopardid = Animal.insert().execute(Species=u'Leopard', Lifespan=73.5,
-                                            ).last_inserted_ids()[0]
+                                            ).inserted_primary_key[0]
         Animal.update(Animal.c.ID==leopardid).execute(ZooID=wap,
                 LastEscape=datetime.datetime(2004, 12, 21, 8, 15, 0, 999907))
 
-        lion = Animal.insert().execute(Species=u'Lion', ZooID=wap).last_inserted_ids()[0]
+        lion = Animal.insert().execute(Species=u'Lion', ZooID=wap).inserted_primary_key[0]
         Animal.insert().execute(Species=u'Slug', Legs=1, Lifespan=.75)
 
         tiger = Animal.insert().execute(Species=u'Tiger', ZooID=sdz
-                                        ).last_inserted_ids()[0]
+                                        ).inserted_primary_key[0]
 
         # Override Legs.default with itself just to make sure it works.
-        Animal.insert().execute(Species=u'Bear', Legs=4)
-        Animal.insert().execute(Species=u'Ostrich', Legs=2, Lifespan=103.2)
-        Animal.insert().execute(Species=u'Centipede', Legs=100)
+        Animal.insert(inline=True).execute(Species=u'Bear', Legs=4)
+        Animal.insert(inline=True).execute(Species=u'Ostrich', Legs=2, Lifespan=103.2)
+        Animal.insert(inline=True).execute(Species=u'Centipede', Legs=100)
 
         emp = Animal.insert().execute(Species=u'Emperor Penguin', Legs=2,
-                                      ZooID=seaworld).last_inserted_ids()[0]
+                                      ZooID=seaworld).inserted_primary_key[0]
         adelie = Animal.insert().execute(Species=u'Adelie Penguin', Legs=2,
-                                         ZooID=seaworld).last_inserted_ids()[0]
+                                         ZooID=seaworld).inserted_primary_key[0]
 
-        Animal.insert().execute(Species=u'Millipede', Legs=1000000, ZooID=sdz)
+        Animal.insert(inline=True).execute(Species=u'Millipede', Legs=1000000, ZooID=sdz)
 
         # Add a mother and child to test relationships
         bai_yun = Animal.insert().execute(Species=u'Ape', Name=u'Bai Yun',
-                                          Legs=2).last_inserted_ids()[0]
-        Animal.insert().execute(Species=u'Ape', Name=u'Hua Mei', Legs=2,
+                                          Legs=2).inserted_primary_key[0]
+        Animal.insert(inline=True).execute(Species=u'Ape', Name=u'Hua Mei', Legs=2,
                                 MotherID=bai_yun)
 
     def test_baseline_2_insert(self):
         Animal = metadata.tables['Animal']
-        i = Animal.insert()
+        i = Animal.insert(inline=True)
         for x in xrange(ITERATIONS):
             tick = i.execute(Species=u'Tick', Name=u'Tick %d' % x, Legs=8)
 
@@ -142,7 +142,7 @@ class ZooMarkTest(TestBase):
 
         def fullobject(select):
             """Iterate over the full result row."""
-            return list(select.execute().fetchone())
+            return list(select.execute().first())
 
         for x in xrange(ITERATIONS):
             # Zoos
@@ -254,7 +254,7 @@ class ZooMarkTest(TestBase):
 
         for x in xrange(ITERATIONS):
             # Edit
-            SDZ = Zoo.select(Zoo.c.Name==u'San Diego Zoo').execute().fetchone()
+            SDZ = Zoo.select(Zoo.c.Name==u'San Diego Zoo').execute().first()
             Zoo.update(Zoo.c.ID==SDZ['ID']).execute(
                      Name=u'The San Diego Zoo',
                      Founded = datetime.date(1900, 1, 1),
@@ -262,7 +262,7 @@ class ZooMarkTest(TestBase):
                      Admission = "35.00")
 
             # Test edits
-            SDZ = Zoo.select(Zoo.c.Name==u'The San Diego Zoo').execute().fetchone()
+            SDZ = Zoo.select(Zoo.c.Name==u'The San Diego Zoo').execute().first()
             assert SDZ['Founded'] == datetime.date(1900, 1, 1), SDZ['Founded']
 
             # Change it back
@@ -273,7 +273,7 @@ class ZooMarkTest(TestBase):
                      Admission = "0")
 
             # Test re-edits
-            SDZ = Zoo.select(Zoo.c.Name==u'San Diego Zoo').execute().fetchone()
+            SDZ = Zoo.select(Zoo.c.Name==u'San Diego Zoo').execute().first()
             assert SDZ['Founded'] == datetime.date(1935, 9, 13)
 
     def test_baseline_7_multiview(self):
@@ -316,10 +316,10 @@ class ZooMarkTest(TestBase):
         global metadata
 
         player = lambda: dbapi_session.player()
-        engine = create_engine('postgres:///', creator=player)
+        engine = create_engine('postgresql:///', creator=player)
         metadata = MetaData(engine)
 
-    @profiling.function_call_count(3230, {'2.4': 1796})
+    @profiling.function_call_count(2991, {'2.4': 1796})
     def test_profile_1_create_tables(self):
         self.test_baseline_1_create_tables()
 
@@ -327,7 +327,7 @@ class ZooMarkTest(TestBase):
     def test_profile_1a_populate(self):
         self.test_baseline_1a_populate()
 
-    @profiling.function_call_count(322, {'2.4': 202})
+    @profiling.function_call_count(305, {'2.4': 202})
     def test_profile_2_insert(self):
         self.test_baseline_2_insert()
 
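A recurring migration throughout this merge, sketched minimally below: ResultProxy.last_inserted_ids() becomes the inserted_primary_key attribute, single-row fetches move from fetchone() to first(), inserts that don't need the generated key back use insert(inline=True), and dialect URLs rename from postgres:// to postgresql://. The SQLite URL and table below are illustrative only, assuming a 0.6-series install.

    # Illustrative sketch only; names and engine URL are not from the test suite.
    from sqlalchemy import create_engine, MetaData, Table, Column, Integer, String

    engine = create_engine('sqlite://')
    meta = MetaData(engine)
    animal = Table('animal', meta,
                   Column('id', Integer, primary_key=True),
                   Column('species', String(50)))
    meta.create_all()

    # inline=True embeds defaults in the statement; no primary-key post-fetch occurs
    animal.insert(inline=True).execute(species=u'Bear')

    r = animal.insert().execute(species=u'Ostrich')
    pk = r.inserted_primary_key[0]            # was: r.last_inserted_ids()[0]
    row = animal.select().execute().first()   # was: .fetchone()
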
index 57e1e24049c37e2efa92a80a63049324d0213ca8..660f47811036e5ef7bc153e83fb738b90e1abab7 100644 (file)
@@ -27,7 +27,7 @@ class ZooMarkTest(TestBase):
 
     """
 
-    __only_on__ = 'postgres'
+    __only_on__ = 'postgresql+psycopg2'
     __skip_if__ = ((lambda: sys.version_info < (2, 5)), )  # TODO: get 2.4 support
 
     def test_baseline_0_setup(self):
@@ -281,7 +281,7 @@ class ZooMarkTest(TestBase):
         global metadata, session
 
         player = lambda: dbapi_session.player()
-        engine = create_engine('postgres:///', creator=player)
+        engine = create_engine('postgresql:///', creator=player)
         metadata = MetaData(engine)
         session = sessionmaker()()
 
index 0457d552a4eef7d22f6535b1aea38e9b1d3a2053..890dd76078f533a3de8eededc9c80c820d5e0a2c 100644 (file)
@@ -178,9 +178,12 @@ class DependencySortTest(TestBase):
         self.assert_sort(tuples, head)
 
     def testbigsort(self):
-        tuples = []
-        for i in range(0,1500, 2):
-            tuples.append((i, i+1))
+        tuples = [(i, i + 1) for i in range(0, 1500, 2)]
         head = topological.sort_as_tree(tuples, [])
 
 
+    def testids(self):
+        # ticket:1380 regression: would raise a KeyError
+        topological.sort([(id(i), i) for i in range(3)], [])
+
+
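The new test feeds (id(item), item) pairs straight to the sorter, covering the ticket:1380 KeyError. A hedged illustration of the call forms this file exercises, assuming the sqlalchemy.topological import the test module already uses:

    # Hedged illustration; only the call signatures visible in this diff are assumed.
    from sqlalchemy import topological

    edges = [('a', 'b'), ('b', 'c')]        # (parent, child): 'a' before 'b', etc.
    topological.sort(edges, [])             # flat dependency ordering
    topological.sort_as_tree(edges, [])     # same data arranged as a tree
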
index efb18a153c980d1d862e01c74f3e8eef3bb46c6e..fbe0a05de422bac015ea907022ed0497ac106ff7 100644 (file)
@@ -1,10 +1,14 @@
 """Tests exceptions and DB-API exception wrapping."""
-import exceptions as stdlib_exceptions
 from sqlalchemy import exc as sa_exceptions
 from sqlalchemy.test import TestBase
 
+# Py3K
+#StandardError = BaseException
+# Py2K
+from exceptions import StandardError, KeyboardInterrupt, SystemExit
+# end Py2K
 
-class Error(stdlib_exceptions.StandardError):
+class Error(StandardError):
     """This class will be old-style on <= 2.4 and new-style on >= 2.5."""
 class DatabaseError(Error):
     pass
@@ -101,19 +105,19 @@ class WrapTest(TestBase):
     def test_db_error_keyboard_interrupt(self):
         try:
             raise sa_exceptions.DBAPIError.instance(
-                '', [], stdlib_exceptions.KeyboardInterrupt())
+                '', [], KeyboardInterrupt())
         except sa_exceptions.DBAPIError:
             self.assert_(False)
-        except stdlib_exceptions.KeyboardInterrupt:
+        except KeyboardInterrupt:
             self.assert_(True)
 
     def test_db_error_system_exit(self):
         try:
             raise sa_exceptions.DBAPIError.instance(
-                '', [], stdlib_exceptions.SystemExit())
+                '', [], SystemExit())
         except sa_exceptions.DBAPIError:
             self.assert_(False)
-        except stdlib_exceptions.SystemExit:
+        except SystemExit:
             self.assert_(True)
 
 
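The paired comments above follow 0.6's source-marker convention: "# Py2K" ... "# end Py2K" brackets Python 2-only code that a build step comments out when producing a Python 3 tree, while commented "# Py3K" blocks get uncommented in the same pass. A minimal sketch of the Py2K half, assuming only the marker convention shown here; SQLAlchemy's actual build tool may differ:

    # Minimal sketch of marker-driven preprocessing for a Python 3 build.
    def strip_py2k(lines):
        out, in_block = [], False
        for line in lines:
            marker = line.strip()
            if marker == '# Py2K':
                in_block = True
            elif marker == '# end Py2K':
                in_block = False
            elif in_block:
                line = '#' + line        # disable Py2-only code
            out.append(line)
        return out
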
index 39561e9682eae3d3b9b07fd77871bc2133f1a76d..e4c2eaba059309c5752067e1fda39b567ed9d2ac 100644 (file)
@@ -3,6 +3,7 @@ import copy, threading
 from sqlalchemy import util, sql, exc
 from sqlalchemy.test import TestBase
 from sqlalchemy.test.testing import eq_, is_, ne_
+from sqlalchemy.test.util import gc_collect
 
 class OrderedDictTest(TestBase):
     def test_odict(self):
@@ -260,7 +261,7 @@ class IdentitySetTest(TestBase):
         except TypeError:
             assert True
 
-        assert_raises(TypeError, cmp, ids)
+        assert_raises(TypeError, util.cmp, ids)
         assert_raises(TypeError, hash, ids)
 
     def test_difference(self):
@@ -325,11 +326,13 @@ class DictlikeIteritemsTest(TestBase):
         d = subdict(a=1,b=2,c=3)
         self._ok(d)
 
+    # Py2K
     def test_UserDict(self):
         import UserDict
         d = UserDict.UserDict(a=1,b=2,c=3)
         self._ok(d)
-
+    # end Py2K
+
     def test_object(self):
         self._notok(object())
 
@@ -339,12 +342,15 @@ class DictlikeIteritemsTest(TestBase):
                 return iter(self.baseline)
         self._ok(duck1())
 
+    # Py2K
     def test_duck_2(self):
         class duck2(object):
             def items(duck):
                 return list(self.baseline)
         self._ok(duck2())
+    # end Py2K
 
+    # Py2K
     def test_duck_3(self):
         class duck3(object):
             def iterkeys(duck):
@@ -352,6 +358,7 @@ class DictlikeIteritemsTest(TestBase):
             def __getitem__(duck, key):
                 return dict(a=1,b=2,c=3).get(key)
         self._ok(duck3())
+    # end Py2K
 
     def test_duck_4(self):
         class duck4(object):
@@ -376,16 +383,20 @@ class DictlikeIteritemsTest(TestBase):
 
 class DuckTypeCollectionTest(TestBase):
     def test_sets(self):
+        # Py2K
         import sets
+        # end Py2K
         class SetLike(object):
             def add(self):
                 pass
 
         class ForcedSet(list):
             __emulates__ = set
-
+
         for type_ in (set,
+                      # Py2K
                       sets.Set,
+                      # end Py2K
                       SetLike,
                       ForcedSet):
             eq_(util.duck_type_collection(type_), set)
@@ -393,12 +404,14 @@ class DuckTypeCollectionTest(TestBase):
             eq_(util.duck_type_collection(instance), set)
 
         for type_ in (frozenset,
-                      sets.ImmutableSet):
+                      # Py2K
+                      sets.ImmutableSet
+                      # end Py2K
+                      ):
             is_(util.duck_type_collection(type_), None)
             instance = type_()
             is_(util.duck_type_collection(instance), None)
 
-
 class ArgInspectionTest(TestBase):
     def test_get_cls_kwargs(self):
         class A(object):
@@ -646,6 +659,8 @@ class WeakIdentityMappingTest(TestBase):
         assert len(data) == len(wim) == len(wim.by_id)
 
         del data[:]
+        gc_collect()
+
         eq_(wim, {})
         eq_(wim.by_id, {})
         eq_(wim._weakrefs, {})
@@ -657,6 +672,7 @@ class WeakIdentityMappingTest(TestBase):
 
         oid = id(data[0])
         del data[0]
+        gc_collect()
 
         assert len(data) == len(wim) == len(wim.by_id)
         assert oid not in wim.by_id
@@ -679,6 +695,7 @@ class WeakIdentityMappingTest(TestBase):
         th.start()
         cv.wait()
         cv.release()
+        gc_collect()
 
         eq_(wim, {})
         eq_(wim.by_id, {})
@@ -939,7 +956,8 @@ class TestClassHierarchy(TestBase):
 
         eq_(set(util.class_hierarchy(A)), set((A, B, C, object)))
         eq_(set(util.class_hierarchy(B)), set((A, B, C, object)))
-
+
+    # Py2K
     def test_oldstyle_mixin(self):
         class A(object):
             pass
@@ -953,5 +971,5 @@ class TestClassHierarchy(TestBase):
         eq_(set(util.class_hierarchy(B)), set((A, B, object)))
         eq_(set(util.class_hierarchy(Mixin)), set())
         eq_(set(util.class_hierarchy(A)), set((A, B, object)))
-
+    # end Py2K
         
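The gc_collect() calls added after each del matter because a weak reference is only guaranteed to clear once a collection pass runs: CPython's refcounting usually clears it immediately, but interpreters such as Jython do not. A sketch of the underlying behavior, assuming the test helper is essentially a portable gc.collect() wrapper:

    # Why the tests now collect explicitly before asserting emptiness.
    import gc
    import weakref

    class Thing(object):
        pass

    obj = Thing()
    ref = weakref.ref(obj)
    del obj
    gc.collect()              # on non-refcounting VMs the ref survives until here
    assert ref() is None
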
index fa608c9a18e5c0761cf582b352ab38dc3b780b22..2dc6af91b76ae7c4c8246e63f6632b845a0fbcd1 100644 (file)
@@ -50,6 +50,7 @@ class DomainReflectionTest(TestBase, AssertsExecutionResults):
         con.execute('DROP GENERATOR gen_testtable_id')
 
     def test_table_is_reflected(self):
+        from sqlalchemy.types import Integer, Text, Binary, String, Date, Time, DateTime
         metadata = MetaData(testing.db)
         table = Table('testtable', metadata, autoload=True)
         eq_(set(table.columns.keys()),
@@ -57,17 +58,17 @@ class DomainReflectionTest(TestBase, AssertsExecutionResults):
                           "Columns of reflected table didn't equal expected columns")
         eq_(table.c.question.primary_key, True)
         eq_(table.c.question.sequence.name, 'gen_testtable_id')
-        eq_(table.c.question.type.__class__, firebird.FBInteger)
+        assert isinstance(table.c.question.type, Integer)
         eq_(table.c.question.server_default.arg.text, "42")
-        eq_(table.c.answer.type.__class__, firebird.FBString)
+        assert isinstance(table.c.answer.type, String)
         eq_(table.c.answer.server_default.arg.text, "'no answer'")
-        eq_(table.c.remark.type.__class__, firebird.FBText)
+        assert isinstance(table.c.remark.type, Text)
         eq_(table.c.remark.server_default.arg.text, "''")
-        eq_(table.c.photo.type.__class__, firebird.FBBinary)
+        assert isinstance(table.c.photo.type, Binary)
         # The following assume a Dialect 3 database
-        eq_(table.c.d.type.__class__, firebird.FBDate)
-        eq_(table.c.t.type.__class__, firebird.FBTime)
-        eq_(table.c.dt.type.__class__, firebird.FBDateTime)
+        assert isinstance(table.c.d.type, Date)
+        assert isinstance(table.c.t.type, Time)
+        assert isinstance(table.c.dt.type, DateTime)
 
 
 class CompileTest(TestBase, AssertsCompiledSQL):
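Reflection against the rewritten Firebird dialect now returns generic sqlalchemy.types instances, so the assertions above switch from exact-class eq_() comparisons on FB-prefixed types to isinstance() checks. A hedged sketch reusing this file's own fixture, assuming the bound testing.db:

    # Hedged sketch; assumes the 'testtable' fixture and testing.db used above.
    from sqlalchemy import MetaData, Table
    from sqlalchemy.types import Integer, String

    meta = MetaData(testing.db)
    t = Table('testtable', meta, autoload=True)
    assert isinstance(t.c.question.type, Integer)   # no longer firebird.FBInteger
    assert isinstance(t.c.answer.type, String)
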
@@ -76,7 +77,13 @@ class CompileTest(TestBase, AssertsCompiledSQL):
     def test_alias(self):
         t = table('sometable', column('col1'), column('col2'))
         s = select([t.alias()])
-        self.assert_compile(s, "SELECT sometable_1.col1, sometable_1.col2 FROM sometable sometable_1")
+        self.assert_compile(s, "SELECT sometable_1.col1, sometable_1.col2 FROM sometable AS sometable_1")
+
+        dialect = firebird.FBDialect()
+        dialect._version_two = False
+        self.assert_compile(s, "SELECT sometable_1.col1, sometable_1.col2 FROM sometable sometable_1",
+            dialect = dialect
+        )
 
     def test_function(self):
         self.assert_compile(func.foo(1, 2), "foo(:foo_1, :foo_2)")
@@ -98,15 +105,15 @@ class CompileTest(TestBase, AssertsCompiledSQL):
             column('description', String(128)),
         )
 
-        u = update(table1, values=dict(name='foo'), firebird_returning=[table1.c.myid, table1.c.name])
+        u = update(table1, values=dict(name='foo')).returning(table1.c.myid, table1.c.name)
         self.assert_compile(u, "UPDATE mytable SET name=:name RETURNING mytable.myid, mytable.name")
 
-        u = update(table1, values=dict(name='foo'), firebird_returning=[table1])
+        u = update(table1, values=dict(name='foo')).returning(table1)
         self.assert_compile(u, "UPDATE mytable SET name=:name "\
             "RETURNING mytable.myid, mytable.name, mytable.description")
 
-        u = update(table1, values=dict(name='foo'), firebird_returning=[func.length(table1.c.name)])
-        self.assert_compile(u, "UPDATE mytable SET name=:name RETURNING char_length(mytable.name)")
+        u = update(table1, values=dict(name='foo')).returning(func.length(table1.c.name))
+        self.assert_compile(u, "UPDATE mytable SET name=:name RETURNING char_length(mytable.name) AS length_1")
 
     def test_insert_returning(self):
         table1 = table('mytable',
@@ -115,90 +122,20 @@ class CompileTest(TestBase, AssertsCompiledSQL):
             column('description', String(128)),
         )
 
-        i = insert(table1, values=dict(name='foo'), firebird_returning=[table1.c.myid, table1.c.name])
+        i = insert(table1, values=dict(name='foo')).returning(table1.c.myid, table1.c.name)
         self.assert_compile(i, "INSERT INTO mytable (name) VALUES (:name) RETURNING mytable.myid, mytable.name")
 
-        i = insert(table1, values=dict(name='foo'), firebird_returning=[table1])
+        i = insert(table1, values=dict(name='foo')).returning(table1)
         self.assert_compile(i, "INSERT INTO mytable (name) VALUES (:name) "\
             "RETURNING mytable.myid, mytable.name, mytable.description")
 
-        i = insert(table1, values=dict(name='foo'), firebird_returning=[func.length(table1.c.name)])
-        self.assert_compile(i, "INSERT INTO mytable (name) VALUES (:name) RETURNING char_length(mytable.name)")
+        i = insert(table1, values=dict(name='foo')).returning(func.length(table1.c.name))
+        self.assert_compile(i, "INSERT INTO mytable (name) VALUES (:name) RETURNING char_length(mytable.name) AS length_1")
 
 
-class ReturningTest(TestBase, AssertsExecutionResults):
-    __only_on__ = 'firebird'
-
-    @testing.exclude('firebird', '<', (2, 1), '2.1+ feature')
-    def test_update_returning(self):
-        meta = MetaData(testing.db)
-        table = Table('tables', meta,
-            Column('id', Integer, Sequence('gen_tables_id'), primary_key=True),
-            Column('persons', Integer),
-            Column('full', Boolean)
-        )
-        table.create()
-        try:
-            table.insert().execute([{'persons': 5, 'full': False}, {'persons': 3, 'full': False}])
-
-            result = table.update(table.c.persons > 4, dict(full=True), firebird_returning=[table.c.id]).execute()
-            eq_(result.fetchall(), [(1,)])
-
-            result2 = select([table.c.id, table.c.full]).order_by(table.c.id).execute()
-            eq_(result2.fetchall(), [(1,True),(2,False)])
-        finally:
-            table.drop()
-
-    @testing.exclude('firebird', '<', (2, 0), '2.0+ feature')
-    def test_insert_returning(self):
-        meta = MetaData(testing.db)
-        table = Table('tables', meta,
-            Column('id', Integer, Sequence('gen_tables_id'), primary_key=True),
-            Column('persons', Integer),
-            Column('full', Boolean)
-        )
-        table.create()
-        try:
-            result = table.insert(firebird_returning=[table.c.id]).execute({'persons': 1, 'full': False})
-
-            eq_(result.fetchall(), [(1,)])
-
-            # Multiple inserts only return the last row
-            result2 = table.insert(firebird_returning=[table]).execute(
-                 [{'persons': 2, 'full': False}, {'persons': 3, 'full': True}])
-
-            eq_(result2.fetchall(), [(3,3,True)])
-
-            result3 = table.insert(firebird_returning=[table.c.id]).execute({'persons': 4, 'full': False})
-            eq_([dict(row) for row in result3], [{'ID':4}])
-
-            result4 = testing.db.execute('insert into tables (id, persons, "full") values (5, 10, 1) returning persons')
-            eq_([dict(row) for row in result4], [{'PERSONS': 10}])
-        finally:
-            table.drop()
-
-    @testing.exclude('firebird', '<', (2, 1), '2.1+ feature')
-    def test_delete_returning(self):
-        meta = MetaData(testing.db)
-        table = Table('tables', meta,
-            Column('id', Integer, Sequence('gen_tables_id'), primary_key=True),
-            Column('persons', Integer),
-            Column('full', Boolean)
-        )
-        table.create()
-        try:
-            table.insert().execute([{'persons': 5, 'full': False}, {'persons': 3, 'full': False}])
-
-            result = table.delete(table.c.persons > 4, firebird_returning=[table.c.id]).execute()
-            eq_(result.fetchall(), [(1,)])
-
-            result2 = select([table.c.id, table.c.full]).order_by(table.c.id).execute()
-            eq_(result2.fetchall(), [(2,False),])
-        finally:
-            table.drop()
 
 
-class MiscFBTests(TestBase):
+class MiscTest(TestBase):
     __only_on__ = 'firebird'
 
     def test_strlen(self):
@@ -217,12 +154,20 @@ class MiscFBTests(TestBase):
         try:
             t.insert(values=dict(name='dante')).execute()
             t.insert(values=dict(name='alighieri')).execute()
-            select([func.count(t.c.id)],func.length(t.c.name)==5).execute().fetchone()[0] == 1
+            eq_(select([func.count(t.c.id)], func.length(t.c.name)==5).execute().first()[0], 1)
         finally:
             meta.drop_all()
 
     def test_server_version_info(self):
-        version = testing.db.dialect.server_version_info(testing.db.connect())
+        version = testing.db.dialect.server_version_info
         assert len(version) == 3, "Got strange version info: %s" % repr(version)
 
+    def test_percents_in_text(self):
+        for expr, result in (
+            (text("select '%' from rdb$database"), '%'),
+            (text("select '%%' from rdb$database"), '%%'),
+            (text("select '%%%' from rdb$database"), '%%%'),
+            (text("select 'hello % world' from rdb$database"), "hello % world")
+        ):
+            eq_(testing.db.scalar(expr), result)
 
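The dialect-specific firebird_returning keyword argument is replaced by the generative .returning() method, which the compiler renders per-dialect. A minimal sketch mirroring the compile assertions above:

    # Minimal sketch of the generative RETURNING API; mirrors the tests above.
    from sqlalchemy import insert, update
    from sqlalchemy.sql import table, column

    t = table('mytable', column('myid'), column('name'))

    i = insert(t, values=dict(name='foo')).returning(t.c.myid)
    # INSERT INTO mytable (name) VALUES (:name) RETURNING mytable.myid

    u = update(t, values=dict(name='foo')).returning(t.c.myid, t.c.name)
    # UPDATE mytable SET name=:name RETURNING mytable.myid, mytable.name
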
index 86a4e751d41ab75d8752c76e16d2345380152a2e..e647990d3102900400232335e9d1244029825529 100644 (file)
@@ -4,7 +4,8 @@ from sqlalchemy.test import *
 
 
 class CompileTest(TestBase, AssertsCompiledSQL):
-    __dialect__ = informix.InfoDialect()
+    __only_on__ = 'informix'
+    __dialect__ = informix.InformixDialect()
     
     def test_statements(self):
         meta =MetaData()
index 033a05533f1aaf0aebf27b80b85d0afb6c1cd8c4..c69a81120f98c36d24184042dc149c25d0ef93fe 100644 (file)
@@ -185,7 +185,7 @@ class DBAPITest(TestBase, AssertsExecutionResults):
             vals = []
             for i in xrange(3):
                 cr.execute('SELECT busto.NEXTVAL FROM DUAL')
                 vals.append(cr.fetchone()[0])
 
             # should be 1,2,3, but no...
             self.assert_(vals != [1,2,3])
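Note that the raw DB-API cursor in this test keeps fetchone(): first() is a method of SQLAlchemy's ResultProxy and does not exist on cursors. A hedged contrast, with an illustrative in-memory engine:

    # ResultProxy vs. raw DB-API cursor fetch methods; SQLite URL illustrative.
    from sqlalchemy import create_engine

    engine = create_engine('sqlite://')

    result = engine.connect().execute("select 1")   # ResultProxy
    row = result.first()                            # fetches one row, closes result

    cursor = engine.raw_connection().cursor()       # raw DB-API cursor
    cursor.execute("select 1")
    row = cursor.fetchone()                         # cursors: fetchone()/fetchall()
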
index dd86ce0de2822b1d1d4c9b77f85029520b1345b0..423310db62c274d0515de219dc3eaef73b6c7a09 100644 (file)
@@ -2,17 +2,18 @@
 from sqlalchemy.test.testing import eq_
 import datetime, os, re
 from sqlalchemy import *
-from sqlalchemy import types, exc
+from sqlalchemy import types, exc, schema
 from sqlalchemy.orm import *
 from sqlalchemy.sql import table, column
 from sqlalchemy.databases import mssql
-import sqlalchemy.engine.url as url
+from sqlalchemy.dialects.mssql import pyodbc
+from sqlalchemy.engine import url
 from sqlalchemy.test import *
 from sqlalchemy.test.testing import eq_
 
 
 class CompileTest(TestBase, AssertsCompiledSQL):
-    __dialect__ = mssql.MSSQLDialect()
+    __dialect__ = mssql.dialect()
 
     def test_insert(self):
         t = table('sometable', column('somecolumn'))
@@ -157,6 +158,45 @@ class CompileTest(TestBase, AssertsCompiledSQL):
                 select([extract(field, t.c.col1)]),
                 'SELECT DATEPART("%s", t.col1) AS anon_1 FROM t' % field)
 
+    def test_update_returning(self):
+        table1 = table('mytable',
+            column('myid', Integer),
+            column('name', String(128)),
+            column('description', String(128)),
+        )
+
+        u = update(table1, values=dict(name='foo')).returning(table1.c.myid, table1.c.name)
+        self.assert_compile(u, "UPDATE mytable SET name=:name OUTPUT inserted.myid, inserted.name")
+
+        u = update(table1, values=dict(name='foo')).returning(table1)
+        self.assert_compile(u, "UPDATE mytable SET name=:name OUTPUT inserted.myid, "
+                            "inserted.name, inserted.description")
+
+        u = update(table1, values=dict(name='foo')).returning(table1).where(table1.c.name=='bar')
+        self.assert_compile(u, "UPDATE mytable SET name=:name OUTPUT inserted.myid, "
+                            "inserted.name, inserted.description WHERE mytable.name = :name_1")
+
+        u = update(table1, values=dict(name='foo')).returning(func.length(table1.c.name))
+        self.assert_compile(u, "UPDATE mytable SET name=:name OUTPUT LEN(inserted.name) AS length_1")
+
+    def test_insert_returning(self):
+        table1 = table('mytable',
+            column('myid', Integer),
+            column('name', String(128)),
+            column('description', String(128)),
+        )
+
+        i = insert(table1, values=dict(name='foo')).returning(table1.c.myid, table1.c.name)
+        self.assert_compile(i, "INSERT INTO mytable (name) OUTPUT inserted.myid, inserted.name VALUES (:name)")
+
+        i = insert(table1, values=dict(name='foo')).returning(table1)
+        self.assert_compile(i, "INSERT INTO mytable (name) OUTPUT inserted.myid, "
+                                "inserted.name, inserted.description VALUES (:name)")
+
+        i = insert(table1, values=dict(name='foo')).returning(func.length(table1.c.name))
+        self.assert_compile(i, "INSERT INTO mytable (name) OUTPUT LEN(inserted.name) AS length_1 VALUES (:name)")
+
+
 
 class IdentityInsertTest(TestBase, AssertsCompiledSQL):
     __only_on__ = 'mssql'
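On SQL Server the same generative .returning() call renders as an OUTPUT clause against the 'inserted' pseudo-table, as the new compile tests assert. A hedged sketch using this module's imports:

    # Hedged sketch of RETURNING-as-OUTPUT on SQL Server.
    from sqlalchemy import update
    from sqlalchemy.sql import table, column
    from sqlalchemy.databases import mssql

    t = table('mytable', column('myid'), column('name'))
    u = update(t, values=dict(name='foo')).returning(t.c.myid)
    print u.compile(dialect=mssql.dialect())
    # UPDATE mytable SET name=:name OUTPUT inserted.myid
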
@@ -189,9 +229,9 @@ class IdentityInsertTest(TestBase, AssertsCompiledSQL):
         eq_([(9, 'Python')], list(cats))
 
         result = cattable.insert().values(description='PHP').execute()
-        eq_([10], result.last_inserted_ids())
+        eq_([10], result.inserted_primary_key)
         lastcat = cattable.select().order_by(desc(cattable.c.id)).execute()
-        eq_((10, 'PHP'), lastcat.fetchone())
+        eq_((10, 'PHP'), lastcat.first())
 
     def test_executemany(self):
         cattable.insert().execute([
@@ -213,10 +253,51 @@ class IdentityInsertTest(TestBase, AssertsCompiledSQL):
         eq_([(91, 'Smalltalk'), (90, 'PHP')], list(lastcats))
 
 
-class ReflectionTest(TestBase):
+class ReflectionTest(TestBase, ComparesTables):
     __only_on__ = 'mssql'
 
-    def testidentity(self):
+    def test_basic_reflection(self):
+        meta = MetaData(testing.db)
+
+        users = Table('engine_users', meta,
+            Column('user_id', types.INT, primary_key=True),
+            Column('user_name', types.VARCHAR(20), nullable=False),
+            Column('test1', types.CHAR(5), nullable=False),
+            Column('test2', types.Float(5), nullable=False),
+            Column('test3', types.Text),
+            Column('test4', types.Numeric, nullable = False),
+            Column('test5', types.DateTime),
+            Column('parent_user_id', types.Integer,
+                   ForeignKey('engine_users.user_id')),
+            Column('test6', types.DateTime, nullable=False),
+            Column('test7', types.Text),
+            Column('test8', types.Binary),
+            Column('test_passivedefault2', types.Integer, server_default='5'),
+            Column('test9', types.Binary(100)),
+            Column('test_numeric', types.Numeric()),
+            test_needs_fk=True,
+        )
+
+        addresses = Table('engine_email_addresses', meta,
+            Column('address_id', types.Integer, primary_key = True),
+            Column('remote_user_id', types.Integer, ForeignKey(users.c.user_id)),
+            Column('email_address', types.String(20)),
+            test_needs_fk=True,
+        )
+        meta.create_all()
+
+        try:
+            meta2 = MetaData()
+            reflected_users = Table('engine_users', meta2, autoload=True,
+                                    autoload_with=testing.db)
+            reflected_addresses = Table('engine_email_addresses', meta2,
+                                        autoload=True, autoload_with=testing.db)
+            self.assert_tables_equal(users, reflected_users)
+            self.assert_tables_equal(addresses, reflected_addresses)
+        finally:
+            meta.drop_all()
+
+    def test_identity(self):
         meta = MetaData(testing.db)
         table = Table(
             'identity_test', meta,
@@ -240,7 +321,7 @@ class QueryUnicodeTest(TestBase):
         meta = MetaData(testing.db)
         t1 = Table('unitest_table', meta,
                 Column('id', Integer, primary_key=True),
-                Column('descr', mssql.MSText(200, convert_unicode=True)))
+                Column('descr', mssql.MSText(convert_unicode=True)))
         meta.create_all()
         con = testing.db.connect()
 
@@ -248,7 +329,7 @@ class QueryUnicodeTest(TestBase):
         con.execute(u"insert into unitest_table values ('bien mangé')".encode('UTF-8'))
 
         try:
-            r = t1.select().execute().fetchone()
+            r = t1.select().execute().first()
             assert isinstance(r[1], unicode), '%s is %s instead of unicode, working on %s' % (
                     r[1], type(r[1]), meta.bind)
 
@@ -262,7 +343,9 @@ class QueryTest(TestBase):
         meta = MetaData(testing.db)
         t1 = Table('t1', meta,
                 Column('id', Integer, Sequence('fred', 100, 1), primary_key=True),
-                Column('descr', String(200)))
+                Column('descr', String(200)),
+                implicit_returning = False
+                )
         t2 = Table('t2', meta,
                 Column('id', Integer, Sequence('fred', 200, 1), primary_key=True),
                 Column('descr', String(200)))
@@ -274,9 +357,9 @@ class QueryTest(TestBase):
         try:
             tr = con.begin()
             r = con.execute(t2.insert(), descr='hello')
-            self.assert_(r.last_inserted_ids() == [200])
+            self.assert_(r.inserted_primary_key == [200])
             r = con.execute(t1.insert(), descr='hello')
-            self.assert_(r.last_inserted_ids() == [100])
+            self.assert_(r.inserted_primary_key == [100])
 
         finally:
             tr.commit()
@@ -295,6 +378,19 @@ class QueryTest(TestBase):
             tbl.drop()
             con.execute('drop schema paj')
 
+    def test_returning_no_autoinc(self):
+        meta = MetaData(testing.db)
+
+        table = Table('t1', meta, Column('id', Integer, primary_key=True), Column('data', String(50)))
+        table.create()
+        try:
+            result = table.insert().values(id=1, data=func.lower("SomeString")).returning(table.c.id, table.c.data).execute()
+            eq_(result.fetchall(), [(1, 'somestring',)])
+        finally:
+            # this will hang if the "SET IDENTITY_INSERT t1 OFF" occurs before the
+            # result is fetched
+            table.drop()
+
     def test_delete_schema(self):
         meta = MetaData(testing.db)
         con = testing.db.connect()
@@ -371,36 +467,26 @@ class SchemaTest(TestBase):
         )
         self.column = t.c.test_column
 
+        dialect = mssql.dialect()
+        self.ddl_compiler = dialect.ddl_compiler(dialect, schema.CreateTable(t))
+
+    def _column_spec(self):
+        return self.ddl_compiler.get_column_specification(self.column)
+
     def test_that_mssql_default_nullability_emits_null(self):
-        schemagenerator = \
-            mssql.MSSQLDialect().schemagenerator(mssql.MSSQLDialect(), None)
-        column_specification = \
-            schemagenerator.get_column_specification(self.column)
-        eq_("test_column VARCHAR NULL", column_specification)
+        eq_("test_column VARCHAR NULL", self._column_spec())
 
     def test_that_mssql_none_nullability_does_not_emit_nullability(self):
-        schemagenerator = \
-            mssql.MSSQLDialect().schemagenerator(mssql.MSSQLDialect(), None)
         self.column.nullable = None
-        column_specification = \
-            schemagenerator.get_column_specification(self.column)
-        eq_("test_column VARCHAR", column_specification)
+        eq_("test_column VARCHAR", self._column_spec())
 
     def test_that_mssql_specified_nullable_emits_null(self):
-        schemagenerator = \
-            mssql.MSSQLDialect().schemagenerator(mssql.MSSQLDialect(), None)
         self.column.nullable = True
-        column_specification = \
-            schemagenerator.get_column_specification(self.column)
-        eq_("test_column VARCHAR NULL", column_specification)
+        eq_("test_column VARCHAR NULL", self._column_spec())
 
     def test_that_mssql_specified_not_nullable_emits_not_null(self):
-        schemagenerator = \
-            mssql.MSSQLDialect().schemagenerator(mssql.MSSQLDialect(), None)
         self.column.nullable = False
-        column_specification = \
-            schemagenerator.get_column_specification(self.column)
-        eq_("test_column VARCHAR NOT NULL", column_specification)
+        eq_("test_column VARCHAR NOT NULL", self._column_spec())
 
 
 def full_text_search_missing():
@@ -515,79 +601,73 @@ class MatchTest(TestBase, AssertsCompiledSQL):
 class ParseConnectTest(TestBase, AssertsCompiledSQL):
     __only_on__ = 'mssql'
 
+    @classmethod
+    def setup_class(cls):
+        global dialect
+        dialect = pyodbc.MSDialect_pyodbc()
+
     def test_pyodbc_connect_dsn_trusted(self):
         u = url.make_url('mssql://mydsn')
-        dialect = mssql.MSSQLDialect_pyodbc()
         connection = dialect.create_connect_args(u)
         eq_([['dsn=mydsn;TrustedConnection=Yes'], {}], connection)
 
     def test_pyodbc_connect_old_style_dsn_trusted(self):
         u = url.make_url('mssql:///?dsn=mydsn')
-        dialect = mssql.MSSQLDialect_pyodbc()
         connection = dialect.create_connect_args(u)
         eq_([['dsn=mydsn;TrustedConnection=Yes'], {}], connection)
 
     def test_pyodbc_connect_dsn_non_trusted(self):
         u = url.make_url('mssql://username:password@mydsn')
-        dialect = mssql.MSSQLDialect_pyodbc()
         connection = dialect.create_connect_args(u)
         eq_([['dsn=mydsn;UID=username;PWD=password'], {}], connection)
 
     def test_pyodbc_connect_dsn_extra(self):
         u = url.make_url('mssql://username:password@mydsn/?LANGUAGE=us_english&foo=bar')
-        dialect = mssql.MSSQLDialect_pyodbc()
         connection = dialect.create_connect_args(u)
         eq_([['dsn=mydsn;UID=username;PWD=password;LANGUAGE=us_english;foo=bar'], {}], connection)
 
     def test_pyodbc_connect(self):
         u = url.make_url('mssql://username:password@hostspec/database')
-        dialect = mssql.MSSQLDialect_pyodbc()
         connection = dialect.create_connect_args(u)
         eq_([['DRIVER={SQL Server};Server=hostspec;Database=database;UID=username;PWD=password'], {}], connection)
 
     def test_pyodbc_connect_comma_port(self):
         u = url.make_url('mssql://username:password@hostspec:12345/database')
-        dialect = mssql.MSSQLDialect_pyodbc()
         connection = dialect.create_connect_args(u)
         eq_([['DRIVER={SQL Server};Server=hostspec,12345;Database=database;UID=username;PWD=password'], {}], connection)
 
     def test_pyodbc_connect_config_port(self):
         u = url.make_url('mssql://username:password@hostspec/database?port=12345')
-        dialect = mssql.MSSQLDialect_pyodbc()
         connection = dialect.create_connect_args(u)
         eq_([['DRIVER={SQL Server};Server=hostspec;Database=database;UID=username;PWD=password;port=12345'], {}], connection)
 
     def test_pyodbc_extra_connect(self):
         u = url.make_url('mssql://username:password@hostspec/database?LANGUAGE=us_english&foo=bar')
-        dialect = mssql.MSSQLDialect_pyodbc()
         connection = dialect.create_connect_args(u)
         eq_([['DRIVER={SQL Server};Server=hostspec;Database=database;UID=username;PWD=password;foo=bar;LANGUAGE=us_english'], {}], connection)
 
     def test_pyodbc_odbc_connect(self):
         u = url.make_url('mssql:///?odbc_connect=DRIVER%3D%7BSQL+Server%7D%3BServer%3Dhostspec%3BDatabase%3Ddatabase%3BUID%3Dusername%3BPWD%3Dpassword')
-        dialect = mssql.MSSQLDialect_pyodbc()
         connection = dialect.create_connect_args(u)
         eq_([['DRIVER={SQL Server};Server=hostspec;Database=database;UID=username;PWD=password'], {}], connection)
 
     def test_pyodbc_odbc_connect_with_dsn(self):
         u = url.make_url('mssql:///?odbc_connect=dsn%3Dmydsn%3BDatabase%3Ddatabase%3BUID%3Dusername%3BPWD%3Dpassword')
-        dialect = mssql.MSSQLDialect_pyodbc()
         connection = dialect.create_connect_args(u)
         eq_([['dsn=mydsn;Database=database;UID=username;PWD=password'], {}], connection)
 
     def test_pyodbc_odbc_connect_ignores_other_values(self):
         u = url.make_url('mssql://userdiff:passdiff@localhost/dbdiff?odbc_connect=DRIVER%3D%7BSQL+Server%7D%3BServer%3Dhostspec%3BDatabase%3Ddatabase%3BUID%3Dusername%3BPWD%3Dpassword')
-        dialect = mssql.MSSQLDialect_pyodbc()
         connection = dialect.create_connect_args(u)
         eq_([['DRIVER={SQL Server};Server=hostspec;Database=database;UID=username;PWD=password'], {}], connection)
 
 
-class TypesTest(TestBase):
+class TypesTest(TestBase, AssertsExecutionResults, ComparesTables):
     __only_on__ = 'mssql'
 
     @classmethod
     def setup_class(cls):
-        global numeric_table, metadata
+        global metadata
         metadata = MetaData(testing.db)
 
     def teardown(self):
@@ -601,26 +681,22 @@ class TypesTest(TestBase):
         )
         metadata.create_all()
 
-        try:
-            test_items = [decimal.Decimal(d) for d in '1500000.00000000000000000000',
-                          '-1500000.00000000000000000000', '1500000',
-                          '0.0000000000000000002', '0.2', '-0.0000000000000000002', '-2E-2',
-                          '156666.458923543', '-156666.458923543', '1', '-1', '-1234', '1234',
-                          '2E-12', '4E8', '3E-6', '3E-7', '4.1', '1E-1', '1E-2', '1E-3',
-                          '1E-4', '1E-5', '1E-6', '1E-7', '1E-1', '1E-8', '0.2732E2', '-0.2432E2', '4.35656E2',
-                          '-02452E-2', '45125E-2',
-                          '1234.58965E-2', '1.521E+15', '-1E-25', '1E-25', '1254E-25', '-1203E-25',
-                          '0', '-0.00', '-0', '4585E12', '000000000000000000012', '000000000000.32E12',
-                          '00000000000000.1E+12', '000000000000.2E-32']
+        test_items = [decimal.Decimal(d) for d in '1500000.00000000000000000000',
+                      '-1500000.00000000000000000000', '1500000',
+                      '0.0000000000000000002', '0.2', '-0.0000000000000000002', '-2E-2',
+                      '156666.458923543', '-156666.458923543', '1', '-1', '-1234', '1234',
+                      '2E-12', '4E8', '3E-6', '3E-7', '4.1', '1E-1', '1E-2', '1E-3',
+                      '1E-4', '1E-5', '1E-6', '1E-7', '1E-1', '1E-8', '0.2732E2', '-0.2432E2', '4.35656E2',
+                      '-02452E-2', '45125E-2',
+                      '1234.58965E-2', '1.521E+15', '-1E-25', '1E-25', '1254E-25', '-1203E-25',
+                      '0', '-0.00', '-0', '4585E12', '000000000000000000012', '000000000000.32E12',
+                      '00000000000000.1E+12', '000000000000.2E-32']
 
-            for value in test_items:
-                numeric_table.insert().execute(numericcol=value)
+        for value in test_items:
+            numeric_table.insert().execute(numericcol=value)
 
-            for value in select([numeric_table.c.numericcol]).execute():
-                assert value[0] in test_items, "%s not in test_items" % value[0]
-
-        except Exception, e:
-            raise e
+        for value in select([numeric_table.c.numericcol]).execute():
+            assert value[0] in test_items, "%s not in test_items" % value[0]
 
     def test_float(self):
         float_table = Table('float_table', metadata,
@@ -643,11 +719,6 @@ class TypesTest(TestBase):
             raise e
 
 
-class TypesTest2(TestBase, AssertsExecutionResults):
-    "Test Microsoft SQL Server column types"
-
-    __only_on__ = 'mssql'
-
     def test_money(self):
         "Exercise type specification for money types."
 
@@ -659,13 +730,14 @@ class TypesTest2(TestBase, AssertsExecutionResults):
              'SMALLMONEY'),
            ]
 
-        table_args = ['test_mssql_money', MetaData(testing.db)]
+        table_args = ['test_mssql_money', metadata]
         for index, spec in enumerate(columns):
             type_, args, kw, res = spec
             table_args.append(Column('c%s' % index, type_(*args, **kw), nullable=None))
 
         money_table = Table(*table_args)
-        gen = testing.db.dialect.schemagenerator(testing.db.dialect, testing.db, None, None)
+        dialect = mssql.dialect()
+        gen = dialect.ddl_compiler(dialect, schema.CreateTable(money_table))
 
         for col in money_table.c:
             index = int(col.name[1:])
@@ -688,15 +760,27 @@ class TypesTest2(TestBase, AssertsExecutionResults):
             (mssql.MSDateTime, [], {},
              'DATETIME', []),
 
+            (types.DATE, [], {},
+             'DATE', ['>=', (10,)]),
+            (types.Date, [], {},
+             'DATE', ['>=', (10,)]),
+            (types.Date, [], {},
+             'DATETIME', ['<', (10,)], mssql.MSDateTime),
             (mssql.MSDate, [], {},
              'DATE', ['>=', (10,)]),
             (mssql.MSDate, [], {},
              'DATETIME', ['<', (10,)], mssql.MSDateTime),
 
+            (types.TIME, [], {},
+             'TIME', ['>=', (10,)]),
+            (types.Time, [], {},
+             'TIME', ['>=', (10,)]),
             (mssql.MSTime, [], {},
              'TIME', ['>=', (10,)]),
             (mssql.MSTime, [1], {},
              'TIME(1)', ['>=', (10,)]),
+            (types.Time, [], {},
+             'DATETIME', ['<', (10,)], mssql.MSDateTime),
             (mssql.MSTime, [], {},
              'DATETIME', ['<', (10,)], mssql.MSDateTime),
 
@@ -715,14 +799,14 @@ class TypesTest2(TestBase, AssertsExecutionResults):
 
             ]
 
-        table_args = ['test_mssql_dates', MetaData(testing.db)]
+        table_args = ['test_mssql_dates', metadata]
         for index, spec in enumerate(columns):
             type_, args, kw, res, requires = spec[0:5]
             if (requires and testing._is_excluded('mssql', *requires)) or not requires:
                 table_args.append(Column('c%s' % index, type_(*args, **kw), nullable=None))
 
         dates_table = Table(*table_args)
-        gen = testing.db.dialect.schemagenerator(testing.db.dialect, testing.db, None, None)
+        gen = testing.db.dialect.ddl_compiler(testing.db.dialect, schema.CreateTable(dates_table))
 
         for col in dates_table.c:
             index = int(col.name[1:])
@@ -730,49 +814,37 @@ class TypesTest2(TestBase, AssertsExecutionResults):
                            "%s %s" % (col.name, columns[index][3]))
             self.assert_(repr(col))
 
-        try:
-            dates_table.create(checkfirst=True)
-            assert True
-        except:
-            raise
+        dates_table.create(checkfirst=True)
 
         reflected_dates = Table('test_mssql_dates', MetaData(testing.db), autoload=True)
         for col in reflected_dates.c:
-            index = int(col.name[1:])
-            testing.eq_(testing.db.dialect.type_descriptor(col.type).__class__,
-                len(columns[index]) > 5 and columns[index][5] or columns[index][0])
-        dates_table.drop()
-
-    def test_dates2(self):
-        meta = MetaData(testing.db)
-        t = Table('test_dates', meta,
-                  Column('id', Integer,
-                         Sequence('datetest_id_seq', optional=True),
-                         primary_key=True),
-                  Column('adate', Date),
-                  Column('atime', Time),
-                  Column('adatetime', DateTime))
-        t.create(checkfirst=True)
-        try:
-            d1 = datetime.date(2007, 10, 30)
-            t1 = datetime.time(11, 2, 32)
-            d2 = datetime.datetime(2007, 10, 30, 11, 2, 32)
-            t.insert().execute(adate=d1, adatetime=d2, atime=t1)
-            t.insert().execute(adate=d2, adatetime=d2, atime=d2)
+            self.assert_types_base(col, dates_table.c[col.key])
 
-            x = t.select().execute().fetchall()[0]
-            self.assert_(x.adate.__class__ == datetime.date)
-            self.assert_(x.atime.__class__ == datetime.time)
-            self.assert_(x.adatetime.__class__ == datetime.datetime)
+    def test_date_roundtrip(self):
+        t = Table('test_dates', metadata,
+                    Column('id', Integer,
+                           Sequence('datetest_id_seq', optional=True),
+                           primary_key=True),
+                    Column('adate', Date),
+                    Column('atime', Time),
+                    Column('adatetime', DateTime))
+        metadata.create_all()
+        d1 = datetime.date(2007, 10, 30)
+        t1 = datetime.time(11, 2, 32)
+        d2 = datetime.datetime(2007, 10, 30, 11, 2, 32)
+        t.insert().execute(adate=d1, adatetime=d2, atime=t1)
+        t.insert().execute(adate=d2, adatetime=d2, atime=d2)
 
-            t.delete().execute()
+        x = t.select().execute().fetchall()[0]
+        self.assert_(x.adate.__class__ == datetime.date)
+        self.assert_(x.atime.__class__ == datetime.time)
+        self.assert_(x.adatetime.__class__ == datetime.datetime)
 
-            t.insert().execute(adate=d1, adatetime=d2, atime=t1)
+        t.delete().execute()
 
-            eq_(select([t.c.adate, t.c.atime, t.c.adatetime], t.c.adate==d1).execute().fetchall(), [(d1, t1, d2)])
+        t.insert().execute(adate=d1, adatetime=d2, atime=t1)
 
-        finally:
-            t.drop(checkfirst=True)
+        eq_(select([t.c.adate, t.c.atime, t.c.adatetime], t.c.adate==d1).execute().fetchall(), [(d1, t1, d2)])
 
     def test_binary(self):
         "Exercise type specification for binary types."
@@ -781,6 +853,9 @@ class TypesTest2(TestBase, AssertsExecutionResults):
             # column type, args, kwargs, expected ddl
             (mssql.MSBinary, [], {},
              'BINARY'),
+            (types.Binary, [10], {},
+             'BINARY(10)'),
+
             (mssql.MSBinary, [10], {},
              'BINARY(10)'),
 
@@ -798,13 +873,14 @@ class TypesTest2(TestBase, AssertsExecutionResults):
              'BINARY(10)')
             ]
 
-        table_args = ['test_mssql_binary', MetaData(testing.db)]
+        table_args = ['test_mssql_binary', metadata]
         for index, spec in enumerate(columns):
             type_, args, kw, res = spec
             table_args.append(Column('c%s' % index, type_(*args, **kw), nullable=None))
 
         binary_table = Table(*table_args)
-        gen = testing.db.dialect.schemagenerator(testing.db.dialect, testing.db, None, None)
+        dialect = mssql.dialect()
+        gen = dialect.ddl_compiler(dialect, schema.CreateTable(binary_table))
 
         for col in binary_table.c:
             index = int(col.name[1:])
@@ -812,22 +888,15 @@ class TypesTest2(TestBase, AssertsExecutionResults):
                            "%s %s" % (col.name, columns[index][3]))
             self.assert_(repr(col))
 
-        try:
-            binary_table.create(checkfirst=True)
-            assert True
-        except:
-            raise
+        metadata.create_all()
 
         reflected_binary = Table('test_mssql_binary', MetaData(testing.db), autoload=True)
         for col in reflected_binary.c:
-            # don't test the MSGenericBinary since it's a special case and
-            # reflected it will map to a MSImage or MSBinary depending
-            if not testing.db.dialect.type_descriptor(binary_table.c[col.name].type).__class__ == mssql.MSGenericBinary:
-                testing.eq_(testing.db.dialect.type_descriptor(col.type).__class__,
-                    testing.db.dialect.type_descriptor(binary_table.c[col.name].type).__class__)
+            c1 = testing.db.dialect.type_descriptor(col.type).__class__
+            c2 = testing.db.dialect.type_descriptor(binary_table.c[col.name].type).__class__
+            assert issubclass(c1, c2), "%r is not a subclass of %r" % (c1, c2)
             if binary_table.c[col.name].type.length:
                 testing.eq_(col.type.length, binary_table.c[col.name].type.length)
-        binary_table.drop()
 
     def test_boolean(self):
         "Exercise type specification for boolean type."
@@ -838,13 +907,14 @@ class TypesTest2(TestBase, AssertsExecutionResults):
              'BIT'),
            ]
 
-        table_args = ['test_mssql_boolean', MetaData(testing.db)]
+        table_args = ['test_mssql_boolean', metadata]
         for index, spec in enumerate(columns):
             type_, args, kw, res = spec
             table_args.append(Column('c%s' % index, type_(*args, **kw), nullable=None))
 
         boolean_table = Table(*table_args)
-        gen = testing.db.dialect.schemagenerator(testing.db.dialect, testing.db, None, None)
+        dialect = mssql.dialect()
+        gen = dialect.ddl_compiler(dialect, schema.CreateTable(boolean_table))
 
         for col in boolean_table.c:
             index = int(col.name[1:])
@@ -852,12 +922,7 @@ class TypesTest2(TestBase, AssertsExecutionResults):
                            "%s %s" % (col.name, columns[index][3]))
             self.assert_(repr(col))
 
-        try:
-            boolean_table.create(checkfirst=True)
-            assert True
-        except:
-            raise
-        boolean_table.drop()
+        metadata.create_all()
 
     def test_numeric(self):
         "Exercise type specification and options for numeric types."
@@ -865,40 +930,39 @@ class TypesTest2(TestBase, AssertsExecutionResults):
         columns = [
             # column type, args, kwargs, expected ddl
             (mssql.MSNumeric, [], {},
-             'NUMERIC(10, 2)'),
+             'NUMERIC'),
             (mssql.MSNumeric, [None], {},
              'NUMERIC'),
-            (mssql.MSNumeric, [12], {},
-             'NUMERIC(12, 2)'),
             (mssql.MSNumeric, [12, 4], {},
              'NUMERIC(12, 4)'),
 
-            (mssql.MSFloat, [], {},
-             'FLOAT(10)'),
-            (mssql.MSFloat, [None], {},
+            (types.Float, [], {},
+             'FLOAT'),
+            (types.Float, [None], {},
              'FLOAT'),
-            (mssql.MSFloat, [12], {},
+            (types.Float, [12], {},
              'FLOAT(12)'),
             (mssql.MSReal, [], {},
              'REAL'),
 
-            (mssql.MSInteger, [], {},
+            (types.Integer, [], {},
              'INTEGER'),
-            (mssql.MSBigInteger, [], {},
+            (types.BigInteger, [], {},
              'BIGINT'),
             (mssql.MSTinyInteger, [], {},
              'TINYINT'),
-            (mssql.MSSmallInteger, [], {},
+            (types.SmallInteger, [], {},
              'SMALLINT'),
            ]
 
-        table_args = ['test_mssql_numeric', MetaData(testing.db)]
+        table_args = ['test_mssql_numeric', metadata]
         for index, spec in enumerate(columns):
             type_, args, kw, res = spec
             table_args.append(Column('c%s' % index, type_(*args, **kw), nullable=None))
 
         numeric_table = Table(*table_args)
-        gen = testing.db.dialect.schemagenerator(testing.db.dialect, testing.db, None, None)
+        dialect = mssql.dialect()
+        gen = dialect.ddl_compiler(dialect, schema.CreateTable(numeric_table))
 
         for col in numeric_table.c:
             index = int(col.name[1:])
@@ -906,20 +970,11 @@ class TypesTest2(TestBase, AssertsExecutionResults):
                            "%s %s" % (col.name, columns[index][3]))
             self.assert_(repr(col))
 
-        try:
-            numeric_table.create(checkfirst=True)
-            assert True
-        except:
-            raise
-        numeric_table.drop()
+        metadata.create_all()
 
     def test_char(self):
         """Exercise COLLATE-ish options on string types."""
 
-        # modify the text_as_varchar setting since we are not testing that behavior here
-        text_as_varchar = testing.db.dialect.text_as_varchar
-        testing.db.dialect.text_as_varchar = False
-
         columns = [
             (mssql.MSChar, [], {},
              'CHAR'),
@@ -960,13 +1015,14 @@ class TypesTest2(TestBase, AssertsExecutionResults):
              'NTEXT COLLATE Latin1_General_CI_AS'),
            ]
 
-        table_args = ['test_mssql_charset', MetaData(testing.db)]
+        table_args = ['test_mssql_charset', metadata]
         for index, spec in enumerate(columns):
             type_, args, kw, res = spec
             table_args.append(Column('c%s' % index, type_(*args, **kw), nullable=None))
 
         charset_table = Table(*table_args)
-        gen = testing.db.dialect.schemagenerator(testing.db.dialect, testing.db, None, None)
+        dialect = mssql.dialect()
+        gen = dialect.ddl_compiler(dialect, schema.CreateTable(charset_table))
 
         for col in charset_table.c:
             index = int(col.name[1:])
@@ -974,110 +1030,91 @@ class TypesTest2(TestBase, AssertsExecutionResults):
                            "%s %s" % (col.name, columns[index][3]))
             self.assert_(repr(col))
 
-        try:
-            charset_table.create(checkfirst=True)
-            assert True
-        except:
-            raise
-        charset_table.drop()
-
-        testing.db.dialect.text_as_varchar = text_as_varchar
+        metadata.create_all()
 
     def test_timestamp(self):
         """Exercise TIMESTAMP column."""
 
-        meta = MetaData(testing.db)
-
-        try:
-            columns = [
-                (TIMESTAMP,
-                 'TIMESTAMP'),
-                (mssql.MSTimeStamp,
-                 'TIMESTAMP'),
-                ]
-            for idx, (spec, expected) in enumerate(columns):
-                t = Table('mssql_ts%s' % idx, meta,
-                          Column('id', Integer, primary_key=True),
-                          Column('t', spec, nullable=None))
-                testing.eq_(colspec(t.c.t), "t %s" % expected)
-                self.assert_(repr(t.c.t))
-                try:
-                    t.create(checkfirst=True)
-                    assert True
-                except:
-                    raise
-                t.drop()
-        finally:
-            meta.drop_all()
+        dialect = mssql.dialect()
 
+        spec, expected = (TIMESTAMP,'TIMESTAMP')
+        t = Table('mssql_ts', metadata,
+                   Column('id', Integer, primary_key=True),
+                   Column('t', spec, nullable=None))
+        gen = dialect.ddl_compiler(dialect, schema.CreateTable(t))
+        testing.eq_(gen.get_column_specification(t.c.t), "t %s" % expected)
+        self.assert_(repr(t.c.t))
+        t.create(checkfirst=True)
+
     def test_autoincrement(self):
-        meta = MetaData(testing.db)
-        try:
-            Table('ai_1', meta,
-                  Column('int_y', Integer, primary_key=True),
-                  Column('int_n', Integer, DefaultClause('0'),
-                         primary_key=True))
-            Table('ai_2', meta,
-                  Column('int_y', Integer, primary_key=True),
-                  Column('int_n', Integer, DefaultClause('0'),
-                         primary_key=True))
-            Table('ai_3', meta,
-                  Column('int_n', Integer, DefaultClause('0'),
-                         primary_key=True, autoincrement=False),
-                  Column('int_y', Integer, primary_key=True))
-            Table('ai_4', meta,
-                  Column('int_n', Integer, DefaultClause('0'),
-                         primary_key=True, autoincrement=False),
-                  Column('int_n2', Integer, DefaultClause('0'),
-                         primary_key=True, autoincrement=False))
-            Table('ai_5', meta,
-                  Column('int_y', Integer, primary_key=True),
-                  Column('int_n', Integer, DefaultClause('0'),
-                         primary_key=True, autoincrement=False))
-            Table('ai_6', meta,
-                  Column('o1', String(1), DefaultClause('x'),
-                         primary_key=True),
-                  Column('int_y', Integer, primary_key=True))
-            Table('ai_7', meta,
-                  Column('o1', String(1), DefaultClause('x'),
-                         primary_key=True),
-                  Column('o2', String(1), DefaultClause('x'),
-                         primary_key=True),
-                  Column('int_y', Integer, primary_key=True))
-            Table('ai_8', meta,
-                  Column('o1', String(1), DefaultClause('x'),
-                         primary_key=True),
-                  Column('o2', String(1), DefaultClause('x'),
-                         primary_key=True))
-            meta.create_all()
-
-            table_names = ['ai_1', 'ai_2', 'ai_3', 'ai_4',
-                           'ai_5', 'ai_6', 'ai_7', 'ai_8']
-            mr = MetaData(testing.db)
-            mr.reflect(only=table_names)
-
-            for tbl in [mr.tables[name] for name in table_names]:
-                for c in tbl.c:
-                    if c.name.startswith('int_y'):
-                        assert c.autoincrement
-                    elif c.name.startswith('int_n'):
-                        assert not c.autoincrement
-                tbl.insert().execute()
+        Table('ai_1', metadata,
+               Column('int_y', Integer, primary_key=True),
+               Column('int_n', Integer, DefaultClause('0'),
+                      primary_key=True))
+        Table('ai_2', metadata,
+               Column('int_y', Integer, primary_key=True),
+               Column('int_n', Integer, DefaultClause('0'),
+                      primary_key=True))
+        Table('ai_3', metadata,
+               Column('int_n', Integer, DefaultClause('0'),
+                      primary_key=True, autoincrement=False),
+               Column('int_y', Integer, primary_key=True))
+        Table('ai_4', metadata,
+               Column('int_n', Integer, DefaultClause('0'),
+                      primary_key=True, autoincrement=False),
+               Column('int_n2', Integer, DefaultClause('0'),
+                      primary_key=True, autoincrement=False))
+        Table('ai_5', metadata,
+               Column('int_y', Integer, primary_key=True),
+               Column('int_n', Integer, DefaultClause('0'),
+                      primary_key=True, autoincrement=False))
+        Table('ai_6', metadata,
+               Column('o1', String(1), DefaultClause('x'),
+                      primary_key=True),
+               Column('int_y', Integer, primary_key=True))
+        Table('ai_7', metadata,
+               Column('o1', String(1), DefaultClause('x'),
+                      primary_key=True),
+               Column('o2', String(1), DefaultClause('x'),
+                      primary_key=True),
+               Column('int_y', Integer, primary_key=True))
+        Table('ai_8', metadata,
+               Column('o1', String(1), DefaultClause('x'),
+                      primary_key=True),
+               Column('o2', String(1), DefaultClause('x'),
+                      primary_key=True))
+        metadata.create_all()
+
+        table_names = ['ai_1', 'ai_2', 'ai_3', 'ai_4',
+                        'ai_5', 'ai_6', 'ai_7', 'ai_8']
+        mr = MetaData(testing.db)
+
+        for name in table_names:
+            tbl = Table(name, mr, autoload=True)
+            for c in tbl.c:
+                if c.name.startswith('int_y'):
+                    assert c.autoincrement
+                elif c.name.startswith('int_n'):
+                    assert not c.autoincrement
+
+            for counter, engine in enumerate([
+                engines.testing_engine(options={'implicit_returning':False}),
+                engines.testing_engine(options={'implicit_returning':True}),
+                ]
+            ):
+                engine.execute(tbl.insert())
                 if 'int_y' in tbl.c:
-                    assert select([tbl.c.int_y]).scalar() == 1
-                    assert list(tbl.select().execute().fetchone()).count(1) == 1
+                    assert engine.scalar(select([tbl.c.int_y])) == counter + 1
+                    assert list(engine.execute(tbl.select()).first()).count(counter + 1) == 1
                 else:
-                    assert 1 not in list(tbl.select().execute().fetchone())
-        finally:
-            meta.drop_all()
-
-def colspec(c):
-    return testing.db.dialect.schemagenerator(testing.db.dialect,
-        testing.db, None, None).get_column_specification(c)
-
+                    assert 1 not in list(engine.execute(tbl.select()).first())
+                engine.execute(tbl.delete())
 
 class BinaryTest(TestBase, AssertsExecutionResults):
     """Test the Binary and VarBinary types"""
+
+    __only_on__ = 'mssql'
+
     @classmethod
     def setup_class(cls):
         global binary_table, MyPickleType
@@ -1125,6 +1162,11 @@ class BinaryTest(TestBase, AssertsExecutionResults):
         stream2 =self.load_stream('binary_data_two.dat')
         binary_table.insert().execute(primary_id=1, misc='binary_data_one.dat', data=stream1, data_image=stream1, data_slice=stream1[0:100], pickled=testobj1, mypickle=testobj3)
         binary_table.insert().execute(primary_id=2, misc='binary_data_two.dat', data=stream2, data_image=stream2, data_slice=stream2[0:99], pickled=testobj2)
+
+        # TODO: pyodbc does not seem to accept "None" for a VARBINARY column (data=None).
+        # error:  [Microsoft][ODBC SQL Server Driver][SQL Server]Implicit conversion from 
+        # data type varchar to varbinary is not allowed. Use the CONVERT function to run this query. (257)
+        #binary_table.insert().execute(primary_id=3, misc='binary_data_two.dat', data=None, data_image=None, data_slice=stream2[0:99], pickled=None)
         binary_table.insert().execute(primary_id=3, misc='binary_data_two.dat', data_image=None, data_slice=stream2[0:99], pickled=None)
 
         for stmt in (
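
The TODO above records a pyodbc quirk: a bare None bound to a VARBINARY column is sent as varchar, which SQL Server refuses to convert implicitly. A minimal sketch of the workaround the driver's own error message suggests, casting the NULL explicitly; the DSN and the exact column list are hypothetical, and VARBINARY(MAX) assumes SQL Server 2005+:

    from sqlalchemy import create_engine, text

    engine = create_engine('mssql+pyodbc://scott:tiger@mydsn')  # hypothetical DSN
    # Cast the NULL so pyodbc does not bind it as varchar:
    engine.execute(
        text("INSERT INTO binary_table (primary_id, misc, data) "
             "VALUES (:id, :misc, CONVERT(VARBINARY(MAX), NULL))"),
        id=3, misc='binary_data_two.dat')
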
index 8adb2d71c53c20036cfb7dab7cdc8a1082ef7e3a..40526415221b2d70011a85151e684fa1ae80fd4b 100644 (file)
@@ -1,8 +1,12 @@
 from sqlalchemy.test.testing import eq_
+
+# Py2K
 import sets
+# end Py2K
+
 from sqlalchemy import *
 from sqlalchemy import sql, exc
-from sqlalchemy.databases import mysql
+from sqlalchemy.dialects.mysql import base as mysql
 from sqlalchemy.test.testing import eq_
 from sqlalchemy.test import *
 
@@ -56,11 +60,11 @@ class TypesTest(TestBase, AssertsExecutionResults):
             # column type, args, kwargs, expected ddl
             # e.g. Column(Integer(10, unsigned=True)) == 'INTEGER(10) UNSIGNED'
             (mysql.MSNumeric, [], {},
-             'NUMERIC(10, 2)'),
+             'NUMERIC'),
             (mysql.MSNumeric, [None], {},
              'NUMERIC'),
             (mysql.MSNumeric, [12], {},
-             'NUMERIC(12, 2)'),
+             'NUMERIC(12)'),
             (mysql.MSNumeric, [12, 4], {'unsigned':True},
              'NUMERIC(12, 4) UNSIGNED'),
             (mysql.MSNumeric, [12, 4], {'zerofill':True},
@@ -69,11 +73,11 @@ class TypesTest(TestBase, AssertsExecutionResults):
              'NUMERIC(12, 4) UNSIGNED ZEROFILL'),
 
             (mysql.MSDecimal, [], {},
-             'DECIMAL(10, 2)'),
+             'DECIMAL'),
             (mysql.MSDecimal, [None], {},
              'DECIMAL'),
             (mysql.MSDecimal, [12], {},
-             'DECIMAL(12, 2)'),
+             'DECIMAL(12)'),
             (mysql.MSDecimal, [12, None], {},
              'DECIMAL(12)'),
             (mysql.MSDecimal, [12, 4], {'unsigned':True},
@@ -178,11 +182,11 @@ class TypesTest(TestBase, AssertsExecutionResults):
             table_args.append(Column('c%s' % index, type_(*args, **kw)))
 
         numeric_table = Table(*table_args)
-        gen = testing.db.dialect.schemagenerator(testing.db.dialect, testing.db, None, None)
+        gen = testing.db.dialect.ddl_compiler(testing.db.dialect, numeric_table)
 
         for col in numeric_table.c:
             index = int(col.name[1:])
-            self.assert_eq(gen.get_column_specification(col),
+            eq_(gen.get_column_specification(col),
                            "%s %s" % (col.name, columns[index][3]))
             self.assert_(repr(col))
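
The hunks above swap the removed schemagenerator for the 0.6 ddl_compiler, constructed from a dialect plus the element being compiled, exactly as the test now calls it. A self-contained sketch (the DDL in the comment is what MySQL's compiler would be expected to produce):

    from sqlalchemy import MetaData, Table, Column, Integer, String
    from sqlalchemy.dialects.mysql import base as mysql

    t = Table('t', MetaData(),
              Column('id', Integer, primary_key=True),
              Column('name', String(30)))
    dialect = mysql.dialect()
    gen = dialect.ddl_compiler(dialect, t)
    print gen.get_column_specification(t.c.name)   # e.g. "name VARCHAR(30)"
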
 
@@ -262,11 +266,11 @@ class TypesTest(TestBase, AssertsExecutionResults):
             table_args.append(Column('c%s' % index, type_(*args, **kw)))
 
         charset_table = Table(*table_args)
-        gen = testing.db.dialect.schemagenerator(testing.db.dialect, testing.db, None, None)
+        gen = testing.db.dialect.ddl_compiler(testing.db.dialect, charset_table)
 
         for col in charset_table.c:
             index = int(col.name[1:])
-            self.assert_eq(gen.get_column_specification(col),
+            eq_(gen.get_column_specification(col),
                            "%s %s" % (col.name, columns[index][3]))
             self.assert_(repr(col))
 
@@ -292,14 +296,14 @@ class TypesTest(TestBase, AssertsExecutionResults):
                           Column('b7', mysql.MSBit(63)),
                           Column('b8', mysql.MSBit(64)))
 
-        self.assert_eq(colspec(bit_table.c.b1), 'b1 BIT')
-        self.assert_eq(colspec(bit_table.c.b2), 'b2 BIT')
-        self.assert_eq(colspec(bit_table.c.b3), 'b3 BIT NOT NULL')
-        self.assert_eq(colspec(bit_table.c.b4), 'b4 BIT(1)')
-        self.assert_eq(colspec(bit_table.c.b5), 'b5 BIT(8)')
-        self.assert_eq(colspec(bit_table.c.b6), 'b6 BIT(32)')
-        self.assert_eq(colspec(bit_table.c.b7), 'b7 BIT(63)')
-        self.assert_eq(colspec(bit_table.c.b8), 'b8 BIT(64)')
+        eq_(colspec(bit_table.c.b1), 'b1 BIT')
+        eq_(colspec(bit_table.c.b2), 'b2 BIT')
+        eq_(colspec(bit_table.c.b3), 'b3 BIT NOT NULL')
+        eq_(colspec(bit_table.c.b4), 'b4 BIT(1)')
+        eq_(colspec(bit_table.c.b5), 'b5 BIT(8)')
+        eq_(colspec(bit_table.c.b6), 'b6 BIT(32)')
+        eq_(colspec(bit_table.c.b7), 'b7 BIT(63)')
+        eq_(colspec(bit_table.c.b8), 'b8 BIT(64)')
 
         for col in bit_table.c:
             self.assert_(repr(col))
@@ -314,7 +318,7 @@ class TypesTest(TestBase, AssertsExecutionResults):
                 def roundtrip(store, expected=None):
                     expected = expected or store
                     table.insert(store).execute()
-                    row = list(table.select().execute())[0]
+                    row = table.select().execute().first()
                     try:
                         self.assert_(list(row) == expected)
                     except:
@@ -322,7 +326,7 @@ class TypesTest(TestBase, AssertsExecutionResults):
                         print "Expected %s" % expected
                         print "Found %s" % list(row)
                         raise
-                    table.delete().execute()
+                    table.delete().execute().close()
 
                 roundtrip([0] * 8)
                 roundtrip([None, None, 0, None, None, None, None, None])
@@ -350,10 +354,10 @@ class TypesTest(TestBase, AssertsExecutionResults):
                            Column('b3', mysql.MSTinyInteger(1)),
                            Column('b4', mysql.MSTinyInteger))
 
-        self.assert_eq(colspec(bool_table.c.b1), 'b1 BOOL')
-        self.assert_eq(colspec(bool_table.c.b2), 'b2 BOOL')
-        self.assert_eq(colspec(bool_table.c.b3), 'b3 TINYINT(1)')
-        self.assert_eq(colspec(bool_table.c.b4), 'b4 TINYINT')
+        eq_(colspec(bool_table.c.b1), 'b1 BOOL')
+        eq_(colspec(bool_table.c.b2), 'b2 BOOL')
+        eq_(colspec(bool_table.c.b3), 'b3 TINYINT(1)')
+        eq_(colspec(bool_table.c.b4), 'b4 TINYINT')
 
         for col in bool_table.c:
             self.assert_(repr(col))
@@ -364,7 +368,7 @@ class TypesTest(TestBase, AssertsExecutionResults):
             def roundtrip(store, expected=None):
                 expected = expected or store
                 table.insert(store).execute()
-                row = list(table.select().execute())[0]
+                row = table.select().execute().first()
                 try:
                     self.assert_(list(row) == expected)
                     for i, val in enumerate(expected):
@@ -375,7 +379,7 @@ class TypesTest(TestBase, AssertsExecutionResults):
                     print "Expected %s" % expected
                     print "Found %s" % list(row)
                     raise
-                table.delete().execute()
+                table.delete().execute().close()
 
 
             roundtrip([None, None, None, None])
@@ -387,7 +391,7 @@ class TypesTest(TestBase, AssertsExecutionResults):
             meta2 = MetaData(testing.db)
             # replace with reflected
             table = Table('mysql_bool', meta2, autoload=True)
-            self.assert_eq(colspec(table.c.b3), 'b3 BOOL')
+            eq_(colspec(table.c.b3), 'b3 BOOL')
 
             roundtrip([None, None, None, None])
             roundtrip([True, True, 1, 1], [True, True, True, 1])
@@ -430,7 +434,7 @@ class TypesTest(TestBase, AssertsExecutionResults):
                 t = Table('mysql_ts%s' % idx, meta,
                           Column('id', Integer, primary_key=True),
                           Column('t', *spec))
-                self.assert_eq(colspec(t.c.t), "t %s" % expected)
+                eq_(colspec(t.c.t), "t %s" % expected)
                 self.assert_(repr(t.c.t))
                 t.create()
                 r = Table('mysql_ts%s' % idx, MetaData(testing.db),
@@ -460,12 +464,12 @@ class TypesTest(TestBase, AssertsExecutionResults):
 
             for table in year_table, reflected:
                 table.insert(['1950', '50', None, 50, 1950]).execute()
-                row = list(table.select().execute())[0]
-                self.assert_eq(list(row), [1950, 2050, None, 50, 1950])
+                row = table.select().execute().first()
+                eq_(list(row), [1950, 2050, None, 50, 1950])
                 table.delete().execute()
                 self.assert_(colspec(table.c.y1).startswith('y1 YEAR'))
-                self.assert_eq(colspec(table.c.y4), 'y4 YEAR(2)')
-                self.assert_eq(colspec(table.c.y5), 'y5 YEAR(4)')
+                eq_(colspec(table.c.y4), 'y4 YEAR(2)')
+                eq_(colspec(table.c.y5), 'y5 YEAR(4)')
         finally:
             meta.drop_all()
 
@@ -479,9 +483,9 @@ class TypesTest(TestBase, AssertsExecutionResults):
                           Column('s2', mysql.MSSet("'a'")),
                           Column('s3', mysql.MSSet("'5'", "'7'", "'9'")))
 
-        self.assert_eq(colspec(set_table.c.s1), "s1 SET('dq','sq')")
-        self.assert_eq(colspec(set_table.c.s2), "s2 SET('a')")
-        self.assert_eq(colspec(set_table.c.s3), "s3 SET('5','7','9')")
+        eq_(colspec(set_table.c.s1), "s1 SET('dq','sq')")
+        eq_(colspec(set_table.c.s2), "s2 SET('a')")
+        eq_(colspec(set_table.c.s3), "s3 SET('5','7','9')")
 
         for col in set_table.c:
             self.assert_(repr(col))
@@ -494,7 +498,7 @@ class TypesTest(TestBase, AssertsExecutionResults):
                 def roundtrip(store, expected=None):
                     expected = expected or store
                     table.insert(store).execute()
-                    row = list(table.select().execute())[0]
+                    row = table.select().execute().first()
                     try:
                         self.assert_(list(row) == expected)
                     except:
@@ -518,12 +522,12 @@ class TypesTest(TestBase, AssertsExecutionResults):
                                        {'s3':set(['5', '7'])},
                                        {'s3':set(['5', '7', '9'])},
                                        {'s3':set(['7', '9'])})
-            rows = list(select(
+            rows = select(
                 [set_table.c.s3],
-                set_table.c.s3.in_([set(['5']), set(['5', '7'])])).execute())
+                set_table.c.s3.in_([set(['5']), set(['5', '7']), set(['7', '5'])])
+                ).execute().fetchall()
             found = set([frozenset(row[0]) for row in rows])
-            eq_(found,
-                              set([frozenset(['5']), frozenset(['5', '7'])]))
+            eq_(found, set([frozenset(['5']), frozenset(['5', '7'])]))
         finally:
             meta.drop_all()
 
@@ -542,17 +546,17 @@ class TypesTest(TestBase, AssertsExecutionResults):
             Column('e6', mysql.MSEnum("'a'", "b")),
             )
 
-        self.assert_eq(colspec(enum_table.c.e1),
+        eq_(colspec(enum_table.c.e1),
                        "e1 ENUM('a','b')")
-        self.assert_eq(colspec(enum_table.c.e2),
+        eq_(colspec(enum_table.c.e2),
                        "e2 ENUM('a','b') NOT NULL")
-        self.assert_eq(colspec(enum_table.c.e3),
+        eq_(colspec(enum_table.c.e3),
                        "e3 ENUM('a','b')")
-        self.assert_eq(colspec(enum_table.c.e4),
+        eq_(colspec(enum_table.c.e4),
                        "e4 ENUM('a','b') NOT NULL")
-        self.assert_eq(colspec(enum_table.c.e5),
+        eq_(colspec(enum_table.c.e5),
                        "e5 ENUM('a','b')")
-        self.assert_eq(colspec(enum_table.c.e6),
+        eq_(colspec(enum_table.c.e6),
                        "e6 ENUM('''a''','b')")
         enum_table.drop(checkfirst=True)
         enum_table.create()
@@ -585,8 +589,9 @@ class TypesTest(TestBase, AssertsExecutionResults):
         # This is known to fail with MySQLDB 1.2.2 beta versions
         # which return these as sets.Set(['a']), sets.Set(['b'])
         # (even on Pythons with __builtin__.set)
-        if testing.db.dialect.dbapi.version_info < (1, 2, 2, 'beta', 3) and \
-           testing.db.dialect.dbapi.version_info >= (1, 2, 2):
+        if (not testing.against('+zxjdbc') and
+            testing.db.dialect.dbapi.version_info < (1, 2, 2, 'beta', 3) and
+            testing.db.dialect.dbapi.version_info >= (1, 2, 2)):
             # these MySQLdb versions seem to always use 'sets', even on later Pythons
             import sets
             def convert(value):
@@ -602,7 +607,7 @@ class TypesTest(TestBase, AssertsExecutionResults):
                 e.append(tuple([convert(c) for c in row]))
             expected = e
 
-        self.assert_eq(res, expected)
+        eq_(res, expected)
         enum_table.drop()
 
     @testing.exclude('mysql', '<', (4,), "3.23 can't handle an ENUM of ''")
@@ -637,25 +642,52 @@ class TypesTest(TestBase, AssertsExecutionResults):
         finally:
             enum_table.drop()
 
+
+
+class ReflectionTest(TestBase, AssertsExecutionResults):
+
+    __only_on__ = 'mysql'
+
     def test_default_reflection(self):
         """Test reflection of column defaults."""
 
         def_table = Table('mysql_def', MetaData(testing.db),
             Column('c1', String(10), DefaultClause('')),
             Column('c2', String(10), DefaultClause('0')),
-            Column('c3', String(10), DefaultClause('abc')))
+            Column('c3', String(10), DefaultClause('abc')),
+            Column('c4', TIMESTAMP, DefaultClause('2009-04-05 12:00:00')),
+            Column('c5', TIMESTAMP),
+            
+        )
 
+        def_table.create()
         try:
-            def_table.create()
             reflected = Table('mysql_def', MetaData(testing.db),
-                              autoload=True)
-            for t in def_table, reflected:
-                assert t.c.c1.server_default.arg == ''
-                assert t.c.c2.server_default.arg == '0'
-                assert t.c.c3.server_default.arg == 'abc'
+                          autoload=True)
         finally:
             def_table.drop()
+        assert def_table.c.c1.server_default.arg == ''
+        assert def_table.c.c2.server_default.arg == '0'
+        assert def_table.c.c3.server_default.arg == 'abc'
+        assert def_table.c.c4.server_default.arg == '2009-04-05 12:00:00'
+
+        assert str(reflected.c.c1.server_default.arg) == "''"
+        assert str(reflected.c.c2.server_default.arg) == "'0'"
+        assert str(reflected.c.c3.server_default.arg) == "'abc'"
+        assert str(reflected.c.c4.server_default.arg) == "'2009-04-05 12:00:00'"
+            
+        reflected.create()
+        try:
+            reflected2 = Table('mysql_def', MetaData(testing.db), autoload=True)
+        finally:
+            reflected.drop()
 
+        assert str(reflected2.c.c1.server_default.arg) == "''"
+        assert str(reflected2.c.c2.server_default.arg) == "'0'"
+        assert str(reflected2.c.c3.server_default.arg) == "'abc'"
+        assert str(reflected2.c.c4.server_default.arg) == "'2009-04-05 12:00:00'"
+            
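
The assertions above make a distinction worth spelling out: on the table as defined, server_default.arg is the plain string handed to DefaultClause, while on the reflected table it is a SQL text construct whose string form retains the database's quoting. A sketch against a hypothetical MySQL URL:

    from sqlalchemy import MetaData, Table, create_engine

    engine = create_engine('mysql://scott:tiger@localhost/test')  # hypothetical
    reflected = Table('mysql_def', MetaData(engine), autoload=True)
    # As defined:  def_table.c.c2.server_default.arg == '0'
    # As reflected, the arg is a text() clause, so str() keeps the quotes:
    assert str(reflected.c.c2.server_default.arg) == "'0'"
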
     def test_reflection_on_include_columns(self):
         """Test reflection of include_columns to be sure they respect case."""
 
@@ -700,8 +732,8 @@ class TypesTest(TestBase, AssertsExecutionResults):
                  ( mysql.MSSmallInteger(4), mysql.MSSmallInteger(4), ),
                  ( mysql.MSMediumInteger(), mysql.MSMediumInteger(), ),
                  ( mysql.MSMediumInteger(8), mysql.MSMediumInteger(8), ),
-                 ( Binary(3), mysql.MSBlob(3), ),
-                 ( Binary(), mysql.MSBlob() ),
+                 ( Binary(3), mysql.TINYBLOB(), ),
+                 ( Binary(), mysql.BLOB() ),
                  ( mysql.MSBinary(3), mysql.MSBinary(3), ),
                  ( mysql.MSVarBinary(3),),
                  ( mysql.MSVarBinary(), mysql.MSBlob()),
@@ -734,14 +766,15 @@ class TypesTest(TestBase, AssertsExecutionResults):
                 # in a view, e.g. char -> varchar, tinyblob -> mediumblob
                 #
                 # Not sure exactly which point version has the fix.
-                if db.dialect.server_version_info(db.connect()) < (5, 0, 11):
+                if db.dialect.server_version_info < (5, 0, 11):
                     tables = rt,
                 else:
                     tables = rt, rv
 
                 for table in tables:
                     for i, reflected in enumerate(table.c):
-                        assert isinstance(reflected.type, type(expected[i]))
+                        assert isinstance(reflected.type, type(expected[i])), \
+                                "element %d: %r not instance of %r" % (i, reflected.type, type(expected[i]))
             finally:
                 db.execute('DROP VIEW mysql_types_v')
         finally:
@@ -802,17 +835,12 @@ class TypesTest(TestBase, AssertsExecutionResults):
                 tbl.insert().execute()
                 if 'int_y' in tbl.c:
                     assert select([tbl.c.int_y]).scalar() == 1
-                    assert list(tbl.select().execute().fetchone()).count(1) == 1
+                    assert list(tbl.select().execute().first()).count(1) == 1
                 else:
-                    assert 1 not in list(tbl.select().execute().fetchone())
+                    assert 1 not in list(tbl.select().execute().first())
         finally:
             meta.drop_all()
 
-    def assert_eq(self, got, wanted):
-        if got != wanted:
-            print "Expected %s" % wanted
-            print "Found %s" % got
-        eq_(got, wanted)
 
 
 class SQLTest(TestBase, AssertsCompiledSQL):
@@ -909,11 +937,11 @@ class SQLTest(TestBase, AssertsCompiledSQL):
             (m.MSBit, "t.col"),
 
             # this is kind of sucky.  thank you default arguments!
-            (NUMERIC, "CAST(t.col AS DECIMAL(10, 2))"),
-            (DECIMAL, "CAST(t.col AS DECIMAL(10, 2))"),
-            (Numeric, "CAST(t.col AS DECIMAL(10, 2))"),
-            (m.MSNumeric, "CAST(t.col AS DECIMAL(10, 2))"),
-            (m.MSDecimal, "CAST(t.col AS DECIMAL(10, 2))"),
+            (NUMERIC, "CAST(t.col AS DECIMAL)"),
+            (DECIMAL, "CAST(t.col AS DECIMAL)"),
+            (Numeric, "CAST(t.col AS DECIMAL)"),
+            (m.MSNumeric, "CAST(t.col AS DECIMAL)"),
+            (m.MSDecimal, "CAST(t.col AS DECIMAL)"),
 
             (FLOAT, "t.col"),
             (Float, "t.col"),
@@ -928,8 +956,8 @@ class SQLTest(TestBase, AssertsCompiledSQL):
             (DateTime, "CAST(t.col AS DATETIME)"),
             (Date, "CAST(t.col AS DATE)"),
             (Time, "CAST(t.col AS TIME)"),
-            (m.MSDateTime, "CAST(t.col AS DATETIME)"),
-            (m.MSDate, "CAST(t.col AS DATE)"),
+            (DateTime, "CAST(t.col AS DATETIME)"),
+            (Date, "CAST(t.col AS DATE)"),
             (m.MSTime, "CAST(t.col AS TIME)"),
             (m.MSTimeStamp, "CAST(t.col AS DATETIME)"),
             (m.MSYear, "t.col"),
@@ -998,12 +1026,11 @@ class SQLTest(TestBase, AssertsCompiledSQL):
 
 class RawReflectionTest(TestBase):
     def setup(self):
-        self.dialect = mysql.dialect()
-        self.reflector = mysql.MySQLSchemaReflector(
-            self.dialect.identifier_preparer)
+        dialect = mysql.dialect()
+        self.parser = mysql.MySQLTableDefinitionParser(dialect, dialect.identifier_preparer)
 
     def test_key_reflection(self):
-        regex = self.reflector._re_key
+        regex = self.parser._re_key
 
         assert regex.match('  PRIMARY KEY (`id`),')
         assert regex.match('  PRIMARY KEY USING BTREE (`id`),')
@@ -1023,37 +1050,11 @@ class ExecutionTest(TestBase):
 
         cx = engine.connect()
         meta = MetaData()
-
-        assert ('mysql', 'charset') not in cx.info
-        assert ('mysql', 'force_charset') not in cx.info
-
-        cx.execute(text("SELECT 1")).fetchall()
-        assert ('mysql', 'charset') not in cx.info
-
-        meta.reflect(cx)
-        assert ('mysql', 'charset') in cx.info
-
-        cx.execute(text("SET @squiznart=123"))
-        assert ('mysql', 'charset') in cx.info
-
-        # the charset invalidation is very conservative
-        cx.execute(text("SET TIMESTAMP = DEFAULT"))
-        assert ('mysql', 'charset') not in cx.info
-
-        cx.info[('mysql', 'force_charset')] = 'latin1'
-
-        assert engine.dialect._detect_charset(cx) == 'latin1'
-        assert cx.info[('mysql', 'charset')] == 'latin1'
-
-        del cx.info[('mysql', 'force_charset')]
-        del cx.info[('mysql', 'charset')]
+        charset = engine.dialect._detect_charset(cx)
 
         meta.reflect(cx)
-        assert ('mysql', 'charset') in cx.info
-
-        # String execution doesn't go through the detector.
-        cx.execute("SET TIMESTAMP = DEFAULT")
-        assert ('mysql', 'charset') in cx.info
+        eq_(cx.dialect._connection_charset, charset)
+        cx.close()
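
The rewritten test drops the old per-connection cx.info charset cache: in 0.6 the MySQL dialect detects the connection charset and keeps it as _connection_charset. Both names are underscore-private and subject to change; this sketch only mirrors what the test itself touches, assuming engine is a MySQL Engine:

    from sqlalchemy import MetaData

    cx = engine.connect()
    charset = engine.dialect._detect_charset(cx)   # e.g. 'utf8' or 'latin1'
    MetaData().reflect(cx)                         # any reflection will do
    assert cx.dialect._connection_charset == charset
    cx.close()
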
 
 
 class MatchTest(TestBase, AssertsCompiledSQL):
@@ -1102,9 +1103,10 @@ class MatchTest(TestBase, AssertsCompiledSQL):
         metadata.drop_all()
 
     def test_expression(self):
+        format = testing.db.dialect.paramstyle == 'format' and '%s' or '?'
         self.assert_compile(
             matchtable.c.title.match('somstr'),
-            "MATCH (matchtable.title) AGAINST (%s IN BOOLEAN MODE)")
+            "MATCH (matchtable.title) AGAINST (%s IN BOOLEAN MODE)" % format)
 
     def test_simple_match(self):
         results = (matchtable.select().
@@ -1162,6 +1164,5 @@ class MatchTest(TestBase, AssertsCompiledSQL):
 
 
 def colspec(c):
-    return testing.db.dialect.schemagenerator(testing.db.dialect,
-        testing.db, None, None).get_column_specification(c)
+    return testing.db.dialect.ddl_compiler(testing.db.dialect, c.table).get_column_specification(c)
 
index d9d64806e814d21efc439516cdbc5dd67f7c2889..53e0f9ec2f4c0da214edc155f9a1198006f051f3 100644 (file)
@@ -2,12 +2,14 @@
 
 from sqlalchemy.test.testing import eq_
 from sqlalchemy import *
+from sqlalchemy import types as sqltypes
 from sqlalchemy.sql import table, column
-from sqlalchemy.databases import oracle
 from sqlalchemy.test import *
 from sqlalchemy.test.testing import eq_
 from sqlalchemy.test.engines import testing_engine
+from sqlalchemy.dialects.oracle import cx_oracle, base as oracle
 from sqlalchemy.engine import default
+from sqlalchemy.util import jython
 import os
 
 
@@ -43,10 +45,10 @@ class CompileTest(TestBase, AssertsCompiledSQL):
         meta  = MetaData()
         parent = Table('parent', meta, Column('id', Integer, primary_key=True), 
            Column('name', String(50)),
-           owner='ed')
+           schema='ed')
         child = Table('child', meta, Column('id', Integer, primary_key=True),
            Column('parent_id', Integer, ForeignKey('ed.parent.id')),
-           owner = 'ed')
+           schema='ed')
 
         self.assert_compile(parent.join(child), "ed.parent JOIN ed.child ON ed.parent.id = ed.child.parent_id")
 
@@ -342,6 +344,25 @@ class TypesTest(TestBase, AssertsCompiledSQL):
         b = bindparam("foo", u"hello world!")
         assert b.type.dialect_impl(dialect).get_dbapi_type(dbapi) == 'STRING'
 
+    def test_type_adapt(self):
+        dialect = cx_oracle.dialect()
+
+        for start, test in [
+            (DateTime(), cx_oracle._OracleDateTime),
+            (TIMESTAMP(), cx_oracle._OracleTimestamp),
+            (oracle.OracleRaw(), cx_oracle._OracleRaw),
+            (String(), String),
+            (VARCHAR(), VARCHAR),
+            (String(50), String),
+            (Unicode(), Unicode),
+            (Text(), cx_oracle._OracleText),
+            (UnicodeText(), cx_oracle._OracleUnicodeText),
+            (NCHAR(), NCHAR),
+            (oracle.RAW(50), cx_oracle._OracleRaw),
+        ]:
+            assert isinstance(start.dialect_impl(dialect), test), "wanted %r got %r" % (test, start.dialect_impl(dialect))
+
+
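
test_type_adapt exercises 0.6's per-dialect type adaptation: TypeEngine.dialect_impl() returns the cx_oracle-specific implementation where one exists, and the generic type otherwise. A condensed sketch (the _Oracle* classes are private; they are named here only because the test references them):

    from sqlalchemy import DateTime, String
    from sqlalchemy.dialects.oracle import cx_oracle

    dialect = cx_oracle.dialect()
    assert isinstance(DateTime().dialect_impl(dialect), cx_oracle._OracleDateTime)
    assert isinstance(String(50).dialect_impl(dialect), String)   # no override
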
     def test_reflect_raw(self):
         types_table = Table(
         'all_types', MetaData(testing.db),
@@ -354,16 +375,16 @@ class TypesTest(TestBase, AssertsCompiledSQL):
     def test_reflect_nvarchar(self):
         metadata = MetaData(testing.db)
         t = Table('t', metadata,
-            Column('data', oracle.OracleNVarchar(255))
+            Column('data', sqltypes.NVARCHAR(255))
         )
         metadata.create_all()
         try:
             m2 = MetaData(testing.db)
             t2 = Table('t', m2, autoload=True)
-            assert isinstance(t2.c.data.type, oracle.OracleNVarchar)
+            assert isinstance(t2.c.data.type, sqltypes.NVARCHAR)
             data = u'm’a réveillé.'
             t2.insert().execute(data=data)
-            eq_(t2.select().execute().fetchone()['data'], data)
+            eq_(t2.select().execute().first()['data'], data)
         finally:
             metadata.drop_all()
         
@@ -391,7 +412,7 @@ class TypesTest(TestBase, AssertsCompiledSQL):
         t.create(engine)
         try:
             engine.execute(t.insert(), id=1, data='this is text', bindata='this is binary')
-            row = engine.execute(t.select()).fetchone()
+            row = engine.execute(t.select()).first()
             eq_(row['data'].read(), 'this is text')
             eq_(row['bindata'].read(), 'this is binary')
         finally:
@@ -408,7 +429,6 @@ class BufferedColumnTest(TestBase, AssertsCompiledSQL):
            Column('data', Binary)
         )
         meta.create_all()
-        
         stream = os.path.join(os.path.dirname(__file__), "..", 'binary_data_one.dat')
         stream = file(stream).read(12000)
 
@@ -420,17 +440,18 @@ class BufferedColumnTest(TestBase, AssertsCompiledSQL):
         meta.drop_all()
 
     def test_fetch(self):
-        eq_(
-            binary_table.select().execute().fetchall() ,
-            [(i, stream) for i in range(1, 11)], 
-        )
+        result = binary_table.select().execute().fetchall()
+        if jython:
+            result = [(i, value.tostring()) for i, value in result]
+        eq_(result, [(i, stream) for i in range(1, 11)])
 
+    @testing.fails_on('+zxjdbc', 'FIXME: zxjdbc should support this')
     def test_fetch_single_arraysize(self):
         eng = testing_engine(options={'arraysize':1})
-        eq_(
-            eng.execute(binary_table.select()).fetchall(),
-            [(i, stream) for i in range(1, 11)], 
-        )
+        result = eng.execute(binary_table.select()).fetchall()
+        if jython:
+            result = [(i, value.tostring()) for i, value in result]
+        eq_(result, [(i, stream) for i in range(1, 11)])
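
testing_engine(options={'arraysize':1}) forwards keyword arguments to create_engine; arraysize is, if memory of the cx_oracle dialect serves, handed to the cx_oracle cursor and controls fetch buffering, which is why a value of 1 stresses the buffered-column path. Equivalent direct usage, URL hypothetical:

    from sqlalchemy import create_engine

    engine = create_engine('oracle://scott:tiger@localhost/xe', arraysize=1)
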
 
 class SequenceTest(TestBase, AssertsCompiledSQL):
     def test_basic(self):
similarity index 66%
rename from test/dialect/test_postgres.py
rename to test/dialect/test_postgresql.py
index 8ca714badc79033c9441c59562a87a369f46a271..e1c351a93ec50b4553b6ff802c1717cc262ef6c6 100644 (file)
@@ -1,18 +1,19 @@
 from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message
+from sqlalchemy.test import  engines
 import datetime
 from sqlalchemy import *
 from sqlalchemy.orm import *
-from sqlalchemy import exc
-from sqlalchemy.databases import postgres
+from sqlalchemy import exc, schema
+from sqlalchemy.dialects.postgresql import base as postgresql
 from sqlalchemy.engine.strategies import MockEngineStrategy
 from sqlalchemy.test import *
 from sqlalchemy.sql import table, column
-
+from sqlalchemy.test.testing import eq_
 
 class SequenceTest(TestBase, AssertsCompiledSQL):
     def test_basic(self):
         seq = Sequence("my_seq_no_schema")
-        dialect = postgres.PGDialect()
+        dialect = postgresql.PGDialect()
         assert dialect.identifier_preparer.format_sequence(seq) == "my_seq_no_schema"
 
         seq = Sequence("my_seq", schema="some_schema")
@@ -22,43 +23,77 @@ class SequenceTest(TestBase, AssertsCompiledSQL):
         assert dialect.identifier_preparer.format_sequence(seq) == '"Some_Schema"."My_Seq"'
 
 class CompileTest(TestBase, AssertsCompiledSQL):
-    __dialect__ = postgres.dialect()
+    __dialect__ = postgresql.dialect()
 
     def test_update_returning(self):
-        dialect = postgres.dialect()
+        dialect = postgresql.dialect()
         table1 = table('mytable',
             column('myid', Integer),
             column('name', String(128)),
             column('description', String(128)),
         )
 
-        u = update(table1, values=dict(name='foo'), postgres_returning=[table1.c.myid, table1.c.name])
+        u = update(table1, values=dict(name='foo')).returning(table1.c.myid, table1.c.name)
         self.assert_compile(u, "UPDATE mytable SET name=%(name)s RETURNING mytable.myid, mytable.name", dialect=dialect)
 
-        u = update(table1, values=dict(name='foo'), postgres_returning=[table1])
+        u = update(table1, values=dict(name='foo')).returning(table1)
         self.assert_compile(u, "UPDATE mytable SET name=%(name)s "\
             "RETURNING mytable.myid, mytable.name, mytable.description", dialect=dialect)
 
-        u = update(table1, values=dict(name='foo'), postgres_returning=[func.length(table1.c.name)])
-        self.assert_compile(u, "UPDATE mytable SET name=%(name)s RETURNING length(mytable.name)", dialect=dialect)
+        u = update(table1, values=dict(name='foo')).returning(func.length(table1.c.name))
+        self.assert_compile(u, "UPDATE mytable SET name=%(name)s RETURNING length(mytable.name) AS length_1", dialect=dialect)
 
+        
     def test_insert_returning(self):
-        dialect = postgres.dialect()
+        dialect = postgresql.dialect()
         table1 = table('mytable',
             column('myid', Integer),
             column('name', String(128)),
             column('description', String(128)),
         )
 
-        i = insert(table1, values=dict(name='foo'), postgres_returning=[table1.c.myid, table1.c.name])
+        i = insert(table1, values=dict(name='foo')).returning(table1.c.myid, table1.c.name)
         self.assert_compile(i, "INSERT INTO mytable (name) VALUES (%(name)s) RETURNING mytable.myid, mytable.name", dialect=dialect)
 
-        i = insert(table1, values=dict(name='foo'), postgres_returning=[table1])
+        i = insert(table1, values=dict(name='foo')).returning(table1)
         self.assert_compile(i, "INSERT INTO mytable (name) VALUES (%(name)s) "\
             "RETURNING mytable.myid, mytable.name, mytable.description", dialect=dialect)
 
-        i = insert(table1, values=dict(name='foo'), postgres_returning=[func.length(table1.c.name)])
-        self.assert_compile(i, "INSERT INTO mytable (name) VALUES (%(name)s) RETURNING length(mytable.name)", dialect=dialect)
+        i = insert(table1, values=dict(name='foo')).returning(func.length(table1.c.name))
+        self.assert_compile(i, "INSERT INTO mytable (name) VALUES (%(name)s) RETURNING length(mytable.name) AS length_1", dialect=dialect)
+    
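
These compile tests document the 0.6 API shift: the dialect-specific postgres_returning keyword (kept below only behind a deprecation warning) gives way to the generative returning() method on insert, update and delete. A standalone sketch of the new spelling:

    from sqlalchemy import insert, Integer, String
    from sqlalchemy.sql import table, column

    t = table('mytable',
              column('myid', Integer),
              column('name', String(128)),
              column('description', String(128)))
    stmt = insert(t, values=dict(name='foo')).returning(t.c.myid, t.c.name)
    # INSERT INTO mytable (name) VALUES (%(name)s)
    #     RETURNING mytable.myid, mytable.name
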
+    @testing.uses_deprecated(r".*argument is deprecated.  Please use statement.returning.*")
+    def test_old_returning_names(self):
+        dialect = postgresql.dialect()
+        table1 = table('mytable',
+            column('myid', Integer),
+            column('name', String(128)),
+            column('description', String(128)),
+        )
+
+        u = update(table1, values=dict(name='foo'), postgres_returning=[table1.c.myid, table1.c.name])
+        self.assert_compile(u, "UPDATE mytable SET name=%(name)s RETURNING mytable.myid, mytable.name", dialect=dialect)
+
+        u = update(table1, values=dict(name='foo'), postgresql_returning=[table1.c.myid, table1.c.name])
+        self.assert_compile(u, "UPDATE mytable SET name=%(name)s RETURNING mytable.myid, mytable.name", dialect=dialect)
+
+        i = insert(table1, values=dict(name='foo'), postgres_returning=[table1.c.myid, table1.c.name])
+        self.assert_compile(i, "INSERT INTO mytable (name) VALUES (%(name)s) RETURNING mytable.myid, mytable.name", dialect=dialect)
+        
+    def test_create_partial_index(self):
+        tbl = Table('testtbl', MetaData(), Column('data',Integer))
+        idx = Index('test_idx1', tbl.c.data, postgresql_where=and_(tbl.c.data > 5, tbl.c.data < 10))
+
+        self.assert_compile(schema.CreateIndex(idx), 
+            "CREATE INDEX test_idx1 ON testtbl (data) WHERE testtbl.data > 5 AND testtbl.data < 10", dialect=postgresql.dialect())
+
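
test_create_partial_index shows the renamed dialect keyword: postgres_where becomes postgresql_where, with the old spelling surviving below behind a deprecation warning. A standalone sketch producing the DDL the test asserts:

    from sqlalchemy import MetaData, Table, Column, Integer, Index, and_

    tbl = Table('testtbl', MetaData(), Column('data', Integer))
    idx = Index('test_idx1', tbl.c.data,
                postgresql_where=and_(tbl.c.data > 5, tbl.c.data < 10))
    # CREATE INDEX test_idx1 ON testtbl (data)
    #     WHERE testtbl.data > 5 AND testtbl.data < 10
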
+    @testing.uses_deprecated(r".*'postgres_where' argument has been renamed.*")
+    def test_old_create_partial_index(self):
+        tbl = Table('testtbl', MetaData(), Column('data',Integer))
+        idx = Index('test_idx1', tbl.c.data, postgres_where=and_(tbl.c.data > 5, tbl.c.data < 10))
+
+        self.assert_compile(schema.CreateIndex(idx), 
+            "CREATE INDEX test_idx1 ON testtbl (data) WHERE testtbl.data > 5 AND testtbl.data < 10", dialect=postgresql.dialect())
 
     def test_extract(self):
         t = table('t', column('col1'))
@@ -70,72 +105,20 @@ class CompileTest(TestBase, AssertsCompiledSQL):
                 "FROM t" % field)
 
 
-class ReturningTest(TestBase, AssertsExecutionResults):
-    __only_on__ = 'postgres'
-
-    @testing.exclude('postgres', '<', (8, 2), '8.3+ feature')
-    def test_update_returning(self):
-        meta = MetaData(testing.db)
-        table = Table('tables', meta,
-            Column('id', Integer, primary_key=True),
-            Column('persons', Integer),
-            Column('full', Boolean)
-        )
-        table.create()
-        try:
-            table.insert().execute([{'persons': 5, 'full': False}, {'persons': 3, 'full': False}])
-
-            result = table.update(table.c.persons > 4, dict(full=True), postgres_returning=[table.c.id]).execute()
-            eq_(result.fetchall(), [(1,)])
-
-            result2 = select([table.c.id, table.c.full]).order_by(table.c.id).execute()
-            eq_(result2.fetchall(), [(1,True),(2,False)])
-        finally:
-            table.drop()
-
-    @testing.exclude('postgres', '<', (8, 2), '8.3+ feature')
-    def test_insert_returning(self):
-        meta = MetaData(testing.db)
-        table = Table('tables', meta,
-            Column('id', Integer, primary_key=True),
-            Column('persons', Integer),
-            Column('full', Boolean)
-        )
-        table.create()
-        try:
-            result = table.insert(postgres_returning=[table.c.id]).execute({'persons': 1, 'full': False})
-
-            eq_(result.fetchall(), [(1,)])
-
-            @testing.fails_on('postgres', 'Known limitation of psycopg2')
-            def test_executemany():
-                # return value is documented as failing with psycopg2/executemany
-                result2 = table.insert(postgres_returning=[table]).execute(
-                     [{'persons': 2, 'full': False}, {'persons': 3, 'full': True}])
-                eq_(result2.fetchall(), [(2, 2, False), (3,3,True)])
-            
-            test_executemany()
-            
-            result3 = table.insert(postgres_returning=[(table.c.id*2).label('double_id')]).execute({'persons': 4, 'full': False})
-            eq_([dict(row) for row in result3], [{'double_id':8}])
-
-            result4 = testing.db.execute('insert into tables (id, persons, "full") values (5, 10, true) returning persons')
-            eq_([dict(row) for row in result4], [{'persons': 10}])
-        finally:
-            table.drop()
-
-
 class InsertTest(TestBase, AssertsExecutionResults):
-    __only_on__ = 'postgres'
+    __only_on__ = 'postgresql'
 
     @classmethod
     def setup_class(cls):
         global metadata
+        cls.engine = testing.db
         metadata = MetaData(testing.db)
 
     def teardown(self):
         metadata.drop_all()
         metadata.tables.clear()
+        if self.engine is not testing.db:
+            self.engine.dispose()
 
     def test_compiled_insert(self):
         table = Table('testtable', metadata,
@@ -144,7 +127,7 @@ class InsertTest(TestBase, AssertsExecutionResults):
 
         metadata.create_all()
 
-        ins = table.insert(values={'data':bindparam('x')}).compile()
+        ins = table.insert(inline=True, values={'data':bindparam('x')}).compile()
         ins.execute({'x':"five"}, {'x':"seven"})
         assert table.select().execute().fetchall() == [(1, 'five'), (2, 'seven')]
 
@@ -155,6 +138,13 @@ class InsertTest(TestBase, AssertsExecutionResults):
         metadata.create_all()
         self._assert_data_with_sequence(table, "my_seq")
 
+    def test_sequence_returning_insert(self):
+        table = Table('testtable', metadata,
+            Column('id', Integer, Sequence('my_seq'), primary_key=True),
+            Column('data', String(30)))
+        metadata.create_all()
+        self._assert_data_with_sequence_returning(table, "my_seq")
+
     def test_opt_sequence_insert(self):
         table = Table('testtable', metadata,
             Column('id', Integer, Sequence('my_seq', optional=True), primary_key=True),
@@ -162,6 +152,13 @@ class InsertTest(TestBase, AssertsExecutionResults):
         metadata.create_all()
         self._assert_data_autoincrement(table)
 
+    def test_opt_sequence_returning_insert(self):
+        table = Table('testtable', metadata,
+            Column('id', Integer, Sequence('my_seq', optional=True), primary_key=True),
+            Column('data', String(30)))
+        metadata.create_all()
+        self._assert_data_autoincrement_returning(table)
+
     def test_autoincrement_insert(self):
         table = Table('testtable', metadata,
             Column('id', Integer, primary_key=True),
@@ -169,6 +166,13 @@ class InsertTest(TestBase, AssertsExecutionResults):
         metadata.create_all()
         self._assert_data_autoincrement(table)
 
+    def test_autoincrement_returning_insert(self):
+        table = Table('testtable', metadata,
+            Column('id', Integer, primary_key=True),
+            Column('data', String(30)))
+        metadata.create_all()
+        self._assert_data_autoincrement_returning(table)
+
     def test_noautoincrement_insert(self):
         table = Table('testtable', metadata,
             Column('id', Integer, primary_key=True, autoincrement=False),
@@ -177,14 +181,17 @@ class InsertTest(TestBase, AssertsExecutionResults):
         self._assert_data_noautoincrement(table)
 
     def _assert_data_autoincrement(self, table):
+        self.engine = engines.testing_engine(options={'implicit_returning':False})
+        metadata.bind = self.engine
+
         def go():
             # execute with explicit id
             r = table.insert().execute({'id':30, 'data':'d1'})
-            assert r.last_inserted_ids() == [30]
+            assert r.inserted_primary_key == [30]
 
             # execute with prefetch id
             r = table.insert().execute({'data':'d2'})
-            assert r.last_inserted_ids() == [1]
+            assert r.inserted_primary_key == [1]
 
             # executemany with explicit ids
             table.insert().execute({'id':31, 'data':'d3'}, {'id':32, 'data':'d4'})
@@ -201,7 +208,7 @@ class InsertTest(TestBase, AssertsExecutionResults):
         # note that the test framework doesn't capture the "preexecute" of a sequence
         # or default.  we just see it in the bind params.
 
-        self.assert_sql(testing.db, go, [], with_sequences=[
+        self.assert_sql(self.engine, go, [], with_sequences=[
             (
                 "INSERT INTO testtable (id, data) VALUES (:id, :data)",
                 {'id':30, 'data':'d1'}
@@ -242,19 +249,19 @@ class InsertTest(TestBase, AssertsExecutionResults):
 
         # test the same series of events using a reflected
         # version of the table
-        m2 = MetaData(testing.db)
+        m2 = MetaData(self.engine)
         table = Table(table.name, m2, autoload=True)
 
         def go():
             table.insert().execute({'id':30, 'data':'d1'})
             r = table.insert().execute({'data':'d2'})
-            assert r.last_inserted_ids() == [5]
+            assert r.inserted_primary_key == [5]
             table.insert().execute({'id':31, 'data':'d3'}, {'id':32, 'data':'d4'})
             table.insert().execute({'data':'d5'}, {'data':'d6'})
             table.insert(inline=True).execute({'id':33, 'data':'d7'})
             table.insert(inline=True).execute({'data':'d8'})
 
-        self.assert_sql(testing.db, go, [], with_sequences=[
+        self.assert_sql(self.engine, go, [], with_sequences=[
             (
                 "INSERT INTO testtable (id, data) VALUES (:id, :data)",
                 {'id':30, 'data':'d1'}
@@ -293,7 +300,127 @@ class InsertTest(TestBase, AssertsExecutionResults):
         ]
         table.delete().execute()
 
+    def _assert_data_autoincrement_returning(self, table):
+        self.engine = engines.testing_engine(options={'implicit_returning':True})
+        metadata.bind = self.engine
+
+        def go():
+            # execute with explicit id
+            r = table.insert().execute({'id':30, 'data':'d1'})
+            assert r.inserted_primary_key == [30]
+
+            # execute with prefetch id
+            r = table.insert().execute({'data':'d2'})
+            assert r.inserted_primary_key == [1]
+
+            # executemany with explicit ids
+            table.insert().execute({'id':31, 'data':'d3'}, {'id':32, 'data':'d4'})
+
+            # executemany, uses SERIAL
+            table.insert().execute({'data':'d5'}, {'data':'d6'})
+
+            # single execute, explicit id, inline
+            table.insert(inline=True).execute({'id':33, 'data':'d7'})
+
+            # single execute, inline, uses SERIAL
+            table.insert(inline=True).execute({'data':'d8'})
+        
+        self.assert_sql(self.engine, go, [], with_sequences=[
+            (
+                "INSERT INTO testtable (id, data) VALUES (:id, :data)",
+                {'id':30, 'data':'d1'}
+            ),
+            (
+                "INSERT INTO testtable (data) VALUES (:data) RETURNING testtable.id",
+                {'data': 'd2'}
+            ),
+            (
+                "INSERT INTO testtable (id, data) VALUES (:id, :data)",
+                [{'id':31, 'data':'d3'}, {'id':32, 'data':'d4'}]
+            ),
+            (
+                "INSERT INTO testtable (data) VALUES (:data)",
+                [{'data':'d5'}, {'data':'d6'}]
+            ),
+            (
+                "INSERT INTO testtable (id, data) VALUES (:id, :data)",
+                [{'id':33, 'data':'d7'}]
+            ),
+            (
+                "INSERT INTO testtable (data) VALUES (:data)",
+                [{'data':'d8'}]
+            ),
+        ])
+
+        assert table.select().execute().fetchall() == [
+            (30, 'd1'),
+            (1, 'd2'),
+            (31, 'd3'),
+            (32, 'd4'),
+            (2, 'd5'),
+            (3, 'd6'),
+            (33, 'd7'),
+            (4, 'd8'),
+        ]
+        table.delete().execute()
+
+        # test the same series of events using a reflected
+        # version of the table
+        m2 = MetaData(self.engine)
+        table = Table(table.name, m2, autoload=True)
+
+        def go():
+            table.insert().execute({'id':30, 'data':'d1'})
+            r = table.insert().execute({'data':'d2'})
+            assert r.inserted_primary_key == [5]
+            table.insert().execute({'id':31, 'data':'d3'}, {'id':32, 'data':'d4'})
+            table.insert().execute({'data':'d5'}, {'data':'d6'})
+            table.insert(inline=True).execute({'id':33, 'data':'d7'})
+            table.insert(inline=True).execute({'data':'d8'})
+
+        self.assert_sql(self.engine, go, [], with_sequences=[
+            (
+                "INSERT INTO testtable (id, data) VALUES (:id, :data)",
+                {'id':30, 'data':'d1'}
+            ),
+            (
+                "INSERT INTO testtable (data) VALUES (:data) RETURNING testtable.id",
+                {'data':'d2'}
+            ),
+            (
+                "INSERT INTO testtable (id, data) VALUES (:id, :data)",
+                [{'id':31, 'data':'d3'}, {'id':32, 'data':'d4'}]
+            ),
+            (
+                "INSERT INTO testtable (data) VALUES (:data)",
+                [{'data':'d5'}, {'data':'d6'}]
+            ),
+            (
+                "INSERT INTO testtable (id, data) VALUES (:id, :data)",
+                [{'id':33, 'data':'d7'}]
+            ),
+            (
+                "INSERT INTO testtable (data) VALUES (:data)",
+                [{'data':'d8'}]
+            ),
+        ])
+
+        assert table.select().execute().fetchall() == [
+            (30, 'd1'),
+            (5, 'd2'),
+            (31, 'd3'),
+            (32, 'd4'),
+            (6, 'd5'),
+            (7, 'd6'),
+            (33, 'd7'),
+            (8, 'd8'),
+        ]
+        table.delete().execute()
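
The _returning variant runs the same script against an engine created with implicit_returning=True: on PostgreSQL 8.2+ the single-row INSERT grows a RETURNING clause in place of a pre-executed default, and the result's inserted_primary_key, the successor to the deprecated last_inserted_ids(), carries the value either way. A sketch, with table defined as in these tests:

    engine = engines.testing_engine(options={'implicit_returning': True})
    r = engine.execute(table.insert(), {'data': 'd2'})
    # one entry per primary-key column; fetched via RETURNING when available
    assert r.inserted_primary_key == [1]
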
+
     def _assert_data_with_sequence(self, table, seqname):
+        self.engine = engines.testing_engine(options={'implicit_returning':False})
+        metadata.bind = self.engine
+
         def go():
             table.insert().execute({'id':30, 'data':'d1'})
             table.insert().execute({'data':'d2'})
@@ -302,7 +429,7 @@ class InsertTest(TestBase, AssertsExecutionResults):
             table.insert(inline=True).execute({'id':33, 'data':'d7'})
             table.insert(inline=True).execute({'data':'d8'})
 
-        self.assert_sql(testing.db, go, [], with_sequences=[
+        self.assert_sql(self.engine, go, [], with_sequences=[
             (
                 "INSERT INTO testtable (id, data) VALUES (:id, :data)",
                 {'id':30, 'data':'d1'}
@@ -343,18 +470,76 @@ class InsertTest(TestBase, AssertsExecutionResults):
         # can't test reflection here since the Sequence must be
         # explicitly specified
 
+    def _assert_data_with_sequence_returning(self, table, seqname):
+        self.engine = engines.testing_engine(options={'implicit_returning':True})
+        metadata.bind = self.engine
+
+        def go():
+            table.insert().execute({'id':30, 'data':'d1'})
+            table.insert().execute({'data':'d2'})
+            table.insert().execute({'id':31, 'data':'d3'}, {'id':32, 'data':'d4'})
+            table.insert().execute({'data':'d5'}, {'data':'d6'})
+            table.insert(inline=True).execute({'id':33, 'data':'d7'})
+            table.insert(inline=True).execute({'data':'d8'})
+
+        self.assert_sql(self.engine, go, [], with_sequences=[
+            (
+                "INSERT INTO testtable (id, data) VALUES (:id, :data)",
+                {'id':30, 'data':'d1'}
+            ),
+            (
+                "INSERT INTO testtable (id, data) VALUES (nextval('my_seq'), :data) RETURNING testtable.id",
+                {'data':'d2'}
+            ),
+            (
+                "INSERT INTO testtable (id, data) VALUES (:id, :data)",
+                [{'id':31, 'data':'d3'}, {'id':32, 'data':'d4'}]
+            ),
+            (
+                "INSERT INTO testtable (id, data) VALUES (nextval('%s'), :data)" % seqname,
+                [{'data':'d5'}, {'data':'d6'}]
+            ),
+            (
+                "INSERT INTO testtable (id, data) VALUES (:id, :data)",
+                [{'id':33, 'data':'d7'}]
+            ),
+            (
+                "INSERT INTO testtable (id, data) VALUES (nextval('%s'), :data)" % seqname,
+                [{'data':'d8'}]
+            ),
+        ])
+
+        assert table.select().execute().fetchall() == [
+            (30, 'd1'),
+            (1, 'd2'),
+            (31, 'd3'),
+            (32, 'd4'),
+            (2, 'd5'),
+            (3, 'd6'),
+            (33, 'd7'),
+            (4, 'd8'),
+        ]
+
+        # can't test reflection here since the Sequence must be
+        # explicitly specified
+
     def _assert_data_noautoincrement(self, table):
+        self.engine = engines.testing_engine(options={'implicit_returning':False})
+        metadata.bind = self.engine
+
         table.insert().execute({'id':30, 'data':'d1'})
-        try:
-            table.insert().execute({'data':'d2'})
-            assert False
-        except exc.IntegrityError, e:
-            assert "violates not-null constraint" in str(e)
-        try:
-            table.insert().execute({'data':'d2'}, {'data':'d3'})
-            assert False
-        except exc.IntegrityError, e:
-            assert "violates not-null constraint" in str(e)
+        
+        if self.engine.driver == 'pg8000':
+            exception_cls = exc.ProgrammingError
+        else:
+            exception_cls = exc.IntegrityError
+        
+        assert_raises_message(exception_cls, "violates not-null constraint", table.insert().execute, {'data':'d2'})
+        assert_raises_message(exception_cls, "violates not-null constraint", table.insert().execute, {'data':'d2'}, {'data':'d3'})
 
         table.insert().execute({'id':31, 'data':'d2'}, {'id':32, 'data':'d3'})
         table.insert(inline=True).execute({'id':33, 'data':'d4'})
@@ -369,19 +554,12 @@ class InsertTest(TestBase, AssertsExecutionResults):
 
         # test the same series of events using a reflected
         # version of the table
-        m2 = MetaData(testing.db)
+        m2 = MetaData(self.engine)
         table = Table(table.name, m2, autoload=True)
         table.insert().execute({'id':30, 'data':'d1'})
-        try:
-            table.insert().execute({'data':'d2'})
-            assert False
-        except exc.IntegrityError, e:
-            assert "violates not-null constraint" in str(e)
-        try:
-            table.insert().execute({'data':'d2'}, {'data':'d3'})
-            assert False
-        except exc.IntegrityError, e:
-            assert "violates not-null constraint" in str(e)
+
+        assert_raises_message(exception_cls, "violates not-null constraint", table.insert().execute, {'data':'d2'})
+        assert_raises_message(exception_cls, "violates not-null constraint", table.insert().execute, {'data':'d2'}, {'data':'d3'})
 
         table.insert().execute({'id':31, 'data':'d2'}, {'id':32, 'data':'d3'})
         table.insert(inline=True).execute({'id':33, 'data':'d4'})
@@ -396,36 +574,36 @@ class InsertTest(TestBase, AssertsExecutionResults):
 class DomainReflectionTest(TestBase, AssertsExecutionResults):
     "Test PostgreSQL domains"
 
-    __only_on__ = 'postgres'
+    __only_on__ = 'postgresql'
 
     @classmethod
     def setup_class(cls):
         con = testing.db.connect()
         for ddl in ('CREATE DOMAIN testdomain INTEGER NOT NULL DEFAULT 42',
-                    'CREATE DOMAIN alt_schema.testdomain INTEGER DEFAULT 0'):
+                    'CREATE DOMAIN test_schema.testdomain INTEGER DEFAULT 0'):
             try:
                 con.execute(ddl)
             except exc.SQLError, e:
                 if not "already exists" in str(e):
                     raise e
         con.execute('CREATE TABLE testtable (question integer, answer testdomain)')
-        con.execute('CREATE TABLE alt_schema.testtable(question integer, answer alt_schema.testdomain, anything integer)')
-        con.execute('CREATE TABLE crosschema (question integer, answer alt_schema.testdomain)')
+        con.execute('CREATE TABLE test_schema.testtable(question integer, answer test_schema.testdomain, anything integer)')
+        con.execute('CREATE TABLE crosschema (question integer, answer test_schema.testdomain)')
 
     @classmethod
     def teardown_class(cls):
         con = testing.db.connect()
         con.execute('DROP TABLE testtable')
-        con.execute('DROP TABLE alt_schema.testtable')
+        con.execute('DROP TABLE test_schema.testtable')
         con.execute('DROP TABLE crosschema')
         con.execute('DROP DOMAIN testdomain')
-        con.execute('DROP DOMAIN alt_schema.testdomain')
+        con.execute('DROP DOMAIN test_schema.testdomain')
 
     def test_table_is_reflected(self):
         metadata = MetaData(testing.db)
         table = Table('testtable', metadata, autoload=True)
         eq_(set(table.columns.keys()), set(['question', 'answer']), "Columns of reflected table didn't equal expected columns")
-        eq_(table.c.answer.type.__class__, postgres.PGInteger)
+        assert isinstance(table.c.answer.type, Integer)
 
     def test_domain_is_reflected(self):
         metadata = MetaData(testing.db)
@@ -433,15 +611,15 @@ class DomainReflectionTest(TestBase, AssertsExecutionResults):
         eq_(str(table.columns.answer.server_default.arg), '42', "Reflected default value didn't equal expected value")
         assert not table.columns.answer.nullable, "Expected reflected column to not be nullable."
 
-    def test_table_is_reflected_alt_schema(self):
+    def test_table_is_reflected_test_schema(self):
         metadata = MetaData(testing.db)
-        table = Table('testtable', metadata, autoload=True, schema='alt_schema')
+        table = Table('testtable', metadata, autoload=True, schema='test_schema')
         eq_(set(table.columns.keys()), set(['question', 'answer', 'anything']), "Columns of reflected table didn't equal expected columns")
-        eq_(table.c.anything.type.__class__, postgres.PGInteger)
+        assert isinstance(table.c.anything.type, Integer)
 
     def test_schema_domain_is_reflected(self):
         metadata = MetaData(testing.db)
-        table = Table('testtable', metadata, autoload=True, schema='alt_schema')
+        table = Table('testtable', metadata, autoload=True, schema='test_schema')
         eq_(str(table.columns.answer.server_default.arg), '0', "Reflected default value didn't equal expected value")
         assert table.columns.answer.nullable, "Expected reflected column to be nullable."
 
@@ -452,10 +630,10 @@ class DomainReflectionTest(TestBase, AssertsExecutionResults):
         assert table.columns.answer.nullable, "Expected reflected column to be nullable."
 
     def test_unknown_types(self):
-        from sqlalchemy.databases import postgres
+        from sqlalchemy.databases import postgresql
 
-        ischema_names = postgres.ischema_names
-        postgres.ischema_names = {}
+        ischema_names = postgresql.PGDialect.ischema_names
+        postgresql.PGDialect.ischema_names = {}
         try:
             m2 = MetaData(testing.db)
             assert_raises(exc.SAWarning, Table, "testtable", m2, autoload=True)
@@ -467,11 +645,11 @@ class DomainReflectionTest(TestBase, AssertsExecutionResults):
                 assert t3.c.answer.type.__class__ == sa.types.NullType
 
         finally:
-            postgres.ischema_names = ischema_names
+            postgresql.PGDialect.ischema_names = ischema_names
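
test_unknown_types empties the dialect's ischema_names map to show what reflection does with a type it cannot name: an SAWarning is emitted (raised as an error under the test framework) and the column falls back to sa.types.NullType. A sketch of the normal, non-test behavior, assuming engine points at the database above:

    from sqlalchemy import MetaData, Table, types

    t = Table('testtable', MetaData(engine), autoload=True)   # emits SAWarning
    assert isinstance(t.c.answer.type, types.NullType)
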
 
 
-class MiscTest(TestBase, AssertsExecutionResults):
-    __only_on__ = 'postgres'
+class MiscTest(TestBase, AssertsExecutionResults, AssertsCompiledSQL):
+    __only_on__ = 'postgresql'
 
     def test_date_reflection(self):
         m1 = MetaData(testing.db)
@@ -536,26 +714,26 @@ class MiscTest(TestBase, AssertsExecutionResults):
             'FROM mytable')
 
     def test_schema_reflection(self):
-        """note: this test requires that the 'alt_schema' schema be separate and accessible by the test user"""
+        """note: this test requires that the 'test_schema' schema be separate and accessible by the test user"""
 
         meta1 = MetaData(testing.db)
         users = Table('users', meta1,
             Column('user_id', Integer, primary_key = True),
             Column('user_name', String(30), nullable = False),
-            schema="alt_schema"
+            schema="test_schema"
             )
 
         addresses = Table('email_addresses', meta1,
             Column('address_id', Integer, primary_key = True),
             Column('remote_user_id', Integer, ForeignKey(users.c.user_id)),
             Column('email_address', String(20)),
-            schema="alt_schema"
+            schema="test_schema"
         )
         meta1.create_all()
         try:
             meta2 = MetaData(testing.db)
-            addresses = Table('email_addresses', meta2, autoload=True, schema="alt_schema")
-            users = Table('users', meta2, mustexist=True, schema="alt_schema")
+            addresses = Table('email_addresses', meta2, autoload=True, schema="test_schema")
+            users = Table('users', meta2, mustexist=True, schema="test_schema")
 
             print users
             print addresses
@@ -574,12 +752,12 @@ class MiscTest(TestBase, AssertsExecutionResults):
         referer = Table("referer", meta1,
                         Column("id", Integer, primary_key=True),
                         Column("ref", Integer, ForeignKey('subject.id')),
-                        schema="alt_schema")
+                        schema="test_schema")
         meta1.create_all()
         try:
             meta2 = MetaData(testing.db)
             subject = Table("subject", meta2, autoload=True)
-            referer = Table("referer", meta2, schema="alt_schema", autoload=True)
+            referer = Table("referer", meta2, schema="test_schema", autoload=True)
             print str(subject.join(referer).onclause)
             self.assert_((subject.c.id==referer.c.ref).compare(subject.join(referer).onclause))
         finally:
@@ -589,19 +767,19 @@ class MiscTest(TestBase, AssertsExecutionResults):
         meta1 = MetaData(testing.db)
         subject = Table("subject", meta1,
                         Column("id", Integer, primary_key=True),
-                        schema='alt_schema_2'
+                        schema='test_schema_2'
                         )
 
         referer = Table("referer", meta1,
                         Column("id", Integer, primary_key=True),
-                        Column("ref", Integer, ForeignKey('alt_schema_2.subject.id')),
-                        schema="alt_schema")
+                        Column("ref", Integer, ForeignKey('test_schema_2.subject.id')),
+                        schema="test_schema")
 
         meta1.create_all()
         try:
             meta2 = MetaData(testing.db)
-            subject = Table("subject", meta2, autoload=True, schema="alt_schema_2")
-            referer = Table("referer", meta2, schema="alt_schema", autoload=True)
+            subject = Table("subject", meta2, autoload=True, schema="test_schema_2")
+            referer = Table("referer", meta2, schema="test_schema", autoload=True)
             print str(subject.join(referer).onclause)
             self.assert_((subject.c.id==referer.c.ref).compare(subject.join(referer).onclause))
         finally:
@@ -611,7 +789,7 @@ class MiscTest(TestBase, AssertsExecutionResults):
         meta = MetaData(testing.db)
         users = Table('users', meta,
             Column('id', Integer, primary_key=True),
-            Column('name', String(50)), schema='alt_schema')
+            Column('name', String(50)), schema='test_schema')
         users.create()
         try:
             users.insert().execute(id=1, name='name1')
@@ -646,15 +824,15 @@ class MiscTest(TestBase, AssertsExecutionResults):
                  user_name        VARCHAR    NOT NULL,
                  user_password    VARCHAR    NOT NULL
              );
-            """, None)
+            """)
 
             t = Table("speedy_users", meta, autoload=True)
             r = t.insert().execute(user_name='user', user_password='lala')
-            assert r.last_inserted_ids() == [1]
+            assert r.inserted_primary_key == [1]
             l = t.select().execute().fetchall()
             assert l == [(1, 'user', 'lala')]
         finally:
-            testing.db.execute("drop table speedy_users", None)
+            testing.db.execute("drop table speedy_users")
 
     @testing.emits_warning()
     def test_index_reflection(self):
@@ -676,10 +854,10 @@ class MiscTest(TestBase, AssertsExecutionResults):
         
         testing.db.execute("""
           create index idx1 on party ((id || name))
-        """, None
+        """) 
         testing.db.execute("""
           create unique index idx2 on party (id) where name = 'test'
-        """, None)
+        """)
         
         testing.db.execute("""
             create index idx3 on party using btree
@@ -713,35 +891,42 @@ class MiscTest(TestBase, AssertsExecutionResults):
             warnings.warn = capture_warnings._orig_showwarning
             m1.drop_all()
 
-    def test_create_partial_index(self):
-        tbl = Table('testtbl', MetaData(), Column('data',Integer))
-        idx = Index('test_idx1', tbl.c.data, postgres_where=and_(tbl.c.data > 5, tbl.c.data < 10))
-
-        executed_sql = []
-        mock_strategy = MockEngineStrategy()
-        mock_conn = mock_strategy.create('postgres://', executed_sql.append)
+    def test_set_isolation_level(self):
+        """Test setting the isolation level with create_engine"""
+        eng = create_engine(testing.db.url)
+        eq_(
+            eng.execute("show transaction isolation level").scalar(),
+            'read committed')
+        eng = create_engine(testing.db.url, isolation_level="SERIALIZABLE")
+        eq_(
+            eng.execute("show transaction isolation level").scalar(),
+            'serializable')
+        eng = create_engine(testing.db.url, isolation_level="FOO")
 
-        idx.create(mock_conn)
+        if testing.db.driver == 'zxjdbc':
+            exception_cls = eng.dialect.dbapi.Error
+        else:
+            exception_cls = eng.dialect.dbapi.ProgrammingError
+        assert_raises(exception_cls, eng.execute, "show transaction isolation level")
 
-        assert executed_sql == ['CREATE INDEX test_idx1 ON testtbl (data) WHERE testtbl.data > 5 AND testtbl.data < 10']
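
(The removed block above checked partial-index DDL through a mock engine; the replacement exercises the new isolation_level argument to create_engine. A minimal usage sketch, with an illustrative URL:

    from sqlalchemy import create_engine
    eng = create_engine('postgresql://scott:tiger@localhost/test',
                        isolation_level='SERIALIZABLE')
    # each DBAPI connection the pool creates is set to the requested level
    print eng.execute('show transaction isolation level').scalar()  # 'serializable'
)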
 
 class TimezoneTest(TestBase, AssertsExecutionResults):
     """Test timezone-aware datetimes.
 
-    psycopg will return a datetime with a tzinfo attached to it, if postgres
+    psycopg will return a datetime with a tzinfo attached to it, if postgresql
     returns it.  Python then will not let you compare a datetime with a tzinfo
     to a datetime that doesn't have one.  This test illustrates two ways to
     have datetime types with and without timezone info.
     """
 
-    __only_on__ = 'postgres'
+    __only_on__ = 'postgresql'
 
     @classmethod
     def setup_class(cls):
         global tztable, notztable, metadata
         metadata = MetaData(testing.db)
 
-        # current_timestamp() in postgres is assumed to return TIMESTAMP WITH TIMEZONE
+        # current_timestamp() in postgresql is assumed to return TIMESTAMP WITH TIMEZONE
         tztable = Table('tztable', metadata,
             Column("id", Integer, primary_key=True),
             Column("date", DateTime(timezone=True), onupdate=func.current_timestamp()),
@@ -762,17 +947,17 @@ class TimezoneTest(TestBase, AssertsExecutionResults):
         somedate = testing.db.connect().scalar(func.current_timestamp().select())
         tztable.insert().execute(id=1, name='row1', date=somedate)
         c = tztable.update(tztable.c.id==1).execute(name='newname')
-        print tztable.select(tztable.c.id==1).execute().fetchone()
+        print tztable.select(tztable.c.id==1).execute().first()
 
     def test_without_timezone(self):
         # get a date without a tzinfo
         somedate = datetime.datetime(2005, 10,20, 11, 52, 00)
         notztable.insert().execute(id=1, name='row1', date=somedate)
         c = notztable.update(notztable.c.id==1).execute(name='newname')
-        print notztable.select(tztable.c.id==1).execute().fetchone()
+        print notztable.select(tztable.c.id==1).execute().first()
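
(fetchone() is swapped for first() in these tests; first() returns the first row, or None, and closes the result. Sketch:

    row = tztable.select(tztable.c.id==1).execute().first()
    # row is a RowProxy or None; the underlying cursor is closed afterwards
)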
 
 class ArrayTest(TestBase, AssertsExecutionResults):
-    __only_on__ = 'postgres'
+    __only_on__ = 'postgresql'
 
     @classmethod
     def setup_class(cls):
@@ -781,10 +966,14 @@ class ArrayTest(TestBase, AssertsExecutionResults):
 
         arrtable = Table('arrtable', metadata,
             Column('id', Integer, primary_key=True),
-            Column('intarr', postgres.PGArray(Integer)),
-            Column('strarr', postgres.PGArray(String(convert_unicode=True)), nullable=False)
+            Column('intarr', postgresql.PGArray(Integer)),
+            Column('strarr', postgresql.PGArray(String(convert_unicode=True)), nullable=False)
         )
         metadata.create_all()
+        
+    def teardown(self):
+        arrtable.delete().execute()
+        
     @classmethod
     def teardown_class(cls):
         metadata.drop_all()
@@ -792,34 +981,38 @@ class ArrayTest(TestBase, AssertsExecutionResults):
     def test_reflect_array_column(self):
         metadata2 = MetaData(testing.db)
         tbl = Table('arrtable', metadata2, autoload=True)
-        assert isinstance(tbl.c.intarr.type, postgres.PGArray)
-        assert isinstance(tbl.c.strarr.type, postgres.PGArray)
+        assert isinstance(tbl.c.intarr.type, postgresql.PGArray)
+        assert isinstance(tbl.c.strarr.type, postgresql.PGArray)
         assert isinstance(tbl.c.intarr.type.item_type, Integer)
         assert isinstance(tbl.c.strarr.type.item_type, String)
 
+    @testing.fails_on('postgresql+zxjdbc', 'zxjdbc has no support for PG arrays')
     def test_insert_array(self):
         arrtable.insert().execute(intarr=[1,2,3], strarr=['abc', 'def'])
         results = arrtable.select().execute().fetchall()
         eq_(len(results), 1)
         eq_(results[0]['intarr'], [1,2,3])
         eq_(results[0]['strarr'], ['abc','def'])
-        arrtable.delete().execute()
 
+    @testing.fails_on('postgresql+pg8000', 'pg8000 has poor support for PG arrays')
+    @testing.fails_on('postgresql+zxjdbc', 'zxjdbc has no support for PG arrays')
     def test_array_where(self):
         arrtable.insert().execute(intarr=[1,2,3], strarr=['abc', 'def'])
         arrtable.insert().execute(intarr=[4,5,6], strarr='ABC')
         results = arrtable.select().where(arrtable.c.intarr == [1,2,3]).execute().fetchall()
         eq_(len(results), 1)
         eq_(results[0]['intarr'], [1,2,3])
-        arrtable.delete().execute()
 
+    @testing.fails_on('postgresql+pg8000', 'pg8000 has poor support for PG arrays')
+    @testing.fails_on('postgresql+zxjdbc', 'zxjdbc has no support for PG arrays')
     def test_array_concat(self):
         arrtable.insert().execute(intarr=[1,2,3], strarr=['abc', 'def'])
         results = select([arrtable.c.intarr + [4,5,6]]).execute().fetchall()
         eq_(len(results), 1)
         eq_(results[0][0], [1,2,3,4,5,6])
-        arrtable.delete().execute()
 
+    @testing.fails_on('postgresql+pg8000', 'pg8000 has poor support for PG arrays')
+    @testing.fails_on('postgresql+zxjdbc', 'zxjdbc has no support for PG arrays')
     def test_array_subtype_resultprocessor(self):
         arrtable.insert().execute(intarr=[4,5,6], strarr=[[u'm\xe4\xe4'], [u'm\xf6\xf6']])
         arrtable.insert().execute(intarr=[1,2,3], strarr=[u'm\xe4\xe4', u'm\xf6\xf6'])
@@ -827,13 +1020,14 @@ class ArrayTest(TestBase, AssertsExecutionResults):
         eq_(len(results), 2)
         eq_(results[0]['strarr'], [u'm\xe4\xe4', u'm\xf6\xf6'])
         eq_(results[1]['strarr'], [[u'm\xe4\xe4'], [u'm\xf6\xf6']])
-        arrtable.delete().execute()
 
+    @testing.fails_on('postgresql+pg8000', 'pg8000 has poor support for PG arrays')
+    @testing.fails_on('postgresql+zxjdbc', 'zxjdbc has no support for PG arrays')
     def test_array_mutability(self):
         class Foo(object): pass
         footable = Table('foo', metadata,
             Column('id', Integer, primary_key=True),
-            Column('intarr', postgres.PGArray(Integer), nullable=True)
+            Column('intarr', postgresql.PGArray(Integer), nullable=True)
         )
         mapper(Foo, footable)
         metadata.create_all()
@@ -870,19 +1064,19 @@ class ArrayTest(TestBase, AssertsExecutionResults):
         sess.add(foo)
         sess.flush()
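
(The fails_on targets above use the 0.6 'dialect+driver' notation, and the same form is accepted in database URLs. Illustrative sketches; each requires its DBAPI to be installed:

    from sqlalchemy import create_engine
    create_engine('postgresql+psycopg2://scott:tiger@localhost/test')
    create_engine('postgresql+pg8000://scott:tiger@localhost/test')
    create_engine('postgresql+zxjdbc://scott:tiger@localhost/test')  # Jython only
)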
 
-class TimeStampTest(TestBase, AssertsExecutionResults):
-    __only_on__ = 'postgres'
-    
-    @testing.uses_deprecated()
+class TimestampTest(TestBase, AssertsExecutionResults):
+    __only_on__ = 'postgresql'
+
     def test_timestamp(self):
         engine = testing.db
         connection = engine.connect()
-        s = select([func.TIMESTAMP("12/25/07").label("ts")])
-        result = connection.execute(s).fetchone()
+        
+        s = select(["timestamp '2007-12-25'"])
+        result = connection.execute(s).first()
         eq_(result[0], datetime.datetime(2007, 12, 25, 0, 0))
 
 class ServerSideCursorsTest(TestBase, AssertsExecutionResults):
-    __only_on__ = 'postgres'
+    __only_on__ = 'postgresql+psycopg2'
 
     @classmethod
     def setup_class(cls):
@@ -927,8 +1121,8 @@ class ServerSideCursorsTest(TestBase, AssertsExecutionResults):
 class SpecialTypesTest(TestBase, ComparesTables):
     """test DDL and reflection of PG-specific types """
     
-    __only_on__ = 'postgres'
-    __excluded_on__ = (('postgres', '<', (8, 3, 0)),)
+    __only_on__ = 'postgresql'
+    __excluded_on__ = (('postgresql', '<', (8, 3, 0)),)
     
     @classmethod
     def setup_class(cls):
@@ -936,11 +1130,11 @@ class SpecialTypesTest(TestBase, ComparesTables):
         metadata = MetaData(testing.db)
         
         table = Table('sometable', metadata,
-            Column('id', postgres.PGUuid, primary_key=True),
-            Column('flag', postgres.PGBit),
-            Column('addr', postgres.PGInet),
-            Column('addr2', postgres.PGMacAddr),
-            Column('addr3', postgres.PGCidr)
+            Column('id', postgresql.PGUuid, primary_key=True),
+            Column('flag', postgresql.PGBit),
+            Column('addr', postgresql.PGInet),
+            Column('addr2', postgresql.PGMacAddr),
+            Column('addr3', postgresql.PGCidr)
         )
         
         metadata.create_all()
@@ -957,8 +1151,8 @@ class SpecialTypesTest(TestBase, ComparesTables):
         
 
 class MatchTest(TestBase, AssertsCompiledSQL):
-    __only_on__ = 'postgres'
-    __excluded_on__ = (('postgres', '<', (8, 3, 0)),)
+    __only_on__ = 'postgresql'
+    __excluded_on__ = (('postgresql', '<', (8, 3, 0)),)
 
     @classmethod
     def setup_class(cls):
@@ -992,9 +1186,16 @@ class MatchTest(TestBase, AssertsCompiledSQL):
     def teardown_class(cls):
         metadata.drop_all()
 
-    def test_expression(self):
+    @testing.fails_on('postgresql+pg8000', 'uses positional')
+    @testing.fails_on('postgresql+zxjdbc', 'uses qmark')
+    def test_expression_pyformat(self):
         self.assert_compile(matchtable.c.title.match('somstr'), "matchtable.title @@ to_tsquery(%(title_1)s)")
 
+    @testing.fails_on('postgresql+psycopg2', 'uses pyformat')
+    @testing.fails_on('postgresql+zxjdbc', 'uses qmark')
+    def test_expression_positional(self):
+        self.assert_compile(matchtable.c.title.match('somstr'), "matchtable.title @@ to_tsquery(%s)")
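
(The single test_expression splits in two because the expected bind-parameter syntax now depends on the driver's paramstyle rather than on one postgres dialect. A quick way to see which style is in play, with an illustrative URL:

    from sqlalchemy import create_engine
    eng = create_engine('postgresql+psycopg2://scott:tiger@localhost/test')
    print eng.dialect.paramstyle  # 'pyformat' -> %(title_1)s; pg8000 is positional, zxjdbc is qmark
)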
+
     def test_simple_match(self):
         results = matchtable.select().where(matchtable.c.title.match('python')).order_by(matchtable.c.id).execute().fetchall()
         eq_([2, 5], [r.id for r in results])
index eb4581e20fcca7aa1249183ae3e757668d535038..448ee947c0be6d0dd0d6d1f4ca2d81e378ed423a 100644 (file)
@@ -4,7 +4,7 @@ from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message
 import datetime
 from sqlalchemy import *
 from sqlalchemy import exc, sql
-from sqlalchemy.databases import sqlite
+from sqlalchemy.dialects.sqlite import base as sqlite, pysqlite as pysqlite_dialect
 from sqlalchemy.test import *
 
 
@@ -19,7 +19,7 @@ class TestTypes(TestBase, AssertsExecutionResults):
         meta = MetaData(testing.db)
         t = Table('bool_table', meta,
                   Column('id', Integer, primary_key=True),
-                  Column('boo', sqlite.SLBoolean))
+                  Column('boo', Boolean))
 
         try:
             meta.create_all()
@@ -39,7 +39,7 @@ class TestTypes(TestBase, AssertsExecutionResults):
     def test_time_microseconds(self):
         dt = datetime.datetime(2008, 6, 27, 12, 0, 0, 125)  # 125 usec
         eq_(str(dt), '2008-06-27 12:00:00.000125')
-        sldt = sqlite.SLDateTime()
+        sldt = sqlite._SLDateTime()
         bp = sldt.bind_processor(None)
         eq_(bp(dt), '2008-06-27 12:00:00.000125')
         
@@ -69,59 +69,44 @@ class TestTypes(TestBase, AssertsExecutionResults):
             bindproc = t.dialect_impl(dialect).bind_processor(dialect)
             assert not bindproc or isinstance(bindproc(u"some string"), unicode)
         
-    @testing.uses_deprecated('Using String type with no length')
     def test_type_reflection(self):
         # (ask_for, roundtripped_as_if_different)
-        specs = [( String(), sqlite.SLString(), ),
-                 ( String(1), sqlite.SLString(1), ),
-                 ( String(3), sqlite.SLString(3), ),
-                 ( Text(), sqlite.SLText(), ),
-                 ( Unicode(), sqlite.SLString(), ),
-                 ( Unicode(1), sqlite.SLString(1), ),
-                 ( Unicode(3), sqlite.SLString(3), ),
-                 ( UnicodeText(), sqlite.SLText(), ),
-                 ( CLOB, sqlite.SLText(), ),
-                 ( sqlite.SLChar(1), ),
-                 ( CHAR(3), sqlite.SLChar(3), ),
-                 ( NCHAR(2), sqlite.SLChar(2), ),
-                 ( SmallInteger(), sqlite.SLSmallInteger(), ),
-                 ( sqlite.SLSmallInteger(), ),
-                 ( Binary(3), sqlite.SLBinary(), ),
-                 ( Binary(), sqlite.SLBinary() ),
-                 ( sqlite.SLBinary(3), sqlite.SLBinary(), ),
-                 ( NUMERIC, sqlite.SLNumeric(), ),
-                 ( NUMERIC(10,2), sqlite.SLNumeric(10,2), ),
-                 ( Numeric, sqlite.SLNumeric(), ),
-                 ( Numeric(10, 2), sqlite.SLNumeric(10, 2), ),
-                 ( DECIMAL, sqlite.SLNumeric(), ),
-                 ( DECIMAL(10, 2), sqlite.SLNumeric(10, 2), ),
-                 ( Float, sqlite.SLFloat(), ),
-                 ( sqlite.SLNumeric(), ),
-                 ( INT, sqlite.SLInteger(), ),
-                 ( Integer, sqlite.SLInteger(), ),
-                 ( sqlite.SLInteger(), ),
-                 ( TIMESTAMP, sqlite.SLDateTime(), ),
-                 ( DATETIME, sqlite.SLDateTime(), ),
-                 ( DateTime, sqlite.SLDateTime(), ),
-                 ( sqlite.SLDateTime(), ),
-                 ( DATE, sqlite.SLDate(), ),
-                 ( Date, sqlite.SLDate(), ),
-                 ( sqlite.SLDate(), ),
-                 ( TIME, sqlite.SLTime(), ),
-                 ( Time, sqlite.SLTime(), ),
-                 ( sqlite.SLTime(), ),
-                 ( BOOLEAN, sqlite.SLBoolean(), ),
-                 ( Boolean, sqlite.SLBoolean(), ),
-                 ( sqlite.SLBoolean(), ),
+        specs = [( String(), String(), ),
+                 ( String(1), String(1), ),
+                 ( String(3), String(3), ),
+                 ( Text(), Text(), ),
+                 ( Unicode(), String(), ),
+                 ( Unicode(1), String(1), ),
+                 ( Unicode(3), String(3), ),
+                 ( UnicodeText(), Text(), ),
+                 ( CHAR(1), ),
+                 ( CHAR(3), CHAR(3), ),
+                 ( NUMERIC, NUMERIC(), ),
+                 ( NUMERIC(10,2), NUMERIC(10,2), ),
+                 ( Numeric, NUMERIC(), ),
+                 ( Numeric(10, 2), NUMERIC(10, 2), ),
+                 ( DECIMAL, DECIMAL(), ),
+                 ( DECIMAL(10, 2), DECIMAL(10, 2), ),
+                 ( Float, Float(), ),
+                 ( NUMERIC(), ),
+                 ( TIMESTAMP, TIMESTAMP(), ),
+                 ( DATETIME, DATETIME(), ),
+                 ( DateTime, DateTime(), ),
+                 ( DateTime(), ),
+                 ( DATE, DATE(), ),
+                 ( Date, Date(), ),
+                 ( TIME, TIME(), ),
+                 ( Time, Time(), ),
+                 ( BOOLEAN, BOOLEAN(), ),
+                 ( Boolean, Boolean(), ),
                  ]
         columns = [Column('c%i' % (i + 1), t[0]) for i, t in enumerate(specs)]
 
         db = testing.db
         m = MetaData(db)
         t_table = Table('types', m, *columns)
+        m.create_all()
         try:
-            m.create_all()
-
             m2 = MetaData(db)
             rt = Table('types', m2, autoload=True)
             try:
@@ -131,7 +116,7 @@ class TestTypes(TestBase, AssertsExecutionResults):
                 expected = [len(c) > 1 and c[1] or c[0] for c in specs]
                 for table in rt, rv:
                     for i, reflected in enumerate(table.c):
-                        assert isinstance(reflected.type, type(expected[i])), type(expected[i])
+                        assert isinstance(reflected.type, type(expected[i])), "%d: %r" % (i, type(expected[i]))
             finally:
                 db.execute('DROP VIEW types_v')
         finally:
@@ -163,7 +148,7 @@ class TestDefaults(TestBase, AssertsExecutionResults):
             rt = Table('t_defaults', m2, autoload=True)
             expected = [c[1] for c in specs]
             for i, reflected in enumerate(rt.c):
-                eq_(reflected.server_default.arg.text, expected[i])
+                eq_(str(reflected.server_default.arg), expected[i])
         finally:
             m.drop_all()
 
@@ -173,7 +158,7 @@ class TestDefaults(TestBase, AssertsExecutionResults):
         db = testing.db
         m = MetaData(db)
 
-        expected = ["'my_default'", '0']
+        expected = ["my_default", '0']
         table = """CREATE TABLE r_defaults (
             data VARCHAR(40) DEFAULT 'my_default',
             val INTEGER NOT NULL DEFAULT 0
@@ -184,7 +169,7 @@ class TestDefaults(TestBase, AssertsExecutionResults):
 
             rt = Table('r_defaults', m, autoload=True)
             for i, reflected in enumerate(rt.c):
-                eq_(reflected.server_default.arg.text, expected[i])
+                eq_(str(reflected.server_default.arg), expected[i])
         finally:
             db.execute("DROP TABLE r_defaults")
 
@@ -247,24 +232,24 @@ class DialectTest(TestBase, AssertsExecutionResults):
     def test_attached_as_schema(self):
         cx = testing.db.connect()
         try:
-            cx.execute('ATTACH DATABASE ":memory:" AS  alt_schema')
+            cx.execute('ATTACH DATABASE ":memory:" AS  test_schema')
             dialect = cx.dialect
-            assert dialect.table_names(cx, 'alt_schema') == []
+            assert dialect.table_names(cx, 'test_schema') == []
 
             meta = MetaData(cx)
             Table('created', meta, Column('id', Integer),
-                  schema='alt_schema')
+                  schema='test_schema')
             alt_master = Table('sqlite_master', meta, autoload=True,
-                               schema='alt_schema')
+                               schema='test_schema')
             meta.create_all(cx)
 
-            eq_(dialect.table_names(cx, 'alt_schema'),
+            eq_(dialect.table_names(cx, 'test_schema'),
                               ['created'])
             assert len(alt_master.c) > 0
 
             meta.clear()
             reflected = Table('created', meta, autoload=True,
-                              schema='alt_schema')
+                              schema='test_schema')
             assert len(reflected.c) == 1
 
             cx.execute(reflected.insert(), dict(id=1))
@@ -282,9 +267,9 @@ class DialectTest(TestBase, AssertsExecutionResults):
             # note that sqlite_master is cleared, above
             meta.drop_all()
 
-            assert dialect.table_names(cx, 'alt_schema') == []
+            assert dialect.table_names(cx, 'test_schema') == []
         finally:
-            cx.execute('DETACH DATABASE alt_schema')
+            cx.execute('DETACH DATABASE test_schema')
 
     @testing.exclude('sqlite', '<', (2, 6), 'no database support')
     def test_temp_table_reflection(self):
@@ -305,6 +290,20 @@ class DialectTest(TestBase, AssertsExecutionResults):
                 pass
             raise
 
+    def test_set_isolation_level(self):
+        """Test setting the read uncommitted/serializable levels"""
+        eng = create_engine(testing.db.url)
+        eq_(eng.execute("PRAGMA read_uncommitted").scalar(), 0)
+
+        eng = create_engine(testing.db.url, isolation_level="READ UNCOMMITTED")
+        eq_(eng.execute("PRAGMA read_uncommitted").scalar(), 1)
+
+        eng = create_engine(testing.db.url, isolation_level="SERIALIZABLE")
+        eq_(eng.execute("PRAGMA read_uncommitted").scalar(), 0)
+
+        assert_raises(exc.ArgumentError, create_engine, testing.db.url,
+            isolation_level="FOO")
+
 
 class SQLTest(TestBase, AssertsCompiledSQL):
     """Tests SQLite-dialect specific compilation."""
index 7fd3009bca383f5055f82c1472c49ed3ec94f258..1122f1632fbdb4cb4f8a4724b29a38f10c3dff8b 100644 (file)
@@ -121,7 +121,7 @@ class BindTest(testing.TestBase):
                 table = Table('test_table', metadata,
                     Column('foo', Integer))
 
-                metadata.connect(bind)
+                metadata.bind = bind
 
                 assert metadata.bind is table.bind is bind
                 metadata.create_all()
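
(MetaData.connect() is gone in favor of the plain assignable bind attribute. Equivalent sketches:

    from sqlalchemy import MetaData, create_engine
    metadata = MetaData()
    metadata.bind = create_engine('sqlite://')             # formerly metadata.connect(...)
    metadata = MetaData(bind=create_engine('sqlite://'))   # or at construction time
)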
@@ -199,7 +199,7 @@ class BindTest(testing.TestBase):
                     try:
                         e = elem(bind=bind)
                         assert e.bind is bind
-                        e.execute()
+                        e.execute().close()
                     finally:
                         if isinstance(bind, engine.Connection):
                             bind.close()
index 5716006d93c54359ffa620e35e15d60bb4b25ffb..434a5d873c7c0c6b0a85675c4ea96417f433cfcb 100644 (file)
@@ -1,12 +1,13 @@
-from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message
-from sqlalchemy.schema import DDL
+from sqlalchemy.test.testing import assert_raises, assert_raises_message
+from sqlalchemy.schema import DDL, CheckConstraint, AddConstraint, DropConstraint
 from sqlalchemy import create_engine
 from sqlalchemy import MetaData, Integer, String
 from sqlalchemy.test.schema import Table
 from sqlalchemy.test.schema import Column
 import sqlalchemy as tsa
 from sqlalchemy.test import TestBase, testing, engines
-
+from sqlalchemy.test.testing import AssertsCompiledSQL
+from nose import SkipTest
 
 class DDLEventTest(TestBase):
     class Canary(object):
@@ -15,25 +16,25 @@ class DDLEventTest(TestBase):
             self.schema_item = schema_item
             self.bind = bind
 
-        def before_create(self, action, schema_item, bind):
+        def before_create(self, action, schema_item, bind, **kw):
             assert self.state is None
             assert schema_item is self.schema_item
             assert bind is self.bind
             self.state = action
 
-        def after_create(self, action, schema_item, bind):
+        def after_create(self, action, schema_item, bind, **kw):
             assert self.state in ('before-create', 'skipped')
             assert schema_item is self.schema_item
             assert bind is self.bind
             self.state = action
 
-        def before_drop(self, action, schema_item, bind):
+        def before_drop(self, action, schema_item, bind, **kw):
             assert self.state is None
             assert schema_item is self.schema_item
             assert bind is self.bind
             self.state = action
 
-        def after_drop(self, action, schema_item, bind):
+        def after_drop(self, action, schema_item, bind, **kw):
             assert self.state in ('before-drop', 'skipped')
             assert schema_item is self.schema_item
             assert bind is self.bind
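
(Each listener hook gains a **kw catch-all so the dispatch can grow new arguments without breaking existing listeners. A sketch of a forward-compatible listener:

    class MyListener(object):
        def before_create(self, action, schema_item, bind, **kw):
            # **kw absorbs any arguments added to the event later
            print 'about to create', schema_item
)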
@@ -232,7 +233,33 @@ class DDLExecutionTest(TestBase):
         assert 'klptzyxm' not in strings
         assert 'xyzzy' in strings
         assert 'fnord' in strings
-
+    
+    def test_conditional_constraint(self):
+        metadata, users, engine = self.metadata, self.users, self.engine
+        nonpg_mock = engines.mock_engine(dialect_name='sqlite')
+        pg_mock = engines.mock_engine(dialect_name='postgresql')
+        
+        constraint = CheckConstraint('a < b',name="my_test_constraint", table=users)
+
+        # by placing the constraint in an Add/Drop construct,
+        # the 'inline_ddl' flag is set to False
+        AddConstraint(constraint, on='postgresql').execute_at("after-create", users)
+        DropConstraint(constraint, on='postgresql').execute_at("before-drop", users)
+        
+        metadata.create_all(bind=nonpg_mock)
+        strings = " ".join(str(x) for x in nonpg_mock.mock)
+        assert "my_test_constraint" not in strings
+        metadata.drop_all(bind=nonpg_mock)
+        strings = " ".join(str(x) for x in nonpg_mock.mock)
+        assert "my_test_constraint" not in strings
+
+        metadata.create_all(bind=pg_mock)
+        strings = " ".join(str(x) for x in pg_mock.mock)
+        assert "my_test_constraint" in strings
+        metadata.drop_all(bind=pg_mock)
+        strings = " ".join(str(x) for x in pg_mock.mock)
+        assert "my_test_constraint" in strings
+        
     def test_metadata(self):
         metadata, engine = self.metadata, self.engine
         DDL('mxyzptlk').execute_at('before-create', metadata)
@@ -255,7 +282,10 @@ class DDLExecutionTest(TestBase):
         assert 'fnord' in strings
 
     def test_ddl_execute(self):
-        engine = create_engine('sqlite:///')
+        try:
+            engine = create_engine('sqlite:///')
+        except ImportError:
+            raise SkipTest('Requires sqlite')
         cx = engine.connect()
         table = self.users
         ddl = DDL('SELECT 1')
@@ -286,7 +316,7 @@ class DDLExecutionTest(TestBase):
                 r = eval(py)
                 assert list(r) == [(1,)], py
 
-class DDLTest(TestBase):
+class DDLTest(TestBase, AssertsCompiledSQL):
     def mock_engine(self):
         executor = lambda *a, **kw: None
         engine = create_engine(testing.db.name + '://',
@@ -297,7 +327,6 @@ class DDLTest(TestBase):
 
     def test_tokens(self):
         m = MetaData()
-        bind = self.mock_engine()
         sane_alone = Table('t', m, Column('id', Integer))
         sane_schema = Table('t', m, Column('id', Integer), schema='s')
         insane_alone = Table('t t', m, Column('id', Integer))
@@ -305,20 +334,21 @@ class DDLTest(TestBase):
 
         ddl = DDL('%(schema)s-%(table)s-%(fullname)s')
 
-        eq_(ddl._expand(sane_alone, bind), '-t-t')
-        eq_(ddl._expand(sane_schema, bind), 's-t-s.t')
-        eq_(ddl._expand(insane_alone, bind), '-"t t"-"t t"')
-        eq_(ddl._expand(insane_schema, bind),
-                          '"s s"-"t t"-"s s"."t t"')
+        dialect = self.mock_engine().dialect
+        self.assert_compile(ddl.against(sane_alone), '-t-t', dialect=dialect)
+        self.assert_compile(ddl.against(sane_schema), 's-t-s.t', dialect=dialect)
+        self.assert_compile(ddl.against(insane_alone), '-"t t"-"t t"', dialect=dialect)
+        self.assert_compile(ddl.against(insane_schema), '"s s"-"t t"-"s s"."t t"', dialect=dialect)
 
         # overrides are used piece-meal and verbatim.
         ddl = DDL('%(schema)s-%(table)s-%(fullname)s-%(bonus)s',
                   context={'schema':'S S', 'table': 'T T', 'bonus': 'b'})
-        eq_(ddl._expand(sane_alone, bind), 'S S-T T-t-b')
-        eq_(ddl._expand(sane_schema, bind), 'S S-T T-s.t-b')
-        eq_(ddl._expand(insane_alone, bind), 'S S-T T-"t t"-b')
-        eq_(ddl._expand(insane_schema, bind),
-                          'S S-T T-"s s"."t t"-b')
+
+        self.assert_compile(ddl.against(sane_alone), 'S S-T T-t-b', dialect=dialect)
+        self.assert_compile(ddl.against(sane_schema), 'S S-T T-s.t-b', dialect=dialect)
+        self.assert_compile(ddl.against(insane_alone), 'S S-T T-"t t"-b', dialect=dialect)
+        self.assert_compile(ddl.against(insane_schema), 'S S-T T-"s s"."t t"-b', dialect=dialect)
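
(The same token expansion drives user-level DDL: %(schema)s, %(table)s and %(fullname)s interpolate the quoted names of the table the DDL runs against. A sketch with illustrative names:

    from sqlalchemy import MetaData, Table, Column, Integer
    from sqlalchemy.schema import DDL
    t = Table('report', MetaData(), Column('id', Integer), schema='s')
    ddl = DDL('GRANT SELECT ON %(fullname)s TO reporting')
    ddl.execute_at('after-create', t)  # expands to: GRANT SELECT ON s.report TO reporting
)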
+
     def test_filter(self):
         cx = self.mock_engine()
 
index 08bf80fe2f26c8758137c820e5f0eb72f6ed0011..4783c55080dad5b4809297ad28ad0e75fd170155 100644 (file)
@@ -15,18 +15,20 @@ class ExecuteTest(TestBase):
         global users, metadata
         metadata = MetaData(testing.db)
         users = Table('users', metadata,
-            Column('user_id', INT, primary_key = True),
+            Column('user_id', INT, primary_key = True, test_needs_autoincrement=True),
             Column('user_name', VARCHAR(20)),
         )
         metadata.create_all()
 
+    @engines.close_first
     def teardown(self):
         testing.db.connect().execute(users.delete())
+        
     @classmethod
     def teardown_class(cls):
         metadata.drop_all()
 
-    @testing.fails_on_everything_except('firebird', 'maxdb', 'sqlite')
+    @testing.fails_on_everything_except('firebird', 'maxdb', 'sqlite', 'mysql+pyodbc', '+zxjdbc')
     def test_raw_qmark(self):
         for conn in (testing.db, testing.db.connect()):
             conn.execute("insert into users (user_id, user_name) values (?, ?)", (1,"jack"))
@@ -38,7 +40,8 @@ class ExecuteTest(TestBase):
             assert res.fetchall() == [(1, "jack"), (2, "fred"), (3, "ed"), (4, "horse"), (5, "barney"), (6, "donkey"), (7, 'sally')]
             conn.execute("delete from users")
 
-    @testing.fails_on_everything_except('mysql', 'postgres')
+    @testing.fails_on_everything_except('mysql+mysqldb', 'postgresql')
+    @testing.fails_on('postgresql+zxjdbc', 'sprintf not supported')
     # some psycopg2 versions bomb this.
     def test_raw_sprintf(self):
         for conn in (testing.db, testing.db.connect()):
@@ -52,8 +55,8 @@ class ExecuteTest(TestBase):
 
     # pyformat is supported for mysql, but skipping because a few driver
     # versions have a bug that bombs out on this test. (1.2.2b3, 1.2.2c1, 1.2.2)
-    @testing.skip_if(lambda: testing.against('mysql'), 'db-api flaky')
-    @testing.fails_on_everything_except('postgres')
+    @testing.skip_if(lambda: testing.against('mysql+mysqldb'), 'db-api flaky')
+    @testing.fails_on_everything_except('postgresql+psycopg2')
     def test_raw_python(self):
         for conn in (testing.db, testing.db.connect()):
             conn.execute("insert into users (user_id, user_name) values (%(id)s, %(name)s)", {'id':1, 'name':'jack'})
@@ -63,7 +66,7 @@ class ExecuteTest(TestBase):
             assert res.fetchall() == [(1, "jack"), (2, "ed"), (3, "horse"), (4, 'sally')]
             conn.execute("delete from users")
 
-    @testing.fails_on_everything_except('sqlite', 'oracle')
+    @testing.fails_on_everything_except('sqlite', 'oracle+cx_oracle')
     def test_raw_named(self):
         for conn in (testing.db, testing.db.connect()):
             conn.execute("insert into users (user_id, user_name) values (:id, :name)", {'id':1, 'name':'jack'})
@@ -81,11 +84,12 @@ class ExecuteTest(TestBase):
             except tsa.exc.DBAPIError:
                 assert True
 
-    @testing.fails_on('mssql', 'rowcount returns -1')
     def test_empty_insert(self):
         """test that execute() interprets [] as a list with no params"""
         result = testing.db.execute(users.insert().values(user_name=bindparam('name')), [])
-        eq_(result.rowcount, 1)
+        eq_(testing.db.execute(users.select()).fetchall(), [
+            (1, None)
+        ])
 
 class ProxyConnectionTest(TestBase):
     @testing.fails_on('firebird', 'Data type unknown')
@@ -102,6 +106,7 @@ class ProxyConnectionTest(TestBase):
                 return execute(clauseelement, *multiparams, **params)
 
             def cursor_execute(self, execute, cursor, statement, parameters, context, executemany):
+                print "CE", statement, parameters
                 cursor_stmts.append(
                     (statement, parameters, None)
                 )
@@ -118,8 +123,8 @@ class ProxyConnectionTest(TestBase):
                         break
 
         for engine in (
-            engines.testing_engine(options=dict(proxy=MyProxy())),
-            engines.testing_engine(options=dict(proxy=MyProxy(), strategy='threadlocal'))
+            engines.testing_engine(options=dict(implicit_returning=False, proxy=MyProxy())),
+            engines.testing_engine(options=dict(implicit_returning=False, proxy=MyProxy(), strategy='threadlocal'))
         ):
             m = MetaData(engine)
 
@@ -131,6 +136,7 @@ class ProxyConnectionTest(TestBase):
                 t1.insert().execute(c1=6)
                 assert engine.execute("select * from t1").fetchall() == [(5, 'some data'), (6, 'foo')]
             finally:
+                pass
                 m.drop_all()
             
             engine.dispose()
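
(implicit_returning=False is pinned here because 0.6 can render INSERT .. RETURNING on databases that support it, which would change the cursor-level statement stream these assertions expect. Illustrative:

    from sqlalchemy import create_engine
    engine = create_engine('postgresql://scott:tiger@localhost/test',
                           implicit_returning=False)  # fetch defaults/sequences separately
)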
@@ -143,14 +149,14 @@ class ProxyConnectionTest(TestBase):
                 ("DROP TABLE t1", {}, None)
             ]
 
-            if engine.dialect.preexecute_pk_sequences:
+            if True: # or engine.dialect.preexecute_pk_sequences:
                 cursor = [
-                    ("CREATE TABLE t1", {}, None),
+                    ("CREATE TABLE t1", {}, ()),
                     ("INSERT INTO t1 (c1, c2)", {'c2': 'some data', 'c1': 5}, [5, 'some data']),
                     ("SELECT lower", {'lower_2':'Foo'}, ['Foo']),
                     ("INSERT INTO t1 (c1, c2)", {'c2': 'foo', 'c1': 6}, [6, 'foo']),
-                    ("select * from t1", {}, None),
-                    ("DROP TABLE t1", {}, None)
+                    ("select * from t1", {}, ()),
+                    ("DROP TABLE t1", {}, ())
                 ]
             else:
                 cursor = [
index ca4fbaa48a525087666d4a442ede79e6de397c58..784a7b9ce619a8cbdd3b6726caef33ab9c34eb2b 100644 (file)
@@ -1,11 +1,11 @@
 from sqlalchemy.test.testing import assert_raises, assert_raises_message
 import pickle
-from sqlalchemy import MetaData
-from sqlalchemy import Integer, String, UniqueConstraint, CheckConstraint, ForeignKey
+from sqlalchemy import Integer, String, UniqueConstraint, CheckConstraint, ForeignKey, MetaData
 from sqlalchemy.test.schema import Table
 from sqlalchemy.test.schema import Column
+from sqlalchemy import schema
 import sqlalchemy as tsa
-from sqlalchemy.test import TestBase, ComparesTables, testing, engines
+from sqlalchemy.test import TestBase, ComparesTables, AssertsCompiledSQL, testing, engines
 from sqlalchemy.test.testing import eq_
 
 class MetaDataTest(TestBase, ComparesTables):
@@ -83,7 +83,7 @@ class MetaDataTest(TestBase, ComparesTables):
 
         meta.create_all(testing.db)
         try:
-            for test, has_constraints in ((test_to_metadata, True), (test_pickle, True), (test_pickle_via_reflect, False)):
+            for test, has_constraints in ((test_to_metadata, True), (test_pickle, True),(test_pickle_via_reflect, False)):
                 table_c, table2_c = test()
                 self.assert_tables_equal(table, table_c)
                 self.assert_tables_equal(table2, table2_c)
@@ -143,29 +143,30 @@ class MetaDataTest(TestBase, ComparesTables):
                           MetaData(testing.db), autoload=True)
 
 
-class TableOptionsTest(TestBase):
-    def setup(self):
-        self.engine = engines.mock_engine()
-        self.metadata = MetaData(self.engine)
-
+class TableOptionsTest(TestBase, AssertsCompiledSQL):
     def test_prefixes(self):
-        table1 = Table("temporary_table_1", self.metadata,
+        table1 = Table("temporary_table_1", MetaData(),
                       Column("col1", Integer),
                       prefixes = ["TEMPORARY"])
-        table1.create()
-        assert [str(x) for x in self.engine.mock if 'CREATE TEMPORARY TABLE' in str(x)]
-        del self.engine.mock[:]
-        table2 = Table("temporary_table_2", self.metadata,
+                      
+        self.assert_compile(
+            schema.CreateTable(table1), 
+            "CREATE TEMPORARY TABLE temporary_table_1 (col1 INTEGER)"
+        )
+
+        table2 = Table("temporary_table_2", MetaData(),
                       Column("col1", Integer),
                       prefixes = ["VIRTUAL"])
-        table2.create()
-        assert [str(x) for x in self.engine.mock if 'CREATE VIRTUAL TABLE' in str(x)]
+        self.assert_compile(
+          schema.CreateTable(table2), 
+          "CREATE VIRTUAL TABLE temporary_table_2 (col1 INTEGER)"
+        )
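
(prefixes are injected verbatim, space-joined, between CREATE and TABLE; a sketch with hypothetical names:

    from sqlalchemy import MetaData, Table, Column, Integer, schema
    t = Table('scratch', MetaData(), Column('id', Integer),
              prefixes=['GLOBAL', 'TEMPORARY'])
    print schema.CreateTable(t)  # CREATE GLOBAL TEMPORARY TABLE scratch (id INTEGER)
)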
 
     def test_table_info(self):
-
-        t1 = Table('foo', self.metadata, info={'x':'y'})
-        t2 = Table('bar', self.metadata, info={})
-        t3 = Table('bat', self.metadata)
+        metadata = MetaData()
+        t1 = Table('foo', metadata, info={'x':'y'})
+        t2 = Table('bar', metadata, info={})
+        t3 = Table('bat', metadata)
         assert t1.info == {'x':'y'}
         assert t2.info == {}
         assert t3.info == {}
index 6b7ac37b20f7d14ae4fe46c36679268380be3698..90c0969bed288c8e05fef9b6e86afa10c24a5304 100644 (file)
@@ -1,4 +1,6 @@
-import ConfigParser, StringIO
+from sqlalchemy.test.testing import assert_raises, assert_raises_message
+import ConfigParser
+import StringIO
 import sqlalchemy.engine.url as url
 from sqlalchemy import create_engine, engine_from_config
 import sqlalchemy as tsa
@@ -28,8 +30,6 @@ class ParseConnectTest(TestBase):
             'dbtype://username:apples%2Foranges@hostspec/mydatabase',
         ):
             u = url.make_url(text)
-            print u, text
-            print "username=", u.username, "password=", u.password,  "database=", u.database, "host=", u.host
             assert u.drivername == 'dbtype'
             assert u.username == 'username' or u.username is None
             assert u.password == 'password' or u.password == 'apples/oranges' or u.password is None
@@ -41,21 +41,28 @@ class CreateEngineTest(TestBase):
     def test_connect_query(self):
         dbapi = MockDBAPI(foober='12', lala='18', fooz='somevalue')
 
-        # start the postgres dialect, but put our mock DBAPI as the module instead of psycopg
-        e = create_engine('postgres://scott:tiger@somehost/test?foober=12&lala=18&fooz=somevalue', module=dbapi)
+        e = create_engine(
+                'postgresql://scott:tiger@somehost/test?foober=12&lala=18&fooz=somevalue', 
+                module=dbapi,
+                _initialize=False
+                )
         c = e.connect()
 
     def test_kwargs(self):
         dbapi = MockDBAPI(foober=12, lala=18, hoho={'this':'dict'}, fooz='somevalue')
 
-        # start the postgres dialect, but put our mock DBAPI as the module instead of psycopg
-        e = create_engine('postgres://scott:tiger@somehost/test?fooz=somevalue', connect_args={'foober':12, 'lala':18, 'hoho':{'this':'dict'}}, module=dbapi)
+        e = create_engine(
+                'postgresql://scott:tiger@somehost/test?fooz=somevalue', 
+                connect_args={'foober':12, 'lala':18, 'hoho':{'this':'dict'}}, 
+                module=dbapi,
+                _initialize=False
+                )
         c = e.connect()
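
(_initialize=False appears, judging from these tests, to be a private flag that skips the dialect's first-connection initialization queries, so a mock DBAPI with no real server behind it can satisfy connect(). A sketch reusing this file's MockDBAPI stand-in:

    e = create_engine('postgresql://scott:tiger@somehost/test',
                      module=MockDBAPI(), _initialize=False)
    c = e.connect()
)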
 
     def test_coerce_config(self):
         raw = r"""
 [prefixed]
-sqlalchemy.url=postgres://scott:tiger@somehost/test?fooz=somevalue
+sqlalchemy.url=postgresql://scott:tiger@somehost/test?fooz=somevalue
 sqlalchemy.convert_unicode=0
 sqlalchemy.echo=false
 sqlalchemy.echo_pool=1
@@ -65,7 +72,7 @@ sqlalchemy.pool_size=2
 sqlalchemy.pool_threadlocal=1
 sqlalchemy.pool_timeout=10
 [plain]
-url=postgres://scott:tiger@somehost/test?fooz=somevalue
+url=postgresql://scott:tiger@somehost/test?fooz=somevalue
 convert_unicode=0
 echo=0
 echo_pool=1
@@ -79,7 +86,7 @@ pool_timeout=10
         ini.readfp(StringIO.StringIO(raw))
 
         expected = {
-            'url': 'postgres://scott:tiger@somehost/test?fooz=somevalue',
+            'url': 'postgresql://scott:tiger@somehost/test?fooz=somevalue',
             'convert_unicode': 0,
             'echo': False,
             'echo_pool': True,
@@ -97,17 +104,17 @@ pool_timeout=10
         self.assert_(tsa.engine._coerce_config(plain, '') == expected)
 
     def test_engine_from_config(self):
-        dbapi = MockDBAPI()
+        dbapi = mock_dbapi
 
         config = {
-            'sqlalchemy.url':'postgres://scott:tiger@somehost/test?fooz=somevalue',
+            'sqlalchemy.url':'postgresql://scott:tiger@somehost/test?fooz=somevalue',
             'sqlalchemy.pool_recycle':'50',
             'sqlalchemy.echo':'true'
         }
 
         e = engine_from_config(config, module=dbapi)
         assert e.pool._recycle == 50
-        assert e.url == url.make_url('postgres://scott:tiger@somehost/test?fooz=somevalue')
+        assert e.url == url.make_url('postgresql://scott:tiger@somehost/test?fooz=somevalue')
         assert e.echo is True
 
     def test_custom(self):
@@ -116,109 +123,77 @@ pool_timeout=10
         def connect():
             return dbapi.connect(foober=12, lala=18, fooz='somevalue', hoho={'this':'dict'})
 
-        # start the postgres dialect, but put our mock DBAPI as the module instead of psycopg
-        e = create_engine('postgres://', creator=connect, module=dbapi)
+        # start the postgresql dialect, but put our mock DBAPI as the module instead of psycopg
+        e = create_engine('postgresql://', creator=connect, module=dbapi, _initialize=False)
         c = e.connect()
 
     def test_recycle(self):
         dbapi = MockDBAPI(foober=12, lala=18, hoho={'this':'dict'}, fooz='somevalue')
-        e = create_engine('postgres://', pool_recycle=472, module=dbapi)
+        e = create_engine('postgresql://', pool_recycle=472, module=dbapi, _initialize=False)
         assert e.pool._recycle == 472
 
     def test_badargs(self):
-        # good arg, use MockDBAPI to prevent oracle import errors
-        e = create_engine('oracle://', use_ansi=True, module=MockDBAPI())
-
-        try:
-            e = create_engine("foobar://", module=MockDBAPI())
-            assert False
-        except ImportError:
-            assert True
+        assert_raises(ImportError, create_engine, "foobar://", module=mock_dbapi)
 
         # bad arg
-        try:
-            e = create_engine('postgres://', use_ansi=True, module=MockDBAPI())
-            assert False
-        except TypeError:
-            assert True
+        assert_raises(TypeError, create_engine, 'postgresql://', use_ansi=True, module=mock_dbapi)
 
         # bad arg
-        try:
-            e = create_engine('oracle://', lala=5, use_ansi=True, module=MockDBAPI())
-            assert False
-        except TypeError:
-            assert True
+        assert_raises(TypeError, create_engine, 'oracle://', lala=5, use_ansi=True, module=mock_dbapi)
 
-        try:
-            e = create_engine('postgres://', lala=5, module=MockDBAPI())
-            assert False
-        except TypeError:
-            assert True
+        assert_raises(TypeError, create_engine, 'postgresql://', lala=5, module=mock_dbapi)
 
-        try:
-            e = create_engine('sqlite://', lala=5)
-            assert False
-        except TypeError:
-            assert True
+        assert_raises(TypeError, create_engine,'sqlite://', lala=5, module=mock_sqlite_dbapi)
 
-        try:
-            e = create_engine('mysql://', use_unicode=True, module=MockDBAPI())
-            assert False
-        except TypeError:
-            assert True
-
-        try:
-            # sqlite uses SingletonThreadPool which doesnt have max_overflow
-            e = create_engine('sqlite://', max_overflow=5)
-            assert False
-        except TypeError:
-            assert True
+        assert_raises(TypeError, create_engine, 'mysql+mysqldb://', use_unicode=True, module=mock_dbapi)
 
-        e = create_engine('mysql://', module=MockDBAPI(), connect_args={'use_unicode':True}, convert_unicode=True)
+        # sqlite uses SingletonThreadPool which doesnt have max_overflow
+        assert_raises(TypeError, create_engine, 'sqlite://', max_overflow=5,
+                      module=mock_sqlite_dbapi)
 
-        e = create_engine('sqlite://', connect_args={'use_unicode':True}, convert_unicode=True)
         try:
-            c = e.connect()
-            assert False
-        except tsa.exc.DBAPIError:
-            assert True
+            e = create_engine('sqlite://', connect_args={'use_unicode':True}, convert_unicode=True)
+        except ImportError:
+            # no sqlite
+            pass
+        else:
+            # raises DBAPIerror due to use_unicode not a sqlite arg
+            assert_raises(tsa.exc.DBAPIError, e.connect)
 
     def test_urlattr(self):
         """test the url attribute on ``Engine``."""
 
-        e = create_engine('mysql://scott:tiger@localhost/test', module=MockDBAPI())
+        e = create_engine('mysql://scott:tiger@localhost/test', module=mock_dbapi, _initialize=False)
         u = url.make_url('mysql://scott:tiger@localhost/test')
-        e2 = create_engine(u, module=MockDBAPI())
+        e2 = create_engine(u, module=mock_dbapi, _initialize=False)
         assert e.url.drivername == e2.url.drivername == 'mysql'
         assert e.url.username == e2.url.username == 'scott'
         assert e2.url is u
 
     def test_poolargs(self):
         """test that connection pool args make it thru"""
-        e = create_engine('postgres://', creator=None, pool_recycle=50, echo_pool=None, module=MockDBAPI())
+        e = create_engine('postgresql://', creator=None, pool_recycle=50, echo_pool=None, module=mock_dbapi, _initialize=False)
         assert e.pool._recycle == 50
 
         # these args work for QueuePool
-        e = create_engine('postgres://', max_overflow=8, pool_timeout=60, poolclass=tsa.pool.QueuePool, module=MockDBAPI())
+        e = create_engine('postgresql://', max_overflow=8, pool_timeout=60, poolclass=tsa.pool.QueuePool, module=mock_dbapi)
 
-        try:
-            # but not SingletonThreadPool
-            e = create_engine('sqlite://', max_overflow=8, pool_timeout=60, poolclass=tsa.pool.SingletonThreadPool)
-            assert False
-        except TypeError:
-            assert True
+        # but not SingletonThreadPool
+        assert_raises(TypeError, create_engine, 'sqlite://', max_overflow=8, pool_timeout=60,
+                      poolclass=tsa.pool.SingletonThreadPool, module=mock_sqlite_dbapi)
 
 class MockDBAPI(object):
     def __init__(self, **kwargs):
         self.kwargs = kwargs
         self.paramstyle = 'named'
-    def connect(self, **kwargs):
-        print kwargs, self.kwargs
+    def connect(self, *args, **kwargs):
         for k in self.kwargs:
             assert k in kwargs, "key %s not present in dictionary" % k
             assert kwargs[k]==self.kwargs[k], "value %s does not match %s" % (kwargs[k], self.kwargs[k])
         return MockConnection()
 class MockConnection(object):
+    def get_server_info(self):
+        return "5.0"
     def close(self):
         pass
     def cursor(self):
@@ -227,4 +202,6 @@ class MockCursor(object):
     def close(self):
         pass
 mock_dbapi = MockDBAPI()
-
+mock_sqlite_dbapi = msd = MockDBAPI()
+msd.version_info = msd.sqlite_version_info = (99, 9, 9)
+msd.sqlite_version = '99.9.9'
index d135ad337acc70452024450d6b2640b8e4cfce02..68637281e1462c2a69aac712e5962a7df378920e 100644 (file)
@@ -1,7 +1,8 @@
-import threading, time, gc
-from sqlalchemy import pool, interfaces
+import threading, time
+from sqlalchemy import pool, interfaces, create_engine, select
 import sqlalchemy as tsa
-from sqlalchemy.test import TestBase
+from sqlalchemy.test import TestBase, testing
+from sqlalchemy.test.util import gc_collect, lazy_gc
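
(The pool tests previously leaned on CPython's immediate, refcount-driven __del__; gc_collect/lazy_gc make collection explicit so the assertions also hold where collection is deferred, e.g. on Jython. A plausible sketch of the helpers; the real definitions live in sqlalchemy/test/util.py:

    import gc
    def gc_collect():
        gc.collect()
    def lazy_gc():
        # near a no-op on refcounting interpreters; elsewhere it forces
        # collection of the cycles still holding pooled connections
        gc_collect()
)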
 
 
 mcid = 1
@@ -51,7 +52,6 @@ class PoolTest(PoolTestBase):
         connection2 = manager.connect('foo.db')
         connection3 = manager.connect('bar.db')
 
-        print "connection " + repr(connection)
         self.assert_(connection.cursor() is not None)
         self.assert_(connection is connection2)
         self.assert_(connection2 is not connection3)
@@ -70,8 +70,6 @@ class PoolTest(PoolTestBase):
         connection = manager.connect('foo.db')
         connection2 = manager.connect('foo.db')
 
-        print "connection " + repr(connection)
-
         self.assert_(connection.cursor() is not None)
         self.assert_(connection is not connection2)
 
@@ -103,7 +101,8 @@ class PoolTest(PoolTestBase):
                 c2.close()
             else:
                 c2 = None
-
+                lazy_gc()
+                
             if useclose:
                 c1 = p.connect()
                 c2 = p.connect()
@@ -117,6 +116,8 @@ class PoolTest(PoolTestBase):
 
             # extra tests with QueuePool to ensure connections get __del__()ed when dereferenced
             if isinstance(p, pool.QueuePool):
+                lazy_gc()
+
                 self.assert_(p.checkedout() == 0)
                 c1 = p.connect()
                 c2 = p.connect()
@@ -126,6 +127,7 @@ class PoolTest(PoolTestBase):
                 else:
                     c2 = None
                     c1 = None
+                    lazy_gc()
                 self.assert_(p.checkedout() == 0)
 
     def test_properties(self):
@@ -164,6 +166,8 @@ class PoolTest(PoolTestBase):
             def __init__(self):
                 if hasattr(self, 'connect'):
                     self.connect = self.inst_connect
+                if hasattr(self, 'first_connect'):
+                    self.first_connect = self.inst_first_connect
                 if hasattr(self, 'checkout'):
                     self.checkout = self.inst_checkout
                 if hasattr(self, 'checkin'):
@@ -171,14 +175,17 @@ class PoolTest(PoolTestBase):
                 self.clear()
             def clear(self):
                 self.connected = []
+                self.first_connected = []
                 self.checked_out = []
                 self.checked_in = []
-            def assert_total(innerself, conn, cout, cin):
+            def assert_total(innerself, conn, fconn, cout, cin):
                 self.assert_(len(innerself.connected) == conn)
+                self.assert_(len(innerself.first_connected) == fconn)
                 self.assert_(len(innerself.checked_out) == cout)
                 self.assert_(len(innerself.checked_in) == cin)
-            def assert_in(innerself, item, in_conn, in_cout, in_cin):
+            def assert_in(innerself, item, in_conn, in_fconn, in_cout, in_cin):
                 self.assert_((item in innerself.connected) == in_conn)
+                self.assert_((item in innerself.first_connected) == in_fconn)
                 self.assert_((item in innerself.checked_out) == in_cout)
                 self.assert_((item in innerself.checked_in) == in_cin)
             def inst_connect(self, con, record):
@@ -186,6 +193,11 @@ class PoolTest(PoolTestBase):
                 assert con is not None
                 assert record is not None
                 self.connected.append(con)
+            def inst_first_connect(self, con, record):
+                print "first_connect(%s, %s)" % (con, record)
+                assert con is not None
+                assert record is not None
+                self.first_connected.append(con)
             def inst_checkout(self, con, record, proxy):
                 print "checkout(%s, %s, %s)" % (con, record, proxy)
                 assert con is not None
@@ -203,6 +215,9 @@ class PoolTest(PoolTestBase):
         class ListenConnect(InstrumentingListener):
             def connect(self, con, record):
                 pass
+        class ListenFirstConnect(InstrumentingListener):
+            def first_connect(self, con, record):
+                pass
         class ListenCheckOut(InstrumentingListener):
             def checkout(self, con, record, proxy, num):
                 pass
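
(first_connect is a new pool event; per the assertion counts below it fires only for the first connection a pool produces, making it a hook for one-time setup. A sketch:

    class OnFirstConnect(object):
        def first_connect(self, dbapi_con, con_record):
            cur = dbapi_con.cursor()
            cur.execute('select 1')   # illustrative one-time probe
            cur.close()
    p = pool.QueuePool(lambda: mock_dbapi.connect('foo.db'),
                       listeners=[OnFirstConnect()])
)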
@@ -214,40 +229,43 @@ class PoolTest(PoolTestBase):
             return pool.QueuePool(creator=lambda: dbapi.connect('foo.db'),
                                   use_threadlocal=False, **kw)
 
-        def assert_listeners(p, total, conn, cout, cin):
+        def assert_listeners(p, total, conn, fconn, cout, cin):
             for instance in (p, p.recreate()):
                 self.assert_(len(instance.listeners) == total)
                 self.assert_(len(instance._on_connect) == conn)
+                self.assert_(len(instance._on_first_connect) == fconn)
                 self.assert_(len(instance._on_checkout) == cout)
                 self.assert_(len(instance._on_checkin) == cin)
 
         p = _pool()
-        assert_listeners(p, 0, 0, 0, 0)
+        assert_listeners(p, 0, 0, 0, 0, 0)
 
         p.add_listener(ListenAll())
-        assert_listeners(p, 1, 1, 1, 1)
+        assert_listeners(p, 1, 1, 1, 1, 1)
 
         p.add_listener(ListenConnect())
-        assert_listeners(p, 2, 2, 1, 1)
+        assert_listeners(p, 2, 2, 1, 1, 1)
+
+        p.add_listener(ListenFirstConnect())
+        assert_listeners(p, 3, 2, 2, 1, 1)
 
         p.add_listener(ListenCheckOut())
-        assert_listeners(p, 3, 2, 2, 1)
+        assert_listeners(p, 4, 2, 2, 2, 1)
 
         p.add_listener(ListenCheckIn())
-        assert_listeners(p, 4, 2, 2, 2)
+        assert_listeners(p, 5, 2, 2, 2, 2)
         del p
 
-        print "----"
         snoop = ListenAll()
         p = _pool(listeners=[snoop])
-        assert_listeners(p, 1, 1, 1, 1)
+        assert_listeners(p, 1, 1, 1, 1, 1)
 
         c = p.connect()
-        snoop.assert_total(1, 1, 0)
+        snoop.assert_total(1, 1, 1, 0)
         cc = c.connection
-        snoop.assert_in(cc, True, True, False)
+        snoop.assert_in(cc, True, True, True, False)
         c.close()
-        snoop.assert_in(cc, True, True, True)
+        snoop.assert_in(cc, True, True, True, True)
         del c, cc
 
         snoop.clear()
@@ -255,10 +273,11 @@ class PoolTest(PoolTestBase):
         # this one depends on immediate gc
         c = p.connect()
         cc = c.connection
-        snoop.assert_in(cc, False, True, False)
-        snoop.assert_total(0, 1, 0)
+        snoop.assert_in(cc, False, False, True, False)
+        snoop.assert_total(0, 0, 1, 0)
         del c, cc
-        snoop.assert_total(0, 1, 1)
+        lazy_gc()
+        snoop.assert_total(0, 0, 1, 1)
 
         p.dispose()
         snoop.clear()
@@ -266,44 +285,46 @@ class PoolTest(PoolTestBase):
         c = p.connect()
         c.close()
         c = p.connect()
-        snoop.assert_total(1, 2, 1)
+        snoop.assert_total(1, 0, 2, 1)
         c.close()
-        snoop.assert_total(1, 2, 2)
+        snoop.assert_total(1, 0, 2, 2)
 
         # invalidation
         p.dispose()
         snoop.clear()
 
         c = p.connect()
-        snoop.assert_total(1, 1, 0)
+        snoop.assert_total(1, 0, 1, 0)
         c.invalidate()
-        snoop.assert_total(1, 1, 1)
+        snoop.assert_total(1, 0, 1, 1)
         c.close()
-        snoop.assert_total(1, 1, 1)
+        snoop.assert_total(1, 0, 1, 1)
         del c
-        snoop.assert_total(1, 1, 1)
+        lazy_gc()
+        snoop.assert_total(1, 0, 1, 1)
         c = p.connect()
-        snoop.assert_total(2, 2, 1)
+        snoop.assert_total(2, 0, 2, 1)
         c.close()
         del c
-        snoop.assert_total(2, 2, 2)
+        lazy_gc()
+        snoop.assert_total(2, 0, 2, 2)
 
         # detached
         p.dispose()
         snoop.clear()
 
         c = p.connect()
-        snoop.assert_total(1, 1, 0)
+        snoop.assert_total(1, 0, 1, 0)
         c.detach()
-        snoop.assert_total(1, 1, 0)
+        snoop.assert_total(1, 0, 1, 0)
         c.close()
         del c
-        snoop.assert_total(1, 1, 0)
+        snoop.assert_total(1, 0, 1, 0)
         c = p.connect()
-        snoop.assert_total(2, 2, 0)
+        snoop.assert_total(2, 0, 2, 0)
         c.close()
         del c
-        snoop.assert_total(2, 2, 1)
+        snoop.assert_total(2, 0, 2, 1)
 
     def test_listeners_callables(self):
         dbapi = MockDBAPI()
@@ -362,262 +383,293 @@ class PoolTest(PoolTestBase):
         c.close()
         assert counts == [1, 2, 3]
 
+    def test_listener_after_oninit(self):
+        """Test that listeners are called after OnInit is removed"""
+        called = []
+        def listener(*args):
+            called.append(True)
+        listener.connect = listener
+        engine = create_engine(testing.db.url)
+        engine.pool.add_listener(listener)
+        engine.execute(select([1]))
+        assert called, "Listener not called on connect"
+
+
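For reference, the first_connect hook exercised in these tests sits alongside connect, checkout and checkin on the pool listener interface. A minimal sketch, assuming the sqlalchemy.interfaces.PoolListener base class and the listeners= argument to create_engine (the EchoListener name and the printed strings are invented for illustration):

    from sqlalchemy import create_engine
    from sqlalchemy.interfaces import PoolListener

    class EchoListener(PoolListener):
        def first_connect(self, dbapi_con, con_record):
            # fires once, for the first DBAPI connection the pool creates
            print "first_connect: %r" % dbapi_con

        def connect(self, dbapi_con, con_record):
            # fires for every new DBAPI connection, the first included
            print "connect: %r" % dbapi_con

    engine = create_engine('sqlite://', listeners=[EchoListener()])
    engine.execute("select 1")  # emits first_connect, then connect
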
 class QueuePoolTest(PoolTestBase):
 
-   def testqueuepool_del(self):
-       self._do_testqueuepool(useclose=False)
-
-   def testqueuepool_close(self):
-       self._do_testqueuepool(useclose=True)
-
-   def _do_testqueuepool(self, useclose=False):
-       p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = -1, use_threadlocal = False)
-
-       def status(pool):
-           tup = (pool.size(), pool.checkedin(), pool.overflow(), pool.checkedout())
-           print "Pool size: %d  Connections in pool: %d Current Overflow: %d Current Checked out connections: %d" % tup
-           return tup
-
-       c1 = p.connect()
-       self.assert_(status(p) == (3,0,-2,1))
-       c2 = p.connect()
-       self.assert_(status(p) == (3,0,-1,2))
-       c3 = p.connect()
-       self.assert_(status(p) == (3,0,0,3))
-       c4 = p.connect()
-       self.assert_(status(p) == (3,0,1,4))
-       c5 = p.connect()
-       self.assert_(status(p) == (3,0,2,5))
-       c6 = p.connect()
-       self.assert_(status(p) == (3,0,3,6))
-       if useclose:
-           c4.close()
-           c3.close()
-           c2.close()
-       else:
-           c4 = c3 = c2 = None
-       self.assert_(status(p) == (3,3,3,3))
-       if useclose:
-           c1.close()
-           c5.close()
-           c6.close()
-       else:
-           c1 = c5 = c6 = None
-       self.assert_(status(p) == (3,3,0,0))
-       c1 = p.connect()
-       c2 = p.connect()
-       self.assert_(status(p) == (3, 1, 0, 2), status(p))
-       if useclose:
-           c2.close()
-       else:
-           c2 = None
-       self.assert_(status(p) == (3, 2, 0, 1))
-
-   def test_timeout(self):
-       p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = 0, use_threadlocal = False, timeout=2)
-       c1 = p.connect()
-       c2 = p.connect()
-       c3 = p.connect()
-       now = time.time()
-       try:
-           c4 = p.connect()
-           assert False
-       except tsa.exc.TimeoutError, e:
-           assert int(time.time() - now) == 2
-
-   def test_timeout_race(self):
-       # test a race condition where the initial connecting threads all race
-       # to queue.Empty, then block on the mutex.  each thread consumes a
-       # connection as they go in.  when the limit is reached, the remaining
-       # threads go in, and get TimeoutError; even though they never got to
-       # wait for the timeout on queue.get().  the fix involves checking the
-       # timeout again within the mutex, and if so, unlocking and throwing
-       # them back to the start of do_get()
-       p = pool.QueuePool(creator = lambda: mock_dbapi.connect(delay=.05), pool_size = 2, max_overflow = 1, use_threadlocal = False, timeout=3)
-       timeouts = []
-       def checkout():
-           for x in xrange(1):
-               now = time.time()
-               try:
-                   c1 = p.connect()
-               except tsa.exc.TimeoutError, e:
-                   timeouts.append(int(time.time()) - now)
-                   continue
-               time.sleep(4)
-               c1.close()
-
-       threads = []
-       for i in xrange(10):
-           th = threading.Thread(target=checkout)
-           th.start()
-           threads.append(th)
-       for th in threads:
-           th.join()
-
-       print timeouts
-       assert len(timeouts) > 0
-       for t in timeouts:
-           assert abs(t - 3) < 1, "Not all timeouts were 3 seconds: " + repr(timeouts)
-
-   def _test_overflow(self, thread_count, max_overflow):
-       def creator():
-           time.sleep(.05)
-           return mock_dbapi.connect()
-
-       p = pool.QueuePool(creator=creator,
-                          pool_size=3, timeout=2,
-                          max_overflow=max_overflow)
-       peaks = []
-       def whammy():
-           for i in range(10):
-               try:
-                   con = p.connect()
-                   time.sleep(.005)
-                   peaks.append(p.overflow())
-                   con.close()
-                   del con
-               except tsa.exc.TimeoutError:
-                   pass
-       threads = []
-       for i in xrange(thread_count):
-           th = threading.Thread(target=whammy)
-           th.start()
-           threads.append(th)
-       for th in threads:
-           th.join()
-
-       self.assert_(max(peaks) <= max_overflow)
-
-   def test_no_overflow(self):
-       self._test_overflow(40, 0)
-
-   def test_max_overflow(self):
-       self._test_overflow(40, 5)
-
-   def test_mixed_close(self):
-       p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = -1, use_threadlocal = True)
-       c1 = p.connect()
-       c2 = p.connect()
-       assert c1 is c2
-       c1.close()
-       c2 = None
-       assert p.checkedout() == 1
-       c1 = None
-       assert p.checkedout() == 0
-
-   def test_weakref_kaboom(self):
-       p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = -1, use_threadlocal = True)
-       c1 = p.connect()
-       c2 = p.connect()
-       c1.close()
-       c2 = None
-       del c1
-       del c2
-       gc.collect()
-       assert p.checkedout() == 0
-       c3 = p.connect()
-       assert c3 is not None
-
-   def test_trick_the_counter(self):
-       """this is a "flaw" in the connection pool; since threadlocal uses a single ConnectionFairy per thread
-       with an open/close counter, you can fool the counter into giving you a ConnectionFairy with an
-       ambiguous counter.  i.e. its not true reference counting."""
-       p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = -1, use_threadlocal = True)
-       c1 = p.connect()
-       c2 = p.connect()
-       assert c1 is c2
-       c1.close()
-       c2 = p.connect()
-       c2.close()
-       self.assert_(p.checkedout() != 0)
-
-       c2.close()
-       self.assert_(p.checkedout() == 0)
-
-   def test_recycle(self):
-       p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 1, max_overflow = 0, use_threadlocal = False, recycle=3)
-
-       c1 = p.connect()
-       c_id = id(c1.connection)
-       c1.close()
-       c2 = p.connect()
-       assert id(c2.connection) == c_id
-       c2.close()
-       time.sleep(4)
-       c3= p.connect()
-       assert id(c3.connection) != c_id
-
-   def test_invalidate(self):
-       dbapi = MockDBAPI()
-       p = pool.QueuePool(creator = lambda: dbapi.connect('foo.db'), pool_size = 1, max_overflow = 0, use_threadlocal = False)
-       c1 = p.connect()
-       c_id = c1.connection.id
-       c1.close(); c1=None
-       c1 = p.connect()
-       assert c1.connection.id == c_id
-       c1.invalidate()
-       c1 = None
-
-       c1 = p.connect()
-       assert c1.connection.id != c_id
-
-   def test_recreate(self):
-       dbapi = MockDBAPI()
-       p = pool.QueuePool(creator = lambda: dbapi.connect('foo.db'), pool_size = 1, max_overflow = 0, use_threadlocal = False)
-       p2 = p.recreate()
-       assert p2.size() == 1
-       assert p2._use_threadlocal is False
-       assert p2._max_overflow == 0
-
-   def test_reconnect(self):
-       """tests reconnect operations at the pool level.  SA's engine/dialect includes another
-       layer of reconnect support for 'database was lost' errors."""
+    def testqueuepool_del(self):
+        self._do_testqueuepool(useclose=False)
+
+    def testqueuepool_close(self):
+        self._do_testqueuepool(useclose=True)
+
+    def _do_testqueuepool(self, useclose=False):
+        p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = -1, use_threadlocal = False)
+
+        def status(pool):
+            tup = (pool.size(), pool.checkedin(), pool.overflow(), pool.checkedout())
+            print "Pool size: %d  Connections in pool: %d Current Overflow: %d Current Checked out connections: %d" % tup
+            return tup
+
+        c1 = p.connect()
+        self.assert_(status(p) == (3,0,-2,1))
+        c2 = p.connect()
+        self.assert_(status(p) == (3,0,-1,2))
+        c3 = p.connect()
+        self.assert_(status(p) == (3,0,0,3))
+        c4 = p.connect()
+        self.assert_(status(p) == (3,0,1,4))
+        c5 = p.connect()
+        self.assert_(status(p) == (3,0,2,5))
+        c6 = p.connect()
+        self.assert_(status(p) == (3,0,3,6))
+        if useclose:
+            c4.close()
+            c3.close()
+            c2.close()
+        else:
+            c4 = c3 = c2 = None
+            lazy_gc()
+            
+        self.assert_(status(p) == (3,3,3,3))
+        if useclose:
+            c1.close()
+            c5.close()
+            c6.close()
+        else:
+            c1 = c5 = c6 = None
+            lazy_gc()
+            
+        self.assert_(status(p) == (3,3,0,0))
+        
+        c1 = p.connect()
+        c2 = p.connect()
+        self.assert_(status(p) == (3, 1, 0, 2), status(p))
+        if useclose:
+            c2.close()
+        else:
+            c2 = None
+            lazy_gc()
+            
+        self.assert_(status(p) == (3, 2, 0, 1))  
+
+        c1.close()
+       
+        lazy_gc()
+        assert not pool._refs
        
-       dbapi = MockDBAPI()
-       p = pool.QueuePool(creator = lambda: dbapi.connect('foo.db'), pool_size = 1, max_overflow = 0, use_threadlocal = False)
-       c1 = p.connect()
-       c_id = c1.connection.id
-       c1.close(); c1=None
-
-       c1 = p.connect()
-       assert c1.connection.id == c_id
-       dbapi.raise_error = True
-       c1.invalidate()
-       c1 = None
-
-       c1 = p.connect()
-       assert c1.connection.id != c_id
-
-   def test_detach(self):
-       dbapi = MockDBAPI()
-       p = pool.QueuePool(creator = lambda: dbapi.connect('foo.db'), pool_size = 1, max_overflow = 0, use_threadlocal = False)
-
-       c1 = p.connect()
-       c1.detach()
-       c_id = c1.connection.id
-
-       c2 = p.connect()
-       assert c2.connection.id != c1.connection.id
-       dbapi.raise_error = True
-
-       c2.invalidate()
-       c2 = None
-
-       c2 = p.connect()
-       assert c2.connection.id != c1.connection.id
-
-       con = c1.connection
-
-       assert not con.closed
-       c1.close()
-       assert con.closed
-
-   def test_threadfairy(self):
-       p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = -1, use_threadlocal = True)
-       c1 = p.connect()
-       c1.close()
-       c2 = p.connect()
-       assert c2.connection is not None
+    def test_timeout(self):
+        p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = 0, use_threadlocal = False, timeout=2)
+        c1 = p.connect()
+        c2 = p.connect()
+        c3 = p.connect()
+        now = time.time()
+        try:
+            c4 = p.connect()
+            assert False
+        except tsa.exc.TimeoutError, e:
+            assert int(time.time() - now) == 2
+
+    def test_timeout_race(self):
+        # test a race condition where the initial connecting threads all race
+        # to queue.Empty, then block on the mutex.  each thread consumes a
+        # connection as it goes in.  when the limit is reached, the remaining
+        # threads get TimeoutError immediately, even though they never got to
+        # wait out the timeout on queue.get().  the fix involves re-checking
+        # the limit within the mutex and, if no slot is free, unlocking and
+        # throwing the thread back to the start of do_get().  a sketch of
+        # this double-check follows this test.
+        p = pool.QueuePool(creator = lambda: mock_dbapi.connect(delay=.05), pool_size = 2, max_overflow = 1, use_threadlocal = False, timeout=3)
+        timeouts = []
+        def checkout():
+            for x in xrange(1):
+                now = time.time()
+                try:
+                    c1 = p.connect()
+                except tsa.exc.TimeoutError, e:
+                    timeouts.append(int(time.time()) - now)
+                    continue
+                time.sleep(4)
+                c1.close()  
+
+        threads = []
+        for i in xrange(10):
+            th = threading.Thread(target=checkout)
+            th.start()
+            threads.append(th)
+        for th in threads:
+            th.join() 
+
+        print timeouts
+        assert len(timeouts) > 0
+        for t in timeouts:
+            assert abs(t - 3) < 1, "Not all timeouts were 3 seconds: " + repr(timeouts)
+
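The double-check described in the comment above can be sketched in isolation as follows. This is a minimal illustration of the pattern, not QueuePool's actual code: SketchedPool and its attributes are invented names, and a real implementation would also deduct time already spent waiting and eventually raise TimeoutError rather than retrying indefinitely:

    import Queue
    import threading

    class SketchedPool(object):
        def __init__(self, creator, pool_size, max_overflow, timeout):
            self._creator = creator
            self._pool = Queue.Queue(pool_size)
            self._overflow = 0
            self._max_overflow = max_overflow
            self._overflow_lock = threading.Lock()
            self._timeout = timeout

        def do_get(self):
            try:
                # wait up to the timeout for a checked-in connection
                return self._pool.get(True, self._timeout)
            except Queue.Empty:
                self._overflow_lock.acquire()
                if self._overflow >= self._max_overflow:
                    # the fix: a thread that lost the race for the last
                    # overflow slot unlocks and returns to the start of
                    # do_get() to wait again, instead of timing out here
                    self._overflow_lock.release()
                    return self.do_get()
                self._overflow += 1
                self._overflow_lock.release()
                return self._creator()
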
+    def _test_overflow(self, thread_count, max_overflow):
+        def creator():
+            time.sleep(.05)
+            return mock_dbapi.connect()
+        p = pool.QueuePool(creator=creator,
+                           pool_size=3, timeout=2,
+                           max_overflow=max_overflow)
+        peaks = []
+        def whammy():
+            for i in range(10):
+                try:
+                    con = p.connect()
+                    time.sleep(.005)
+                    peaks.append(p.overflow())
+                    con.close()
+                    del con
+                except tsa.exc.TimeoutError:
+                    pass
+        threads = []
+        for i in xrange(thread_count):
+            th = threading.Thread(target=whammy)
+            th.start()
+            threads.append(th)
+        for th in threads:
+            th.join()
+        self.assert_(max(peaks) <= max_overflow)
+        
+        lazy_gc()
+        assert not pool._refs
+
+    def test_no_overflow(self):
+        self._test_overflow(40, 0)
+
+    def test_max_overflow(self):
+        self._test_overflow(40, 5)
+
+    def test_mixed_close(self):
+        p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = -1, use_threadlocal = True)
+        c1 = p.connect()
+        c2 = p.connect()
+        assert c1 is c2
+        c1.close()
+        c2 = None
+        assert p.checkedout() == 1
+        c1 = None
+        lazy_gc()
+        assert p.checkedout() == 0
+        
+        lazy_gc()
+        assert not pool._refs
+
+    def test_weakref_kaboom(self):
+        p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = -1, use_threadlocal = True)
+        c1 = p.connect()
+        c2 = p.connect()
+        c1.close()
+        c2 = None
+        del c1
+        del c2
+        gc_collect()
+        assert p.checkedout() == 0
+        c3 = p.connect()
+        assert c3 is not None
+
+    def test_trick_the_counter(self):
+        """this is a "flaw" in the connection pool; since threadlocal uses a single ConnectionFairy per thread
+        with an open/close counter, you can fool the counter into giving you a ConnectionFairy with an
+        ambiguous counter, i.e. it's not true reference counting."""
+        p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = -1, use_threadlocal = True)
+        c1 = p.connect()
+        c2 = p.connect()
+        assert c1 is c2
+        c1.close()
+        c2 = p.connect()
+        c2.close()
+        self.assert_(p.checkedout() != 0)
+        c2.close()
+        self.assert_(p.checkedout() == 0)
+
+    def test_recycle(self):
+        p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 1, max_overflow = 0, use_threadlocal = False, recycle=3)
+        c1 = p.connect()
+        c_id = id(c1.connection)
+        c1.close()
+        c2 = p.connect()
+        assert id(c2.connection) == c_id
+        c2.close()
+        time.sleep(4)
+        c3 = p.connect()
+        assert id(c3.connection) != c_id
+
+    def test_invalidate(self):
+        dbapi = MockDBAPI()
+        p = pool.QueuePool(creator = lambda: dbapi.connect('foo.db'), pool_size = 1, max_overflow = 0, use_threadlocal = False)
+        c1 = p.connect()
+        c_id = c1.connection.id
+        c1.close(); c1=None
+        c1 = p.connect()
+        assert c1.connection.id == c_id
+        c1.invalidate()
+        c1 = None
+        c1 = p.connect()
+        assert c1.connection.id != c_id
+
+    def test_recreate(self):
+        dbapi = MockDBAPI()
+        p = pool.QueuePool(creator = lambda: dbapi.connect('foo.db'), pool_size = 1, max_overflow = 0, use_threadlocal = False)
+        p2 = p.recreate()
+        assert p2.size() == 1
+        assert p2._use_threadlocal is False
+        assert p2._max_overflow == 0
+
+    def test_reconnect(self):
+        """tests reconnect operations at the pool level.  SA's engine/dialect includes another
+        layer of reconnect support for 'database was lost' errors."""
+        
+        dbapi = MockDBAPI()
+        p = pool.QueuePool(creator = lambda: dbapi.connect('foo.db'), pool_size = 1, max_overflow = 0, use_threadlocal = False)
+        c1 = p.connect()
+        c_id = c1.connection.id
+        c1.close(); c1=None
+        c1 = p.connect()
+        assert c1.connection.id == c_id
+        dbapi.raise_error = True
+        c1.invalidate()
+        c1 = None
+        c1 = p.connect()
+        assert c1.connection.id != c_id
+
+    def test_detach(self):
+        dbapi = MockDBAPI()
+        p = pool.QueuePool(creator = lambda: dbapi.connect('foo.db'), pool_size = 1, max_overflow = 0, use_threadlocal = False)
+        c1 = p.connect()
+        c1.detach()
+        c_id = c1.connection.id
+        c2 = p.connect()
+        assert c2.connection.id != c1.connection.id
+        dbapi.raise_error = True
+        c2.invalidate()
+        c2 = None
+        c2 = p.connect()
+        assert c2.connection.id != c1.connection.id
+        con = c1.connection
+        assert not con.closed
+        c1.close()
+        assert con.closed
+
+    def test_threadfairy(self):
+        p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = -1, use_threadlocal = True)
+        c1 = p.connect()
+        c1.close()
+        c2 = p.connect()
+        assert c2.connection is not None
 
 class SingletonThreadPoolTest(PoolTestBase):
     def test_cleanup(self):
index 3a525c2a702e1f835c160f0e3a4d30add230b7c4..6afd7151554f12939c25423368a3c37de84138d1 100644 (file)
@@ -1,12 +1,13 @@
 from sqlalchemy.test.testing import eq_
+import time
 import weakref
 from sqlalchemy import select, MetaData, Integer, String, pool
 from sqlalchemy.test.schema import Table
 from sqlalchemy.test.schema import Column
 import sqlalchemy as tsa
 from sqlalchemy.test import TestBase, testing, engines
-import time
-import gc
+from sqlalchemy.test.util import gc_collect
+
 
 class MockDisconnect(Exception):
     pass
@@ -54,7 +55,7 @@ class MockReconnectTest(TestBase):
         dbapi = MockDBAPI()
 
         # create engine using our current dburi
-        db = tsa.create_engine('postgres://foo:bar@localhost/test', module=dbapi)
+        db = tsa.create_engine('postgresql://foo:bar@localhost/test', module=dbapi, _initialize=False)
 
         # monkeypatch disconnect checker
         db.dialect.is_disconnect = lambda e: isinstance(e, MockDisconnect)
@@ -98,7 +99,7 @@ class MockReconnectTest(TestBase):
         assert id(db.pool) != pid
 
         # ensure all connections closed (pool was recycled)
-        gc.collect()
+        gc_collect()
         assert len(dbapi.connections) == 0
 
         conn =db.connect()
@@ -118,7 +119,7 @@ class MockReconnectTest(TestBase):
             pass
 
         # assert was invalidated
-        gc.collect()
+        gc_collect()
         assert len(dbapi.connections) == 0
         assert not conn.closed
         assert conn.invalidated
@@ -168,7 +169,7 @@ class MockReconnectTest(TestBase):
         assert conn.invalidated
 
         # ensure all connections closed (pool was recycled)
-        gc.collect()
+        gc_collect()
         assert len(dbapi.connections) == 0
 
         # test reconnects
@@ -334,7 +335,8 @@ class InvalidateDuringResultTest(TestBase):
         meta.drop_all()
         engine.dispose()
 
-    @testing.fails_on('mysql', 'FIXME: unknown')
+    @testing.fails_on('+mysqldb', "Buffers the result set and doesn't check for connection close")
+    @testing.fails_on('+pg8000', "Buffers the result set and doesn't check for connection close")
     def test_invalidate_on_results(self):
         conn = engine.connect()
 
@@ -344,7 +346,7 @@ class InvalidateDuringResultTest(TestBase):
 
         engine.test_shutdown()
         try:
-            result.fetchone()
+            print "ghost result: %r" % result.fetchone()
             assert False
         except tsa.exc.DBAPIError, e:
             if not e.connection_invalidated:
index ea80776a6a1bb4217fb7542a3738a9c722468537..dff9fa1bb6fcf4f51c5bafb71f87c151b6c0bcf9 100644 (file)
@@ -1,17 +1,22 @@
 from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message
 import StringIO, unicodedata
 import sqlalchemy as sa
+from sqlalchemy import types as sql_types
+from sqlalchemy import schema
+from sqlalchemy.engine.reflection import Inspector
 from sqlalchemy import MetaData
 from sqlalchemy.test.schema import Table
 from sqlalchemy.test.schema import Column
 import sqlalchemy as tsa
 from sqlalchemy.test import TestBase, ComparesTables, testing, engines
 
+create_inspector = Inspector.from_engine
 
 metadata, users = None, None
 
 class ReflectionTest(TestBase, ComparesTables):
 
+    @testing.exclude('mssql', '<', (10, 0, 0), 'Date is only supported on MSSQL 2008+')
     @testing.exclude('mysql', '<', (4, 1, 1), 'early types are squirrely')
     def test_basic_reflection(self):
         meta = MetaData(testing.db)
@@ -22,16 +27,16 @@ class ReflectionTest(TestBase, ComparesTables):
             Column('test1', sa.CHAR(5), nullable=False),
             Column('test2', sa.Float(5), nullable=False),
             Column('test3', sa.Text),
-            Column('test4', sa.Numeric, nullable = False),
-            Column('test5', sa.DateTime),
+            Column('test4', sa.Numeric(10, 2), nullable=False),
+            Column('test5', sa.Date),
             Column('parent_user_id', sa.Integer,
                    sa.ForeignKey('engine_users.user_id')),
-            Column('test6', sa.DateTime, nullable=False),
+            Column('test6', sa.Date, nullable=False),
             Column('test7', sa.Text),
             Column('test8', sa.Binary),
             Column('test_passivedefault2', sa.Integer, server_default='5'),
             Column('test9', sa.Binary(100)),
-            Column('test_numeric', sa.Numeric()),
+            Column('test10', sa.Numeric(10, 2)),
             test_needs_fk=True,
         )
 
@@ -52,9 +57,35 @@ class ReflectionTest(TestBase, ComparesTables):
             self.assert_tables_equal(users, reflected_users)
             self.assert_tables_equal(addresses, reflected_addresses)
         finally:
-            addresses.drop()
-            users.drop()
-
+            meta.drop_all()
+    
+    def test_two_foreign_keys(self):
+        meta = MetaData(testing.db)
+        t1 = Table('t1', meta, 
+                Column('id', sa.Integer, primary_key=True),
+                Column('t2id', sa.Integer, sa.ForeignKey('t2.id')),
+                Column('t3id', sa.Integer, sa.ForeignKey('t3.id')),
+                test_needs_fk=True
+        )
+        t2 = Table('t2', meta, 
+                Column('id', sa.Integer, primary_key=True),
+                test_needs_fk=True
+        )
+        t3 = Table('t3', meta, 
+                Column('id', sa.Integer, primary_key=True),
+                test_needs_fk=True
+        )
+        meta.create_all()
+        try:
+            meta2 = MetaData()
+            t1r, t2r, t3r = [Table(x, meta2, autoload=True, autoload_with=testing.db) for x in ('t1', 't2', 't3')]
+            
+            assert t1r.c.t2id.references(t2r.c.id)
+            assert t1r.c.t3id.references(t3r.c.id)
+            
+        finally:
+            meta.drop_all()
+            
     def test_include_columns(self):
         meta = MetaData(testing.db)
         foo = Table('foo', meta, *[Column(n, sa.String(30))
@@ -84,26 +115,68 @@ class ReflectionTest(TestBase, ComparesTables):
         finally:
             meta.drop_all()
 
+    @testing.emits_warning(r".*omitted columns")
+    def test_include_columns_indexes(self):
+        m = MetaData(testing.db)
+        
+        t1 = Table('t1', m, Column('a', sa.Integer), Column('b', sa.Integer))
+        sa.Index('foobar', t1.c.a, t1.c.b)
+        sa.Index('bat', t1.c.a)
+        m.create_all()
+        try:
+            m2 = MetaData(testing.db)
+            t2 = Table('t1', m2, autoload=True)
+            assert len(t2.indexes) == 2
 
+            m2 = MetaData(testing.db)
+            t2 = Table('t1', m2, autoload=True, include_columns=['a'])
+            assert len(t2.indexes) == 1
+
+            m2 = MetaData(testing.db)
+            t2 = Table('t1', m2, autoload=True, include_columns=['a', 'b'])
+            assert len(t2.indexes) == 2
+        finally:
+            m.drop_all()
+
+    def test_autoincrement_col(self):
+        """test that 'autoincrement' is reflected according to sqla's policy.
+        
+        Don't mark this test as unsupported for any backend!
+        
+        (technically it fails with MySQL InnoDB since "id" comes before "id2")
+        
+        """
+        
+        meta = MetaData(testing.db)
+        t1 = Table('test', meta,
+            Column('id', sa.Integer, primary_key=True),
+            Column('data', sa.String(50)),
+        )
+        t2 = Table('test2', meta,
+            Column('id', sa.Integer, sa.ForeignKey('test.id'), primary_key=True),
+            Column('id2', sa.Integer, primary_key=True),
+            Column('data', sa.String(50)),
+        )
+        meta.create_all()
+        try:
+            m2 = MetaData(testing.db)
+            t1a = Table('test', m2, autoload=True)
+            assert t1a._autoincrement_column is t1a.c.id
+            
+            t2a = Table('test2', m2, autoload=True)
+            assert t2a._autoincrement_column is t2a.c.id2
+            
+        finally:
+            meta.drop_all()
+            
     def test_unknown_types(self):
         meta = MetaData(testing.db)
         t = Table("test", meta,
             Column('foo', sa.DateTime))
 
-        import sys
-        dialect_module = sys.modules[testing.db.dialect.__module__]
-
-        # we're relying on the presence of "ischema_names" in the
-        # dialect module, else we can't test this.  we need to be able
-        # to get the dialect to not be aware of some type so we temporarily
-        # monkeypatch.  not sure what a better way for this could be,
-        # except for an established dialect hook or dialect-specific tests
-        if not hasattr(dialect_module, 'ischema_names'):
-            return
-
-        ischema_names = dialect_module.ischema_names
+        ischema_names = testing.db.dialect.ischema_names
         t.create()
-        dialect_module.ischema_names = {}
+        testing.db.dialect.ischema_names = {}
         try:
             m2 = MetaData(testing.db)
             assert_raises(tsa.exc.SAWarning, Table, "test", m2, autoload=True)
@@ -115,7 +188,7 @@ class ReflectionTest(TestBase, ComparesTables):
                 assert t3.c.foo.type.__class__ == sa.types.NullType
 
         finally:
-            dialect_module.ischema_names = ischema_names
+            testing.db.dialect.ischema_names = ischema_names
             t.drop()
 
     def test_basic_override(self):
@@ -578,7 +651,6 @@ class ReflectionTest(TestBase, ComparesTables):
             m9.reflect()
             self.assert_(not m9.tables)
 
-    @testing.fails_on_everything_except('postgres', 'mysql')
     def test_index_reflection(self):
         m1 = MetaData(testing.db)
         t1 = Table('party', m1,
@@ -698,7 +770,7 @@ class UnicodeReflectionTest(TestBase):
     def test_basic(self):
         try:
             # the 'convert_unicode' should not get in the way of the reflection
-            # process.  reflecttable for oracle, postgres (others?) expect non-unicode
+            # process.  reflecttable for oracle, postgresql (others?) expect non-unicode
             # strings in result sets/bind params
             bind = engines.utf8_engine(options={'convert_unicode':True})
             metadata = MetaData(bind)
@@ -713,7 +785,8 @@ class UnicodeReflectionTest(TestBase):
             metadata.create_all()
 
             reflected = set(bind.table_names())
-            if not names.issubset(reflected):
+            # Jython 2.5 on Java 5 lacks unicodedata.normalize
+            if not names.issubset(reflected) and hasattr(unicodedata, 'normalize'):
                 # Python source files in the utf-8 coding seem to normalize
                 # literals as NFC (and the above are explicitly NFC).  Maybe
                 # this database normalizes NFD on reflection.
@@ -741,23 +814,15 @@ class SchemaTest(TestBase):
             Column('col1', sa.Integer, primary_key=True),
             Column('col2', sa.Integer, sa.ForeignKey('someschema.table1.col1')),
             schema='someschema')
-        # ensure this doesnt crash
-        print [t for t in metadata.sorted_tables]
-        buf = StringIO.StringIO()
-        def foo(s, p=None):
-            buf.write(s)
-        gen = sa.create_engine(testing.db.name + "://", strategy="mock", executor=foo)
-        gen = gen.dialect.schemagenerator(gen.dialect, gen)
-        gen.traverse(table1)
-        gen.traverse(table2)
-        buf = buf.getvalue()
-        print buf
+
+        t1 = str(schema.CreateTable(table1).compile(bind=testing.db))
+        t2 = str(schema.CreateTable(table2).compile(bind=testing.db))
         if testing.db.dialect.preparer(testing.db.dialect).omit_schema:
-            assert buf.index("CREATE TABLE table1") > -1
-            assert buf.index("CREATE TABLE table2") > -1
+            assert t1.index("CREATE TABLE table1") > -1
+            assert t2.index("CREATE TABLE table2") > -1
         else:
-            assert buf.index("CREATE TABLE someschema.table1") > -1
-            assert buf.index("CREATE TABLE someschema.table2") > -1
+            assert t1.index("CREATE TABLE someschema.table1") > -1
+            assert t2.index("CREATE TABLE someschema.table2") > -1
 
     @testing.crashes('firebird', 'No schema support')
     @testing.fails_on('sqlite', 'FIXME: unknown')
@@ -767,9 +832,9 @@ class SchemaTest(TestBase):
     def test_explicit_default_schema(self):
         engine = testing.db
 
-        if testing.against('mysql'):
+        if testing.against('mysql+mysqldb'):
             schema = testing.db.url.database
-        elif testing.against('postgres'):
+        elif testing.against('postgresql'):
             schema = 'public'
         elif testing.against('sqlite'):
             # Works for CREATE TABLE main.foo, SELECT FROM main.foo, etc.,
@@ -820,4 +885,324 @@ class HasSequenceTest(TestBase):
         metadata.drop_all(bind=testing.db)
         eq_(testing.db.dialect.has_sequence(testing.db, 'user_id_seq'), False)
 
+# Tests related to engine.reflection
+
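The helpers and ComponentReflectionTest below exercise the Inspector interface from sqlalchemy.engine.reflection, imported at the top of this module. A rough usage sketch (the in-memory SQLite table is invented for illustration):

    from sqlalchemy import create_engine
    from sqlalchemy.engine.reflection import Inspector

    engine = create_engine('sqlite://')
    engine.execute("CREATE TABLE users ("
                   "user_id INTEGER PRIMARY KEY, "
                   "user_name VARCHAR(20) NOT NULL)")

    insp = Inspector.from_engine(engine)
    print insp.get_table_names()        # ['users']
    for col in insp.get_columns('users'):
        # each column comes back as a dict of reflected attributes
        print col['name'], col['type'], col['nullable']
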
+def get_schema():
+    if testing.against('oracle'):
+        return 'scott'
+    return 'test_schema'
+
+def createTables(meta, schema=None):
+    if schema:
+        parent_user_id = Column('parent_user_id', sa.Integer,
+            sa.ForeignKey('%s.users.user_id' % schema)
+        )
+    else:
+        parent_user_id = Column('parent_user_id', sa.Integer,
+            sa.ForeignKey('users.user_id')
+        )
+
+    users = Table('users', meta,
+        Column('user_id', sa.INT, primary_key=True),
+        Column('user_name', sa.VARCHAR(20), nullable=False),
+        Column('test1', sa.CHAR(5), nullable=False),
+        Column('test2', sa.Float(5), nullable=False),
+        Column('test3', sa.Text),
+        Column('test4', sa.Numeric(10, 2), nullable=False),
+        Column('test5', sa.DateTime),
+        Column('test5-1', sa.TIMESTAMP),
+        parent_user_id,
+        Column('test6', sa.DateTime, nullable=False),
+        Column('test7', sa.Text),
+        Column('test8', sa.Binary),
+        Column('test_passivedefault2', sa.Integer, server_default='5'),
+        Column('test9', sa.Binary(100)),
+        Column('test10', sa.Numeric(10, 2)),
+        schema=schema,
+        test_needs_fk=True,
+    )
+    addresses = Table('email_addresses', meta,
+        Column('address_id', sa.Integer, primary_key=True),
+        Column('remote_user_id', sa.Integer,
+               sa.ForeignKey(users.c.user_id)),
+        Column('email_address', sa.String(20)),
+        schema=schema,
+        test_needs_fk=True,
+    )
+    return (users, addresses)
+
+def createIndexes(con, schema=None):
+    fullname = 'users'
+    if schema:
+        fullname = "%s.%s" % (schema, 'users')
+    query = "CREATE INDEX users_t_idx ON %s (test1, test2)" % fullname
+    con.execute(sa.sql.text(query))
+
+def createViews(con, schema=None):
+    for table_name in ('users', 'email_addresses'):
+        fullname = table_name
+        if schema:
+            fullname = "%s.%s" % (schema, table_name)
+        view_name = fullname + '_v'
+        query = "CREATE VIEW %s AS SELECT * FROM %s" % (view_name,
+                                                                   fullname)
+        con.execute(sa.sql.text(query))
+
+def dropViews(con, schema=None):
+    for table_name in ('email_addresses', 'users'):
+        fullname = table_name
+        if schema:
+            fullname = "%s.%s" % (schema, table_name)
+        view_name = fullname + '_v'
+        query = "DROP VIEW %s" % view_name
+        con.execute(sa.sql.text(query))
+
+
+class ComponentReflectionTest(TestBase):
+
+    @testing.requires.schemas
+    def test_get_schema_names(self):
+        meta = MetaData(testing.db)
+        insp = Inspector(meta.bind)
+        
+        self.assert_(get_schema() in insp.get_schema_names())
+
+    def _test_get_table_names(self, schema=None, table_type='table',
+                              order_by=None):
+        meta = MetaData(testing.db)
+        (users, addresses) = createTables(meta, schema)
+        meta.create_all()
+        createViews(meta.bind, schema)
+        try:
+            insp = Inspector(meta.bind)
+            if table_type == 'view':
+                table_names = insp.get_view_names(schema)
+                table_names.sort()
+                answer = ['email_addresses_v', 'users_v']
+            else:
+                table_names = insp.get_table_names(schema,
+                                                   order_by=order_by)
+                table_names.sort()
+                if order_by == 'foreign_key':
+                    answer = ['users', 'email_addresses']
+                else:
+                    answer = ['email_addresses', 'users']
+            eq_(table_names, answer)
+        finally:
+            dropViews(meta.bind, schema)
+            addresses.drop()
+            users.drop()
+
+    def test_get_table_names(self):
+        self._test_get_table_names()
+
+    @testing.requires.schemas
+    def test_get_table_names_with_schema(self):
+        self._test_get_table_names(get_schema())
+
+    def test_get_view_names(self):
+        self._test_get_table_names(table_type='view')
+
+    @testing.requires.schemas
+    def test_get_view_names_with_schema(self):
+        self._test_get_table_names(get_schema(), table_type='view')
+
+    def _test_get_columns(self, schema=None, table_type='table'):
+        meta = MetaData(testing.db)
+        (users, addresses) = createTables(meta, schema)
+        table_names = ['users', 'email_addresses']
+        meta.create_all()
+        if table_type == 'view':
+            createViews(meta.bind, schema)
+            table_names = ['users_v', 'email_addresses_v']
+        try:
+            insp = Inspector(meta.bind)
+            for (table_name, table) in zip(table_names, (users, addresses)):
+                schema_name = schema
+                cols = insp.get_columns(table_name, schema=schema_name)
+                self.assert_(len(cols) > 0, len(cols))
+                # should be in order
+                for (i, col) in enumerate(table.columns):
+                    eq_(col.name, cols[i]['name'])
+                    ctype = cols[i]['type'].__class__
+                    ctype_def = col.type
+                    if isinstance(ctype_def, sa.types.TypeEngine):
+                        ctype_def = ctype_def.__class__
+                        
+                    # Oracle returns Date for DateTime.
+                    if testing.against('oracle') \
+                        and ctype_def in (sql_types.Date, sql_types.DateTime):
+                            ctype_def = sql_types.Date
+                    
+                    # assert that the desired type and return type
+                    # share a base within one of the generic types.
+                    self.assert_(
+                        len(
+                            set(
+                                ctype.__mro__
+                            ).intersection(ctype_def.__mro__)
+                            .intersection([sql_types.Integer, sql_types.Numeric, 
+                                            sql_types.DateTime, sql_types.Date, sql_types.Time, 
+                                            sql_types.String, sql_types.Binary])
+                            ) > 0
+                    ,("%s(%s), %s(%s)" % (col.name, col.type, cols[i]['name'],
+                                          ctype)))
+        finally:
+            if table_type == 'view':
+                dropViews(meta.bind, schema)
+            addresses.drop()
+            users.drop()
+
+    def test_get_columns(self):
+        self._test_get_columns()
+
+    @testing.requires.schemas
+    def test_get_columns_with_schema(self):
+        self._test_get_columns(schema=get_schema())
+
+    def test_get_view_columns(self):
+        self._test_get_columns(table_type='view')
+
+    @testing.requires.schemas
+    def test_get_view_columns_with_schema(self):
+        self._test_get_columns(schema=get_schema(), table_type='view')
+
+    def _test_get_primary_keys(self, schema=None):
+        meta = MetaData(testing.db)
+        (users, addresses) = createTables(meta, schema)
+        meta.create_all()
+        insp = Inspector(meta.bind)
+        try:
+            users_pkeys = insp.get_primary_keys(users.name,
+                                                schema=schema)
+            eq_(users_pkeys,  ['user_id'])
+            addr_pkeys = insp.get_primary_keys(addresses.name,
+                                               schema=schema)
+            eq_(addr_pkeys,  ['address_id'])
+
+        finally:
+            addresses.drop()
+            users.drop()
+
+    def test_get_primary_keys(self):
+        self._test_get_primary_keys()
+
+    @testing.fails_on('sqlite', 'no schemas')
+    def test_get_primary_keys_with_schema(self):
+        self._test_get_primary_keys(schema=get_schema())
+
+    def _test_get_foreign_keys(self, schema=None):
+        meta = MetaData(testing.db)
+        (users, addresses) = createTables(meta, schema)
+        meta.create_all()
+        insp = Inspector(meta.bind)
+        try:
+            expected_schema = schema
+            # users
+            users_fkeys = insp.get_foreign_keys(users.name,
+                                                schema=schema)
+            fkey1 = users_fkeys[0]
+            self.assert_(fkey1['name'] is not None)
+            eq_(fkey1['referred_schema'], expected_schema)
+            eq_(fkey1['referred_table'], users.name)
+            eq_(fkey1['referred_columns'], ['user_id', ])
+            eq_(fkey1['constrained_columns'], ['parent_user_id'])
+            #addresses
+            addr_fkeys = insp.get_foreign_keys(addresses.name,
+                                               schema=schema)
+            fkey1 = addr_fkeys[0]
+            self.assert_(fkey1['name'] is not None)
+            eq_(fkey1['referred_schema'], expected_schema)
+            eq_(fkey1['referred_table'], users.name)
+            eq_(fkey1['referred_columns'], ['user_id', ])
+            eq_(fkey1['constrained_columns'], ['remote_user_id'])
+        finally:
+            addresses.drop()
+            users.drop()
+
+    def test_get_foreign_keys(self):
+        self._test_get_foreign_keys()
+
+    @testing.requires.schemas
+    def test_get_foreign_keys_with_schema(self):
+        self._test_get_foreign_keys(schema=get_schema())
+
+    def _test_get_indexes(self, schema=None):
+        meta = MetaData(testing.db)
+        (users, addresses) = createTables(meta, schema)
+        meta.create_all()
+        createIndexes(meta.bind, schema)
+        try:
+            # The database may decide to create indexes for foreign keys, etc.
+            # so there may be more indexes than expected.
+            insp = Inspector(meta.bind)
+            indexes = insp.get_indexes('users', schema=schema)
+            indexes.sort()
+            expected_indexes = [
+                {'unique': False,
+                 'column_names': ['test1', 'test2'],
+                 'name': 'users_t_idx'}]
+            index_names = [d['name'] for d in indexes]
+            for e_index in expected_indexes:
+                assert e_index['name'] in index_names
+                index = indexes[index_names.index(e_index['name'])]
+                for key in e_index:
+                    eq_(e_index[key], index[key])
+
+        finally:
+            addresses.drop()
+            users.drop()
+
+    def test_get_indexes(self):
+        self._test_get_indexes()
+
+    @testing.requires.schemas
+    def test_get_indexes_with_schema(self):
+        self._test_get_indexes(schema=get_schema())
+
+    def _test_get_view_definition(self, schema=None):
+        meta = MetaData(testing.db)
+        (users, addresses) = createTables(meta, schema)
+        meta.create_all()
+        createViews(meta.bind, schema)
+        view_name1 = 'users_v'
+        view_name2 = 'email_addresses_v'
+        try:
+            insp = Inspector(meta.bind)
+            v1 = insp.get_view_definition(view_name1, schema=schema)
+            self.assert_(v1)
+            v2 = insp.get_view_definition(view_name2, schema=schema)
+            self.assert_(v2)
+        finally:
+            dropViews(meta.bind, schema)
+            addresses.drop()
+            users.drop()
+
+    def test_get_view_definition(self):
+        self._test_get_view_definition()
+
+    @testing.requires.schemas
+    def test_get_view_definition_with_schema(self):
+        self._test_get_view_definition(schema=get_schema())
+
+    def _test_get_table_oid(self, table_name, schema=None):
+        if testing.against('postgresql'):
+            meta = MetaData(testing.db)
+            (users, addresses) = createTables(meta, schema)
+            meta.create_all()
+            try:
+                insp = create_inspector(meta.bind)
+                oid = insp.get_table_oid(table_name, schema)
+                self.assert_(isinstance(oid, (int, long)))
+            finally:
+                addresses.drop()
+                users.drop()
+
+    def test_get_table_oid(self):
+        self._test_get_table_oid('users')
+
+    @testing.requires.schemas
+    def test_get_table_oid_with_schema(self):
+        self._test_get_table_oid('users', schema=get_schema())
+
 
index 6698259a45d52d2d9bb5e10bf52b6ee58a2609e2..8e3f3412d653b28c7049a6008f0a1a339cd41992 100644 (file)
@@ -20,7 +20,8 @@ class TransactionTest(TestBase):
         users.create(testing.db)
 
     def teardown(self):
-        testing.db.connect().execute(users.delete())
+        testing.db.execute(users.delete()).close()
+
     @classmethod
     def teardown_class(cls):
         users.drop(testing.db)
@@ -40,6 +41,7 @@ class TransactionTest(TestBase):
         result = connection.execute("select * from query_users")
         assert len(result.fetchall()) == 3
         transaction.commit()
+        connection.close()
 
     def test_rollback(self):
         """test a basic rollback"""
@@ -176,6 +178,7 @@ class TransactionTest(TestBase):
         connection.close()
 
     @testing.requires.savepoints
+    @testing.crashes('oracle+zxjdbc', 'Errors out and causes subsequent tests to deadlock')
     def test_nested_subtransaction_commit(self):
         connection = testing.db.connect()
         transaction = connection.begin()
@@ -274,6 +277,7 @@ class TransactionTest(TestBase):
         connection.close()
 
     @testing.requires.two_phase_transactions
+    @testing.crashes('mysql+zxjdbc', 'Deadlocks, causing subsequent tests to fail')
     @testing.fails_on('mysql', 'FIXME: unknown')
     def test_two_phase_recover(self):
         # MySQL recovery doesn't currently seem to work correctly
@@ -369,7 +373,7 @@ class ExplicitAutoCommitTest(TestBase):
     Requires PostgreSQL so that we may define a custom function which modifies the database.
     """
 
-    __only_on__ = 'postgres'
+    __only_on__ = 'postgresql'
 
     @classmethod
     def setup_class(cls):
@@ -380,7 +384,7 @@ class ExplicitAutoCommitTest(TestBase):
         testing.db.execute("create function insert_foo(varchar) returns integer as 'insert into foo(data) values ($1);select 1;' language sql")
 
     def teardown(self):
-        foo.delete().execute()
+        foo.delete().execute().close()
 
     @classmethod
     def teardown_class(cls):
@@ -453,8 +457,10 @@ class TLTransactionTest(TestBase):
             test_needs_acid=True,
         )
         users.create(tlengine)
+
     def teardown(self):
-        tlengine.execute(users.delete())
+        tlengine.execute(users.delete()).close()
+
     @classmethod
     def teardown_class(cls):
         users.drop(tlengine)
@@ -497,6 +503,7 @@ class TLTransactionTest(TestBase):
         try:
             assert len(result.fetchall()) == 0
         finally:
+            c.close()
             external_connection.close()
 
     def test_rollback(self):
@@ -530,7 +537,9 @@ class TLTransactionTest(TestBase):
             external_connection.close()
 
     def test_commits(self):
-        assert tlengine.connect().execute("select count(1) from query_users").scalar() == 0
+        connection = tlengine.connect()
+        assert connection.execute("select count(1) from query_users").scalar() == 0
+        connection.close()
 
         connection = tlengine.contextual_connect()
         transaction = connection.begin()
@@ -547,6 +556,7 @@ class TLTransactionTest(TestBase):
         l = result.fetchall()
         assert len(l) == 3, "expected 3 got %d" % len(l)
         transaction.commit()
+        connection.close()
 
     def test_rollback_off_conn(self):
         # test that a TLTransaction opened off a TLConnection allows that
@@ -563,6 +573,7 @@ class TLTransactionTest(TestBase):
         try:
             assert len(result.fetchall()) == 0
         finally:
+            conn.close()
             external_connection.close()
 
     def test_morerollback_off_conn(self):
@@ -581,6 +592,8 @@ class TLTransactionTest(TestBase):
         try:
             assert len(result.fetchall()) == 0
         finally:
+            conn.close()
+            conn2.close()
             external_connection.close()
 
     def test_commit_off_connection(self):
@@ -596,6 +609,7 @@ class TLTransactionTest(TestBase):
         try:
             assert len(result.fetchall()) == 3
         finally:
+            conn.close()
             external_connection.close()
 
     def test_nesting(self):
@@ -712,8 +726,10 @@ class ForUpdateTest(TestBase):
             test_needs_acid=True,
         )
         counters.create(testing.db)
+
     def teardown(self):
-        testing.db.connect().execute(counters.delete())
+        testing.db.execute(counters.delete()).close()
+
     @classmethod
     def teardown_class(cls):
         counters.drop(testing.db)
@@ -726,7 +742,7 @@ class ForUpdateTest(TestBase):
         for i in xrange(count):
             trans = con.begin()
             try:
-                existing = con.execute(sel).fetchone()
+                existing = con.execute(sel).first()
                 incr = existing['counter_value'] + 1
 
                 time.sleep(delay)
@@ -734,7 +750,7 @@ class ForUpdateTest(TestBase):
                                             values={'counter_value':incr}))
                 time.sleep(delay)
 
-                readback = con.execute(sel).fetchone()
+                readback = con.execute(sel).first()
                 if (readback['counter_value'] != incr):
                     raise AssertionError("Got %s post-update, expected %s" %
                                          (readback['counter_value'], incr))
@@ -778,7 +794,7 @@ class ForUpdateTest(TestBase):
         self.assert_(len(errors) == 0)
 
         sel = counters.select(whereclause=counters.c.counter_id==1)
-        final = db.execute(sel).fetchone()
+        final = db.execute(sel).first()
         self.assert_(final['counter_value'] == iterations * thread_count)
 
     def overlap(self, ids, errors, update_style):
index 8df449718e22369ef198672c3874186e03f9a9a8..4a5775218dcaa5f1e67a88aaf61e821e9df54e3e 100644 (file)
@@ -1,6 +1,5 @@
 from sqlalchemy.test.testing import eq_, assert_raises
 import copy
-import gc
 import pickle
 
 from sqlalchemy import *
@@ -8,6 +7,7 @@ from sqlalchemy.orm import *
 from sqlalchemy.orm.collections import collection
 from sqlalchemy.ext.associationproxy import *
 from sqlalchemy.test import *
+from sqlalchemy.test.util import gc_collect
 
 
 class DictCollection(dict):
@@ -880,7 +880,7 @@ class ReconstitutionTest(TestBase):
 
 
         add_child('p1', 'c1')
-        gc.collect()
+        gc_collect()
         add_child('p1', 'c2')
 
         session.flush()
@@ -895,7 +895,7 @@ class ReconstitutionTest(TestBase):
         p.kids.extend(['c1', 'c2'])
         p_copy = copy.copy(p)
         del p
-        gc.collect()
+        gc_collect()
 
         assert set(p_copy.kids) == set(['c1', 'c2']), p.kids
 
index ce2549099822d26eaedc5a26019da5729d71ca7f..3ee94d0271b15027b8b6fda14fa8587de8602e48 100644 (file)
@@ -1,5 +1,7 @@
 from sqlalchemy import *
+from sqlalchemy.types import TypeEngine
 from sqlalchemy.sql.expression import ClauseElement, ColumnClause
+from sqlalchemy.schema import DDLElement
 from sqlalchemy.ext.compiler import compiles
 from sqlalchemy.sql import table, column
 from sqlalchemy.test import *
@@ -25,7 +27,35 @@ class UserDefinedTest(TestBase, AssertsCompiledSQL):
             select([MyThingy('x'), MyThingy('y')]).where(MyThingy() == 5),
             "SELECT >>x<<, >>y<< WHERE >>MYTHINGY!<< = :MYTHINGY!_1"
         )
+    
+    def test_types(self):
+        class MyType(TypeEngine):
+            pass
+        
+        @compiles(MyType, 'sqlite')
+        def visit_type(type, compiler, **kw):
+            return "SQLITE_FOO"
+
+        @compiles(MyType, 'postgresql')
+        def visit_type(type, compiler, **kw):
+            return "POSTGRES_FOO"
+
+        from sqlalchemy.dialects.sqlite import base as sqlite
+        from sqlalchemy.dialects.postgresql import base as postgresql
 
+        self.assert_compile(
+            MyType(),
+            "SQLITE_FOO",
+            dialect=sqlite.dialect()
+        )
+
+        self.assert_compile(
+            MyType(),
+            "POSTGRES_FOO",
+            dialect=postgresql.dialect()
+        )
+        
+        
     def test_stateful(self):
         class MyThingy(ColumnClause):
             def __init__(self):
@@ -71,10 +101,10 @@ class UserDefinedTest(TestBase, AssertsCompiledSQL):
         )
 
     def test_dialect_specific(self):
-        class AddThingy(ClauseElement):
+        class AddThingy(DDLElement):
             __visit_name__ = 'add_thingy'
 
-        class DropThingy(ClauseElement):
+        class DropThingy(DDLElement):
             __visit_name__ = 'drop_thingy'
 
         @compiles(AddThingy, 'sqlite')
@@ -97,7 +127,7 @@ class UserDefinedTest(TestBase, AssertsCompiledSQL):
             "DROP THINGY"
         )
 
-        from sqlalchemy.databases import sqlite as base
+        from sqlalchemy.dialects.sqlite import base
         self.assert_compile(AddThingy(),
             "ADD SPECIAL SL THINGY",
             dialect=base.dialect()
index 224f41731a9d696ce22ce3b171735745e2fc8ba2..745e3b7cf8cab3169a8bf7ae6472e11518a1e7b5 100644 (file)
@@ -5,8 +5,7 @@ from sqlalchemy import exc
 import sqlalchemy as sa
 from sqlalchemy.test import testing
 from sqlalchemy import MetaData, Integer, String, ForeignKey, ForeignKeyConstraint, asc, Index
-from sqlalchemy.test.schema import Table
-from sqlalchemy.test.schema import Column
+from sqlalchemy.test.schema import Table, Column
 from sqlalchemy.orm import relation, create_session, class_mapper, eagerload, compile_mappers, backref, clear_mappers, polymorphic_union, deferred
 from sqlalchemy.test.testing import eq_
 
@@ -27,14 +26,14 @@ class DeclarativeTest(DeclarativeTestBase):
         class User(Base, ComparableEntity):
             __tablename__ = 'users'
 
-            id = Column('id', Integer, primary_key=True)
+            id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True)
             name = Column('name', String(50))
             addresses = relation("Address", backref="user")
 
         class Address(Base, ComparableEntity):
             __tablename__ = 'addresses'
 
-            id = Column(Integer, primary_key=True)
+            id = Column(Integer, primary_key=True, test_needs_autoincrement=True)
             email = Column(String(50), key='_email')
             user_id = Column('user_id', Integer, ForeignKey('users.id'),
                              key='_user_id')
@@ -127,7 +126,7 @@ class DeclarativeTest(DeclarativeTestBase):
         
         class User(Base, ComparableEntity):
             __tablename__ = 'users'
-            id = Column(Integer, primary_key=True)
+            id = Column(Integer, primary_key=True, test_needs_autoincrement=True)
             name = Column(String(50))
             addresses = relation("Address", order_by="desc(Address.email)", 
                 primaryjoin="User.id==Address.user_id", foreign_keys="[Address.user_id]",
@@ -136,7 +135,7 @@ class DeclarativeTest(DeclarativeTestBase):
         
         class Address(Base, ComparableEntity):
             __tablename__ = 'addresses'
-            id = Column(Integer, primary_key=True)
+            id = Column(Integer, primary_key=True, test_needs_autoincrement=True)
             email = Column(String(50))
             user_id = Column(Integer)  # note no foreign key
         
@@ -180,13 +179,13 @@ class DeclarativeTest(DeclarativeTestBase):
     def test_uncompiled_attributes_in_relation(self):
         class Address(Base, ComparableEntity):
             __tablename__ = 'addresses'
-            id = Column(Integer, primary_key=True)
+            id = Column(Integer, primary_key=True, test_needs_autoincrement=True)
             email = Column(String(50))
             user_id = Column(Integer, ForeignKey('users.id'))
 
         class User(Base, ComparableEntity):
             __tablename__ = 'users'
-            id = Column(Integer, primary_key=True)
+            id = Column(Integer, primary_key=True, test_needs_autoincrement=True)
             name = Column(String(50))
             addresses = relation("Address", order_by=Address.email, 
                 foreign_keys=Address.user_id, 
@@ -272,14 +271,14 @@ class DeclarativeTest(DeclarativeTestBase):
         class User(Base, ComparableEntity):
             __tablename__ = 'users'
 
-            id = Column('id', Integer, primary_key=True)
+            id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True)
         User.name = Column('name', String(50))
         User.addresses = relation("Address", backref="user")
 
         class Address(Base, ComparableEntity):
             __tablename__ = 'addresses'
 
-            id = Column(Integer, primary_key=True)
+            id = Column(Integer, primary_key=True, test_needs_autoincrement=True)
         Address.email = Column(String(50), key='_email')
         Address.user_id = Column('user_id', Integer, ForeignKey('users.id'),
                              key='_user_id')
@@ -312,14 +311,14 @@ class DeclarativeTest(DeclarativeTestBase):
         class Address(Base, ComparableEntity):
             __tablename__ = 'addresses'
 
-            id = Column('id', Integer, primary_key=True)
+            id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True)
             email = Column('email', String(50))
             user_id = Column('user_id', Integer, ForeignKey('users.id'))
 
         class User(Base, ComparableEntity):
             __tablename__ = 'users'
 
-            id = Column('id', Integer, primary_key=True)
+            id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True)
             name = Column('name', String(50))
             addresses = relation("Address", order_by=Address.email)
 
@@ -341,14 +340,14 @@ class DeclarativeTest(DeclarativeTestBase):
         class Address(Base, ComparableEntity):
             __tablename__ = 'addresses'
 
-            id = Column('id', Integer, primary_key=True)
+            id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True)
             email = Column('email', String(50))
             user_id = Column('user_id', Integer, ForeignKey('users.id'))
 
         class User(Base, ComparableEntity):
             __tablename__ = 'users'
 
-            id = Column('id', Integer, primary_key=True)
+            id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True)
             name = Column('name', String(50))
             addresses = relation("Address", order_by=(Address.email, Address.id))
 
@@ -368,14 +367,14 @@ class DeclarativeTest(DeclarativeTestBase):
         class User(ComparableEntity):
             __tablename__ = 'users'
 
-            id = Column('id', Integer, primary_key=True)
+            id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True)
             name = Column('name', String(50))
             addresses = relation("Address", backref="user")
 
         class Address(ComparableEntity):
             __tablename__ = 'addresses'
 
-            id = Column('id', Integer, primary_key=True)
+            id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True)
             email = Column('email', String(50))
             user_id = Column('user_id', Integer, ForeignKey('users.id'))
         
@@ -478,14 +477,14 @@ class DeclarativeTest(DeclarativeTestBase):
         class User(Base, ComparableEntity):
             __tablename__ = 'users'
 
-            id = Column('id', Integer, primary_key=True)
+            id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True)
             name = Column('name', String(50))
             addresses = relation("Address", backref="user")
 
         class Address(Base, ComparableEntity):
             __tablename__ = 'addresses'
 
-            id = Column('id', Integer, primary_key=True)
+            id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True)
             email = Column('email', String(50))
             user_id = Column('user_id', Integer, ForeignKey('users.id'))
 
@@ -513,7 +512,7 @@ class DeclarativeTest(DeclarativeTestBase):
         class User(Base, ComparableEntity):
             __tablename__ = 'users'
 
-            id = Column('id', Integer, primary_key=True)
+            id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True)
             name = Column('name', String(50))
 
         User.a = Column('a', String(10))
@@ -535,14 +534,14 @@ class DeclarativeTest(DeclarativeTestBase):
     def test_column_properties(self):
         class Address(Base, ComparableEntity):
             __tablename__ = 'addresses'
-            id = Column(Integer, primary_key=True)
+            id = Column(Integer, primary_key=True, test_needs_autoincrement=True)
             email = Column(String(50))
             user_id = Column(Integer, ForeignKey('users.id'))
 
         class User(Base, ComparableEntity):
             __tablename__ = 'users'
 
-            id = Column('id', Integer, primary_key=True)
+            id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True)
             name = Column('name', String(50))
             adr_count = sa.orm.column_property(
                 sa.select([sa.func.count(Address.id)], Address.user_id == id).
@@ -588,7 +587,7 @@ class DeclarativeTest(DeclarativeTestBase):
         class User(Base, ComparableEntity):
             __tablename__ = 'users'
 
-            id = Column(Integer, primary_key=True)
+            id = Column(Integer, primary_key=True, test_needs_autoincrement=True)
             name = sa.orm.deferred(Column(String(50)))
 
         Base.metadata.create_all()
@@ -607,7 +606,7 @@ class DeclarativeTest(DeclarativeTestBase):
         class User(Base, ComparableEntity):
             __tablename__ = 'users'
 
-            id = Column('id', Integer, primary_key=True)
+            id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True)
             _name = Column('name', String(50))
             def _set_name(self, name):
                 self._name = "SOMENAME " + name
@@ -636,7 +635,7 @@ class DeclarativeTest(DeclarativeTestBase):
         class User(Base, ComparableEntity):
             __tablename__ = 'users'
 
-            id = Column('id', Integer, primary_key=True)
+            id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True)
             _name = Column('name', String(50))
             name = sa.orm.synonym('_name', comparator_factory=CustomCompare)
         
@@ -652,7 +651,7 @@ class DeclarativeTest(DeclarativeTestBase):
         class User(Base, ComparableEntity):
             __tablename__ = 'users'
 
-            id = Column('id', Integer, primary_key=True)
+            id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True)
             _name = Column('name', String(50))
             def _set_name(self, name):
                 self._name = "SOMENAME " + name
@@ -674,14 +673,14 @@ class DeclarativeTest(DeclarativeTestBase):
         class User(Base, ComparableEntity):
             __tablename__ = 'users'
 
-            id = Column('id', Integer, primary_key=True)
+            id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True)
             name = Column('name', String(50))
             addresses = relation("Address", backref="user")
 
         class Address(Base, ComparableEntity):
             __tablename__ = 'addresses'
 
-            id = Column('id', Integer, primary_key=True)
+            id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True)
             email = Column('email', String(50))
             user_id = Column('user_id', Integer, ForeignKey(User.id))
 
@@ -711,14 +710,14 @@ class DeclarativeTest(DeclarativeTestBase):
         class Address(Base, ComparableEntity):
             __tablename__ = 'addresses'
 
-            id = Column('id', Integer, primary_key=True)
+            id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True)
             email = Column('email', String(50))
             user_id = Column('user_id', Integer, ForeignKey('users.id'))
 
         class User(Base, ComparableEntity):
             __tablename__ = 'users'
 
-            id = Column('id', Integer, primary_key=True)
+            id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True)
             name = Column('name', String(50))
             addresses = relation("Address", backref="user",
                                  primaryjoin=id == Address.user_id)
@@ -763,7 +762,7 @@ class DeclarativeTest(DeclarativeTestBase):
     def test_with_explicit_autoloaded(self):
         meta = MetaData(testing.db)
         t1 = Table('t1', meta,
-                   Column('id', String(50), primary_key=True),
+                   Column('id', String(50), primary_key=True, test_needs_autoincrement=True),
                    Column('data', String(50)))
         meta.create_all()
         try:
@@ -779,6 +778,70 @@ class DeclarativeTest(DeclarativeTestBase):
         finally:
             meta.drop_all()
 
+    def test_synonym_for(self):
+        class User(Base, ComparableEntity):
+            __tablename__ = 'users'
+
+            id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True)
+            name = Column('name', String(50))
+
+            @decl.synonym_for('name')
+            @property
+            def namesyn(self):
+                return self.name
+
+        Base.metadata.create_all()
+
+        sess = create_session()
+        u1 = User(name='someuser')
+        eq_(u1.name, "someuser")
+        eq_(u1.namesyn, 'someuser')
+        sess.add(u1)
+        sess.flush()
+
+        rt = sess.query(User).filter(User.namesyn == 'someuser').one()
+        eq_(rt, u1)
+
+    def test_comparable_using(self):
+        class NameComparator(sa.orm.PropComparator):
+            @property
+            def upperself(self):
+                cls = self.prop.parent.class_
+                col = getattr(cls, 'name')
+                return sa.func.upper(col)
+
+            def operate(self, op, other, **kw):
+                return op(self.upperself, other, **kw)
+
+        class User(Base, ComparableEntity):
+            __tablename__ = 'users'
+
+            id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True)
+            name = Column('name', String(50))
+
+            @decl.comparable_using(NameComparator)
+            @property
+            def uc_name(self):
+                return self.name is not None and self.name.upper() or None
+
+        Base.metadata.create_all()
+
+        sess = create_session()
+        u1 = User(name='someuser')
+        eq_(u1.name, "someuser", u1.name)
+        eq_(u1.uc_name, 'SOMEUSER', u1.uc_name)
+        sess.add(u1)
+        sess.flush()
+        sess.expunge_all()
+
+        rt = sess.query(User).filter(User.uc_name == 'SOMEUSER').one()
+        eq_(rt, u1)
+        sess.expunge_all()
+
+        rt = sess.query(User).filter(User.uc_name.startswith('SOMEUSE')).one()
+        eq_(rt, u1)
+
+
 class DeclarativeInheritanceTest(DeclarativeTestBase):
     def test_custom_join_condition(self):
         class Foo(Base):
@@ -797,13 +860,13 @@ class DeclarativeInheritanceTest(DeclarativeTestBase):
     def test_joined(self):
         class Company(Base, ComparableEntity):
             __tablename__ = 'companies'
-            id = Column('id', Integer, primary_key=True)
+            id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True)
             name = Column('name', String(50))
             employees = relation("Person")
 
         class Person(Base, ComparableEntity):
             __tablename__ = 'people'
-            id = Column('id', Integer, primary_key=True)
+            id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True)
             company_id = Column('company_id', Integer,
                                 ForeignKey('companies.id'))
             name = Column('name', String(50))
@@ -911,13 +974,13 @@ class DeclarativeInheritanceTest(DeclarativeTestBase):
         
         class Company(Base, ComparableEntity):
             __tablename__ = 'companies'
-            id = Column('id', Integer, primary_key=True)
+            id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True)
             name = Column('name', String(50))
             employees = relation("Person")
 
         class Person(Base, ComparableEntity):
             __tablename__ = 'people'
-            id = Column('id', Integer, primary_key=True)
+            id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True)
             company_id = Column('company_id', Integer,
                                 ForeignKey('companies.id'))
             name = Column('name', String(50))
@@ -967,13 +1030,13 @@ class DeclarativeInheritanceTest(DeclarativeTestBase):
 
         class Company(Base, ComparableEntity):
             __tablename__ = 'companies'
-            id = Column('id', Integer, primary_key=True)
+            id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True)
             name = Column('name', String(50))
             employees = relation("Person")
 
         class Person(Base, ComparableEntity):
             __tablename__ = 'people'
-            id = Column(Integer, primary_key=True)
+            id = Column(Integer, primary_key=True, test_needs_autoincrement=True)
             company_id = Column(Integer,
                                 ForeignKey('companies.id'))
             name = Column(String(50))
@@ -1037,13 +1100,13 @@ class DeclarativeInheritanceTest(DeclarativeTestBase):
     def test_joined_from_single(self):
         class Company(Base, ComparableEntity):
             __tablename__ = 'companies'
-            id = Column('id', Integer, primary_key=True)
+            id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True)
             name = Column('name', String(50))
             employees = relation("Person")
         
         class Person(Base, ComparableEntity):
             __tablename__ = 'people'
-            id = Column(Integer, primary_key=True)
+            id = Column(Integer, primary_key=True, test_needs_autoincrement=True)
             company_id = Column(Integer, ForeignKey('companies.id'))
             name = Column(String(50))
             discriminator = Column('type', String(50))
@@ -1100,7 +1163,7 @@ class DeclarativeInheritanceTest(DeclarativeTestBase):
     def test_add_deferred(self):
         class Person(Base, ComparableEntity):
             __tablename__ = 'people'
-            id = Column('id', Integer, primary_key=True)
+            id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True)
 
         Person.name = deferred(Column(String(10)))
 
@@ -1117,6 +1180,8 @@ class DeclarativeInheritanceTest(DeclarativeTestBase):
                 Person(name='ratbert')
             ]
         )
+        sess.expunge_all()
+
         person = sess.query(Person).filter(Person.name == 'ratbert').one()
         assert 'name' not in person.__dict__
 
@@ -1127,7 +1192,7 @@ class DeclarativeInheritanceTest(DeclarativeTestBase):
 
         class Person(Base, ComparableEntity):
             __tablename__ = 'people'
-            id = Column(Integer, primary_key=True)
+            id = Column(Integer, primary_key=True, test_needs_autoincrement=True)
             name = Column(String(50))
             discriminator = Column('type', String(50))
             __mapper_args__ = {'polymorphic_on':discriminator}
@@ -1139,7 +1204,7 @@ class DeclarativeInheritanceTest(DeclarativeTestBase):
             
         class Language(Base, ComparableEntity):
             __tablename__ = 'languages'
-            id = Column(Integer, primary_key=True)
+            id = Column(Integer, primary_key=True, test_needs_autoincrement=True)
             name = Column(String(50))
 
         assert not hasattr(Person, 'primary_language_id')
@@ -1236,12 +1301,12 @@ class DeclarativeInheritanceTest(DeclarativeTestBase):
         
     def test_concrete(self):
         engineers = Table('engineers', Base.metadata,
-                        Column('id', Integer, primary_key=True),
+                        Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
                         Column('name', String(50)),
                         Column('primary_language', String(50))
                     )
         managers = Table('managers', Base.metadata,
-                        Column('id', Integer, primary_key=True),
+                        Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
                         Column('name', String(50)),
                         Column('golf_swing', String(50))
                     )
@@ -1293,12 +1358,12 @@ def _produce_test(inline, stringbased):
 
             class User(Base, ComparableEntity):
                 __tablename__ = 'users'
-                id = Column(Integer, primary_key=True)
+                id = Column(Integer, primary_key=True, test_needs_autoincrement=True)
                 name = Column(String(50))
             
             class Address(Base, ComparableEntity):
                 __tablename__ = 'addresses'
-                id = Column(Integer, primary_key=True)
+                id = Column(Integer, primary_key=True, test_needs_autoincrement=True)
                 email = Column(String(50))
                 user_id = Column(Integer, ForeignKey('users.id'))
                 if inline:
@@ -1363,16 +1428,16 @@ class DeclarativeReflectionTest(testing.TestBase):
         reflection_metadata = MetaData(testing.db)
 
         Table('users', reflection_metadata,
-              Column('id', Integer, primary_key=True),
+              Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
               Column('name', String(50)),
               test_needs_fk=True)
         Table('addresses', reflection_metadata,
-              Column('id', Integer, primary_key=True),
+              Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
               Column('email', String(50)),
               Column('user_id', Integer, ForeignKey('users.id')),
               test_needs_fk=True)
         Table('imhandles', reflection_metadata,
-              Column('id', Integer, primary_key=True),
+              Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
               Column('user_id', Integer),
               Column('network', String(50)),
               Column('handle', String(50)),
@@ -1398,12 +1463,17 @@ class DeclarativeReflectionTest(testing.TestBase):
         class User(Base, ComparableEntity):
             __tablename__ = 'users'
             __autoload__ = True
+            if testing.against('oracle', 'firebird'):
+                id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True)
             addresses = relation("Address", backref="user")
 
         class Address(Base, ComparableEntity):
             __tablename__ = 'addresses'
             __autoload__ = True
 
+            if testing.against('oracle', 'firebird'):
+                id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True)
+        
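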
         u1 = User(name='u1', addresses=[
             Address(email='one'),
             Address(email='two'),
@@ -1428,12 +1498,16 @@ class DeclarativeReflectionTest(testing.TestBase):
         class User(Base, ComparableEntity):
             __tablename__ = 'users'
             __autoload__ = True
+            if testing.against('oracle', 'firebird'):
+                id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True)
             nom = Column('name', String(50), key='nom')
             addresses = relation("Address", backref="user")
 
         class Address(Base, ComparableEntity):
             __tablename__ = 'addresses'
             __autoload__ = True
+            if testing.against('oracle', 'firebird'):
+                id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True)
 
         u1 = User(nom='u1', addresses=[
             Address(email='one'),
@@ -1461,12 +1535,16 @@ class DeclarativeReflectionTest(testing.TestBase):
         class IMHandle(Base, ComparableEntity):
             __tablename__ = 'imhandles'
             __autoload__ = True
+            if testing.against('oracle', 'firebird'):
+                id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True)
 
             user_id = Column('user_id', Integer,
                              ForeignKey('users.id'))
         class User(Base, ComparableEntity):
             __tablename__ = 'users'
             __autoload__ = True
+            if testing.against('oracle', 'firebird'):
+                id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True)
             handles = relation("IMHandle", backref="user")
 
         u1 = User(name='u1', handles=[
@@ -1487,69 +1565,3 @@ class DeclarativeReflectionTest(testing.TestBase):
         eq_(a1, IMHandle(network='lol', handle='zomg'))
         eq_(a1.user, User(name='u1'))
 
-    def test_synonym_for(self):
-        class User(Base, ComparableEntity):
-            __tablename__ = 'users'
-
-            id = Column('id', Integer, primary_key=True)
-            name = Column('name', String(50))
-
-            @decl.synonym_for('name')
-            @property
-            def namesyn(self):
-                return self.name
-
-        Base.metadata.create_all()
-
-        sess = create_session()
-        u1 = User(name='someuser')
-        eq_(u1.name, "someuser")
-        eq_(u1.namesyn, 'someuser')
-        sess.add(u1)
-        sess.flush()
-
-        rt = sess.query(User).filter(User.namesyn == 'someuser').one()
-        eq_(rt, u1)
-
-    def test_comparable_using(self):
-        class NameComparator(sa.orm.PropComparator):
-            @property
-            def upperself(self):
-                cls = self.prop.parent.class_
-                col = getattr(cls, 'name')
-                return sa.func.upper(col)
-
-            def operate(self, op, other, **kw):
-                return op(self.upperself, other, **kw)
-
-        class User(Base, ComparableEntity):
-            __tablename__ = 'users'
-
-            id = Column('id', Integer, primary_key=True)
-            name = Column('name', String(50))
-
-            @decl.comparable_using(NameComparator)
-            @property
-            def uc_name(self):
-                return self.name is not None and self.name.upper() or None
-
-        Base.metadata.create_all()
-
-        sess = create_session()
-        u1 = User(name='someuser')
-        eq_(u1.name, "someuser", u1.name)
-        eq_(u1.uc_name, 'SOMEUSER', u1.uc_name)
-        sess.add(u1)
-        sess.flush()
-        sess.expunge_all()
-
-        rt = sess.query(User).filter(User.uc_name == 'SOMEUSER').one()
-        eq_(rt, u1)
-        sess.expunge_all()
-
-        rt = sess.query(User).filter(User.uc_name.startswith('SOMEUSE')).one()
-        eq_(rt, u1)
-
-
-if __name__ == '__main__':
-    testing.main()
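The two tests removed here, test_synonym_for and test_comparable_using, are the same ones re-added above under DeclarativeTest, gaining only the test_needs_autoincrement flag; dropping the module-level testing.main() hook is consistent with the move to the nose-based runner elsewhere in this merge. Reduced to a sketch outside the test harness (class and attribute names illustrative), the synonym_for pattern they cover is:

    from sqlalchemy import Column, Integer, String
    from sqlalchemy.ext.declarative import declarative_base, synonym_for

    Base = declarative_base()

    class User(Base):
        __tablename__ = 'users'
        id = Column(Integer, primary_key=True)
        name = Column(String(50))

        # a read-only Python property that remains usable in queries:
        # User.namesyn == 'x' compiles against the 'name' column
        @synonym_for('name')
        @property
        def namesyn(self):
            return self.name

comparable_using works the same way, except the decorator attaches a PropComparator subclass so class-level expressions (User.uc_name == 'SOMEUSER' above) compile to SQL of the comparator's choosing.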
index b8a8e3fef9d27dba92dfca42abfc8e823ada895b..c400797b0ecb519bbc4c2ed3ce66ea889c0578ba 100644 (file)
@@ -96,8 +96,6 @@ class SerializeTest(MappedTest):
             [(7, u'jack'), (8, u'ed'), (8, u'ed'), (8, u'ed'), (9, u'fred')]
         )
     
-    # fails due to pure Python pickle bug:  http://bugs.python.org/issue998998
-    @testing.fails_if(lambda: util.py3k) 
     def test_query(self):
         q = Session.query(User).filter(User.name=='ed').options(eagerload(User.addresses))
         eq_(q.all(), [User(name='ed', addresses=[Address(id=2), Address(id=3), Address(id=4)])])
index 8d695e912b99b505eba8504ac62f8d594080886e..f08d253d57666adceeaadc52e14d01dd3fd85237 100644 (file)
@@ -1,10 +1,11 @@
-import gc
 import inspect
 import sys
 import types
 import sqlalchemy as sa
+import sqlalchemy.exceptions as sa_exc
 from sqlalchemy.test import config, testing
 from sqlalchemy.test.testing import resolve_artifact_names, adict
+from sqlalchemy.test.engines import drop_all_tables
 from sqlalchemy.util import function_named
 
 
@@ -74,20 +75,19 @@ class ComparableEntity(BasicEntity):
                 if attr.startswith('_'):
                     continue
                 value = getattr(a, attr)
-                if (hasattr(value, '__iter__') and
-                    not isinstance(value, basestring)):
-                    try:
-                        # catch AttributeError so that lazy loaders trigger
-                        battr = getattr(b, attr)
-                    except AttributeError:
-                        return False
 
+                try:
+                    # handle lazy loader errors
+                    battr = getattr(b, attr)
+                except (AttributeError, sa_exc.UnboundExecutionError):
+                    return False
+
+                if hasattr(value, '__iter__'):
                     if list(value) != list(battr):
                         return False
                 else:
-                    if value is not None:
-                        if value != getattr(b, attr, None):
-                            return False
+                    if value is not None and value != battr:
+                        return False
             return True
         finally:
             _recursion_stack.remove(id(self))
@@ -173,7 +173,7 @@ class MappedTest(ORMTest):
     def setup(self):
         if self.run_define_tables == 'each':
             self.tables.clear()
-            self.metadata.drop_all()
+            drop_all_tables(self.metadata)
             self.metadata.clear()
             self.define_tables(self.metadata)
             self.metadata.create_all()
@@ -217,7 +217,7 @@ class MappedTest(ORMTest):
         for cl in cls.classes.values():
             cls.unregister_class(cl)
         ORMTest.teardown_class()
-        cls.metadata.drop_all()
+        drop_all_tables(cls.metadata)
         cls.metadata.bind = None
 
     @classmethod
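Both teardown paths in this fixture base now route through drop_all_tables() instead of MetaData.drop_all(). Its implementation isn't shown in this diff, but the motivation — dropping dependent tables before their parents on backends that enforce foreign-key ordering strictly — suggests something along these lines (a guess, not the committed code):

    def drop_all_tables(metadata):
        # drop children before parents to satisfy FK ordering
        for table in reversed(metadata.sorted_tables):
            table.drop(checkfirst=True)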
index 931d8cadf87f7a8433dba39c16712727254948e3..e9d6ac16565c4a55efa17471cae2a27a99ab6405 100644 (file)
@@ -60,7 +60,7 @@ email_bounces = fixture_table(
 
 orders = fixture_table(
     Table('orders', fixture_metadata,
-          Column('id', Integer, primary_key=True),
+          Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
           Column('user_id', None, ForeignKey('users.id')),
           Column('address_id', None, ForeignKey('addresses.id')),
           Column('description', String(30)),
@@ -76,7 +76,7 @@ orders = fixture_table(
 
 dingalings = fixture_table(
     Table("dingalings", fixture_metadata,
-          Column('id', Integer, primary_key=True),
+          Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
           Column('address_id', None, ForeignKey('addresses.id')),
           Column('data', String(30)),
           test_needs_acid=True,
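The test_needs_autoincrement flag this commit threads through hundreds of Column() calls is consumed by the sqlalchemy.test.schema wrappers imported at the top of these modules; it replaces the per-test Sequence(..., optional=True) boilerplate being deleted below. A plausible reduction of that wrapper — illustrative only, the real one lives in sqlalchemy.test.schema and is not shown in this diff:

    from sqlalchemy import schema
    from sqlalchemy.test import testing

    def Column(*args, **kw):
        # on backends without native autoincrement, synthesize a sequence
        needs_autoinc = kw.pop('test_needs_autoincrement', False)
        args = list(args)
        if needs_autoinc and testing.against('oracle', 'firebird'):
            # assumes the column name is the first positional argument
            args.append(schema.Sequence(args[0] + '_seq', optional=True))
        return schema.Column(*args, **kw)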
index 4e55cf70eaa906f81bb71e3e831693c3d5a93a20..f6d5111b2c2cc31396b8b2a6e9f251db2b5f9ec5 100644 (file)
@@ -3,6 +3,7 @@ from sqlalchemy.orm import *
 from sqlalchemy.orm.interfaces import ONETOMANY, MANYTOONE
 
 from sqlalchemy.test import testing
+from sqlalchemy.test.schema import Table, Column
 from test.orm import _base
 
 
@@ -15,7 +16,7 @@ def produce_test(parent, child, direction):
         def define_tables(cls, metadata):
             global ta, tb, tc
             ta = ["a", metadata]
-            ta.append(Column('id', Integer, primary_key=True)),
+            ta.append(Column('id', Integer, primary_key=True, test_needs_autoincrement=True)),
             ta.append(Column('a_data', String(30)))
             if "a"== parent and direction == MANYTOONE:
                 ta.append(Column('child_id', Integer, ForeignKey("%s.id" % child, use_alter=True, name="foo")))
index 8cad8ed78163f1328290f35c77eb9c83d0d35492..2dab59bb25d24f4683c72d3b560c110ba31741ce 100644 (file)
@@ -4,13 +4,14 @@ from sqlalchemy.orm import *
 
 from sqlalchemy.util import function_named
 from test.orm import _base, _fixtures
+from sqlalchemy.test.schema import Table, Column
 
 class ABCTest(_base.MappedTest):
     @classmethod
     def define_tables(cls, metadata):
         global a, b, c
         a = Table('a', metadata,
-            Column('id', Integer, primary_key=True),
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('adata', String(30)),
             Column('type', String(30)),
             )
@@ -61,7 +62,7 @@ class ABCTest(_base.MappedTest):
                 C(cdata='c1', bdata='c1', adata='c1'),
                 C(cdata='c2', bdata='c2', adata='c2'),
                 C(cdata='c2', bdata='c2', adata='c2'),
-            ] == sess.query(A).all()
+            ] == sess.query(A).order_by(A.id).all()
 
             assert [
                 B(bdata='b1', adata='b1'),
index bad6920de7ad2a16716097fe7fb92711144309cd..b2e00de3598261c24b53847425964b128ea35cdb 100644 (file)
@@ -7,6 +7,7 @@ from sqlalchemy.orm import exc as orm_exc
 from sqlalchemy.test import testing, engines
 from sqlalchemy.util import function_named
 from test.orm import _base, _fixtures
+from sqlalchemy.test.schema import Table, Column
 
 class O2MTest(_base.MappedTest):
     """deals with inheritance and one-to-many relationships"""
@@ -14,8 +15,7 @@ class O2MTest(_base.MappedTest):
     def define_tables(cls, metadata):
         global foo, bar, blub
         foo = Table('foo', metadata,
-            Column('id', Integer, Sequence('foo_seq', optional=True),
-                   primary_key=True),
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('data', String(20)))
 
         bar = Table('bar', metadata,
@@ -73,9 +73,8 @@ class FalseDiscriminatorTest(_base.MappedTest):
     def define_tables(cls, metadata):
         global t1
         t1 = Table('t1', metadata, 
-                    Column('id', Integer, primary_key=True), 
-                    Column('type', Boolean, nullable=False)
-                )
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True), 
+            Column('type', Boolean, nullable=False))
         
     def test_false_on_sub(self):
         class Foo(object):pass
@@ -108,7 +107,7 @@ class PolymorphicSynonymTest(_base.MappedTest):
     def define_tables(cls, metadata):
         global t1, t2
         t1 = Table('t1', metadata,
-                   Column('id', Integer, primary_key=True),
+                   Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
                    Column('type', String(10), nullable=False),
                    Column('info', String(255)))
         t2 = Table('t2', metadata,
@@ -149,12 +148,12 @@ class CascadeTest(_base.MappedTest):
     def define_tables(cls, metadata):
         global t1, t2, t3, t4
         t1= Table('t1', metadata,
-            Column('id', Integer, primary_key=True),
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('data', String(30))
             )
 
         t2 = Table('t2', metadata,
-            Column('id', Integer, primary_key=True),
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('t1id', Integer, ForeignKey('t1.id')),
             Column('type', String(30)),
             Column('data', String(30))
@@ -164,7 +163,7 @@ class CascadeTest(_base.MappedTest):
             Column('moredata', String(30)))
 
         t4 = Table('t4', metadata,
-            Column('id', Integer, primary_key=True),
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('t3id', Integer, ForeignKey('t3.id')),
             Column('data', String(30)))
 
@@ -214,8 +213,7 @@ class GetTest(_base.MappedTest):
     def define_tables(cls, metadata):
         global foo, bar, blub
         foo = Table('foo', metadata,
-            Column('id', Integer, Sequence('foo_seq', optional=True),
-                   primary_key=True),
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('type', String(30)),
             Column('data', String(20)))
 
@@ -224,7 +222,7 @@ class GetTest(_base.MappedTest):
             Column('data', String(20)))
 
         blub = Table('blub', metadata,
-            Column('id', Integer, primary_key=True),
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('foo_id', Integer, ForeignKey('foo.id')),
             Column('bar_id', Integer, ForeignKey('bar.id')),
             Column('data', String(20)))
@@ -304,8 +302,7 @@ class EagerLazyTest(_base.MappedTest):
     def define_tables(cls, metadata):
         global foo, bar, bar_foo
         foo = Table('foo', metadata,
-                    Column('id', Integer, Sequence('foo_seq', optional=True),
-                           primary_key=True),
+                    Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
                     Column('data', String(30)))
         bar = Table('bar', metadata,
                     Column('id', Integer, ForeignKey('foo.id'), primary_key=True),
@@ -350,13 +347,13 @@ class FlushTest(_base.MappedTest):
     def define_tables(cls, metadata):
         global users, roles, user_roles, admins
         users = Table('users', metadata,
-            Column('id', Integer, primary_key=True),
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('email', String(128)),
             Column('password', String(16)),
         )
 
         roles = Table('role', metadata,
-            Column('id', Integer, primary_key=True),
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('description', String(32))
         )
 
@@ -366,7 +363,7 @@ class FlushTest(_base.MappedTest):
         )
 
         admins = Table('admin', metadata,
-            Column('admin_id', Integer, primary_key=True),
+            Column('admin_id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('user_id', Integer, ForeignKey('users.id'))
         )
 
@@ -439,7 +436,7 @@ class VersioningTest(_base.MappedTest):
     def define_tables(cls, metadata):
         global base, subtable, stuff
         base = Table('base', metadata,
-            Column('id', Integer, Sequence('version_test_seq', optional=True), primary_key=True ),
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('version_id', Integer, nullable=False),
             Column('value', String(40)),
             Column('discriminator', Integer, nullable=False)
@@ -449,11 +446,10 @@ class VersioningTest(_base.MappedTest):
             Column('subdata', String(50))
             )
         stuff = Table('stuff', metadata,
-            Column('id', Integer, primary_key=True),
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('parent', Integer, ForeignKey('base.id'))
             )
 
-    @testing.fails_on('mssql', 'FIXME: the flush still happens with the concurrency issue.')
     @engines.close_open_connections
     def test_save_update(self):
         class Base(_fixtures.Base):
@@ -493,16 +489,16 @@ class VersioningTest(_base.MappedTest):
 
         try:
             sess2.flush()
-            assert False
+            assert not testing.db.dialect.supports_sane_rowcount
         except orm_exc.ConcurrentModificationError, e:
             assert True
 
         sess2.refresh(s2)
-        assert s2.subdata == 'sess1 subdata'
+        if testing.db.dialect.supports_sane_rowcount:
+            assert s2.subdata == 'sess1 subdata'
         s2.subdata = 'sess2 subdata'
         sess2.flush()
 
-    @testing.fails_on('mssql', 'FIXME: the flush still happens with the concurrency issue.')
     def test_delete(self):
         class Base(_fixtures.Base):
             pass
@@ -534,7 +530,7 @@ class VersioningTest(_base.MappedTest):
         try:
             s1.subdata = 'some new subdata'
             sess.flush()
-            assert False
+            assert not testing.db.dialect.supports_sane_rowcount
         except orm_exc.ConcurrentModificationError, e:
             assert True
 
@@ -550,12 +546,12 @@ class DistinctPKTest(_base.MappedTest):
         global person_table, employee_table, Person, Employee
 
         person_table = Table("persons", metadata,
-                Column("id", Integer, primary_key=True),
+                Column("id", Integer, primary_key=True, test_needs_autoincrement=True),
                 Column("name", String(80)),
                 )
 
         employee_table = Table("employees", metadata,
-                Column("id", Integer, primary_key=True),
+                Column("id", Integer, primary_key=True, test_needs_autoincrement=True),
                 Column("salary", Integer),
                 Column("person_id", Integer, ForeignKey("persons.id")),
                 )
@@ -623,7 +619,7 @@ class SyncCompileTest(_base.MappedTest):
         global _a_table, _b_table, _c_table
 
         _a_table = Table('a', metadata,
-           Column('id', Integer, primary_key=True),
+           Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
            Column('data1', String(128))
         )
 
@@ -691,7 +687,7 @@ class OverrideColKeyTest(_base.MappedTest):
         global base, subtable
         
         base = Table('base', metadata, 
-            Column('base_id', Integer, primary_key=True),
+            Column('base_id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('data', String(255)),
             Column('sqlite_fixer', String(10))
             )
@@ -921,7 +917,7 @@ class OptimizedLoadTest(_base.MappedTest):
     def define_tables(cls, metadata):
         global base, sub
         base = Table('base', metadata,
-            Column('id', Integer, primary_key=True),
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('data', String(50)),
             Column('type', String(50))
         )
@@ -1008,7 +1004,7 @@ class PKDiscriminatorTest(_base.MappedTest):
     @classmethod
     def define_tables(cls, metadata):
         parents = Table('parents', metadata,
-                           Column('id', Integer, primary_key=True),
+                           Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
                            Column('name', String(60)))
                            
         children = Table('children', metadata,
@@ -1061,14 +1057,14 @@ class DeleteOrphanTest(_base.MappedTest):
     def define_tables(cls, metadata):
         global single, parent
         single = Table('single', metadata,
-            Column('id', Integer, primary_key=True),
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('type', String(50), nullable=False),
             Column('data', String(50)),
             Column('parent_id', Integer, ForeignKey('parent.id'), nullable=False),
             )
             
         parent = Table('parent', metadata,
-                Column('id', Integer, primary_key=True),
+                Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
                 Column('data', String(50))
             )
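The versioning assertions above are the behavioral change in this file: on DBAPIs that cannot report accurate rowcounts (dialect.supports_sane_rowcount is False), the ORM cannot detect a stale UPDATE, so ConcurrentModificationError may legitimately not be raised and the old unconditional `assert False` was too strict. Isolated, the guarded pattern (sess2 and orm_exc as in the tests above) is:

    # a missing exception is only a failure where rowcounts are trustworthy
    try:
        sess2.flush()
        assert not testing.db.dialect.supports_sane_rowcount
    except orm_exc.ConcurrentModificationError:
        pass  # expected stale-row detection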
     
index 46bd171e4405286b7ffca8d226c2762559890dbb..3a78be9d7bc9c60ee05466cee0ac77a2716df2ae 100644 (file)
@@ -8,6 +8,7 @@ from sqlalchemy.test import testing
 from test.orm import _base
 from sqlalchemy.orm import attributes
 from sqlalchemy.test.testing import eq_
+from sqlalchemy.test.schema import Table, Column
 
 class Employee(object):
     def __init__(self, name):
@@ -48,31 +49,31 @@ class ConcreteTest(_base.MappedTest):
         global managers_table, engineers_table, hackers_table, companies, employees_table
 
         companies = Table('companies', metadata,
-           Column('id', Integer, primary_key=True),
+           Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
            Column('name', String(50)))
 
         employees_table = Table('employees', metadata,
-            Column('employee_id', Integer, primary_key=True),
+            Column('employee_id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('name', String(50)),
             Column('company_id', Integer, ForeignKey('companies.id'))
         )
         
         managers_table = Table('managers', metadata,
-            Column('employee_id', Integer, primary_key=True),
+            Column('employee_id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('name', String(50)),
             Column('manager_data', String(50)),
             Column('company_id', Integer, ForeignKey('companies.id'))
         )
 
         engineers_table = Table('engineers', metadata,
-            Column('employee_id', Integer, primary_key=True),
+            Column('employee_id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('name', String(50)),
             Column('engineer_info', String(50)),
             Column('company_id', Integer, ForeignKey('companies.id'))
         )
 
         hackers_table = Table('hackers', metadata,
-            Column('employee_id', Integer, primary_key=True),
+            Column('employee_id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('name', String(50)),
             Column('engineer_info', String(50)),
             Column('company_id', Integer, ForeignKey('companies.id')),
@@ -320,17 +321,17 @@ class PropertyInheritanceTest(_base.MappedTest):
     @classmethod
     def define_tables(cls, metadata):
         Table('a_table', metadata,
-            Column('id', Integer, primary_key=True),
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('some_c_id', Integer, ForeignKey('c_table.id')),
             Column('aname', String(50)),
         )
         Table('b_table', metadata,
-            Column('id', Integer, primary_key=True),
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('some_c_id', Integer, ForeignKey('c_table.id')),
             Column('bname', String(50)),
         )
         Table('c_table', metadata,
-            Column('id', Integer, primary_key=True),
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('cname', String(50)),
             
         )
@@ -525,11 +526,11 @@ class ColKeysTest(_base.MappedTest):
     def define_tables(cls, metadata):
         global offices_table, refugees_table
         refugees_table = Table('refugee', metadata,
-           Column('refugee_fid', Integer, primary_key=True),
+           Column('refugee_fid', Integer, primary_key=True, test_needs_autoincrement=True),
            Column('refugee_name', Unicode(30), key='name'))
 
         offices_table = Table('office', metadata,
-           Column('office_fid', Integer, primary_key=True),
+           Column('office_fid', Integer, primary_key=True, test_needs_autoincrement=True),
            Column('office_name', Unicode(30), key='name'))
     
     @classmethod
index 06730125113c505901fa640ae580de81ede529cb..f94781c278c41d839e449e3667cb00da01f2f760 100644 (file)
@@ -4,6 +4,7 @@ from sqlalchemy.orm import *
 from sqlalchemy.test import testing
 from sqlalchemy.util import function_named
 from test.orm import _base
+from sqlalchemy.test.schema import Table, Column
 
 class BaseObject(object):
     def __init__(self, *args, **kwargs):
@@ -75,49 +76,48 @@ class MagazineTest(_base.MappedTest):
         global publication_table, issue_table, location_table, location_name_table, magazine_table, \
         page_table, magazine_page_table, classified_page_table, page_size_table
 
-        zerodefault = {} #{'default':0}
         publication_table = Table('publication', metadata,
-            Column('id', Integer, primary_key=True, default=None),
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('name', String(45), default=''),
         )
         issue_table = Table('issue', metadata,
-            Column('id', Integer, primary_key=True, default=None),
-            Column('publication_id', Integer, ForeignKey('publication.id'), **zerodefault),
-            Column('issue', Integer, **zerodefault),
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
+            Column('publication_id', Integer, ForeignKey('publication.id')),
+            Column('issue', Integer),
         )
         location_table = Table('location', metadata,
-            Column('id', Integer, primary_key=True, default=None),
-            Column('issue_id', Integer, ForeignKey('issue.id'), **zerodefault),
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
+            Column('issue_id', Integer, ForeignKey('issue.id')),
             Column('ref', CHAR(3), default=''),
-            Column('location_name_id', Integer, ForeignKey('location_name.id'), **zerodefault),
+            Column('location_name_id', Integer, ForeignKey('location_name.id')),
         )
         location_name_table = Table('location_name', metadata,
-            Column('id', Integer, primary_key=True, default=None),
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('name', String(45), default=''),
         )
         magazine_table = Table('magazine', metadata,
-            Column('id', Integer, primary_key=True, default=None),
-            Column('location_id', Integer, ForeignKey('location.id'), **zerodefault),
-            Column('page_size_id', Integer, ForeignKey('page_size.id'), **zerodefault),
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
+            Column('location_id', Integer, ForeignKey('location.id')),
+            Column('page_size_id', Integer, ForeignKey('page_size.id')),
         )
         page_table = Table('page', metadata,
-            Column('id', Integer, primary_key=True, default=None),
-            Column('page_no', Integer, **zerodefault),
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
+            Column('page_no', Integer),
             Column('type', CHAR(1), default='p'),
         )
         magazine_page_table = Table('magazine_page', metadata,
-            Column('page_id', Integer, ForeignKey('page.id'), primary_key=True, **zerodefault),
-            Column('magazine_id', Integer, ForeignKey('magazine.id'), **zerodefault),
-            Column('orders', TEXT, default=''),
+            Column('page_id', Integer, ForeignKey('page.id'), primary_key=True),
+            Column('magazine_id', Integer, ForeignKey('magazine.id')),
+            Column('orders', Text, default=''),
         )
         classified_page_table = Table('classified_page', metadata,
-            Column('magazine_page_id', Integer, ForeignKey('magazine_page.page_id'), primary_key=True, **zerodefault),
+            Column('magazine_page_id', Integer, ForeignKey('magazine_page.page_id'), primary_key=True),
             Column('titles', String(45), default=''),
         )
         page_size_table = Table('page_size', metadata,
-            Column('id', Integer, primary_key=True, default=None),
-            Column('width', Integer, **zerodefault),
-            Column('height', Integer, **zerodefault),
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
+            Column('width', Integer),
+            Column('height', Integer),
             Column('name', String(45), default=''),
         )
 
@@ -176,10 +176,11 @@ def generate_round_trip_test(use_unions=False, use_joins=False):
                 'magazine': relation(Magazine, backref=backref('pages', order_by=page_table.c.page_no))
             })
 
-        classified_page_mapper = mapper(ClassifiedPage, classified_page_table, inherits=magazine_page_mapper, polymorphic_identity='c', primary_key=[page_table.c.id])
-        #compile_mappers()
-        #print [str(s) for s in classified_page_mapper.primary_key]
-        #print classified_page_mapper.columntoproperty[page_table.c.id]
+        classified_page_mapper = mapper(ClassifiedPage, 
+                                    classified_page_table, 
+                                    inherits=magazine_page_mapper, 
+                                    polymorphic_identity='c', 
+                                    primary_key=[page_table.c.id])
 
 
         session = create_session()
index f7e676bbbcf76d1b3b3a9692d33222269f499cbc..7b6ad04eb20f3c7fe925e1d6009efbb2ef897e77 100644 (file)
@@ -194,11 +194,11 @@ class InheritTest3(_base.MappedTest):
         b.foos.append(Foo("foo #1"))
         b.foos.append(Foo("foo #2"))
         sess.flush()
-        compare = repr(b) + repr(sorted([repr(o) for o in b.foos]))
+        compare = [repr(b)] + sorted([repr(o) for o in b.foos])
         sess.expunge_all()
         l = sess.query(Bar).all()
         print repr(l[0]) + repr(l[0].foos)
-        found = repr(l[0]) + repr(sorted([repr(o) for o in l[0].foos]))
+        found = [repr(l[0])] + sorted([repr(o) for o in l[0].foos])
         eq_(found, compare)
 
     @testing.fails_on('maxdb', 'FIXME: unknown')
index 67b543f31c6ab6255fe467a5390a3e8b04a0c39d..e434218b9ce0e39078f69c56cd720e23e30b73fb 100644 (file)
@@ -3,6 +3,7 @@ from sqlalchemy.orm import *
 
 from test.orm import _base
 from sqlalchemy.test import testing
+from sqlalchemy.test.schema import Table, Column
 
 
 class PolymorphicCircularTest(_base.MappedTest):
@@ -12,7 +13,7 @@ class PolymorphicCircularTest(_base.MappedTest):
     def define_tables(cls, metadata):
         global Table1, Table1B, Table2, Table3,  Data
         table1 = Table('table1', metadata,
-                       Column('id', Integer, primary_key=True),
+                       Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
                        Column('related_id', Integer, ForeignKey('table1.id'), nullable=True),
                        Column('type', String(30)),
                        Column('name', String(30))
@@ -27,7 +28,7 @@ class PolymorphicCircularTest(_base.MappedTest):
                       )
 
         data = Table('data', metadata,
-            Column('id', Integer, primary_key=True),
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('node_id', Integer, ForeignKey('table1.id')),
             Column('data', String(30))
             )
@@ -72,7 +73,7 @@ class PolymorphicCircularTest(_base.MappedTest):
                                    polymorphic_on=table1.c.type,
                                    polymorphic_identity='table1',
                                    properties={
-                                    'next': relation(Table1,
+                                    'nxt': relation(Table1,
                                         backref=backref('prev', foreignkey=join.c.id, uselist=False),
                                         uselist=False, primaryjoin=join.c.id==join.c.related_id),
                                     'data':relation(mapper(Data, data))
@@ -86,15 +87,16 @@ class PolymorphicCircularTest(_base.MappedTest):
 
         # currently, the "eager" relationships degrade to lazy relationships
         # due to the polymorphic load.
-        # the "next" relation used to have a "lazy=False" on it, but the EagerLoader raises the "self-referential"
+        # the "nxt" relation used to have a "lazy=False" on it, but the EagerLoader raises the "self-referential"
         # exception now.  since eager loading would never work for that relation anyway, its better that the user
         # gets an exception instead of it silently not eager loading.
+        # NOTE: using "nxt" instead of "next" to avoid 2to3 turning it into __next__() for some reason.
         table1_mapper = mapper(Table1, table1,
                                #select_table=join,
                                polymorphic_on=table1.c.type,
                                polymorphic_identity='table1',
                                properties={
-                               'next': relation(Table1,
+                               'nxt': relation(Table1,
                                    backref=backref('prev', remote_side=table1.c.id, uselist=False),
                                    uselist=False, primaryjoin=table1.c.id==table1.c.related_id),
                                'data':relation(mapper(Data, data), lazy=False, order_by=data.c.id)
@@ -147,7 +149,7 @@ class PolymorphicCircularTest(_base.MappedTest):
             else:
                 newobj = c
             if obj is not None:
-                obj.next = newobj
+                obj.nxt = newobj
             else:
                 t = newobj
             obj = newobj
@@ -161,7 +163,7 @@ class PolymorphicCircularTest(_base.MappedTest):
         node = t
         while (node):
             assertlist.append(node)
-            n = node.next
+            n = node.nxt
             if n is not None:
                 assert n.prev is node
             node = n
@@ -174,7 +176,7 @@ class PolymorphicCircularTest(_base.MappedTest):
         assertlist = []
         while (node):
             assertlist.append(node)
-            n = node.next
+            n = node.nxt
             if n is not None:
                 assert n.prev is node
             node = n
@@ -188,7 +190,7 @@ class PolymorphicCircularTest(_base.MappedTest):
             assertlist.insert(0, node)
             n = node.prev
             if n is not None:
-                assert n.next is node
+                assert n.nxt is node
             node = n
         backwards = repr(assertlist)
 
index 51b6d4970a5a92c6c55e623b69e5bb3e39967b50..80c14413a0c4c738e6874232f4279e6565a6418f 100644 (file)
@@ -11,6 +11,7 @@ from sqlalchemy.test import TestBase, AssertsExecutionResults, testing
 from sqlalchemy.util import function_named
 from test.orm import _base, _fixtures
 from sqlalchemy.test.testing import eq_
+from sqlalchemy.test.schema import Table, Column
 
 class AttrSettable(object):
     def __init__(self, **kwargs):
@@ -105,7 +106,7 @@ class RelationTest2(_base.MappedTest):
     def define_tables(cls, metadata):
         global people, managers, data
         people = Table('people', metadata,
-           Column('person_id', Integer, Sequence('person_id_seq', optional=True), primary_key=True),
+           Column('person_id', Integer, primary_key=True, test_needs_autoincrement=True),
            Column('name', String(50)),
            Column('type', String(30)))
 
@@ -201,7 +202,7 @@ class RelationTest3(_base.MappedTest):
     def define_tables(cls, metadata):
         global people, managers, data
         people = Table('people', metadata,
-           Column('person_id', Integer, Sequence('person_id_seq', optional=True), primary_key=True),
+           Column('person_id', Integer, primary_key=True, test_needs_autoincrement=True),
            Column('colleague_id', Integer, ForeignKey('people.person_id')),
            Column('name', String(50)),
            Column('type', String(30)))
@@ -307,7 +308,7 @@ class RelationTest4(_base.MappedTest):
     def define_tables(cls, metadata):
         global people, engineers, managers, cars
         people = Table('people', metadata,
-           Column('person_id', Integer, primary_key=True),
+           Column('person_id', Integer, primary_key=True, test_needs_autoincrement=True),
            Column('name', String(50)))
 
         engineers = Table('engineers', metadata,
@@ -319,7 +320,7 @@ class RelationTest4(_base.MappedTest):
            Column('longer_status', String(70)))
 
         cars = Table('cars', metadata,
-           Column('car_id', Integer, primary_key=True),
+           Column('car_id', Integer, primary_key=True, test_needs_autoincrement=True),
            Column('owner', Integer, ForeignKey('people.person_id')))
 
     def testmanytoonepolymorphic(self):
@@ -420,7 +421,7 @@ class RelationTest5(_base.MappedTest):
     def define_tables(cls, metadata):
         global people, engineers, managers, cars
         people = Table('people', metadata,
-           Column('person_id', Integer, primary_key=True),
+           Column('person_id', Integer, primary_key=True, test_needs_autoincrement=True),
            Column('name', String(50)),
            Column('type', String(50)))
 
@@ -433,7 +434,7 @@ class RelationTest5(_base.MappedTest):
            Column('longer_status', String(70)))
 
         cars = Table('cars', metadata,
-           Column('car_id', Integer, primary_key=True),
+           Column('car_id', Integer, primary_key=True, test_needs_autoincrement=True),
            Column('owner', Integer, ForeignKey('people.person_id')))
 
     def testeagerempty(self):
@@ -482,7 +483,7 @@ class RelationTest6(_base.MappedTest):
     def define_tables(cls, metadata):
         global people, managers, data
         people = Table('people', metadata,
-           Column('person_id', Integer, Sequence('person_id_seq', optional=True), primary_key=True),
+           Column('person_id', Integer, primary_key=True, test_needs_autoincrement=True),
            Column('name', String(50)),
            )
 
@@ -525,14 +526,14 @@ class RelationTest7(_base.MappedTest):
     def define_tables(cls, metadata):
         global people, engineers, managers, cars, offroad_cars
         cars = Table('cars', metadata,
-                Column('car_id', Integer, primary_key=True),
+                Column('car_id', Integer, primary_key=True, test_needs_autoincrement=True),
                 Column('name', String(30)))
 
         offroad_cars = Table('offroad_cars', metadata,
                 Column('car_id',Integer, ForeignKey('cars.car_id'),nullable=False,primary_key=True))
 
         people = Table('people', metadata,
-                Column('person_id', Integer, primary_key=True),
+                Column('person_id', Integer, primary_key=True, test_needs_autoincrement=True),
                 Column('car_id', Integer, ForeignKey('cars.car_id'), nullable=False),
                 Column('name', String(50)))
 
@@ -625,7 +626,7 @@ class RelationTest8(_base.MappedTest):
     def define_tables(cls, metadata):
         global taggable, users
         taggable = Table('taggable', metadata,
-                         Column('id', Integer, primary_key=True),
+                         Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
                          Column('type', String(30)),
                          Column('owner_id', Integer, ForeignKey('taggable.id')),
                          )
@@ -680,11 +681,11 @@ class GenerativeTest(TestBase, AssertsExecutionResults):
         metadata = MetaData(testing.db)
         # table definitions
         status = Table('status', metadata,
-           Column('status_id', Integer, primary_key=True),
+           Column('status_id', Integer, primary_key=True, test_needs_autoincrement=True),
            Column('name', String(20)))
 
         people = Table('people', metadata,
-           Column('person_id', Integer, primary_key=True),
+           Column('person_id', Integer, primary_key=True, test_needs_autoincrement=True),
            Column('status_id', Integer, ForeignKey('status.status_id'), nullable=False),
            Column('name', String(50)))
 
@@ -697,7 +698,7 @@ class GenerativeTest(TestBase, AssertsExecutionResults):
            Column('category', String(70)))
 
         cars = Table('cars', metadata,
-           Column('car_id', Integer, primary_key=True),
+           Column('car_id', Integer, primary_key=True, test_needs_autoincrement=True),
            Column('status_id', Integer, ForeignKey('status.status_id'), nullable=False),
            Column('owner', Integer, ForeignKey('people.person_id'), nullable=False))
 
@@ -786,13 +787,13 @@ class GenerativeTest(TestBase, AssertsExecutionResults):
         e = exists([Car.owner], Car.owner==employee_join.c.person_id)
         Query(Person)._adapt_clause(employee_join, False, False)
         
-        r = session.query(Person).filter(Person.name.like('%2')).join('status').filter_by(name="active")
-        assert str(list(r)) == "[Manager M2, category YYYYYYYYY, status Status active, Engineer E2, field X, status Status active]"
+        r = session.query(Person).filter(Person.name.like('%2')).join('status').filter_by(name="active").order_by(Person.person_id)
+        eq_(str(list(r)), "[Manager M2, category YYYYYYYYY, status Status active, Engineer E2, field X, status Status active]")
         r = session.query(Engineer).join('status').filter(Person.name.in_(['E2', 'E3', 'E4', 'M4', 'M2', 'M1']) & (status.c.name=="active")).order_by(Person.name)
-        assert str(list(r)) == "[Engineer E2, field X, status Status active, Engineer E3, field X, status Status active]"
+        eq_(str(list(r)), "[Engineer E2, field X, status Status active, Engineer E3, field X, status Status active]")
 
         r = session.query(Person).filter(exists([1], Car.owner==Person.person_id))
-        assert str(list(r)) == "[Engineer E4, field X, status Status dead]"
+        eq_(str(list(r)), "[Engineer E4, field X, status Status dead]")
 
 class MultiLevelTest(_base.MappedTest):
     @classmethod
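Two things happen in the hunk above: the query gains an order_by() so the stringified result list is deterministic across backends, and the bare assert comparisons become eq_() calls, which report both operands on failure instead of an opaque AssertionError. A minimal sketch of what such a helper provides (the shipped sqlalchemy.test.testing.eq_ may differ in detail):

    def eq_(a, b, msg=None):
        # show both sides on failure rather than a bare AssertionError
        assert a == b, msg or "%r != %r" % (a, b)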
@@ -800,7 +801,7 @@ class MultiLevelTest(_base.MappedTest):
         global table_Employee, table_Engineer, table_Manager
         table_Employee = Table( 'Employee', metadata,
             Column( 'name', type_= String(100), ),
-            Column( 'id', primary_key= True, type_= Integer, ),
+            Column( 'id', primary_key= True, type_= Integer, test_needs_autoincrement=True),
             Column( 'atype', type_= String(100), ),
         )
 
@@ -878,7 +879,7 @@ class ManyToManyPolyTest(_base.MappedTest):
         global base_item_table, item_table, base_item_collection_table, collection_table
         base_item_table = Table(
             'base_item', metadata,
-            Column('id', Integer, primary_key=True),
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('child_name', String(255), default=None))
 
         item_table = Table(
@@ -893,7 +894,7 @@ class ManyToManyPolyTest(_base.MappedTest):
 
         collection_table = Table(
             'collection', metadata,
-            Column('id', Integer, primary_key=True),
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('name', Unicode(255)))
 
     def test_pjoin_compile(self):
@@ -928,7 +929,7 @@ class CustomPKTest(_base.MappedTest):
     def define_tables(cls, metadata):
         global t1, t2
         t1 = Table('t1', metadata,
-            Column('id', Integer, primary_key=True),
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('type', String(30), nullable=False),
             Column('data', String(30)))
         # note that the primary key column in t2 is named differently
@@ -1013,7 +1014,7 @@ class InheritingEagerTest(_base.MappedTest):
         global people, employees, tags, peopleTags
 
         people = Table('people', metadata,
-                           Column('id', Integer, primary_key=True),
+                           Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
                            Column('_type', String(30), nullable=False),
                           )
 
@@ -1023,7 +1024,7 @@ class InheritingEagerTest(_base.MappedTest):
                         )
 
         tags = Table('tags', metadata,
-                           Column('id', Integer, primary_key=True),
+                           Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
                            Column('label', String(50), nullable=False),
                        )
 
@@ -1074,11 +1075,11 @@ class MissingPolymorphicOnTest(_base.MappedTest):
     def define_tables(cls, metadata):
         global tablea, tableb, tablec, tabled
         tablea = Table('tablea', metadata, 
-            Column('id', Integer, primary_key=True),
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('adata', String(50)),
             )
         tableb = Table('tableb', metadata, 
-            Column('id', Integer, primary_key=True),
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('aid', Integer, ForeignKey('tablea.id')),
             Column('data', String(50)),
             )
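Throughout these hunks, explicit optional Sequence objects give way to test_needs_autoincrement=True on the Column imported from sqlalchemy.test.schema, centralizing the per-dialect decision in one wrapper. Roughly the shape of such a wrapper (a sketch only: _needs_sequence is a hypothetical predicate, not the actual implementation):

    from sqlalchemy import schema

    def Column(name, *args, **kw):
        if kw.pop('test_needs_autoincrement', False) and _needs_sequence():
            # attach an optional sequence only on dialects (e.g. Oracle,
            # Firebird) that cannot autoincrement an integer primary key
            args = args + (schema.Sequence(name + '_seq', optional=True),)
        return schema.Column(name, *args, **kw)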
index b2bcb85d54b570b465729508dc492a635fb29b09..4c593e2a38f69e51bce3bbbbbe1a256a847233df 100644 (file)
@@ -2,10 +2,9 @@ from datetime import datetime
 from sqlalchemy import *
 from sqlalchemy.orm import *
 
-
 from sqlalchemy.test import testing
 from test.orm import _base
-
+from sqlalchemy.test.schema import Table, Column
 
 class InheritTest(_base.MappedTest):
     """tests some various inheritance round trips involving a particular set of polymorphic inheritance relationships"""
@@ -15,14 +14,14 @@ class InheritTest(_base.MappedTest):
         global Product, Detail, Assembly, SpecLine, Document, RasterDocument
 
         products_table = Table('products', metadata,
-           Column('product_id', Integer, primary_key=True),
+           Column('product_id', Integer, primary_key=True, test_needs_autoincrement=True),
            Column('product_type', String(128)),
            Column('name', String(128)),
            Column('mark', String(128)),
            )
 
         specification_table = Table('specification', metadata,
-            Column('spec_line_id', Integer, primary_key=True),
+            Column('spec_line_id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('master_id', Integer, ForeignKey("products.product_id"),
                 nullable=True),
             Column('slave_id', Integer, ForeignKey("products.product_id"),
@@ -31,7 +30,7 @@ class InheritTest(_base.MappedTest):
             )
 
         documents_table = Table('documents', metadata,
-            Column('document_id', Integer, primary_key=True),
+            Column('document_id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('document_type', String(128)),
             Column('product_id', Integer, ForeignKey('products.product_id')),
             Column('create_date', DateTime, default=lambda:datetime.now()),
index 5b57e8f4575e2fe2a6ef62b70012b64904fdda31..daf8bf3bd03d30f984d6de73a7cb024f8b905633 100644 (file)
@@ -8,6 +8,7 @@ from sqlalchemy.engine import default
 from sqlalchemy.test import AssertsCompiledSQL, testing
 from test.orm import _base, _fixtures
 from sqlalchemy.test.testing import eq_
+from sqlalchemy.test.schema import Table, Column
 
 class Company(_fixtures.Base):
     pass
@@ -38,11 +39,11 @@ def _produce_test(select_type):
             global companies, people, engineers, managers, boss, paperwork, machines
 
             companies = Table('companies', metadata,
-               Column('company_id', Integer, Sequence('company_id_seq', optional=True), primary_key=True),
+               Column('company_id', Integer, primary_key=True, test_needs_autoincrement=True),
                Column('name', String(50)))
 
             people = Table('people', metadata,
-               Column('person_id', Integer, Sequence('person_id_seq', optional=True), primary_key=True),
+               Column('person_id', Integer, primary_key=True, test_needs_autoincrement=True),
                Column('company_id', Integer, ForeignKey('companies.company_id')),
                Column('name', String(50)),
                Column('type', String(30)))
@@ -55,7 +56,7 @@ def _produce_test(select_type):
               )
          
             machines = Table('machines', metadata,
-                Column('machine_id', Integer, primary_key=True),
+                Column('machine_id', Integer, primary_key=True, test_needs_autoincrement=True),
                 Column('name', String(50)),
                 Column('engineer_id', Integer, ForeignKey('engineers.person_id')))
             
@@ -71,7 +72,7 @@ def _produce_test(select_type):
                 )
 
             paperwork = Table('paperwork', metadata,
-                Column('paperwork_id', Integer, primary_key=True),
+                Column('paperwork_id', Integer, primary_key=True, test_needs_autoincrement=True),
                 Column('description', String(50)),
                 Column('person_id', Integer, ForeignKey('people.person_id')))
 
@@ -771,7 +772,7 @@ class SelfReferentialTestJoinedToBase(_base.MappedTest):
     def define_tables(cls, metadata):
         global people, engineers
         people = Table('people', metadata,
-           Column('person_id', Integer, Sequence('person_id_seq', optional=True), primary_key=True),
+           Column('person_id', Integer, primary_key=True, test_needs_autoincrement=True),
            Column('name', String(50)),
            Column('type', String(30)))
 
@@ -831,7 +832,7 @@ class SelfReferentialJ2JTest(_base.MappedTest):
     def define_tables(cls, metadata):
         global people, engineers, managers
         people = Table('people', metadata,
-           Column('person_id', Integer, Sequence('person_id_seq', optional=True), primary_key=True),
+           Column('person_id', Integer, primary_key=True, test_needs_autoincrement=True),
            Column('name', String(50)),
            Column('type', String(30)))
 
@@ -947,7 +948,7 @@ class M2MFilterTest(_base.MappedTest):
         global people, engineers, organizations, engineers_to_org
         
         organizations = Table('organizations', metadata,
-            Column('id', Integer, Sequence('org_id_seq', optional=True), primary_key=True),
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('name', String(50)),
             )
         engineers_to_org = Table('engineers_org', metadata,
@@ -956,7 +957,7 @@ class M2MFilterTest(_base.MappedTest):
         )
         
         people = Table('people', metadata,
-           Column('person_id', Integer, Sequence('person_id_seq', optional=True), primary_key=True),
+           Column('person_id', Integer, primary_key=True, test_needs_autoincrement=True),
            Column('name', String(50)),
            Column('type', String(30)))
 
@@ -1023,7 +1024,7 @@ class SelfReferentialM2MTest(_base.MappedTest, AssertsCompiledSQL):
 
         class Parent(Base):
            __tablename__ = 'parent'
-           id = Column(Integer, primary_key=True)
+           id = Column(Integer, primary_key=True, test_needs_autoincrement=True)
            cls = Column(String(50))
            __mapper_args__ = dict(polymorphic_on = cls )
 
index a151af4fa29f129b655935a00f257cc673bdc9f7..7c9920f6f8b00e6105a6382604b7621129fef696 100644 (file)
@@ -46,6 +46,6 @@ class InheritingSelectablesTest(MappedTest):
 
         s = sessionmaker(bind=testing.db)()
 
-        assert [Baz(), Baz(), Bar(), Bar()] == s.query(Foo).all()
+        assert [Baz(), Baz(), Bar(), Bar()] == s.query(Foo).order_by(Foo.b.desc()).all()
         assert [Bar(), Bar()] == s.query(Bar).all()
 
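The added order_by() here is not cosmetic: without ORDER BY, SQL leaves row order undefined, so comparing against a literal list only passes by luck on some backends. The contrast in isolation:

    # backend-dependent order; may pass on SQLite and fail elsewhere:
    s.query(Foo).all()
    # pinning the order makes the list comparison portable:
    s.query(Foo).order_by(Foo.b.desc()).all()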
index 70582688576f7cfaea463d0d3e96cff18de92d50..fc30955db8e2831b792754d46cec278fbc50e298 100644 (file)
@@ -5,20 +5,21 @@ from sqlalchemy.orm import *
 from sqlalchemy.test import testing
 from test.orm import _fixtures
 from test.orm._base import MappedTest, ComparableEntity
+from sqlalchemy.test.schema import Table, Column
 
 
 class SingleInheritanceTest(MappedTest):
     @classmethod
     def define_tables(cls, metadata):
         Table('employees', metadata,
-            Column('employee_id', Integer, primary_key=True),
+            Column('employee_id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('name', String(50)),
             Column('manager_data', String(50)),
             Column('engineer_info', String(50)),
             Column('type', String(20)))
 
         Table('reports', metadata,
-              Column('report_id', Integer, primary_key=True),
+              Column('report_id', Integer, primary_key=True, test_needs_autoincrement=True),
               Column('employee_id', ForeignKey('employees.employee_id')),
               Column('name', String(50)),
         )
@@ -186,7 +187,7 @@ class RelationToSingleTest(MappedTest):
     @classmethod
     def define_tables(cls, metadata):
         Table('employees', metadata,
-            Column('employee_id', Integer, primary_key=True),
+            Column('employee_id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('name', String(50)),
             Column('manager_data', String(50)),
             Column('engineer_info', String(50)),
@@ -195,7 +196,7 @@ class RelationToSingleTest(MappedTest):
         )
         
         Table('companies', metadata,
-            Column('company_id', Integer, primary_key=True),
+            Column('company_id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('name', String(50)),
         )
     
@@ -342,7 +343,7 @@ class SingleOnJoinedTest(MappedTest):
         global persons_table, employees_table
         
         persons_table = Table('persons', metadata,
-           Column('person_id', Integer, primary_key=True),
+           Column('person_id', Integer, primary_key=True, test_needs_autoincrement=True),
            Column('name', String(50)),
            Column('type', String(20), nullable=False)
         )
index 89e23fb759139bff3a55c8aa9aa6b665eac5608c..e8ffaa7cad599b2a3d5a6d53981a10f8d7e0f8f9 100644 (file)
@@ -6,6 +6,7 @@ from sqlalchemy.orm.shard import ShardedSession
 from sqlalchemy.sql import operators
 from sqlalchemy.test import *
 from sqlalchemy.test.testing import eq_
+from nose import SkipTest
 
 # TODO: ShardTest can be turned into a base for further subclasses
 
@@ -14,7 +15,10 @@ class ShardTest(TestBase):
     def setup_class(cls):
         global db1, db2, db3, db4, weather_locations, weather_reports
 
-        db1 = create_engine('sqlite:///shard1.db')
+        try:
+            db1 = create_engine('sqlite:///shard1.db')
+        except ImportError:
+            raise SkipTest('Requires sqlite')
         db2 = create_engine('sqlite:///shard2.db')
         db3 = create_engine('sqlite:///shard3.db')
         db4 = create_engine('sqlite:///shard4.db')
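Guarding the first create_engine() call lets the whole test class opt out cleanly when the sqlite DBAPI is not importable on the current interpreter; nose reports a raised SkipTest as a skip rather than an error. The pattern in isolation:

    from nose import SkipTest
    from sqlalchemy import create_engine

    try:
        # create_engine imports the DBAPI module up front, so a missing
        # sqlite3/pysqlite surfaces here as ImportError
        db1 = create_engine('sqlite:///shard1.db')
    except ImportError:
        raise SkipTest('Requires sqlite')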
index ee7fb7af94ac29744abadc92298fd24693b40ecb..d537430cc6e53e64e17b835ea59b4b8a6a546841 100644 (file)
@@ -1,8 +1,7 @@
 
 from sqlalchemy.test import testing
 from sqlalchemy import Integer, String, ForeignKey
-from sqlalchemy.test.schema import Table
-from sqlalchemy.test.schema import Column
+from sqlalchemy.test.schema import Table, Column
 from sqlalchemy.orm import mapper, relation, create_session
 from test.orm import _base
 from sqlalchemy.test.testing import eq_
@@ -15,14 +14,14 @@ class AssociationTest(_base.MappedTest):
     @classmethod
     def define_tables(cls, metadata):
         Table('items', metadata,
-            Column('item_id', Integer, primary_key=True),
+            Column('item_id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('name', String(40)))
         Table('item_keywords', metadata,
             Column('item_id', Integer, ForeignKey('items.item_id')),
             Column('keyword_id', Integer, ForeignKey('keywords.keyword_id')),
             Column('data', String(40)))
         Table('keywords', metadata,
-            Column('keyword_id', Integer, primary_key=True),
+            Column('keyword_id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('name', String(40)))
 
     @classmethod
index 09f0075479aaca48c28207714ecf4bce408bbeab..94a98d9aead5fdac22346245b694d139b818e439 100644 (file)
@@ -8,8 +8,7 @@ import datetime
 import sqlalchemy as sa
 from sqlalchemy.test import testing
 from sqlalchemy import Integer, String, ForeignKey
-from sqlalchemy.test.schema import Table
-from sqlalchemy.test.schema import Column
+from sqlalchemy.test.schema import Table, Column
 from sqlalchemy.orm import mapper, relation, backref, create_session
 from sqlalchemy.test.testing import eq_
 from test.orm import _base
@@ -37,27 +36,24 @@ class EagerTest(_base.MappedTest):
         cls.other_artifacts['false'] = false
 
         Table('owners', metadata ,
-              Column('id', Integer, primary_key=True, nullable=False),
+              Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
               Column('data', String(30)))
 
         Table('categories', metadata,
-              Column('id', Integer, primary_key=True, nullable=False),
+              Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
               Column('name', String(20)))
 
         Table('tests', metadata ,
-              Column('id', Integer, primary_key=True, nullable=False ),
+              Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
               Column('owner_id', Integer, ForeignKey('owners.id'),
                      nullable=False),
               Column('category_id', Integer, ForeignKey('categories.id'),
                      nullable=False))
 
         Table('options', metadata ,
-              Column('test_id', Integer, ForeignKey('tests.id'),
-                     primary_key=True, nullable=False),
-              Column('owner_id', Integer, ForeignKey('owners.id'),
-                     primary_key=True, nullable=False),
-              Column('someoption', sa.Boolean, server_default=false,
-                     nullable=False))
+              Column('test_id', Integer, ForeignKey('tests.id'), primary_key=True),
+              Column('owner_id', Integer, ForeignKey('owners.id'), primary_key=True),
+              Column('someoption', sa.Boolean, server_default=false, nullable=False))
 
     @classmethod
     def setup_classes(cls):
@@ -219,7 +215,7 @@ class EagerTest2(_base.MappedTest):
             Column('data', String(50), primary_key=True))
 
         Table('middle', metadata,
-            Column('id', Integer, primary_key = True),
+            Column('id', Integer, primary_key = True, test_needs_autoincrement=True),
             Column('data', String(50)))
 
         Table('right', metadata,
@@ -280,17 +276,15 @@ class EagerTest3(_base.MappedTest):
     @classmethod
     def define_tables(cls, metadata):
         Table('datas', metadata,
-              Column('id', Integer, primary_key=True, nullable=False),
+              Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
               Column('a', Integer, nullable=False))
 
         Table('foo', metadata,
-              Column('data_id', Integer,
-                     ForeignKey('datas.id'),
-                     nullable=False, primary_key=True),
+              Column('data_id', Integer, ForeignKey('datas.id'),primary_key=True),
               Column('bar', Integer))
 
         Table('stats', metadata,
-              Column('id', Integer, primary_key=True, nullable=False ),
+              Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
               Column('data_id', Integer, ForeignKey('datas.id')),
               Column('somedata', Integer, nullable=False ))
 
@@ -364,11 +358,11 @@ class EagerTest4(_base.MappedTest):
     @classmethod
     def define_tables(cls, metadata):
         Table('departments', metadata,
-              Column('department_id', Integer, primary_key=True),
+              Column('department_id', Integer, primary_key=True, test_needs_autoincrement=True),
               Column('name', String(50)))
 
         Table('employees', metadata,
-              Column('person_id', Integer, primary_key=True),
+              Column('person_id', Integer, primary_key=True, test_needs_autoincrement=True),
               Column('name', String(50)),
               Column('department_id', Integer,
                      ForeignKey('departments.department_id')))
@@ -422,17 +416,15 @@ class EagerTest5(_base.MappedTest):
               Column('x', String(30)))
 
         Table('derived', metadata,
-              Column('uid', String(30), ForeignKey('base.uid'),
-                     primary_key=True),
+              Column('uid', String(30), ForeignKey('base.uid'), primary_key=True),
               Column('y', String(30)))
 
         Table('derivedII', metadata,
-              Column('uid', String(30), ForeignKey('base.uid'),
-                     primary_key=True),
+              Column('uid', String(30), ForeignKey('base.uid'), primary_key=True),
               Column('z', String(30)))
 
         Table('comments', metadata,
-              Column('id', Integer, primary_key=True),
+              Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
               Column('uid', String(30), ForeignKey('base.uid')),
               Column('comment', String(30)))
 
@@ -505,21 +497,21 @@ class EagerTest6(_base.MappedTest):
     @classmethod
     def define_tables(cls, metadata):
         Table('design_types', metadata,
-            Column('design_type_id', Integer, primary_key=True))
+            Column('design_type_id', Integer, primary_key=True, test_needs_autoincrement=True))
 
         Table('design', metadata,
-              Column('design_id', Integer, primary_key=True),
+              Column('design_id', Integer, primary_key=True, test_needs_autoincrement=True),
               Column('design_type_id', Integer,
                      ForeignKey('design_types.design_type_id')))
 
         Table('parts', metadata,
-              Column('part_id', Integer, primary_key=True),
+              Column('part_id', Integer, primary_key=True, test_needs_autoincrement=True),
               Column('design_id', Integer, ForeignKey('design.design_id')),
               Column('design_type_id', Integer,
                      ForeignKey('design_types.design_type_id')))
 
         Table('inherited_part', metadata,
-              Column('ip_id', Integer, primary_key=True),
+              Column('ip_id', Integer, primary_key=True, test_needs_autoincrement=True),
               Column('part_id', Integer, ForeignKey('parts.part_id')),
               Column('design_id', Integer, ForeignKey('design.design_id')))
 
@@ -573,32 +565,27 @@ class EagerTest7(_base.MappedTest):
     @classmethod
     def define_tables(cls, metadata):
         Table('companies', metadata,
-              Column('company_id', Integer, primary_key=True,
-                     test_needs_autoincrement=True),
+              Column('company_id', Integer, primary_key=True, test_needs_autoincrement=True),
               Column('company_name', String(40)))
 
         Table('addresses', metadata,
-              Column('address_id', Integer, primary_key=True,
-                     test_needs_autoincrement=True),
+              Column('address_id', Integer, primary_key=True,test_needs_autoincrement=True),
               Column('company_id', Integer, ForeignKey("companies.company_id")),
               Column('address', String(40)))
 
         Table('phone_numbers', metadata,
-              Column('phone_id', Integer, primary_key=True,
-                     test_needs_autoincrement=True),
+              Column('phone_id', Integer, primary_key=True,test_needs_autoincrement=True),
               Column('address_id', Integer, ForeignKey('addresses.address_id')),
               Column('type', String(20)),
               Column('number', String(10)))
 
         Table('invoices', metadata,
-              Column('invoice_id', Integer, primary_key=True,
-                     test_needs_autoincrement=True),
+              Column('invoice_id', Integer, primary_key=True, test_needs_autoincrement=True),
               Column('company_id', Integer, ForeignKey("companies.company_id")),
               Column('date', sa.DateTime))
 
         Table('items', metadata,
-              Column('item_id', Integer, primary_key=True,
-                     test_needs_autoincrement=True),
+              Column('item_id', Integer, primary_key=True, test_needs_autoincrement=True),
               Column('invoice_id', Integer, ForeignKey('invoices.invoice_id')),
               Column('code', String(20)),
               Column('qty', Integer))
@@ -722,12 +709,12 @@ class EagerTest8(_base.MappedTest):
     @classmethod
     def define_tables(cls, metadata):
         Table('prj', metadata,
-              Column('id', Integer, primary_key=True),
+              Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
               Column('created', sa.DateTime ),
               Column('title', sa.Unicode(100)))
 
         Table('task', metadata,
-              Column('id', Integer, primary_key=True),
+              Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
               Column('status_id', Integer,
                      ForeignKey('task_status.id'), nullable=False),
               Column('title', sa.Unicode(100)),
@@ -736,19 +723,19 @@ class EagerTest8(_base.MappedTest):
               Column('prj_id', Integer , ForeignKey('prj.id'), nullable=False))
 
         Table('task_status', metadata,
-              Column('id', Integer, primary_key=True))
+              Column('id', Integer, primary_key=True, test_needs_autoincrement=True))
 
         Table('task_type', metadata,
-              Column('id', Integer, primary_key=True))
+              Column('id', Integer, primary_key=True, test_needs_autoincrement=True))
 
         Table('msg', metadata,
-              Column('id', Integer, primary_key=True),
+              Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
               Column('posted', sa.DateTime, index=True,),
               Column('type_id', Integer, ForeignKey('msg_type.id')),
               Column('task_id', Integer, ForeignKey('task.id')))
 
         Table('msg_type', metadata,
-              Column('id', Integer, primary_key=True),
+              Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
               Column('name', sa.Unicode(20)),
               Column('display_name', sa.Unicode(20)))
 
@@ -814,15 +801,15 @@ class EagerTest9(_base.MappedTest):
     @classmethod
     def define_tables(cls, metadata):
         Table('accounts', metadata,
-            Column('account_id', Integer, primary_key=True),
+            Column('account_id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('name', String(40)))
 
         Table('transactions', metadata,
-            Column('transaction_id', Integer, primary_key=True),
+            Column('transaction_id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('name', String(40)))
 
         Table('entries', metadata,
-            Column('entry_id', Integer, primary_key=True),
+            Column('entry_id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('name', String(40)),
             Column('account_id', Integer,
                    ForeignKey('accounts.account_id')),
index ca8cef3ad85a72a77c2ddc5d760e2c53d0b77a1b..fa26ec7d7e0742b84bdd0659d813fc36bedc2a66 100644 (file)
@@ -6,7 +6,8 @@ from sqlalchemy import exc as sa_exc
 from sqlalchemy.test import *
 from sqlalchemy.test.testing import eq_
 from test.orm import _base
-import gc
+from sqlalchemy.test.util import gc_collect
+from sqlalchemy.util import cmp, jython
 
 # global for pickling tests
 MyTest = None
@@ -80,7 +81,9 @@ class AttributesTest(_base.ORMTest):
             del o2.__dict__['mt2']
             o2.__dict__[o_mt2_str] = former
 
-            self.assert_(pk_o == pk_o2)
+            # Relies on dict ordering
+            if not jython:
+                self.assert_(pk_o == pk_o2)
 
         # the above is kind of disturbing, so let's do it again a little
         # differently.  the string-id in serialization thing is just an
@@ -93,7 +96,9 @@ class AttributesTest(_base.ORMTest):
         o4 = pickle.loads(pk_o3)
         pk_o4 = pickle.dumps(o4)
 
-        self.assert_(pk_o3 == pk_o4)
+        # Relies on dict ordering
+        if not jython:
+            self.assert_(pk_o3 == pk_o4)
 
         # and lastly make sure we still have our data after all that.
         # identical serialization is great, *if* it's complete :)
@@ -117,7 +122,7 @@ class AttributesTest(_base.ORMTest):
         f.bar = "foo"
         assert state.dict == {'bar':'foo', state.manager.STATE_ATTR:state}
         del f
-        gc.collect()
+        gc_collect()
         assert state.obj() is None
         assert state.dict == {}
         
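Direct gc.collect() calls become sqlalchemy.test.util's gc_collect so that interpreters without CPython's deterministic refcounting can be handled in one place. A plausible shape for such a shim (a sketch; the shipped helper may differ):

    import gc
    import sys
    import time

    jython = sys.platform.startswith('java')

    def gc_collect():
        gc.collect()
        if jython:
            # Jython's collector runs asynchronously; give the JVM a
            # moment to actually finalize unreachable objects
            time.sleep(0.1)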
index d0a7b9ded6956fd4b10a319f4beb7b7fb030ded9..c523fb5f0134e71d657d972ce65d9c6d76324be0 100644 (file)
@@ -1,8 +1,7 @@
 
 from sqlalchemy.test.testing import assert_raises, assert_raises_message
 from sqlalchemy import Integer, String, ForeignKey, Sequence, exc as sa_exc
-from sqlalchemy.test.schema import Table
-from sqlalchemy.test.schema import Column
+from sqlalchemy.test.schema import Table, Column
 from sqlalchemy.orm import mapper, relation, create_session, class_mapper, backref
 from sqlalchemy.orm import attributes, exc as orm_exc
 from sqlalchemy.test import testing
@@ -20,7 +19,7 @@ class O2MCascadeTest(_fixtures.FixtureTest):
         mapper(User, users, properties = dict(
             addresses = relation(Address, cascade="all, delete-orphan", backref="user"),
             orders = relation(
-                mapper(Order, orders), cascade="all, delete-orphan")
+                mapper(Order, orders), cascade="all, delete-orphan", order_by=orders.c.id)
         ))
         mapper(Dingaling,dingalings, properties={
             'address':relation(Address)
@@ -50,16 +49,12 @@ class O2MCascadeTest(_fixtures.FixtureTest):
                     orders=[Order(description="order 3"),
                             Order(description="order 4")]))
 
-        eq_(sess.query(Order).all(),
+        eq_(sess.query(Order).order_by(Order.id).all(),
             [Order(description="order 3"), Order(description="order 4")])
 
         o5 = Order(description="order 5")
         sess.add(o5)
-        try:
-            sess.flush()
-            assert False
-        except orm_exc.FlushError, e:
-            assert "is an orphan" in str(e)
+        assert_raises_message(orm_exc.FlushError, "is an orphan", sess.flush)
 
 
     @testing.resolve_artifact_names
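The five-line try/flush/assert False/except FlushError dance collapses into a single assert_raises_message call. A helper of that shape looks roughly like this (sketch only; the shipped version may match by regex rather than substring):

    def assert_raises_message(except_cls, msg, callable_, *args, **kw):
        try:
            callable_(*args, **kw)
            assert False, "Callable did not raise an exception"
        except except_cls, e:
            assert msg in str(e), "%r not in %r" % (msg, str(e))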
@@ -351,18 +346,15 @@ class M2OCascadeTest(_base.MappedTest):
     @classmethod
     def define_tables(cls, metadata):
         Table("extra", metadata,
-            Column("id", Integer, Sequence("extra_id_seq", optional=True),
-                   primary_key=True),
+            Column("id", Integer, primary_key=True, test_needs_autoincrement=True),
             Column("prefs_id", Integer, ForeignKey("prefs.id")))
 
         Table('prefs', metadata,
-            Column('id', Integer, Sequence('prefs_id_seq', optional=True),
-                   primary_key=True),
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('data', String(40)))
 
         Table('users', metadata,
-            Column('id', Integer, Sequence('user_id_seq', optional=True),
-                   primary_key=True),
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('name', String(40)),
             Column('pref_id', Integer, ForeignKey('prefs.id')))
 
@@ -453,22 +445,22 @@ class M2OCascadeTest(_base.MappedTest):
         jack.pref = newpref
         jack.pref = newpref
         sess.flush()
-        eq_(sess.query(Pref).all(),
+        eq_(sess.query(Pref).order_by(Pref.id).all(),
             [Pref(data="pref 1"), Pref(data="pref 3"), Pref(data="newpref")])
 
 class M2OCascadeDeleteTest(_base.MappedTest):
     @classmethod
     def define_tables(cls, metadata):
         Table('t1', metadata,
-              Column('id', Integer, primary_key=True),
+              Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
               Column('data', String(50)),
               Column('t2id', Integer, ForeignKey('t2.id')))
         Table('t2', metadata,
-              Column('id', Integer, primary_key=True),
+              Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
               Column('data', String(50)),
               Column('t3id', Integer, ForeignKey('t3.id')))
         Table('t3', metadata,
-              Column('id', Integer, primary_key=True),
+              Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
               Column('data', String(50)))
 
     @classmethod
@@ -581,15 +573,15 @@ class M2OCascadeDeleteOrphanTest(_base.MappedTest):
     @classmethod
     def define_tables(cls, metadata):
         Table('t1', metadata,
-              Column('id', Integer, primary_key=True),
+              Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
               Column('data', String(50)),
               Column('t2id', Integer, ForeignKey('t2.id')))
         Table('t2', metadata,
-              Column('id', Integer, primary_key=True),
+              Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
               Column('data', String(50)),
               Column('t3id', Integer, ForeignKey('t3.id')))
         Table('t3', metadata,
-              Column('id', Integer, primary_key=True),
+              Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
               Column('data', String(50)))
 
     @classmethod
@@ -696,12 +688,12 @@ class M2MCascadeTest(_base.MappedTest):
     @classmethod
     def define_tables(cls, metadata):
         Table('a', metadata,
-            Column('id', Integer, primary_key=True),
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('data', String(30)),
             test_needs_fk=True
             )
         Table('b', metadata,
-            Column('id', Integer, primary_key=True),
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('data', String(30)),
             test_needs_fk=True
             
@@ -713,7 +705,7 @@ class M2MCascadeTest(_base.MappedTest):
             
             )
         Table('c', metadata,
-              Column('id', Integer, primary_key=True),
+              Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
               Column('data', String(30)),
               Column('bid', Integer, ForeignKey('b.id')),
               test_needs_fk=True
@@ -838,15 +830,11 @@ class UnsavedOrphansTest(_base.MappedTest):
     @classmethod
     def define_tables(cls, metadata):
         Table('users', metadata,
-            Column('user_id', Integer,
-                   Sequence('user_id_seq', optional=True),
-                   primary_key=True),
+            Column('user_id', Integer,primary_key=True, test_needs_autoincrement=True),
             Column('name', String(40)))
 
         Table('addresses', metadata,
-            Column('address_id', Integer,
-                   Sequence('address_id_seq', optional=True),
-                   primary_key=True),
+            Column('address_id', Integer,primary_key=True, test_needs_autoincrement=True),
             Column('user_id', Integer, ForeignKey('users.user_id')),
             Column('email_address', String(40)))
 
@@ -923,20 +911,17 @@ class UnsavedOrphansTest2(_base.MappedTest):
     @classmethod
     def define_tables(cls, meta):
         Table('orders', meta,
-            Column('id', Integer, Sequence('order_id_seq'),
-                   primary_key=True),
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('name', String(50)))
 
         Table('items', meta,
-            Column('id', Integer, Sequence('item_id_seq'),
-                   primary_key=True),
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('order_id', Integer, ForeignKey('orders.id'),
                    nullable=False),
             Column('name', String(50)))
 
         Table('attributes', meta,
-            Column('id', Integer, Sequence('attribute_id_seq'),
-                   primary_key=True),
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('item_id', Integer, ForeignKey('items.id'),
                    nullable=False),
             Column('name', String(50)))
@@ -982,19 +967,13 @@ class UnsavedOrphansTest3(_base.MappedTest):
     @classmethod
     def define_tables(cls, meta):
         Table('sales_reps', meta,
-            Column('sales_rep_id', Integer,
-                   Sequence('sales_rep_id_seq'),
-                   primary_key=True),
+            Column('sales_rep_id', Integer,primary_key=True, test_needs_autoincrement=True),
             Column('name', String(50)))
         Table('accounts', meta,
-            Column('account_id', Integer,
-                   Sequence('account_id_seq'),
-                   primary_key=True),
+            Column('account_id', Integer,primary_key=True, test_needs_autoincrement=True),
             Column('balance', Integer))
         Table('customers', meta,
-            Column('customer_id', Integer,
-                   Sequence('customer_id_seq'),
-                   primary_key=True),
+            Column('customer_id', Integer,primary_key=True, test_needs_autoincrement=True),
             Column('name', String(50)),
             Column('sales_rep_id', Integer,
                    ForeignKey('sales_reps.sales_rep_id')),
@@ -1087,19 +1066,19 @@ class DoubleParentOrphanTest(_base.MappedTest):
     @classmethod
     def define_tables(cls, metadata):
         Table('addresses', metadata,
-            Column('address_id', Integer, primary_key=True),
+            Column('address_id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('street', String(30)),
         )
 
         Table('homes', metadata,
-            Column('home_id', Integer, primary_key=True, key="id"),
+            Column('home_id', Integer, primary_key=True, key="id", test_needs_autoincrement=True),
             Column('description', String(30)),
             Column('address_id', Integer, ForeignKey('addresses.address_id'),
                    nullable=False),
         )
 
         Table('businesses', metadata,
-            Column('business_id', Integer, primary_key=True, key="id"),
+            Column('business_id', Integer, primary_key=True, key="id", test_needs_autoincrement=True),
             Column('description', String(30), key="description"),
             Column('address_id', Integer, ForeignKey('addresses.address_id'),
                    nullable=False),
@@ -1159,10 +1138,10 @@ class CollectionAssignmentOrphanTest(_base.MappedTest):
     @classmethod
     def define_tables(cls, metadata):
         Table('table_a', metadata,
-              Column('id', Integer, primary_key=True),
+              Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
               Column('name', String(30)))
         Table('table_b', metadata,
-              Column('id', Integer, primary_key=True),
+              Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
               Column('name', String(30)),
               Column('a_id', Integer, ForeignKey('table_a.id')))
 
@@ -1208,12 +1187,12 @@ class PartialFlushTest(_base.MappedTest):
     @classmethod
     def define_tables(cls, metadata):
         Table("base", metadata,
-            Column("id", Integer, primary_key=True),
+            Column("id", Integer, primary_key=True, test_needs_autoincrement=True),
             Column("descr", String(50))
         )
 
         Table("noninh_child", metadata, 
-            Column('id', Integer, primary_key=True),
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('base_id', Integer, ForeignKey('base.id'))
         )
 
index 12ff25c460ca8efeb4d2a68ba7e3cc97ed3f9181..3d1b30bc9cb55578afbf09ba61d55ed0eba81b0c 100644 (file)
@@ -8,12 +8,11 @@ from sqlalchemy.orm.collections import collection
 import sqlalchemy as sa
 from sqlalchemy.test import testing
 from sqlalchemy import Integer, String, ForeignKey
-from sqlalchemy.test.schema import Table
-from sqlalchemy.test.schema import Column
+from sqlalchemy.test.schema import Table, Column
 from sqlalchemy import util, exc as sa_exc
-from sqlalchemy.orm import create_session, mapper, relation,     attributes
+from sqlalchemy.orm import create_session, mapper, relation, attributes
 from test.orm import _base
-from sqlalchemy.test.testing import eq_
+from sqlalchemy.test.testing import eq_, assert_raises
 
 class Canary(sa.orm.interfaces.AttributeExtension):
     def __init__(self):
@@ -169,6 +168,13 @@ class CollectionsTest(_base.ORMTest):
                 control[slice(0,-1)] = values
                 assert_eq()
 
+                values = [creator(),creator(),creator()]
+                control[:] = values
+                direct[:] = values
+                def invalid():
+                    direct[slice(0, 6, 2)] = [creator()]
+                assert_raises(ValueError, invalid)
+                
         if hasattr(direct, '__delitem__'):
             e = creator()
             direct.append(e)
@@ -193,7 +199,7 @@ class CollectionsTest(_base.ORMTest):
                 del direct[::2]
                 del control[::2]
                 assert_eq()
-
+            
         if hasattr(direct, 'remove'):
             e = creator()
             direct.append(e)
@@ -202,8 +208,21 @@ class CollectionsTest(_base.ORMTest):
             direct.remove(e)
             control.remove(e)
             assert_eq()
-
-        if hasattr(direct, '__setslice__'):
+        
+        if hasattr(direct, '__setitem__') or hasattr(direct, '__setslice__'):
+            
+            values = [creator(), creator()]
+            direct[:] = values
+            control[:] = values
+            assert_eq()
+            
+            # test slice assignment where the
+            # slice range extends past the end of the list
+            values = [creator(), creator()]
+            direct[1:3] = values
+            control[1:3] = values
+            assert_eq()
+            
             values = [creator(), creator()]
             direct[0:1] = values
             control[0:1] = values
@@ -228,8 +247,19 @@ class CollectionsTest(_base.ORMTest):
             direct[1::2] = values
             control[1::2] = values
             assert_eq()
+            
+            values = [creator(), creator()]
+            direct[-1:-3] = values
+            control[-1:-3] = values
+            assert_eq()
 
-        if hasattr(direct, '__delslice__'):
+            values = [creator(), creator()]
+            direct[-2:-1] = values
+            control[-2:-1] = values
+            assert_eq()
+            
+
+        if hasattr(direct, '__delitem__') or hasattr(direct, '__delslice__'):
             for i in range(1, 4):
                 e = creator()
                 direct.append(e)
@@ -246,7 +276,7 @@ class CollectionsTest(_base.ORMTest):
             del direct[:]
             del control[:]
             assert_eq()
-
+        
         if hasattr(direct, 'extend'):
             values = [creator(), creator(), creator()]
 
@@ -345,6 +375,45 @@ class CollectionsTest(_base.ORMTest):
         self._test_list(list)
         self._test_list_bulk(list)
 
+    def test_list_setitem_with_slices(self):
+        
+        # this is a "list" that has no __setslice__
+        # or __delslice__ methods.  The __setitem__
+        # and __delitem__ must therefore accept
+        # slice objects (i.e. as in py3k)
+        class ListLike(object):
+            def __init__(self):
+                self.data = list()
+            def append(self, item):
+                self.data.append(item)
+            def remove(self, item):
+                self.data.remove(item)
+            def insert(self, index, item):
+                self.data.insert(index, item)
+            def pop(self, index=-1):
+                return self.data.pop(index)
+            def extend(self):
+                assert False
+            def __len__(self):
+                return len(self.data)
+            def __setitem__(self, key, value):
+                self.data[key] = value
+            def __getitem__(self, key):
+                return self.data[key]
+            def __delitem__(self, key):
+                del self.data[key]
+            def __iter__(self):
+                return iter(self.data)
+            __hash__ = object.__hash__
+            def __eq__(self, other):
+                return self.data == other
+            def __repr__(self):
+                return 'ListLike(%s)' % repr(self.data)
+
+        self._test_adapter(ListLike)
+        self._test_list(ListLike)
+        self._test_list_bulk(ListLike)
+
     def test_list_subclass(self):
         class MyList(list):
             pass
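The new ListLike fixture and the widened hasattr checks track a Python 3 change: __setslice__ and __delslice__ are gone there, and l[0:2] = x routes through __setitem__ with a slice object, which Python 2 also does whenever __setslice__ is absent. A quick demonstration:

    class L(object):
        def __init__(self, data):
            self.data = data
        # no __setslice__ defined, so even Python 2 sends slice objects here
        def __setitem__(self, key, value):
            assert isinstance(key, (int, slice))
            self.data[key] = value

    l = L([1, 2, 3])
    l[0:2] = [9, 9]          # calls __setitem__(slice(0, 2, None), [9, 9])
    assert l.data == [9, 9, 3]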
@@ -1343,10 +1412,10 @@ class DictHelpersTest(_base.MappedTest):
     @classmethod
     def define_tables(cls, metadata):
         Table('parents', metadata,
-              Column('id', Integer, primary_key=True),
+              Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
               Column('label', String(128)))
         Table('children', metadata,
-              Column('id', Integer, primary_key=True),
+              Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
               Column('parent_id', Integer, ForeignKey('parents.id'),
                      nullable=False),
               Column('a', String(128)),
@@ -1481,12 +1550,12 @@ class DictHelpersTest(_base.MappedTest):
 
         class Foo(BaseObject):
             __tablename__ = "foo"
-            id = Column(Integer(), primary_key=True)
+            id = Column(Integer(), primary_key=True, test_needs_autoincrement=True)
             bar_id = Column(Integer, ForeignKey('bar.id'))
             
         class Bar(BaseObject):
             __tablename__ = "bar"
-            id = Column(Integer(), primary_key=True)
+            id = Column(Integer(), primary_key=True, test_needs_autoincrement=True)
             foos = relation(Foo, collection_class=collections.column_mapped_collection(Foo.id))
             foos2 = relation(Foo, collection_class=collections.column_mapped_collection((Foo.id, Foo.bar_id)))
             
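The Foo/Bar mappings above exercise column_mapped_collection, which turns a relation into a dict keyed on a mapped column, or on a tuple of columns as foos2 shows. Illustrative use, assuming the ids have been assigned:

    from sqlalchemy.orm.collections import column_mapped_collection

    # bar.foos behaves like a dict keyed on Foo.id
    bar = Bar()
    foo = Foo()
    foo.id = 1
    bar.foos[foo.id] = foo
    assert bar.foos[1] is foo
    # foos2 would instead be keyed on the tuple (Foo.id, Foo.bar_id)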
@@ -1521,17 +1590,16 @@ class DictHelpersTest(_base.MappedTest):
         collection_class = lambda: Ordered2(lambda v: (v.a, v.b))
         self._test_composite_mapped(collection_class)
 
-# TODO: are these tests redundant vs. the above tests ?
-# remove if so
 class CustomCollectionsTest(_base.MappedTest):
+    """test the integration of collections with mapped classes."""
 
     @classmethod
     def define_tables(cls, metadata):
         Table('sometable', metadata,
-              Column('col1',Integer, primary_key=True),
+              Column('col1',Integer, primary_key=True, test_needs_autoincrement=True),
               Column('data', String(30)))
         Table('someothertable', metadata,
-              Column('col1', Integer, primary_key=True),
+              Column('col1', Integer, primary_key=True, test_needs_autoincrement=True),
               Column('scol1', Integer,
                      ForeignKey('sometable.col1')),
               Column('data', String(20)))
@@ -1646,15 +1714,50 @@ class CustomCollectionsTest(_base.MappedTest):
         replaced = set([id(b) for b in f.bars.values()])
         self.assert_(existing != replaced)
 
-    @testing.resolve_artifact_names
     def test_list(self):
+        self._test_list(list)
+
+    def test_list_no_setslice(self):
+        class ListLike(object):
+            def __init__(self):
+                self.data = list()
+            def append(self, item):
+                self.data.append(item)
+            def remove(self, item):
+                self.data.remove(item)
+            def insert(self, index, item):
+                self.data.insert(index, item)
+            def pop(self, index=-1):
+                return self.data.pop(index)
+            def extend(self):
+                assert False
+            def __len__(self):
+                return len(self.data)
+            def __setitem__(self, key, value):
+                self.data[key] = value
+            def __getitem__(self, key):
+                return self.data[key]
+            def __delitem__(self, key):
+                del self.data[key]
+            def __iter__(self):
+                return iter(self.data)
+            __hash__ = object.__hash__
+            def __eq__(self, other):
+                return self.data == other
+            def __repr__(self):
+                return 'ListLike(%s)' % repr(self.data)
+        
+        self._test_list(ListLike)
+        
+    @testing.resolve_artifact_names
+    def _test_list(self, listcls):
         class Parent(object):
             pass
         class Child(object):
             pass
 
         mapper(Parent, sometable, properties={
-            'children':relation(Child, collection_class=list)
+            'children':relation(Child, collection_class=listcls)
         })
         mapper(Child, someothertable)
 
index fe77b360187e894d7a692ec8706dd84a203f9009..6fbfe7fe1823055d11095c330eb63fbdcdb585be 100644 (file)
@@ -7,8 +7,7 @@ T1/T2.
 """
 from sqlalchemy.test import testing
 from sqlalchemy import Integer, String, ForeignKey
-from sqlalchemy.test.schema import Table
-from sqlalchemy.test.schema import Column
+from sqlalchemy.test.schema import Table, Column
 from sqlalchemy.orm import mapper, relation, backref, create_session
 from sqlalchemy.test.testing import eq_
 from sqlalchemy.test.assertsql import RegexSQL, ExactSQL, CompiledSQL, AllOf
@@ -138,7 +137,7 @@ class SelfReferentialNoPKTest(_base.MappedTest):
     @classmethod
     def define_tables(cls, metadata):
         Table('item', metadata,
-           Column('id', Integer, primary_key=True),
+           Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
            Column('uuid', String(32), unique=True, nullable=False),
            Column('parent_uuid', String(32), ForeignKey('item.uuid'),
                   nullable=True))
@@ -190,18 +189,16 @@ class InheritTestOne(_base.MappedTest):
     @classmethod
     def define_tables(cls, metadata):
         Table("parent", metadata,
-            Column("id", Integer, primary_key=True),
+            Column("id", Integer, primary_key=True, test_needs_autoincrement=True),
             Column("parent_data", String(50)),
             Column("type", String(10)))
 
         Table("child1", metadata,
-              Column("id", Integer, ForeignKey("parent.id"),
-                     primary_key=True),
+              Column("id", Integer, ForeignKey("parent.id"), primary_key=True),
               Column("child1_data", String(50)))
 
         Table("child2", metadata,
-            Column("id", Integer, ForeignKey("parent.id"),
-                   primary_key=True),
+            Column("id", Integer, ForeignKey("parent.id"), primary_key=True),
             Column("child1_id", Integer, ForeignKey("child1.id"),
                    nullable=False),
             Column("child2_data", String(50)))
@@ -262,7 +259,7 @@ class InheritTestTwo(_base.MappedTest):
     @classmethod
     def define_tables(cls, metadata):
         Table('a', metadata,
-            Column('id', Integer, primary_key=True),
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('data', String(30)),
             Column('cid', Integer, ForeignKey('c.id')))
 
@@ -271,7 +268,7 @@ class InheritTestTwo(_base.MappedTest):
             Column('data', String(30)))
 
         Table('c', metadata,
-            Column('id', Integer, primary_key=True),
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('data', String(30)),
             Column('aid', Integer,
                    ForeignKey('a.id', use_alter=True, name="foo")))
@@ -311,16 +308,16 @@ class BiDirectionalManyToOneTest(_base.MappedTest):
     @classmethod
     def define_tables(cls, metadata):
         Table('t1', metadata,
-            Column('id', Integer, primary_key=True),
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('data', String(30)),
             Column('t2id', Integer, ForeignKey('t2.id')))
         Table('t2', metadata,
-            Column('id', Integer, primary_key=True),
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('data', String(30)),
             Column('t1id', Integer,
                    ForeignKey('t1.id', use_alter=True, name="foo_fk")))
         Table('t3', metadata,
-            Column('id', Integer, primary_key=True),
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('data', String(30)),
             Column('t1id', Integer, ForeignKey('t1.id'), nullable=False),
             Column('t2id', Integer, ForeignKey('t2.id'), nullable=False))
@@ -402,13 +399,11 @@ class BiDirectionalOneToManyTest(_base.MappedTest):
     @classmethod
     def define_tables(cls, metadata):
         Table('t1', metadata,
-              Column('c1', Integer, primary_key=True,
-                     test_needs_autoincrement=True),
+              Column('c1', Integer, primary_key=True, test_needs_autoincrement=True),
               Column('c2', Integer, ForeignKey('t2.c1')))
 
         Table('t2', metadata,
-              Column('c1', Integer, primary_key=True,
-                     test_needs_autoincrement=True),
+              Column('c1', Integer, primary_key=True, test_needs_autoincrement=True),
               Column('c2', Integer,
                      ForeignKey('t1.c1', use_alter=True, name='t1c1_fk')))
 
@@ -453,18 +448,18 @@ class BiDirectionalOneToManyTest2(_base.MappedTest):
     @classmethod
     def define_tables(cls, metadata):
         Table('t1', metadata,
-              Column('c1', Integer, primary_key=True),
+              Column('c1', Integer, primary_key=True, test_needs_autoincrement=True),
               Column('c2', Integer, ForeignKey('t2.c1')),
               test_needs_autoincrement=True)
 
         Table('t2', metadata,
-              Column('c1', Integer, primary_key=True),
+              Column('c1', Integer, primary_key=True, test_needs_autoincrement=True),
               Column('c2', Integer,
                      ForeignKey('t1.c1', use_alter=True, name='t1c1_fq')),
               test_needs_autoincrement=True)
 
         Table('t1_data', metadata,
-              Column('c1', Integer, primary_key=True),
+              Column('c1', Integer, primary_key=True, test_needs_autoincrement=True),
               Column('t1id', Integer, ForeignKey('t1.c1')),
               Column('data', String(20)),
               test_needs_autoincrement=True)
@@ -530,15 +525,13 @@ class OneToManyManyToOneTest(_base.MappedTest):
     @classmethod
     def define_tables(cls, metadata):
         Table('ball', metadata,
-              Column('id', Integer, primary_key=True,
-                     test_needs_autoincrement=True),
+              Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
               Column('person_id', Integer,
                      ForeignKey('person.id', use_alter=True, name='fk_person_id')),
               Column('data', String(30)))
 
         Table('person', metadata,
-              Column('id', Integer, primary_key=True,
-                     test_needs_autoincrement=True),
+              Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
               Column('favorite_ball_id', Integer, ForeignKey('ball.id')),
               Column('data', String(30)))
 
@@ -841,7 +834,7 @@ class SelfReferentialPostUpdateTest2(_base.MappedTest):
     @classmethod
     def define_tables(cls, metadata):
         Table("a_table", metadata,
-              Column("id", Integer(), primary_key=True),
+              Column("id", Integer(), primary_key=True, test_needs_autoincrement=True),
               Column("fui", String(128)),
               Column("b", Integer(), ForeignKey("a_table.id")))
 
index b063780ac72b0059e38dec2fb22712446dcf8b3c..5379c9714995b5d46170bc2d014a12fba1970b7d 100644 (file)
@@ -2,8 +2,7 @@
 import sqlalchemy as sa
 from sqlalchemy.test import testing
 from sqlalchemy import Integer, String, ForeignKey
-from sqlalchemy.test.schema import Table
-from sqlalchemy.test.schema import Column
+from sqlalchemy.test.schema import Table, Column
 from sqlalchemy.orm import mapper, relation, create_session
 from test.orm import _base
 from sqlalchemy.test.testing import eq_
@@ -15,7 +14,7 @@ class TriggerDefaultsTest(_base.MappedTest):
     @classmethod
     def define_tables(cls, metadata):
         dt = Table('dt', metadata,
-                   Column('id', Integer, primary_key=True),
+                   Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
                    Column('col1', String(20)),
                    Column('col2', String(20),
                           server_default=sa.schema.FetchedValue()),
@@ -34,17 +33,23 @@ class TriggerDefaultsTest(_base.MappedTest):
                    "UPDATE dt SET col2='ins', col4='ins' "
                    "WHERE dt.id IN (SELECT id FROM inserted);",
                    on='mssql'),
-            ):
-            if testing.against(ins.on):
-                break
-        else:
-            ins = sa.DDL("CREATE TRIGGER dt_ins BEFORE INSERT ON dt "
+            sa.DDL("CREATE TRIGGER dt_ins BEFORE INSERT "
+                     "ON dt "
+                     "FOR EACH ROW "
+                     "BEGIN "
+                     ":NEW.col2 := 'ins'; :NEW.col4 := 'ins'; END;",
+                     on='oracle'),
+            sa.DDL("CREATE TRIGGER dt_ins BEFORE INSERT ON dt "
                          "FOR EACH ROW BEGIN "
-                         "SET NEW.col2='ins'; SET NEW.col4='ins'; END")
-        ins.execute_at('after-create', dt)
+                         "SET NEW.col2='ins'; SET NEW.col4='ins'; END",
+                         on=lambda event, schema_item, bind, **kw: 
+                                bind.engine.name not in ('oracle', 'mssql', 'sqlite')
+                ),
+            ):
+            ins.execute_at('after-create', dt)
+            
         sa.DDL("DROP TRIGGER dt_ins").execute_at('before-drop', dt)
 
-
         for up in (
             sa.DDL("CREATE TRIGGER dt_up AFTER UPDATE ON dt "
                    "FOR EACH ROW BEGIN "
@@ -55,14 +60,19 @@ class TriggerDefaultsTest(_base.MappedTest):
                    "UPDATE dt SET col3='up', col4='up' "
                    "WHERE dt.id IN (SELECT id FROM deleted);",
                    on='mssql'),
-            ):
-            if testing.against(up.on):
-                break
-        else:
-            up = sa.DDL("CREATE TRIGGER dt_up BEFORE UPDATE ON dt "
+            sa.DDL("CREATE TRIGGER dt_up BEFORE UPDATE ON dt "
+                  "FOR EACH ROW BEGIN "
+                  ":NEW.col3 := 'up'; :NEW.col4 := 'up'; END;",
+                  on='oracle'),
+            sa.DDL("CREATE TRIGGER dt_up BEFORE UPDATE ON dt "
                         "FOR EACH ROW BEGIN "
-                        "SET NEW.col3='up'; SET NEW.col4='up'; END")
-        up.execute_at('after-create', dt)
+                        "SET NEW.col3='up'; SET NEW.col4='up'; END",
+                        on=lambda event, schema_item, bind, **kw: 
+                                bind.engine.name not in ('oracle', 'mssql', 'sqlite')
+                    ),
+            ):
+            up.execute_at('after-create', dt)
+
         sa.DDL("DROP TRIGGER dt_up").execute_at('before-drop', dt)
 
 
@@ -115,7 +125,7 @@ class ExcludedDefaultsTest(_base.MappedTest):
     @classmethod
     def define_tables(cls, metadata):
         dt = Table('dt', metadata,
-                   Column('id', Integer, primary_key=True),
+                   Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
                    Column('col1', String(20), default="hello"),
         )
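The trigger rewrite above replaces the old for/else plus testing.against() selection with DDL's own on= filter: every variant is registered via execute_at() and the filter decides at event time. on= takes either a dialect name string or a callable receiving (event, schema_item, bind, **kw), exactly as the lambda above shows. Both forms condensed into a sketch (trigger bodies abbreviated, table illustrative):

import sqlalchemy as sa
from sqlalchemy import MetaData, Table, Column, Integer, String

metadata = MetaData()
dt = Table('dt', metadata,
           Column('id', Integer, primary_key=True),
           Column('col2', String(20)))

# string form: fires only when the target bind is the named dialect
sa.DDL("CREATE TRIGGER dt_ins ...", on='postgresql').execute_at('after-create', dt)

# callable form: arbitrary predicate over the target bind
sa.DDL("CREATE TRIGGER dt_ins BEFORE INSERT ON dt "
       "FOR EACH ROW BEGIN SET NEW.col2='ins'; END",
       on=lambda event, schema_item, bind, **kw:
           bind.engine.name not in ('oracle', 'mssql', 'sqlite')
       ).execute_at('after-create', dt)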
         
index f2089a4351b2e9f6ddc4693ef3d1ba74409fd15d..23a5fc87625fb5a7cc58bafbcb485251c341ec31 100644 (file)
@@ -3,8 +3,7 @@ import operator
 from sqlalchemy.orm import dynamic_loader, backref
 from sqlalchemy.test import testing
 from sqlalchemy import Integer, String, ForeignKey, desc, select, func
-from sqlalchemy.test.schema import Table
-from sqlalchemy.test.schema import Column
+from sqlalchemy.test.schema import Table, Column
 from sqlalchemy.orm import mapper, relation, create_session, Query, attributes
 from sqlalchemy.orm.dynamic import AppenderMixin
 from sqlalchemy.test.testing import eq_
@@ -344,7 +343,8 @@ class SessionTest(_fixtures.FixtureTest):
         sess.flush()
         sess.commit()
         u1.addresses.append(Address(email_address='foo@bar.com'))
-        eq_(u1.addresses.all(), [Address(email_address='lala@hoho.com'), Address(email_address='foo@bar.com')])
+        eq_(u1.addresses.order_by(Address.id).all(), 
+                 [Address(email_address='lala@hoho.com'), Address(email_address='foo@bar.com')])
         sess.rollback()
         eq_(u1.addresses.all(), [Address(email_address='lala@hoho.com')])
 
@@ -502,13 +502,13 @@ class DontDereferenceTest(_base.MappedTest):
     @classmethod
     def define_tables(cls, metadata):
         Table('users', metadata,
-              Column('id', Integer, primary_key=True),
+              Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
               Column('name', String(40)),
               Column('fullname', String(100)),
               Column('password', String(15)))
 
         Table('addresses', metadata,
-              Column('id', Integer, primary_key=True),
+              Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
               Column('email_address', String(100), nullable=False),
               Column('user_id', Integer, ForeignKey('users.id')))
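The order_by additions in this file follow from what a dynamic relation is: each attribute access returns a live Query, so row order is whatever the backend happens to return unless order_by is configured on the loader or applied per query. An illustrative 0.6-era sketch (tables and names are stand-ins):

import sqlalchemy as sa
from sqlalchemy import MetaData, Table, Column, Integer, String, ForeignKey
from sqlalchemy.orm import mapper, dynamic_loader, create_session

engine = sa.create_engine('sqlite://')
metadata = MetaData()
users = Table('users', metadata,
              Column('id', Integer, primary_key=True),
              Column('name', String(30)))
addresses = Table('addresses', metadata,
                  Column('id', Integer, primary_key=True),
                  Column('user_id', Integer, ForeignKey('users.id')),
                  Column('email_address', String(50)))
metadata.create_all(engine)

class User(object):
    pass

class Address(object):
    pass

mapper(User, users, properties={
    # default ordering lives on the loader; individual queries may re-order
    'addresses': dynamic_loader(Address, order_by=addresses.c.id)
})
mapper(Address, addresses)

sess = create_session(bind=engine)
u = User()
sess.add(u)
sess.flush()

a = Address()
a.email_address = 'foo@bar.com'
u.addresses.append(a)
sess.flush()

# each access composes a fresh Query against current state
assert u.addresses.order_by(sa.desc(addresses.c.id)).count() == 1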
 
index 384e0472f6c1d5862caa0dacaa371687304113d3..425c08c61006dbbbe003875e7c254312ba898964 100644 (file)
@@ -5,8 +5,7 @@ import sqlalchemy as sa
 from sqlalchemy.test import testing
 from sqlalchemy.orm import eagerload, deferred, undefer
 from sqlalchemy import Integer, String, Date, ForeignKey, and_, select, func
-from sqlalchemy.test.schema import Table
-from sqlalchemy.test.schema import Column
+from sqlalchemy.test.schema import Table, Column
 from sqlalchemy.orm import mapper, relation, create_session, lazyload, aliased
 from sqlalchemy.test.testing import eq_
 from sqlalchemy.test.assertsql import CompiledSQL
@@ -459,20 +458,14 @@ class EagerTest(_fixtures.FixtureTest):
         })
         mapper(User, users, properties={
             'addresses':relation(mapper(Address, addresses), lazy=False, order_by=addresses.c.id),
-            'orders':relation(Order, lazy=True)
+            'orders':relation(Order, lazy=True, order_by=orders.c.id)
         })
 
         sess = create_session()
         q = sess.query(User)
 
-        if testing.against('mysql'):
-            l = q.limit(2).all()
-            assert self.static.user_all_result[:2] == l
-        else:
-            l = q.order_by(User.id).limit(2).offset(1).all()
-            print self.static.user_all_result[1:3]
-            print l
-            assert self.static.user_all_result[1:3] == l
+        l = q.order_by(User.id).limit(2).offset(1).all()
+        eq_(self.static.user_all_result[1:3], l)
 
     @testing.resolve_artifact_names
     def test_distinct(self):
@@ -483,15 +476,15 @@ class EagerTest(_fixtures.FixtureTest):
         s = sa.union_all(u2.select(use_labels=True), u2.select(use_labels=True), u2.select(use_labels=True)).alias('u')
 
         mapper(User, users, properties={
-            'addresses':relation(mapper(Address, addresses), lazy=False),
+            'addresses':relation(mapper(Address, addresses), lazy=False, order_by=addresses.c.id),
         })
 
         sess = create_session()
         q = sess.query(User)
 
         def go():
-            l = q.filter(s.c.u2_id==User.id).distinct().all()
-            assert self.static.user_address_result == l
+            l = q.filter(s.c.u2_id==User.id).distinct().order_by(User.id).all()
+            eq_(self.static.user_address_result, l)
         self.assert_sql_count(testing.db, go, 1)
 
     @testing.fails_on('maxdb', 'FIXME: unknown')
@@ -656,9 +649,12 @@ class EagerTest(_fixtures.FixtureTest):
 
         mapper(Order, orders)
         mapper(User, users, properties={
-               'orders':relation(Order, backref='user', lazy=False),
-               'max_order':relation(mapper(Order, max_orders, non_primary=True), lazy=False, uselist=False)
+               'orders':relation(Order, backref='user', lazy=False, order_by=orders.c.id),
+               'max_order':relation(
+                                mapper(Order, max_orders, non_primary=True), 
+                                lazy=False, uselist=False)
                })
+
         q = create_session().query(User)
 
         def go():
@@ -675,7 +671,7 @@ class EagerTest(_fixtures.FixtureTest):
                     max_order=Order(id=4)
                 ),
                 User(id=10),
-            ] == q.all()
+            ] == q.order_by(User.id).all()
         self.assert_sql_count(testing.db, go, 1)
 
     @testing.resolve_artifact_names
@@ -823,15 +819,15 @@ class OrderBySecondaryTest(_base.MappedTest):
     @classmethod
     def define_tables(cls, metadata):
         Table('m2m', metadata,
-              Column('id', Integer, primary_key=True),
+              Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
               Column('aid', Integer, ForeignKey('a.id')),
               Column('bid', Integer, ForeignKey('b.id')))
 
         Table('a', metadata,
-              Column('id', Integer, primary_key=True),
+              Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
               Column('data', String(50)))
         Table('b', metadata,
-              Column('id', Integer, primary_key=True),
+              Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
               Column('data', String(50)))
 
     @classmethod
@@ -873,8 +869,7 @@ class SelfReferentialEagerTest(_base.MappedTest):
     @classmethod
     def define_tables(cls, metadata):
         Table('nodes', metadata,
-              Column('id', Integer, sa.Sequence('node_id_seq', optional=True),
-                     primary_key=True),
+              Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('parent_id', Integer, ForeignKey('nodes.id')),
             Column('data', String(30)))
 
@@ -1088,11 +1083,11 @@ class MixedSelfReferentialEagerTest(_base.MappedTest):
     @classmethod
     def define_tables(cls, metadata):
         Table('a_table', metadata,
-                       Column('id', Integer, primary_key=True)
+                       Column('id', Integer, primary_key=True, test_needs_autoincrement=True)
                        )
 
         Table('b_table', metadata,
-                       Column('id', Integer, primary_key=True),
+                       Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
                        Column('parent_b1_id', Integer, ForeignKey('b_table.id')),
                        Column('parent_a_id', Integer, ForeignKey('a_table.id')),
                        Column('parent_b2_id', Integer, ForeignKey('b_table.id')))
@@ -1161,7 +1156,7 @@ class SelfReferentialM2MEagerTest(_base.MappedTest):
     @classmethod
     def define_tables(cls, metadata):
         Table('widget', metadata,
-            Column('id', Integer, primary_key=True),
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('name', sa.Unicode(40), nullable=False, unique=True),
         )
 
@@ -1244,7 +1239,7 @@ class MixedEntitiesTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL):
             )
         self.assert_sql_count(testing.db, go, 1)
 
-    @testing.exclude('sqlite', '>', (0, 0, 0), "sqlite flat out blows it on the multiple JOINs")
+    @testing.exclude('sqlite', '>', (0, ), "sqlite flat out blows it on the multiple JOINs")
     @testing.resolve_artifact_names
     def test_two_entities_with_joins(self):
         sess = create_session()
@@ -1337,13 +1332,13 @@ class CyclicalInheritingEagerTest(_base.MappedTest):
     @classmethod
     def define_tables(cls, metadata):
         Table('t1', metadata,
-            Column('c1', Integer, primary_key=True),
+            Column('c1', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('c2', String(30)),
             Column('type', String(30))
             )
 
         Table('t2', metadata,
-            Column('c1', Integer, primary_key=True),
+            Column('c1', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('c2', String(30)),
             Column('type', String(30)),
             Column('t1.id', Integer, ForeignKey('t1.c1')))
@@ -1376,12 +1371,12 @@ class SubqueryTest(_base.MappedTest):
     @classmethod
     def define_tables(cls, metadata):
         Table('users_table', metadata,
-            Column('id', Integer, primary_key=True),
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('name', String(16))
         )
 
         Table('tags_table', metadata,
-            Column('id', Integer, primary_key=True),
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('user_id', Integer, ForeignKey("users_table.id")),
             Column('score1', sa.Float),
             Column('score2', sa.Float),
@@ -1461,16 +1456,20 @@ class CorrelatedSubqueryTest(_base.MappedTest):
     Exercises a variety of ways to configure this.
     
     """
+
+    # another argument for eagerload learning about inner joins
+    
+    __requires__ = ('correlated_outer_joins', )
     
     @classmethod
     def define_tables(cls, metadata):
         users = Table('users', metadata,
-            Column('id', Integer, primary_key=True),
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('name', String(50))
             )
 
         stuff = Table('stuff', metadata,
-            Column('id', Integer, primary_key=True),
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('date', Date),
             Column('user_id', Integer, ForeignKey('users.id')))
     
@@ -1549,11 +1548,13 @@ class CorrelatedSubqueryTest(_base.MappedTest):
 
         if ondate:
             # the more 'relational' way to do this, join on the max date
-            stuff_view = select([func.max(salias.c.date).label('max_date')]).where(salias.c.user_id==users.c.id).correlate(users)
+            stuff_view = select([func.max(salias.c.date).label('max_date')]).\
+                                where(salias.c.user_id==users.c.id).correlate(users)
         else:
             # a common method with the MySQL crowd, which actually might perform better in some
             # cases - subquery does a limit with order by DESC, join on the id
-            stuff_view = select([salias.c.id]).where(salias.c.user_id==users.c.id).correlate(users).order_by(salias.c.date.desc()).limit(1)
+            stuff_view = select([salias.c.id]).where(salias.c.user_id==users.c.id).\
+                                    correlate(users).order_by(salias.c.date.desc()).limit(1)
 
         if labeled == 'label':
             stuff_view = stuff_view.label('foo')
index 65934989788a013879909c3fd18fe397c5d95501..c602ac963f3607592f36d8f90018e0452a47bf36 100644 (file)
@@ -1,7 +1,7 @@
 """Attribute/instance expiration, deferral of attributes, etc."""
 
 from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message
-import gc
+from sqlalchemy.test.util import gc_collect
 import sqlalchemy as sa
 from sqlalchemy.test import testing
 from sqlalchemy import Integer, String, ForeignKey, exc as sa_exc
@@ -666,7 +666,7 @@ class ExpireTest(_fixtures.FixtureTest):
         assert self.static.user_address_result == userlist
         assert len(list(sess)) == 9
         sess.expire_all()
-        gc.collect()
+        gc_collect()
         assert len(list(sess)) == 4 # since addresses were gc'ed
 
         userlist = sess.query(User).order_by(User.id).all()
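On the import swap above: a bare gc.collect() is only a synchronous, complete collection on CPython, so the suite now routes through sqlalchemy.test.util.gc_collect, letting other runtimes be special-cased (this merge grows Jython/zxJDBC support elsewhere). The helper's body is not part of this hunk; a plausible shim, offered as an assumption rather than the actual implementation:

import gc
import sys

def gc_collect():
    """best-effort synchronous collection across CPython and Jython"""
    if sys.platform.startswith('java'):
        # Jython: collection is advisory, so make several passes
        for _ in range(3):
            gc.collect()
    else:
        gc.collect()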
index 0efc1814ed6e0adb632f205358d1be31387db2ae..8f61d4d1483fea975da075c7cb5001ae500743ba 100644 (file)
@@ -70,12 +70,17 @@ class GenerativeQueryTest(_base.MappedTest):
         assert sess.query(func.min(foo.c.bar)).filter(foo.c.bar<30).one() == (0,)
         
         assert sess.query(func.max(foo.c.bar)).filter(foo.c.bar<30).one() == (29,)
+        # Py3K
+        #assert query.filter(foo.c.bar<30).values(sa.func.max(foo.c.bar)).__next__()[0] == 29
+        #assert query.filter(foo.c.bar<30).values(sa.func.max(foo.c.bar)).__next__()[0] == 29
+        # Py2K
         assert query.filter(foo.c.bar<30).values(sa.func.max(foo.c.bar)).next()[0] == 29
         assert query.filter(foo.c.bar<30).values(sa.func.max(foo.c.bar)).next()[0] == 29
-
+        # end Py2K
+        
     @testing.resolve_artifact_names
     def test_aggregate_1(self):
-        if (testing.against('mysql') and
+        if (testing.against('mysql') and not testing.against('+zxjdbc') and
             testing.db.dialect.dbapi.version_info[:4] == (1, 2, 1, 'gamma')):
             return
 
@@ -95,10 +100,18 @@ class GenerativeQueryTest(_base.MappedTest):
     def test_aggregate_3(self):
         query = create_session().query(Foo)
 
+        # Py3K
+        #avg_f = query.filter(foo.c.bar<30).values(sa.func.avg(foo.c.bar)).__next__()[0]
+        # Py2K
         avg_f = query.filter(foo.c.bar<30).values(sa.func.avg(foo.c.bar)).next()[0]
+        # end Py2K
         assert round(avg_f, 1) == 14.5
 
+        # Py3K
+        #avg_o = query.filter(foo.c.bar<30).values(sa.func.avg(foo.c.bar)).__next__()[0]
+        # Py2K
         avg_o = query.filter(foo.c.bar<30).values(sa.func.avg(foo.c.bar)).next()[0]
+        # end Py2K
         assert round(avg_o, 1) == 14.5
 
     @testing.resolve_artifact_names
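The # Py3K / # Py2K / # end Py2K markers being introduced here are a source-conversion convention: when the tree is translated for Python 3, Py3K-commented lines are uncommented and the live Py2K lines are commented out. The converter itself is not shown in this hunk; a minimal sketch of the idea, which may differ from the real tool:

def flip_to_py3k(source):
    """uncomment '# Py3K' blocks; comment out '# Py2K'..'# end Py2K' blocks"""
    out, mode = [], None
    for line in source.splitlines(True):
        stripped = line.strip()
        if stripped == '# Py3K':
            mode = 'py3k'
        elif stripped == '# Py2K':
            mode = 'py2k'
        elif stripped == '# end Py2K':
            mode = None
        elif mode == 'py3k' and stripped.startswith('#'):
            # activate the Python 3 variant
            line = line.replace('#', '', 1)
        elif mode == 'py2k' and stripped:
            # deactivate the Python 2 variant
            indent = line[:len(line) - len(line.lstrip())]
            line = indent + '#' + line.lstrip()
        out.append(line)
    return ''.join(out)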
index b4c8f8601c0d1475ecb2c157a1c8d4b8e718487d..6390e2596356da23cb72f5fddfe177b8f9620f5d 100644 (file)
@@ -488,7 +488,7 @@ class InstrumentationCollisionTest(_base.ORMTest):
         class B(A):
             __sa_instrumentation_manager__ = staticmethod(mgr_factory)
 
-        assert_raises(TypeError, attributes.register_class, B)
+        assert_raises_message(TypeError, "multiple instrumentation implementations", attributes.register_class, B)
 
     def test_single_up(self):
 
@@ -499,7 +499,8 @@ class InstrumentationCollisionTest(_base.ORMTest):
         class B(A):
             __sa_instrumentation_manager__ = staticmethod(mgr_factory)
         attributes.register_class(B)
-        assert_raises(TypeError, attributes.register_class, A)
+
+        assert_raises_message(TypeError, "multiple instrumentation implementations", attributes.register_class, A)
 
     def test_diamond_b1(self):
         mgr_factory = lambda cls: attributes.ClassManager(cls)
@@ -507,10 +508,10 @@ class InstrumentationCollisionTest(_base.ORMTest):
         class A(object): pass
         class B1(A): pass
         class B2(A):
-            __sa_instrumentation_manager__ = mgr_factory
+            __sa_instrumentation_manager__ = staticmethod(mgr_factory)
         class C(object): pass
 
-        assert_raises(TypeError, attributes.register_class, B1)
+        assert_raises_message(TypeError, "multiple instrumentation implementations", attributes.register_class, B1)
 
     def test_diamond_b2(self):
         mgr_factory = lambda cls: attributes.ClassManager(cls)
@@ -518,10 +519,11 @@ class InstrumentationCollisionTest(_base.ORMTest):
         class A(object): pass
         class B1(A): pass
         class B2(A):
-            __sa_instrumentation_manager__ = mgr_factory
+            __sa_instrumentation_manager__ = staticmethod(mgr_factory)
         class C(object): pass
 
-        assert_raises(TypeError, attributes.register_class, B2)
+        attributes.register_class(B2)
+        assert_raises_message(TypeError, "multiple instrumentation implementations", attributes.register_class, B1)
 
     def test_diamond_c_b(self):
         mgr_factory = lambda cls: attributes.ClassManager(cls)
@@ -529,12 +531,12 @@ class InstrumentationCollisionTest(_base.ORMTest):
         class A(object): pass
         class B1(A): pass
         class B2(A):
-            __sa_instrumentation_manager__ = mgr_factory
+            __sa_instrumentation_manager__ = staticmethod(mgr_factory)
         class C(object): pass
 
         attributes.register_class(C)
-        assert_raises(TypeError, attributes.register_class, B1)
 
+        assert_raises_message(TypeError, "multiple instrumentation implementations", attributes.register_class, B1)
 
 class OnLoadTest(_base.ORMTest):
     """Check that Events.on_load is not hit in regular attributes operations."""
index 819f29911ebf154dc09b983a6fdae7d0807686cb..8c196cfcfbf3d257a0058554ddd1991212db6018 100644 (file)
@@ -163,9 +163,8 @@ class LazyTest(_fixtures.FixtureTest):
         # use a union all to get a lot of rows to join against
         u2 = users.alias('u2')
         s = sa.union_all(u2.select(use_labels=True), u2.select(use_labels=True), u2.select(use_labels=True)).alias('u')
-        print [key for key in s.c.keys()]
-        l = q.filter(s.c.u2_id==User.id).distinct().all()
-        assert self.static.user_all_result == l
+        l = q.filter(s.c.u2_id==User.id).order_by(User.id).distinct().all()
+        eq_(self.static.user_all_result, l)
 
     @testing.resolve_artifact_names
     def test_one_to_many_scalar(self):
index 13913578a55b8c0d2d0de07628bbea0a5f830031..c34ccdbab852b6def978bceb2a6a86f71236c227 100644 (file)
@@ -3,9 +3,8 @@
 from sqlalchemy.test.testing import assert_raises, assert_raises_message
 import sqlalchemy as sa
 from sqlalchemy.test import testing, pickleable
-from sqlalchemy import MetaData, Integer, String, ForeignKey, func
-from sqlalchemy.test.schema import Table
-from sqlalchemy.test.schema import Column
+from sqlalchemy import MetaData, Integer, String, ForeignKey, func, util
+from sqlalchemy.test.schema import Table, Column
 from sqlalchemy.engine import default
 from sqlalchemy.orm import mapper, relation, backref, create_session, class_mapper, compile_mappers, reconstructor, validates, aliased
 from sqlalchemy.orm import defer, deferred, synonym, attributes, column_property, composite, relation, dynamic_loader, comparable_property
@@ -390,7 +389,7 @@ class MapperTest(_fixtures.FixtureTest):
     @testing.resolve_artifact_names
     def test_self_ref_synonym(self):
         t = Table('nodes', MetaData(),
-            Column('id', Integer, primary_key=True),
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('parent_id', Integer, ForeignKey('nodes.id')))
 
         class Node(object):
@@ -432,7 +431,7 @@ class MapperTest(_fixtures.FixtureTest):
     @testing.resolve_artifact_names
     def test_prop_filters(self):
         t = Table('person', MetaData(),
-                  Column('id', Integer, primary_key=True),
+                  Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
                   Column('type', String(128)),
                   Column('name', String(128)),
                   Column('employee_number', Integer),
@@ -870,6 +869,7 @@ class MapperTest(_fixtures.FixtureTest):
     @testing.resolve_artifact_names
     def test_comparable_column(self):
         class MyComparator(sa.orm.properties.ColumnProperty.Comparator):
+            __hash__ = None
             def __eq__(self, other):
                 # lower case comparison
                 return func.lower(self.__clause_element__()) == func.lower(other)
@@ -1451,12 +1451,12 @@ class DeferredTest(_fixtures.FixtureTest):
     @testing.resolve_artifact_names
     def test_group(self):
         """Deferred load with a group"""
-        mapper(Order, orders, properties={
-            'userident': deferred(orders.c.user_id, group='primary'),
-            'addrident': deferred(orders.c.address_id, group='primary'),
-            'description': deferred(orders.c.description, group='primary'),
-            'opened': deferred(orders.c.isopen, group='primary')
-        })
+        mapper(Order, orders, properties=util.OrderedDict([
+            ('userident', deferred(orders.c.user_id, group='primary')),
+            ('addrident', deferred(orders.c.address_id, group='primary')),
+            ('description', deferred(orders.c.description, group='primary')),
+            ('opened', deferred(orders.c.isopen, group='primary'))
+        ]))
 
         sess = create_session()
         q = sess.query(Order).order_by(Order.id)
@@ -1562,10 +1562,12 @@ class DeferredTest(_fixtures.FixtureTest):
 
     @testing.resolve_artifact_names
     def test_undefer_group(self):
-        mapper(Order, orders, properties={
-            'userident':deferred(orders.c.user_id, group='primary'),
-            'description':deferred(orders.c.description, group='primary'),
-            'opened':deferred(orders.c.isopen, group='primary')})
+        mapper(Order, orders, properties=util.OrderedDict([
+            ('userident',deferred(orders.c.user_id, group='primary')),
+            ('description',deferred(orders.c.description, group='primary')),
+            ('opened',deferred(orders.c.isopen, group='primary'))
+            ]
+            ))
 
         sess = create_session()
         q = sess.query(Order).order_by(Order.id)
@@ -1796,11 +1798,11 @@ class DeferredPopulationTest(_base.MappedTest):
     @classmethod
     def define_tables(cls, metadata):
         Table("thing", metadata,
-            Column("id", Integer, primary_key=True),
+            Column("id", Integer, primary_key=True, test_needs_autoincrement=True),
             Column("name", String(20)))
 
         Table("human", metadata,
-            Column("id", Integer, primary_key=True),
+            Column("id", Integer, primary_key=True, test_needs_autoincrement=True),
             Column("thing_id", Integer, ForeignKey("thing.id")),
             Column("name", String(20)))
 
@@ -1884,13 +1886,12 @@ class CompositeTypesTest(_base.MappedTest):
     @classmethod
     def define_tables(cls, metadata):
         Table('graphs', metadata,
-            Column('id', Integer, primary_key=True),
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('version_id', Integer, primary_key=True, nullable=True),
             Column('name', String(30)))
 
         Table('edges', metadata,
-            Column('id', Integer, primary_key=True,
-                   test_needs_autoincrement=True),
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('graph_id', Integer, nullable=False),
             Column('graph_version_id', Integer, nullable=False),
             Column('x1', Integer),
@@ -1902,7 +1903,7 @@ class CompositeTypesTest(_base.MappedTest):
             ['graphs.id', 'graphs.version_id']))
 
         Table('foobars', metadata,
-            Column('id', Integer, primary_key=True),
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('x1', Integer, default=2),
             Column('x2', Integer),
             Column('x3', Integer, default=15),
@@ -2041,7 +2042,7 @@ class CompositeTypesTest(_base.MappedTest):
 
         # test pk with one column NULL
         # TODO: can't seem to get NULL in for a PK value
-        # in either mysql or postgres, autoincrement=False etc.
+        # in either mysql or postgresql, autoincrement=False etc.
         # notwithstanding
         @testing.fails_on_everything_except("sqlite")
         def go():
@@ -2475,33 +2476,26 @@ class RequirementsTest(_base.MappedTest):
     @classmethod
     def define_tables(cls, metadata):
         Table('ht1', metadata,
-              Column('id', Integer, primary_key=True,
-                     test_needs_autoincrement=True),
+              Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
               Column('value', String(10)))
         Table('ht2', metadata,
-              Column('id', Integer, primary_key=True,
-                     test_needs_autoincrement=True),
+              Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
               Column('ht1_id', Integer, ForeignKey('ht1.id')),
               Column('value', String(10)))
         Table('ht3', metadata,
-              Column('id', Integer, primary_key=True,
-                     test_needs_autoincrement=True),
+              Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
               Column('value', String(10)))
         Table('ht4', metadata,
-              Column('ht1_id', Integer, ForeignKey('ht1.id'),
-                     primary_key=True),
-              Column('ht3_id', Integer, ForeignKey('ht3.id'),
-                     primary_key=True))
+              Column('ht1_id', Integer, ForeignKey('ht1.id'), primary_key=True),
+              Column('ht3_id', Integer, ForeignKey('ht3.id'), primary_key=True))
         Table('ht5', metadata,
-              Column('ht1_id', Integer, ForeignKey('ht1.id'),
-                     primary_key=True))
+              Column('ht1_id', Integer, ForeignKey('ht1.id'), primary_key=True))
         Table('ht6', metadata,
-              Column('ht1a_id', Integer, ForeignKey('ht1.id'),
-                     primary_key=True),
-              Column('ht1b_id', Integer, ForeignKey('ht1.id'),
-                     primary_key=True),
+              Column('ht1a_id', Integer, ForeignKey('ht1.id'), primary_key=True),
+              Column('ht1b_id', Integer, ForeignKey('ht1.id'), primary_key=True),
               Column('value', String(10)))
 
+    # Py2K
     @testing.resolve_artifact_names
     def test_baseclass(self):
         class OldStyle:
@@ -2516,7 +2510,8 @@ class RequirementsTest(_base.MappedTest):
 
         # TODO: is weakref support detectable without an instance?
         #self.assertRaises(sa.exc.ArgumentError, mapper, NoWeakrefSupport, t2)
-
+    # end Py2K
+    
     @testing.resolve_artifact_names
     def test_comparison_overrides(self):
         """Simple tests to ensure users can supply comparison __methods__.
@@ -2618,12 +2613,12 @@ class MagicNamesTest(_base.MappedTest):
     @classmethod
     def define_tables(cls, metadata):
         Table('cartographers', metadata,
-              Column('id', Integer, primary_key=True),
+              Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
               Column('name', String(50)),
               Column('alias', String(50)),
               Column('quip', String(100)))
         Table('maps', metadata,
-              Column('id', Integer, primary_key=True),
+              Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
               Column('cart_id', Integer,
                      ForeignKey('cartographers.id')),
               Column('state', String(2)),
@@ -2665,7 +2660,7 @@ class MagicNamesTest(_base.MappedTest):
         for reserved in (sa.orm.attributes.ClassManager.STATE_ATTR,
                          sa.orm.attributes.ClassManager.MANAGER_ATTR):
             t = Table('t', sa.MetaData(),
-                      Column('id', Integer, primary_key=True),
+                      Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
                       Column(reserved, Integer))
             class T(object):
                 pass
index f4e3872b06a83d2a6f3144640487998de17b2b8e..5433515caa5cf6a6d09dccd9722fda30469df325 100644 (file)
@@ -1,13 +1,13 @@
 from sqlalchemy.test.testing import assert_raises, assert_raises_message
 import sqlalchemy as sa
-from sqlalchemy import Table, Column, Integer, PickleType
+from sqlalchemy import Integer, PickleType
 import operator
 from sqlalchemy.test import testing
 from sqlalchemy.util import OrderedSet
 from sqlalchemy.orm import mapper, relation, create_session, PropComparator, synonym, comparable_property, sessionmaker
 from sqlalchemy.test.testing import eq_, ne_
 from test.orm import _base, _fixtures
-
+from sqlalchemy.test.schema import Table, Column
 
 class MergeTest(_fixtures.FixtureTest):
     """Session.merge() functionality"""
@@ -103,6 +103,7 @@ class MergeTest(_fixtures.FixtureTest):
             'addresses':relation(Address,
                         backref='user',
                         collection_class=OrderedSet,
+                                order_by=addresses.c.id,
                                  cascade="all, delete-orphan")
         })
         mapper(Address, addresses)
@@ -154,6 +155,7 @@ class MergeTest(_fixtures.FixtureTest):
         mapper(User, users, properties={
             'addresses':relation(Address,
                                  backref='user',
+                                 order_by=addresses.c.id,
                                  collection_class=OrderedSet)})
         mapper(Address, addresses)
         on_load = self.on_load_tracker(User)
@@ -300,20 +302,20 @@ class MergeTest(_fixtures.FixtureTest):
 
         # test with "dontload" merge
         sess5 = create_session()
-        u = sess5.merge(u, dont_load=True)
+        u = sess5.merge(u, load=False)
         assert len(u.addresses)
         for a in u.addresses:
             assert a.user is u
         def go():
             sess5.flush()
         # no changes; therefore flush should do nothing
-        # but also, dont_load wipes out any difference in committed state,
+        # but also, load=False wipes out any difference in committed state,
         # so no flush at all
         self.assert_sql_count(testing.db, go, 0)
         eq_(on_load.called, 15)
 
         sess4 = create_session()
-        u = sess4.merge(u, dont_load=True)
+        u = sess4.merge(u, load=False)
         # post merge change
         u.addresses[1].email_address='afafds'
         def go():
@@ -445,17 +447,35 @@ class MergeTest(_fixtures.FixtureTest):
         assert u3 is u
 
     @testing.resolve_artifact_names
-    def test_transient_dontload(self):
+    def test_transient_no_load(self):
         mapper(User, users)
 
         sess = create_session()
         u = User()
-        assert_raises_message(sa.exc.InvalidRequestError, "dont_load=True option does not support", sess.merge, u, dont_load=True)
+        assert_raises_message(sa.exc.InvalidRequestError, "load=False option does not support", sess.merge, u, load=False)
 
+    @testing.resolve_artifact_names
+    def test_dont_load_deprecated(self):
+        mapper(User, users)
+
+        sess = create_session()
+        u = User(name='ed')
+        sess.add(u)
+        sess.flush()
+        u = sess.query(User).first()
+        sess.expunge(u)
+        sess.execute(users.update().values(name='jack'))
+        @testing.uses_deprecated("dont_load=True has been renamed")
+        def go():
+            u1 = sess.merge(u, dont_load=True)
+            assert u1 in sess
+            assert u1.name=='ed'
+            assert u1 not in sess.dirty
+        go()
 
     @testing.resolve_artifact_names
-    def test_dontload_with_backrefs(self):
-        """dontload populates relations in both directions without requiring a load"""
+    def test_no_load_with_backrefs(self):
+        """load=False populates relations in both directions without requiring a load"""
         mapper(User, users, properties={
             'addresses':relation(mapper(Address, addresses), backref='user')
         })
@@ -470,7 +490,7 @@ class MergeTest(_fixtures.FixtureTest):
         assert 'user' in u.addresses[1].__dict__
 
         sess = create_session()
-        u2 = sess.merge(u, dont_load=True)
+        u2 = sess.merge(u, load=False)
         assert 'user' in u2.addresses[1].__dict__
         eq_(u2.addresses[1].user, User(id=7, name='fred'))
 
@@ -479,7 +499,7 @@ class MergeTest(_fixtures.FixtureTest):
         sess.close()
 
         sess = create_session()
-        u = sess.merge(u2, dont_load=True)
+        u = sess.merge(u2, load=False)
         assert 'user' not in u.addresses[1].__dict__
         eq_(u.addresses[1].user, User(id=7, name='fred'))
 
@@ -488,12 +508,12 @@ class MergeTest(_fixtures.FixtureTest):
     def test_dontload_with_eager(self):
         """
 
-        This test illustrates that with dont_load=True, we can't just copy the
+        This test illustrates that with load=False, we can't just copy the
         committed_state of the merged instance over, since it references
         collection objects which themselves are to be merged.  This
         committed_state would instead need to be piecemeal 'converted' to
         represent the correct objects.  However, at the moment I'd rather not
-        support this use case; if you are merging with dont_load=True, you're
+        support this use case; if you are merging with load=False, you're
         typically dealing with caching and the merged objects shouldn't be
         'dirty'.
 
@@ -516,16 +536,16 @@ class MergeTest(_fixtures.FixtureTest):
         u2 = sess2.query(User).options(sa.orm.eagerload('addresses')).get(7)
 
         sess3 = create_session()
-        u3 = sess3.merge(u2, dont_load=True)
+        u3 = sess3.merge(u2, load=False)
         def go():
             sess3.flush()
         self.assert_sql_count(testing.db, go, 0)
 
     @testing.resolve_artifact_names
-    def test_dont_load_disallows_dirty(self):
-        """dont_load doesnt support 'dirty' objects right now
+    def test_no_load_disallows_dirty(self):
+        """load=False doesnt support 'dirty' objects right now
 
-        (see test_dont_load_with_eager()). Therefore lets assert it.
+        (see test_no_load_with_eager()). Therefore let's assert it.
 
         """
         mapper(User, users)
@@ -539,17 +559,17 @@ class MergeTest(_fixtures.FixtureTest):
         u.name = 'ed'
         sess2 = create_session()
         try:
-            sess2.merge(u, dont_load=True)
+            sess2.merge(u, load=False)
             assert False
         except sa.exc.InvalidRequestError, e:
-            assert ("merge() with dont_load=True option does not support "
+            assert ("merge() with load=False option does not support "
                     "objects marked as 'dirty'.  flush() all changes on mapped "
-                    "instances before merging with dont_load=True.") in str(e)
+                    "instances before merging with load=False.") in str(e)
 
         u2 = sess2.query(User).get(7)
 
         sess3 = create_session()
-        u3 = sess3.merge(u2, dont_load=True)
+        u3 = sess3.merge(u2, load=False)
         assert not sess3.dirty
         def go():
             sess3.flush()
@@ -557,7 +577,7 @@ class MergeTest(_fixtures.FixtureTest):
 
 
     @testing.resolve_artifact_names
-    def test_dont_load_sets_backrefs(self):
+    def test_no_load_sets_backrefs(self):
         mapper(User, users, properties={
             'addresses':relation(mapper(Address, addresses),backref='user')})
 
@@ -575,17 +595,17 @@ class MergeTest(_fixtures.FixtureTest):
         assert u.addresses[0].user is u
 
         sess2 = create_session()
-        u2 = sess2.merge(u, dont_load=True)
+        u2 = sess2.merge(u, load=False)
         assert not sess2.dirty
         def go():
             assert u2.addresses[0].user is u2
         self.assert_sql_count(testing.db, go, 0)
 
     @testing.resolve_artifact_names
-    def test_dont_load_preserves_parents(self):
-        """Merge with dont_load does not trigger a 'delete-orphan' operation.
+    def test_no_load_preserves_parents(self):
+        """Merge with load=False does not trigger a 'delete-orphan' operation.
 
-        merge with dont_load sets attributes without using events.  this means
+        merge with load=False sets attributes without using events.  this means
         the 'hasparent' flag is not propagated to the newly merged instance.
         in fact this works out OK, because the '_state.parents' collection on
         the newly merged instance is empty; since the mapper doesn't see an
@@ -610,7 +630,7 @@ class MergeTest(_fixtures.FixtureTest):
         assert u.addresses[0].user is u
 
         sess2 = create_session()
-        u2 = sess2.merge(u, dont_load=True)
+        u2 = sess2.merge(u, load=False)
         assert not sess2.dirty
         a2 = u2.addresses[0]
         a2.email_address='somenewaddress'
@@ -624,19 +644,19 @@ class MergeTest(_fixtures.FixtureTest):
 
         # this use case is not supported; this is with a pending Address on
         # the pre-merged object, and we currently dont support 'dirty' objects
-        # being merged with dont_load=True.  in this case, the empty
+        # being merged with load=False.  in this case, the empty
         # '_state.parents' collection would be an issue, since the optimistic
         # flag is False in _is_orphan() for pending instances.  so if we start
-        # supporting 'dirty' with dont_load=True, this test will need to pass
+        # supporting 'dirty' with load=False, this test will need to pass
         sess = create_session()
         u = sess.query(User).get(7)
         u.addresses.append(Address())
         sess2 = create_session()
         try:
-            u2 = sess2.merge(u, dont_load=True)
+            u2 = sess2.merge(u, load=False)
             assert False
 
-            # if dont_load is changed to support dirty objects, this code
+            # if load=False is changed to support dirty objects, this code
             # needs to pass
             a2 = u2.addresses[0]
             a2.email_address='somenewaddress'
@@ -647,7 +667,7 @@ class MergeTest(_fixtures.FixtureTest):
             eq_(sess2.query(User).get(u2.id).addresses[0].email_address,
                 'somenewaddress')
         except sa.exc.InvalidRequestError, e:
-            assert "dont_load=True option does not support" in str(e)
+            assert "load=False option does not support" in str(e)
 
     @testing.resolve_artifact_names
     def test_synonym_comparable(self):
@@ -737,7 +757,7 @@ class MutableMergeTest(_base.MappedTest):
     @classmethod
     def define_tables(cls, metadata):
         Table("data", metadata, 
-            Column('id', Integer, primary_key=True),
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('data', PickleType(comparator=operator.eq))
         )
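The dont_load to load=False rename throughout this file is mechanical, but the contract the docstrings describe is worth restating: merge(obj, load=False) copies a detached, clean instance's state into a session without emitting SQL, which is the caching use case, and dirty or transient instances are rejected. A self-contained sketch (engine and table are illustrative):

import sqlalchemy as sa
from sqlalchemy import MetaData, Table, Column, Integer, String
from sqlalchemy.orm import mapper, create_session

engine = sa.create_engine('sqlite://')
metadata = MetaData()
users = Table('users', metadata,
              Column('id', Integer, primary_key=True),
              Column('name', String(30)))
metadata.create_all(engine)

class User(object):
    pass

mapper(User, users)

sess = create_session(bind=engine)
u = User()
u.name = 'ed'
sess.add(u)
sess.flush()
sess.expunge(u)                    # detached and clean

sess2 = create_session(bind=engine)
u2 = sess2.merge(u, load=False)    # no SELECT emitted
assert u2 in sess2 and u2.name == 'ed'

u2.name = 'jack'                   # now dirty in sess2
sess3 = create_session(bind=engine)
try:
    sess3.merge(u2, load=False)    # dirty objects are rejected
    assert False
except sa.exc.InvalidRequestError:
    pass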
     
index 1376c402e755f1c993ae8981fe6120b125609bf7..e99bfb794b0746ec785358fa107a4e3848e6fd0e 100644 (file)
@@ -6,8 +6,7 @@ from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message
 import sqlalchemy as sa
 from sqlalchemy.test import testing
 from sqlalchemy import Integer, String, ForeignKey
-from sqlalchemy.test.schema import Table
-from sqlalchemy.test.schema import Column
+from sqlalchemy.test.schema import Table, Column
 from sqlalchemy.orm import mapper, relation, create_session
 from sqlalchemy.test.testing import eq_
 from test.orm import _base
@@ -16,6 +15,11 @@ class NaturalPKTest(_base.MappedTest):
 
     @classmethod
     def define_tables(cls, metadata):
+        if testing.against('oracle'):
+            fk_args = dict(deferrable=True, initially='deferred')
+        else:
+            fk_args = dict(onupdate='cascade')
+            
         users = Table('users', metadata,
             Column('username', String(50), primary_key=True),
             Column('fullname', String(100)),
@@ -23,7 +27,7 @@ class NaturalPKTest(_base.MappedTest):
 
         addresses = Table('addresses', metadata,
             Column('email', String(50), primary_key=True),
-            Column('username', String(50), ForeignKey('users.username', onupdate="cascade")),
+            Column('username', String(50), ForeignKey('users.username', **fk_args)),
             test_needs_fk=True)
 
         items = Table('items', metadata,
@@ -32,8 +36,8 @@ class NaturalPKTest(_base.MappedTest):
             test_needs_fk=True)
 
         users_to_items = Table('users_to_items', metadata,
-            Column('username', String(50), ForeignKey('users.username', onupdate='cascade'), primary_key=True),
-            Column('itemname', String(50), ForeignKey('items.itemname', onupdate='cascade'), primary_key=True),
+            Column('username', String(50), ForeignKey('users.username', **fk_args), primary_key=True),
+            Column('itemname', String(50), ForeignKey('items.itemname', **fk_args), primary_key=True),
             test_needs_fk=True)
 
     @classmethod
@@ -110,6 +114,7 @@ class NaturalPKTest(_base.MappedTest):
         
 
     @testing.fails_on('sqlite', 'sqlite doesnt support ON UPDATE CASCADE')
+    @testing.fails_on('oracle', 'oracle doesnt support ON UPDATE CASCADE')
     def test_onetomany_passive(self):
         self._test_onetomany(True)
 
@@ -161,6 +166,7 @@ class NaturalPKTest(_base.MappedTest):
         
 
     @testing.fails_on('sqlite', 'sqlite doesnt support ON UPDATE CASCADE')
+    @testing.fails_on('oracle', 'oracle doesnt support ON UPDATE CASCADE')
     def test_manytoone_passive(self):
         self._test_manytoone(True)
 
@@ -203,6 +209,7 @@ class NaturalPKTest(_base.MappedTest):
         eq_([Address(username='ed'), Address(username='ed')], sess.query(Address).all())
 
     @testing.fails_on('sqlite', 'sqlite doesnt support ON UPDATE CASCADE')
+    @testing.fails_on('oracle', 'oracle doesnt support ON UPDATE CASCADE')
     def test_onetoone_passive(self):
         self._test_onetoone(True)
 
@@ -244,6 +251,7 @@ class NaturalPKTest(_base.MappedTest):
         eq_([Address(username='ed')], sess.query(Address).all())
         
     @testing.fails_on('sqlite', 'sqlite doesnt support ON UPDATE CASCADE')
+    @testing.fails_on('oracle', 'oracle doesnt support ON UPDATE CASCADE')
     def test_bidirectional_passive(self):
         self._test_bidirectional(True)
 
@@ -298,10 +306,12 @@ class NaturalPKTest(_base.MappedTest):
 
 
     @testing.fails_on('sqlite', 'sqlite doesnt support ON UPDATE CASCADE')
+    @testing.fails_on('oracle', 'oracle doesnt support ON UPDATE CASCADE')
     def test_manytomany_passive(self):
         self._test_manytomany(True)
 
-    @testing.fails_on('mysql', 'the executemany() of the association table fails to report the correct row count')
+    # mysqldb executemany() of the association table fails to report the correct row count
+    @testing.fails_if(lambda: testing.against('mysql') and not testing.against('+zxjdbc'))
     def test_manytomany_nonpassive(self):
         self._test_manytomany(False)
 
@@ -361,10 +371,15 @@ class SelfRefTest(_base.MappedTest):
 
     @classmethod
     def define_tables(cls, metadata):
+        if testing.against('oracle'):
+            fk_args = dict(deferrable=True, initially='deferred')
+        else:
+            fk_args = dict(onupdate='cascade')
+        
         Table('nodes', metadata,
               Column('name', String(50), primary_key=True),
               Column('parent', String(50),
-                     ForeignKey('nodes.name', onupdate='cascade')))
+                     ForeignKey('nodes.name', **fk_args)))
 
     @classmethod
     def setup_classes(cls):
@@ -400,17 +415,22 @@ class SelfRefTest(_base.MappedTest):
 class NonPKCascadeTest(_base.MappedTest):
     @classmethod
     def define_tables(cls, metadata):
+        if testing.against('oracle'):
+            fk_args = dict(deferrable=True, initially='deferred')
+        else:
+            fk_args = dict(onupdate='cascade')
+
         Table('users', metadata,
-            Column('id', Integer, primary_key=True),
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('username', String(50), unique=True),
             Column('fullname', String(100)),
             test_needs_fk=True)
 
         Table('addresses', metadata,
-              Column('id', Integer, primary_key=True),
+              Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
               Column('email', String(50)),
               Column('username', String(50),
-                     ForeignKey('users.username', onupdate="cascade")),
+                     ForeignKey('users.username', **fk_args)),
                      test_needs_fk=True
                      )
 
@@ -422,6 +442,7 @@ class NonPKCascadeTest(_base.MappedTest):
             pass
 
     @testing.fails_on('sqlite', 'sqlite doesnt support ON UPDATE CASCADE')
+    @testing.fails_on('oracle', 'oracle doesnt support ON UPDATE CASCADE')
     def test_onetomany_passive(self):
         self._test_onetomany(True)
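About the fk_args switch running through these fixtures: Oracle has no ON UPDATE CASCADE, so tests that mutate primary keys create their foreign keys DEFERRABLE INITIALLY DEFERRED there instead, letting parent and child updates flush in either order and be checked only at commit. The pattern in isolation (testing.against is the suite's helper and assumes a configured test database):

from sqlalchemy import String, Column, ForeignKey
from sqlalchemy.test import testing

if testing.against('oracle'):
    # no referential update action available; defer constraint checking instead
    fk_args = dict(deferrable=True, initially='deferred')
else:
    # a real ON UPDATE CASCADE propagates the new parent key for us
    fk_args = dict(onupdate='cascade')

username = Column('username', String(50),
                  ForeignKey('users.username', **fk_args))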
 
index 0d66915ea5d79230bf5cd45a61ec8f0c9dfe0c87..6880f1f747e4eff1e0edc82f5e6157b373ef8136 100644 (file)
@@ -1,8 +1,7 @@
 import sqlalchemy as sa
 from sqlalchemy.test import testing
 from sqlalchemy import Integer, String, ForeignKey
-from sqlalchemy.test.schema import Table
-from sqlalchemy.test.schema import Column
+from sqlalchemy.test.schema import Table, Column
 from sqlalchemy.orm import mapper, relation, create_session
 from test.orm import _base
 
@@ -11,13 +10,13 @@ class O2OTest(_base.MappedTest):
     @classmethod
     def define_tables(cls, metadata):
         Table('jack', metadata,
-              Column('id', Integer, primary_key=True),
+              Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
               Column('number', String(50)),
               Column('status', String(20)),
               Column('subroom', String(5)))
 
         Table('port', metadata,
-              Column('id', Integer, primary_key=True),
+              Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
               Column('name', String(30)),
               Column('description', String(100)),
               Column('jack_id', Integer, ForeignKey("jack.id")))
index 5343cc15b940a201a1677c5b2624a1a83d301e67..6ac9f24701cd819de33f4a233fd7bed919f82fa8 100644 (file)
@@ -3,8 +3,7 @@ import pickle
 import sqlalchemy as sa
 from sqlalchemy.test import testing
 from sqlalchemy import Integer, String, ForeignKey
-from sqlalchemy.test.schema import Table
-from sqlalchemy.test.schema import Column
+from sqlalchemy.test.schema import Table, Column
 from sqlalchemy.orm import mapper, relation, create_session, attributes
 from test.orm import _base, _fixtures
 
@@ -60,7 +59,7 @@ class PickleTest(_fixtures.FixtureTest):
 
         u2 = pickle.loads(pickle.dumps(u1))
         sess2 = create_session()
-        u2 = sess2.merge(u2, dont_load=True)
+        u2 = sess2.merge(u2, load=False)
         eq_(u2.name, 'ed')
         eq_(u2, User(name='ed', addresses=[Address(email_address='ed@bar.com')]))
 
@@ -94,7 +93,7 @@ class PickleTest(_fixtures.FixtureTest):
 
         u2 = pickle.loads(pickle.dumps(u1))
         sess2 = create_session()
-        u2 = sess2.merge(u2, dont_load=True)
+        u2 = sess2.merge(u2, load=False)
         eq_(u2.name, 'ed')
         assert 'addresses' not in u2.__dict__
         ad = u2.addresses[0]
@@ -136,7 +135,7 @@ class PolymorphicDeferredTest(_base.MappedTest):
     @classmethod
     def define_tables(cls, metadata):
         Table('users', metadata,
-            Column('id', Integer, primary_key=True),
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('name', String(30)),
             Column('type', String(30)))
         Table('email_users', metadata,
index 88a95bf7608a3eda14ac1f46eb679dd8c648b5bd..8cb7ef969bd676bfcea8ef574e7071490f66eb95 100644 (file)
@@ -109,13 +109,8 @@ class GetTest(QueryTest):
             pass
         s = users.select(users.c.id!=12).alias('users')
         m = mapper(SomeUser, s)
-        print s.primary_key
-        print m.primary_key
         assert s.primary_key == m.primary_key
 
-        row = s.select(use_labels=True).execute().fetchone()
-        print row[s.primary_key[0]]
-
         sess = create_session()
         assert sess.query(SomeUser).get(7).name == 'jack'
 
@@ -145,15 +140,20 @@ class GetTest(QueryTest):
     @testing.requires.unicode_connections
     def test_unicode(self):
         """test that Query.get properly sets up the type for the bind parameter.  using unicode would normally fail
-        on postgres, mysql and oracle unless it is converted to an encoded string"""
+        on postgresql, mysql and oracle unless it is converted to an encoded string"""
 
         metadata = MetaData(engines.utf8_engine())
         table = Table('unicode_data', metadata,
-            Column('id', Unicode(40), primary_key=True),
+            Column('id', Unicode(40), primary_key=True, test_needs_autoincrement=True),
             Column('data', Unicode(40)))
         try:
             metadata.create_all()
+            # Py3K
+            #ustring = 'petit voix m\xe2\x80\x99a'
+            # Py2K
             ustring = 'petit voix m\xe2\x80\x99a'.decode('utf-8')
+            # end Py2K
+            
             table.insert().execute(id=ustring, data=ustring)
             class LocalFoo(Base):
                 pass
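
The added "# Py3K" / "# Py2K" / "# end Py2K" markers are a source-translation convention in the 0.6 tree: the Python 2 form stays active while the Python 3 variant ships commented out, and a build step swaps them. A simplified, assumed illustration of such a preprocessor (the actual 0.6 tooling is not part of this diff):

    def translate_to_py3(source_lines):
        """Activate Py3K blocks, drop Py2K blocks.  Illustrative only."""
        out, state = [], None
        for line in source_lines:
            marker = line.strip()
            if marker == '# Py3K':
                state = 'py3k'
            elif marker == '# Py2K':
                state = 'py2k'
            elif marker == '# end Py2K':
                state = None
            elif state == 'py3k' and marker.startswith('#'):
                out.append(line.replace('#', '', 1))  # uncomment Py3 variant
            elif state == 'py2k':
                pass  # Py2-only line, dropped for Py3
            else:
                state = None
                out.append(line)
        return out
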
@@ -195,7 +195,7 @@ class GetTest(QueryTest):
         assert u.addresses[0].email_address == 'jack@bean.com'
         assert u.orders[1].items[2].description == 'item 5'
 
-    @testing.fails_on_everything_except('sqlite', 'mssql')
+    @testing.fails_on_everything_except('sqlite', '+pyodbc', '+zxjdbc')
     def test_query_str(self):
         s = create_session()
         q = s.query(User).filter(User.id==1)
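
The rewritten decorator targets dialects by driver: 0.6 identifies a backend as "database+driver" (e.g. postgresql+pg8000), and the test decorators accept either the full pair or a bare "+driver" suffix. Illustrative usage, assuming the sqlalchemy.test.testing decorators seen elsewhere in this diff:

    from sqlalchemy.test import testing

    @testing.fails_on('postgresql+pg8000', 'driver pre-parses SQL')
    @testing.fails_on_everything_except('sqlite', '+pyodbc', '+zxjdbc')
    def test_something():
        pass
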
@@ -299,7 +299,12 @@ class OperatorTest(QueryTest, AssertsCompiledSQL):
     def test_arithmetic(self):
         create_session().query(User)
         for (py_op, sql_op) in ((operator.add, '+'), (operator.mul, '*'),
-                                (operator.sub, '-'), (operator.div, '/'),
+                                (operator.sub, '-'), 
+                                # Py3k
+                                #(operator.truediv, '/'),
+                                # Py2K
+                                (operator.div, '/'),
+                                # end Py2K
                                 ):
             for (lhs, rhs, res) in (
                 (5, User.id, ':id_1 %s users.id'),
@@ -489,10 +494,16 @@ class RawSelectTest(QueryTest, AssertsCompiledSQL):
         sess = create_session()
 
         self.assert_compile(sess.query(users).select_from(users.select()).with_labels().statement, 
-            "SELECT users.id AS users_id, users.name AS users_name FROM users, (SELECT users.id AS id, users.name AS name FROM users) AS anon_1")
+            "SELECT users.id AS users_id, users.name AS users_name FROM users, "
+            "(SELECT users.id AS id, users.name AS name FROM users) AS anon_1",
+            dialect=default.DefaultDialect()
+            )
 
         self.assert_compile(sess.query(users, exists([1], from_obj=addresses)).with_labels().statement, 
-            "SELECT users.id AS users_id, users.name AS users_name, EXISTS (SELECT 1 FROM addresses) AS anon_1 FROM users")
+            "SELECT users.id AS users_id, users.name AS users_name, EXISTS "
+            "(SELECT 1 FROM addresses) AS anon_1 FROM users",
+            dialect=default.DefaultDialect()
+            )
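
Passing dialect=default.DefaultDialect() pins assert_compile to the generic dialect, so the expected SQL string no longer depends on whichever backend the suite happens to be configured against. The same idea outside the test harness, as a small runnable sketch under the 0.6 import paths:

    from sqlalchemy import MetaData, Table, Column, Integer, select
    from sqlalchemy.engine import default

    m = MetaData()
    t = Table('users', m, Column('id', Integer, primary_key=True))
    stmt = select([t])
    # compiles to backend-neutral SQL regardless of any bound engine
    print stmt.compile(dialect=default.DefaultDialect())
    # SELECT users.id FROM users
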
 
         # a little tedious here, adding labels to work around Query's auto-labelling.
         # also correlate needed explicitly.  hmmm.....
@@ -687,15 +698,19 @@ class FilterTest(QueryTest):
         sess = create_session()
         assert [Address(id=5)] == sess.query(Address).filter(Address.user.has(name='fred')).all()
 
-        assert [Address(id=2), Address(id=3), Address(id=4), Address(id=5)] == sess.query(Address).filter(Address.user.has(User.name.like('%ed%'))).all()
+        assert [Address(id=2), Address(id=3), Address(id=4), Address(id=5)] == \
+                sess.query(Address).filter(Address.user.has(User.name.like('%ed%'))).order_by(Address.id).all()
 
-        assert [Address(id=2), Address(id=3), Address(id=4)] == sess.query(Address).filter(Address.user.has(User.name.like('%ed%'), id=8)).all()
+        assert [Address(id=2), Address(id=3), Address(id=4)] == \
+            sess.query(Address).filter(Address.user.has(User.name.like('%ed%'), id=8)).order_by(Address.id).all()
 
         # test has() doesn't overcorrelate
-        assert [Address(id=2), Address(id=3), Address(id=4)] == sess.query(Address).join("user").filter(Address.user.has(User.name.like('%ed%'), id=8)).all()
+        assert [Address(id=2), Address(id=3), Address(id=4)] == \
+            sess.query(Address).join("user").filter(Address.user.has(User.name.like('%ed%'), id=8)).order_by(Address.id).all()
 
         # test has() doesn't get subquery contents adapted by aliased join
-        assert [Address(id=2), Address(id=3), Address(id=4)] == sess.query(Address).join("user", aliased=True).filter(Address.user.has(User.name.like('%ed%'), id=8)).all()
+        assert [Address(id=2), Address(id=3), Address(id=4)] == \
+            sess.query(Address).join("user", aliased=True).filter(Address.user.has(User.name.like('%ed%'), id=8)).order_by(Address.id).all()
         
         dingaling = sess.query(Dingaling).get(2)
         assert [User(id=9)] == sess.query(User).filter(User.addresses.any(Address.dingaling==dingaling)).all()
@@ -730,8 +745,8 @@ class FilterTest(QueryTest):
         assert [Address(id=5)] == sess.query(Address).filter(Address.dingaling==dingaling).all()
 
         # m2m
-        eq_(sess.query(Item).filter(Item.keywords==None).all(), [Item(id=4), Item(id=5)])
-        eq_(sess.query(Item).filter(Item.keywords!=None).all(), [Item(id=1),Item(id=2), Item(id=3)])
+        eq_(sess.query(Item).filter(Item.keywords==None).order_by(Item.id).all(), [Item(id=4), Item(id=5)])
+        eq_(sess.query(Item).filter(Item.keywords!=None).order_by(Item.id).all(), [Item(id=1),Item(id=2), Item(id=3)])
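
Most of the assertion changes in this file follow one pattern: an order_by() is added before comparing against a fixed list, since SQL guarantees no row ordering without ORDER BY and different backends (including the newly targeted Jython/zxjdbc drivers) return rows in different orders. The pattern, extracted from the hunk above (fixture names as in the test):

    # deterministic: compare an explicitly ordered result to a fixed list
    eq_(
        sess.query(Item).filter(Item.keywords == None).order_by(Item.id).all(),
        [Item(id=4), Item(id=5)],
    )
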
     
     def test_filter_by(self):
         sess = create_session()
@@ -748,8 +763,9 @@ class FilterTest(QueryTest):
         sess = create_session()
         
         # o2o
-        eq_([Address(id=1), Address(id=3), Address(id=4)], sess.query(Address).filter(Address.dingaling==None).all())
-        eq_([Address(id=2), Address(id=5)], sess.query(Address).filter(Address.dingaling != None).all())
+        eq_([Address(id=1), Address(id=3), Address(id=4)], 
+            sess.query(Address).filter(Address.dingaling==None).order_by(Address.id).all())
+        eq_([Address(id=2), Address(id=5)], sess.query(Address).filter(Address.dingaling != None).order_by(Address.id).all())
         
         # m2o
         eq_([Order(id=5)], sess.query(Order).filter(Order.address==None).all())
@@ -806,11 +822,15 @@ class FromSelfTest(QueryTest, AssertsCompiledSQL):
         
         s = create_session()
         
+        oracle_as = not testing.against('oracle') and "AS " or ""
+        
         self.assert_compile(
             s.query(User).options(eagerload(User.addresses)).from_self().statement,
             "SELECT anon_1.users_id, anon_1.users_name, addresses_1.id, addresses_1.user_id, "\
-            "addresses_1.email_address FROM (SELECT users.id AS users_id, users.name AS users_name FROM users) AS anon_1 "\
-            "LEFT OUTER JOIN addresses AS addresses_1 ON anon_1.users_id = addresses_1.user_id ORDER BY addresses_1.id"
+            "addresses_1.email_address FROM (SELECT users.id AS users_id, users.name AS users_name FROM users) %(oracle_as)sanon_1 "\
+            "LEFT OUTER JOIN addresses %(oracle_as)saddresses_1 ON anon_1.users_id = addresses_1.user_id ORDER BY addresses_1.id" % {
+                'oracle_as':oracle_as
+            }
         )
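
The oracle_as switch reflects a dialect quirk: Oracle accepts "FROM (SELECT ...) anon_1" but rejects the AS keyword before a FROM-clause alias, so the expected SQL is parameterized per backend. An equivalent, perhaps clearer spelling of the same switch:

    # "AS " everywhere except Oracle, which disallows AS for table aliases
    oracle_as = "" if testing.against('oracle') else "AS "
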
             
     def test_aliases(self):
@@ -987,8 +1007,14 @@ class CountTest(QueryTest):
         
 class DistinctTest(QueryTest):
     def test_basic(self):
-        assert [User(id=7), User(id=8), User(id=9),User(id=10)] == create_session().query(User).distinct().all()
-        assert [User(id=7), User(id=9), User(id=8),User(id=10)] == create_session().query(User).distinct().order_by(desc(User.name)).all()
+        eq_(
+            [User(id=7), User(id=8), User(id=9),User(id=10)],
+            create_session().query(User).order_by(User.id).distinct().all()
+        )
+        eq_(
+            [User(id=7), User(id=9), User(id=8),User(id=10)], 
+            create_session().query(User).distinct().order_by(desc(User.name)).all()
+        ) 
 
     def test_joined(self):
         """test that orderbys from a joined table get placed into the columns clause when DISTINCT is used"""
@@ -1017,7 +1043,6 @@ class DistinctTest(QueryTest):
 
 class YieldTest(QueryTest):
     def test_basic(self):
-        import gc
         sess = create_session()
         q = iter(sess.query(User).yield_per(1).from_statement("select * from users"))
 
@@ -1447,11 +1472,11 @@ class MultiplePathTest(_base.MappedTest):
     def define_tables(cls, metadata):
         global t1, t2, t1t2_1, t1t2_2
         t1 = Table('t1', metadata,
-            Column('id', Integer, primary_key=True),
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('data', String(30))
             )
         t2 = Table('t2', metadata,
-            Column('id', Integer, primary_key=True),
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('data', String(30))
             )
 
@@ -1715,7 +1740,7 @@ class MixedEntitiesTest(QueryTest):
         eq_(list(q2), [(u'jack',), (u'ed',)])
     
         q = sess.query(User)
-        q2 = q.order_by(User.id).values(User.name, User.name + " " + cast(User.id, String))
+        q2 = q.order_by(User.id).values(User.name, User.name + " " + cast(User.id, String(50)))
         eq_(list(q2), [(u'jack', u'jack 7'), (u'ed', u'ed 8'), (u'fred', u'fred 9'), (u'chuck', u'chuck 10')])
     
         q2 = q.join('addresses').filter(User.name.like('%e%')).order_by(User.id, Address.id).values(User.name, Address.email_address)
@@ -1755,6 +1780,9 @@ class MixedEntitiesTest(QueryTest):
         eq_(list(q2), [(u'jack', u'jack', u'jack'), (u'jack', u'jack', u'ed'), (u'jack', u'jack', u'fred'), (u'jack', u'jack', u'chuck'), (u'ed', u'ed', u'jack'), (u'ed', u'ed', u'ed'), (u'ed', u'ed', u'fred'), (u'ed', u'ed', u'chuck')])
 
     @testing.fails_on('mssql', 'FIXME: unknown')
+    @testing.fails_on('oracle', "Oracle doesn't support boolean expressions as columns")
+    @testing.fails_on('postgresql+pg8000', "pg8000 parses the SQL itself before passing on to PG, doesn't parse this")
+    @testing.fails_on('postgresql+zxjdbc', "zxjdbc parses the SQL itself before passing on to PG, doesn't parse this")
     def test_values_with_boolean_selects(self):
         """Tests a values clause that works with select boolean evaluations"""
         sess = create_session()
@@ -1763,6 +1791,10 @@ class MixedEntitiesTest(QueryTest):
         q2 = q.group_by([User.name.like('%j%')]).order_by(desc(User.name.like('%j%'))).values(User.name.like('%j%'), func.count(User.name.like('%j%')))
         eq_(list(q2), [(True, 1), (False, 3)])
 
+        q2 = q.order_by(desc(User.name.like('%j%'))).values(User.name.like('%j%'))
+        eq_(list(q2), [(True,), (False,), (False,), (False,)])
+
+
     def test_correlated_subquery(self):
         """test that a subquery constructed from ORM attributes doesn't leak out 
         those entities to the outermost query.
@@ -2514,12 +2546,14 @@ class SelfReferentialTest(_base.MappedTest):
         sess = create_session()
         eq_(sess.query(Node).filter(Node.children.any(Node.data=='n1')).all(), [])
         eq_(sess.query(Node).filter(Node.children.any(Node.data=='n12')).all(), [Node(data='n1')])
-        eq_(sess.query(Node).filter(~Node.children.any()).all(), [Node(data='n11'), Node(data='n13'),Node(data='n121'),Node(data='n122'),Node(data='n123'),])
+        eq_(sess.query(Node).filter(~Node.children.any()).order_by(Node.id).all(), 
+                [Node(data='n11'), Node(data='n13'),Node(data='n121'),Node(data='n122'),Node(data='n123'),])
 
     def test_has(self):
         sess = create_session()
     
-        eq_(sess.query(Node).filter(Node.parent.has(Node.data=='n12')).all(), [Node(data='n121'),Node(data='n122'),Node(data='n123')])
+        eq_(sess.query(Node).filter(Node.parent.has(Node.data=='n12')).order_by(Node.id).all(), 
+            [Node(data='n121'),Node(data='n122'),Node(data='n123')])
         eq_(sess.query(Node).filter(Node.parent.has(Node.data=='n122')).all(), [])
         eq_(sess.query(Node).filter(~Node.parent.has()).all(), [Node(data='n1')])
 
@@ -2660,7 +2694,7 @@ class ExternalColumnsTest(QueryTest):
         for x in range(2):
             sess.expunge_all()
             def go():
-               eq_(sess.query(Address).options(eagerload('user')).all(), address_result)
+               eq_(sess.query(Address).options(eagerload('user')).order_by(Address.id).all(), address_result)
             self.assert_sql_count(testing.db, go, 1)
     
         ualias = aliased(User)
@@ -2691,7 +2725,9 @@ class ExternalColumnsTest(QueryTest):
         )
 
         ua = aliased(User)
-        eq_(sess.query(Address, ua.concat, ua.count).select_from(join(Address, ua, 'user')).options(eagerload(Address.user)).all(),
+        eq_(sess.query(Address, ua.concat, ua.count).
+                    select_from(join(Address, ua, 'user')).
+                    options(eagerload(Address.user)).order_by(Address.id).all(),
             [
                 (Address(id=1, user=User(id=7, concat=14, count=1)), 14, 1),
                 (Address(id=2, user=User(id=8, concat=16, count=3)), 16, 3),
@@ -2742,7 +2778,7 @@ class TestOverlyEagerEquivalentCols(_base.MappedTest):
     def define_tables(cls, metadata):
         global base, sub1, sub2
         base = Table('base', metadata, 
-            Column('id', Integer, primary_key=True),
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('data', String(50))
         )
 
@@ -2800,12 +2836,12 @@ class UpdateDeleteTest(_base.MappedTest):
     @classmethod
     def define_tables(cls, metadata):
         Table('users', metadata,
-              Column('id', Integer, primary_key=True),
+              Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
               Column('name', String(32)),
               Column('age', Integer))
 
         Table('documents', metadata,
-              Column('id', Integer, primary_key=True),
+              Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
               Column('user_id', None, ForeignKey('users.id')),
               Column('title', String(32)))
 
@@ -2875,7 +2911,7 @@ class UpdateDeleteTest(_base.MappedTest):
         sess = create_session(bind=testing.db, autocommit=False)
 
         john,jack,jill,jane = sess.query(User).order_by(User.id).all()
-        sess.query(User).filter('name = :name').params(name='john').delete()
+        sess.query(User).filter('name = :name').params(name='john').delete('fetch')
         assert john not in sess
 
         eq_(sess.query(User).order_by(User.id).all(), [jack,jill,jane])
@@ -2922,12 +2958,17 @@ class UpdateDeleteTest(_base.MappedTest):
 
     @testing.fails_on('mysql', 'FIXME: unknown')
     @testing.resolve_artifact_names
-    def test_delete_fallback(self):
+    def test_delete_invalid_evaluation(self):
         sess = create_session(bind=testing.db, autocommit=False)
     
         john,jack,jill,jane = sess.query(User).order_by(User.id).all()
-        sess.query(User).filter(User.name == select([func.max(User.name)])).delete(synchronize_session='evaluate')
     
+        assert_raises(sa_exc.InvalidRequestError,
+            sess.query(User).filter(User.name == select([func.max(User.name)])).delete, synchronize_session='evaluate'
+        )
+        
+        sess.query(User).filter(User.name == select([func.max(User.name)])).delete(synchronize_session='fetch')
+        
         assert john not in sess
     
         eq_(sess.query(User).order_by(User.id).all(), [jack,jill,jane])
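
The renamed test formalizes the limits of synchronize_session='evaluate': criteria that cannot be evaluated in Python against in-session objects (here, a correlated SELECT) now raise InvalidRequestError, and the operation must fall back to 'fetch', which first selects the matching primary keys. A condensed sketch of the two behaviors, using the names from the hunk above:

    # 'fetch' works for arbitrary criteria, since the rows to synchronize
    # are pre-selected before the DELETE is emitted:
    sess.query(User).filter(
        User.name == select([func.max(User.name)])
    ).delete(synchronize_session='fetch')

    # 'evaluate' would raise InvalidRequestError for the same criterion,
    # as the subquery cannot be evaluated against in-memory objects.
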
@@ -2957,7 +2998,7 @@ class UpdateDeleteTest(_base.MappedTest):
 
         john,jack,jill,jane = sess.query(User).order_by(User.id).all()
 
-        sess.query(User).filter('age > :x').params(x=29).update({'age': User.age - 10}, synchronize_session='evaluate')
+        sess.query(User).filter('age > :x').params(x=29).update({'age': User.age - 10}, synchronize_session='fetch')
 
         eq_([john.age, jack.age, jill.age, jane.age], [25,37,29,27])
         eq_(sess.query(User.age).order_by(User.id).all(), zip([25,37,29,27]))
@@ -3017,11 +3058,12 @@ class UpdateDeleteTest(_base.MappedTest):
         sess = create_session(bind=testing.db, autocommit=False)
     
         john,jack,jill,jane = sess.query(User).order_by(User.id).all()
-        sess.query(User).filter(User.age > 29).update({'age': User.age - 10}, synchronize_session='expire')
+        sess.query(User).filter(User.age > 29).update({'age': User.age - 10}, synchronize_session='fetch')
     
         eq_([john.age, jack.age, jill.age, jane.age], [25,37,29,27])
         eq_(sess.query(User.age).order_by(User.id).all(), zip([25,37,29,27]))
 
+    @testing.fails_if(lambda: not testing.db.dialect.supports_sane_rowcount)
     @testing.resolve_artifact_names
     def test_update_returns_rowcount(self):
         sess = create_session(bind=testing.db, autocommit=False)
@@ -3032,6 +3074,7 @@ class UpdateDeleteTest(_base.MappedTest):
         rowcount = sess.query(User).filter(User.age > 29).update({'age': User.age - 10})
         eq_(rowcount, 2)
 
+    @testing.fails_if(lambda: not testing.db.dialect.supports_sane_rowcount)
     @testing.resolve_artifact_names
     def test_delete_returns_rowcount(self):
         sess = create_session(bind=testing.db, autocommit=False)
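
The new fails_if guards tie the rowcount tests to dialect.supports_sane_rowcount, since some DBAPIs (notably certain ODBC/JDBC bridges) report -1 rather than the number of matched rows. The same guard written as plain code, per the decorator above:

    # only assert on rowcount where the driver reports accurate counts
    if testing.db.dialect.supports_sane_rowcount:
        rowcount = sess.query(User).filter(User.age > 29).update(
            {'age': User.age - 10})
        eq_(rowcount, 2)
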
@@ -3046,7 +3089,7 @@ class UpdateDeleteTest(_base.MappedTest):
         sess = create_session(bind=testing.db, autocommit=False)
 
         foo,bar,baz = sess.query(Document).order_by(Document.id).all()
-        sess.query(Document).filter(Document.user_id == 1).update({'title': Document.title+Document.title}, synchronize_session='expire')
+        sess.query(Document).filter(Document.user_id == 1).update({'title': Document.title+Document.title}, synchronize_session='fetch')
 
         eq_([foo.title, bar.title, baz.title], ['foofoo','barbar', 'baz'])
         eq_(sess.query(Document.title).order_by(Document.id).all(), zip(['foofoo','barbar', 'baz']))
@@ -3056,7 +3099,7 @@ class UpdateDeleteTest(_base.MappedTest):
         sess = create_session(bind=testing.db, autocommit=False)
 
         john,jack,jill,jane = sess.query(User).order_by(User.id).all()
-        sess.query(User).options(eagerload(User.documents)).filter(User.age > 29).update({'age': User.age - 10}, synchronize_session='expire')
+        sess.query(User).options(eagerload(User.documents)).filter(User.age > 29).update({'age': User.age - 10}, synchronize_session='fetch')
 
         eq_([john.age, jack.age, jill.age, jane.age], [25,37,29,27])
         eq_(sess.query(User.age).order_by(User.id).all(), zip([25,37,29,27]))
index fef1577f0757c5e3e334d1b84732e329b315eae9..481deb81b1ece7f9be887ffad6f18009201979cd 100644 (file)
@@ -3,8 +3,7 @@ import datetime
 import sqlalchemy as sa
 from sqlalchemy.test import testing
 from sqlalchemy import Integer, String, ForeignKey, MetaData, and_
-from sqlalchemy.test.schema import Table
-from sqlalchemy.test.schema import Column
+from sqlalchemy.test.schema import Table, Column
 from sqlalchemy.orm import mapper, relation, backref, create_session, compile_mappers, clear_mappers, sessionmaker
 from sqlalchemy.test.testing import eq_, startswith_
 from test.orm import _base, _fixtures
@@ -32,17 +31,17 @@ class RelationTest(_base.MappedTest):
     @classmethod
     def define_tables(cls, metadata):
         Table("tbl_a", metadata,
-            Column("id", Integer, primary_key=True),
+            Column("id", Integer, primary_key=True, test_needs_autoincrement=True),
             Column("name", String(128)))
         Table("tbl_b", metadata,
-            Column("id", Integer, primary_key=True),
+            Column("id", Integer, primary_key=True, test_needs_autoincrement=True),
             Column("name", String(128)))
         Table("tbl_c", metadata,
-            Column("id", Integer, primary_key=True),
+            Column("id", Integer, primary_key=True, test_needs_autoincrement=True),
             Column("tbl_a_id", Integer, ForeignKey("tbl_a.id"), nullable=False),
             Column("name", String(128)))
         Table("tbl_d", metadata,
-            Column("id", Integer, primary_key=True),
+            Column("id", Integer, primary_key=True, test_needs_autoincrement=True),
             Column("tbl_c_id", Integer, ForeignKey("tbl_c.id"), nullable=False),
             Column("tbl_b_id", Integer, ForeignKey("tbl_b.id")),
             Column("name", String(128)))
@@ -132,7 +131,7 @@ class RelationTest2(_base.MappedTest):
     @classmethod
     def define_tables(cls, metadata):
         Table('company_t', metadata,
-              Column('company_id', Integer, primary_key=True),
+              Column('company_id', Integer, primary_key=True, test_needs_autoincrement=True),
               Column('name', sa.Unicode(30)))
 
         Table('employee_t', metadata,
@@ -395,7 +394,7 @@ class RelationTest4(_base.MappedTest):
     @classmethod
     def define_tables(cls, metadata):
         Table("tableA", metadata,
-              Column("id",Integer,primary_key=True),
+              Column("id",Integer,primary_key=True, test_needs_autoincrement=True),
               Column("foo",Integer,),
               test_needs_fk=True)
         Table("tableB",metadata,
@@ -456,7 +455,7 @@ class RelationTest4(_base.MappedTest):
     @testing.fails_on_everything_except('sqlite', 'mysql')
     @testing.resolve_artifact_names
     def test_nullPKsOK_BtoA(self):
-        # postgres cant handle a nullable PK column...?
+        # postgresql can't handle a nullable PK column...?
         tableC = Table('tablec', tableA.metadata,
             Column('id', Integer, primary_key=True),
             Column('a_id', Integer, ForeignKey('tableA.id'),
@@ -642,12 +641,12 @@ class RelationTest6(_base.MappedTest):
     
     @classmethod
     def define_tables(cls, metadata):
-        Table('tags', metadata, Column("id", Integer, primary_key=True),
+        Table('tags', metadata, Column("id", Integer, primary_key=True, test_needs_autoincrement=True),
             Column("data", String(50)),
         )
 
         Table('tag_foo', metadata, 
-            Column("id", Integer, primary_key=True),
+            Column("id", Integer, primary_key=True, test_needs_autoincrement=True),
             Column('tagid', Integer),
             Column("data", String(50)),
         )
@@ -691,11 +690,11 @@ class BackrefPropagatesForwardsArgs(_base.MappedTest):
     @classmethod
     def define_tables(cls, metadata):
         Table('users', metadata, 
-            Column('id', Integer, primary_key=True),
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('name', String(50))
         )
         Table('addresses', metadata, 
-            Column('id', Integer, primary_key=True),
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('user_id', Integer),
             Column('email', String(50))
         )
@@ -738,7 +737,7 @@ class AmbiguousJoinInterpretedAsSelfRef(_base.MappedTest):
     @classmethod
     def define_tables(cls, metadata):
         subscriber_table = Table('subscriber', metadata,
-           Column('id', Integer, primary_key=True),
+           Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
            Column('dummy', String(10)) # to appease older sqlite version
           )
 
@@ -947,18 +946,18 @@ class TypeMatchTest(_base.MappedTest):
     @classmethod
     def define_tables(cls, metadata):
         Table("a", metadata,
-              Column('aid', Integer, primary_key=True),
+              Column('aid', Integer, primary_key=True, test_needs_autoincrement=True),
               Column('data', String(30)))
         Table("b", metadata,
-               Column('bid', Integer, primary_key=True),
+               Column('bid', Integer, primary_key=True, test_needs_autoincrement=True),
                Column("a_id", Integer, ForeignKey("a.aid")),
                Column('data', String(30)))
         Table("c", metadata,
-              Column('cid', Integer, primary_key=True),
+              Column('cid', Integer, primary_key=True, test_needs_autoincrement=True),
               Column("b_id", Integer, ForeignKey("b.bid")),
               Column('data', String(30)))
         Table("d", metadata,
-              Column('did', Integer, primary_key=True),
+              Column('did', Integer, primary_key=True, test_needs_autoincrement=True),
               Column("a_id", Integer, ForeignKey("a.aid")),
               Column('data', String(30)))
 
@@ -1116,14 +1115,14 @@ class ViewOnlyOverlappingNames(_base.MappedTest):
     @classmethod
     def define_tables(cls, metadata):
         Table("t1", metadata,
-            Column('id', Integer, primary_key=True),
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('data', String(40)))
         Table("t2", metadata,
-            Column('id', Integer, primary_key=True),
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('data', String(40)),
             Column('t1id', Integer, ForeignKey('t1.id')))
         Table("t3", metadata,
-            Column('id', Integer, primary_key=True),
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('data', String(40)),
             Column('t2id', Integer, ForeignKey('t2.id')))
 
@@ -1176,14 +1175,14 @@ class ViewOnlyUniqueNames(_base.MappedTest):
     @classmethod
     def define_tables(cls, metadata):
         Table("t1", metadata,
-            Column('t1id', Integer, primary_key=True),
+            Column('t1id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('data', String(40)))
         Table("t2", metadata,
-            Column('t2id', Integer, primary_key=True),
+            Column('t2id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('data', String(40)),
             Column('t1id_ref', Integer, ForeignKey('t1.t1id')))
         Table("t3", metadata,
-            Column('t3id', Integer, primary_key=True),
+            Column('t3id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('data', String(40)),
             Column('t2id_ref', Integer, ForeignKey('t2.t2id')))
 
@@ -1309,12 +1308,12 @@ class ViewOnlyRepeatedRemoteColumn(_base.MappedTest):
     @classmethod
     def define_tables(cls, metadata):
         Table('foos', metadata,
-              Column('id', Integer, primary_key=True),
+              Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
               Column('bid1', Integer,ForeignKey('bars.id')),
               Column('bid2', Integer,ForeignKey('bars.id')))
 
         Table('bars', metadata,
-              Column('id', Integer, primary_key=True),
+              Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
               Column('data', String(50)))
 
     @testing.resolve_artifact_names
@@ -1357,10 +1356,10 @@ class ViewOnlyRepeatedLocalColumn(_base.MappedTest):
     @classmethod
     def define_tables(cls, metadata):
         Table('foos', metadata,
-              Column('id', Integer, primary_key=True),
+              Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
               Column('data', String(50)))
 
-        Table('bars', metadata, Column('id', Integer, primary_key=True),
+        Table('bars', metadata, Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
               Column('fid1', Integer, ForeignKey('foos.id')),
               Column('fid2', Integer, ForeignKey('foos.id')),
               Column('data', String(50)))
@@ -1405,14 +1404,14 @@ class ViewOnlyComplexJoin(_base.MappedTest):
     @classmethod
     def define_tables(cls, metadata):
         Table('t1', metadata,
-            Column('id', Integer, primary_key=True),
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('data', String(50)))
         Table('t2', metadata,
-            Column('id', Integer, primary_key=True),
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('data', String(50)),
             Column('t1id', Integer, ForeignKey('t1.id')))
         Table('t3', metadata,
-            Column('id', Integer, primary_key=True),
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('data', String(50)))
         Table('t2tot3', metadata,
             Column('t2id', Integer, ForeignKey('t2.id')),
@@ -1476,10 +1475,10 @@ class ExplicitLocalRemoteTest(_base.MappedTest):
     @classmethod
     def define_tables(cls, metadata):
         Table('t1', metadata,
-            Column('id', String(50), primary_key=True),
+            Column('id', String(50), primary_key=True, test_needs_autoincrement=True),
             Column('data', String(50)))
         Table('t2', metadata,
-            Column('id', Integer, primary_key=True),
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('data', String(50)),
             Column('t1id', String(50)))
 
index 9f2f59e19b42db526f5d5920b489b73159b02fb1..0d6b3deaecbbdc5fee2f83771e6aa0c4070ac013 100644 (file)
@@ -3,13 +3,13 @@ import sqlalchemy as sa
 from sqlalchemy.test import testing
 from sqlalchemy.orm import scoped_session
 from sqlalchemy import Integer, String, ForeignKey
-from sqlalchemy.test.schema import Table
-from sqlalchemy.test.schema import Column
+from sqlalchemy.test.schema import Table, Column
 from sqlalchemy.orm import mapper, relation, query
 from sqlalchemy.test.testing import eq_
 from test.orm import _base
 
 
+
 class _ScopedTest(_base.MappedTest):
     """Adds another lookup bucket to emulate Session globals."""
 
@@ -34,10 +34,10 @@ class ScopedSessionTest(_base.MappedTest):
     @classmethod
     def define_tables(cls, metadata):
         Table('table1', metadata,
-              Column('id', Integer, primary_key=True),
+              Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
               Column('data', String(30)))
         Table('table2', metadata,
-              Column('id', Integer, primary_key=True),
+              Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
               Column('someid', None, ForeignKey('table1.id')))
 
     @testing.resolve_artifact_names
@@ -82,10 +82,10 @@ class ScopedMapperTest(_ScopedTest):
     @classmethod
     def define_tables(cls, metadata):
         Table('table1', metadata,
-            Column('id', Integer, primary_key=True),
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('data', String(30)))
         Table('table2', metadata,
-            Column('id', Integer, primary_key=True),
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('someid', None, ForeignKey('table1.id')))
 
     @classmethod
@@ -204,11 +204,11 @@ class ScopedMapperTest2(_ScopedTest):
     @classmethod
     def define_tables(cls, metadata):
         Table('table1', metadata,
-            Column('id', Integer, primary_key=True),
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('data', String(30)),
             Column('type', String(30)))
         Table('table2', metadata,
-            Column('id', Integer, primary_key=True),
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('someid', None, ForeignKey('table1.id')),
             Column('somedata', String(30)))
 
index 0a20253607772a0b9348ac693e6b0b1696e122b6..bfa40089573e0f0418a0abc34f990f7140b167f4 100644 (file)
@@ -3,8 +3,7 @@ from sqlalchemy.test.testing import assert_raises, assert_raises_message
 import sqlalchemy as sa
 from sqlalchemy.test import testing
 from sqlalchemy import String, Integer, select
-from sqlalchemy.test.schema import Table
-from sqlalchemy.test.schema import Column
+from sqlalchemy.test.schema import Table, Column
 from sqlalchemy.orm import mapper, create_session
 from sqlalchemy.test.testing import eq_
 from test.orm import _base
@@ -16,7 +15,7 @@ class SelectableNoFromsTest(_base.MappedTest):
     @classmethod
     def define_tables(cls, metadata):
         Table('common', metadata,
-              Column('id', Integer, primary_key=True),
+              Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
               Column('data', Integer),
               Column('extra', String(45)))
 
index 328cbee8ee991dd685ddaeb9b91d1cc226a281ef..2d99e20630ac88a729a9dfb2ae4663c2850a2a79 100644 (file)
@@ -1,13 +1,12 @@
 from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message
-import gc
+from sqlalchemy.test.util import gc_collect
 import inspect
 import pickle
 from sqlalchemy.orm import create_session, sessionmaker, attributes
 import sqlalchemy as sa
 from sqlalchemy.test import engines, testing, config
 from sqlalchemy import Integer, String, Sequence
-from sqlalchemy.test.schema import Table
-from sqlalchemy.test.schema import Column
+from sqlalchemy.test.schema import Table, Column
 from sqlalchemy.orm import mapper, relation, backref, eagerload
 from sqlalchemy.test.testing import eq_
 from test.engine import _base as engine_base
@@ -229,7 +228,7 @@ class SessionTest(_fixtures.FixtureTest):
         u = sess.query(User).get(u.id)
         q = sess.query(Address).filter(Address.user==u)
         del u
-        gc.collect()
+        gc_collect()
         eq_(q.one(), Address(email_address='foo'))
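
The switch from a raw gc.collect() to sqlalchemy.test.util.gc_collect suggests a platform shim: on CPython a single pass suffices, while Jython (the zxjdbc target added throughout this commit) has no deterministic refcounting. A guessed shape for such a helper; the real implementation is not shown in this diff:

    import gc
    import sys

    def gc_collect():
        if sys.platform.startswith('java'):
            # Jython: ask the JVM collector (assumption: running under Jython)
            from java.lang import System
            System.gc()
        else:
            # a few passes to pick up newly unreachable cycles
            for _ in range(3):
                gc.collect()
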
 
 
@@ -381,18 +380,18 @@ class SessionTest(_fixtures.FixtureTest):
         session = create_session(bind=testing.db)
 
         session.begin()
-        session.connection().execute("insert into users (name) values ('user1')")
+        session.connection().execute(users.insert().values(name='user1'))
 
         session.begin(subtransactions=True)
 
         session.begin_nested()
 
-        session.connection().execute("insert into users (name) values ('user2')")
+        session.connection().execute(users.insert().values(name='user2'))
         assert session.connection().execute("select count(1) from users").scalar() == 2
 
         session.rollback()
         assert session.connection().execute("select count(1) from users").scalar() == 1
-        session.connection().execute("insert into users (name) values ('user3')")
+        session.connection().execute(users.insert().values(name='user3'))
 
         session.commit()
         assert session.connection().execute("select count(1) from users").scalar() == 2
@@ -771,18 +770,18 @@ class SessionTest(_fixtures.FixtureTest):
 
         user = s.query(User).one()
         del user
-        gc.collect()
+        gc_collect()
         assert len(s.identity_map) == 0
 
         user = s.query(User).one()
         user.name = 'fred'
         del user
-        gc.collect()
+        gc_collect()
         assert len(s.identity_map) == 1
         assert len(s.dirty) == 1
         assert None not in s.dirty
         s.flush()
-        gc.collect()
+        gc_collect()
         assert not s.dirty
         assert not s.identity_map
 
@@ -809,13 +808,13 @@ class SessionTest(_fixtures.FixtureTest):
         s.add(u2)
         
         del u2
-        gc.collect()
+        gc_collect()
         
         assert len(s.identity_map) == 1
         assert len(s.dirty) == 1
         assert None not in s.dirty
         s.flush()
-        gc.collect()
+        gc_collect()
         assert not s.dirty
         
         assert not s.identity_map
@@ -835,14 +834,14 @@ class SessionTest(_fixtures.FixtureTest):
         eq_(user, User(name="ed", addresses=[Address(email_address="ed1")]))
         
         del user
-        gc.collect()
+        gc_collect()
         assert len(s.identity_map) == 0
 
         user = s.query(User).options(eagerload(User.addresses)).one()
         user.addresses[0].email_address='ed2'
         user.addresses[0].user # lazyload
         del user
-        gc.collect()
+        gc_collect()
         assert len(s.identity_map) == 2
         
         s.commit()
@@ -864,7 +863,7 @@ class SessionTest(_fixtures.FixtureTest):
         eq_(user, User(name="ed", address=Address(email_address="ed1")))
 
         del user
-        gc.collect()
+        gc_collect()
         assert len(s.identity_map) == 0
 
         user = s.query(User).options(eagerload(User.address)).one()
@@ -872,7 +871,7 @@ class SessionTest(_fixtures.FixtureTest):
         user.address.user # lazyload
 
         del user
-        gc.collect()
+        gc_collect()
         assert len(s.identity_map) == 2
         
         s.commit()
@@ -890,8 +889,7 @@ class SessionTest(_fixtures.FixtureTest):
         user = s.query(User).one()
         user = None
         print s.identity_map
-        import gc
-        gc.collect()
+        gc_collect()
         assert len(s.identity_map) == 1
 
         user = s.query(User).one()
@@ -901,7 +899,7 @@ class SessionTest(_fixtures.FixtureTest):
         s.flush()
         eq_(users.select().execute().fetchall(), [(user.id, 'u2')])
         
-        
+    @testing.fails_on('+zxjdbc', 'http://www.sqlalchemy.org/trac/ticket/1473')
     @testing.resolve_artifact_names
     def test_prune(self):
         s = create_session(weak_identity_map=False)
@@ -914,8 +912,7 @@ class SessionTest(_fixtures.FixtureTest):
         self.assert_(len(s.identity_map) == 0)
         self.assert_(s.prune() == 0)
         s.flush()
-        import gc
-        gc.collect()
+        gc_collect()
         self.assert_(s.prune() == 9)
         self.assert_(len(s.identity_map) == 1)
 
@@ -1228,7 +1225,7 @@ class DisposedStates(_base.MappedTest):
     def define_tables(cls, metadata):
         global t1
         t1 = Table('t1', metadata, 
-            Column('id', Integer, primary_key=True),
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('data', String(50))
             )
 
@@ -1327,7 +1324,7 @@ class SessionInterface(testing.TestBase):
 
     def _map_it(self, cls):
         return mapper(cls, Table('t', sa.MetaData(),
-                                 Column('id', Integer, primary_key=True)))
+                                 Column('id', Integer, primary_key=True, test_needs_autoincrement=True)))
 
     @testing.uses_deprecated()
     def _test_instance_guards(self, user_arg):
@@ -1447,7 +1444,7 @@ class TLTransactionTest(engine_base.AltEngineTest, _base.MappedTest):
     @classmethod
     def define_tables(cls, metadata):
         Table('users', metadata,
-              Column('id', Integer, primary_key=True),
+              Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
               Column('name', String(20)),
               test_needs_acid=True)
 
index 5aa541cdadaf7b4ec324f2f7c6d07003ae8c577a..51b345cebd96deb7563c4787dfc82be40d0f7ac0 100644 (file)
@@ -4,13 +4,11 @@ from sqlalchemy import *
 from sqlalchemy.orm import attributes
 from sqlalchemy import exc as sa_exc
 from sqlalchemy.orm import *
-
+from sqlalchemy.test.util import gc_collect
 from sqlalchemy.test import testing
 from test.orm import _base
 from test.orm._fixtures import FixtureTest, User, Address, users, addresses
 
-import gc
-
 class TransactionTest(FixtureTest):
     run_setup_mappers = 'once'
     run_inserts = None
@@ -20,7 +18,7 @@ class TransactionTest(FixtureTest):
     def setup_mappers(cls):
         mapper(User, users, properties={
             'addresses':relation(Address, backref='user',
-                                 cascade="all, delete-orphan"),
+                                 cascade="all, delete-orphan", order_by=addresses.c.id),
             })
         mapper(Address, addresses)
 
@@ -109,7 +107,7 @@ class AutoExpireTest(TransactionTest):
         assert u1_state not in s.identity_map.all_states()
         assert u1_state not in s._deleted
         del u1
-        gc.collect()
+        gc_collect()
         assert u1_state.obj() is None
         
         s.rollback()
index f95346902be8ca91a7c0ded35cdc1dc6202ea51a..4d2056b264f26c491216a814471a788d5b291d20 100644 (file)
@@ -379,7 +379,6 @@ class MutableTypesTest(_base.MappedTest):
              "WHERE mutable_t.id = :mutable_t_id",
              {'mutable_t_id': f1.id, 'val': u'hi', 'data':f1.data})])
 
-
     @testing.resolve_artifact_names
     def test_resurrect(self):
         f1 = Foo()
@@ -392,42 +391,13 @@ class MutableTypesTest(_base.MappedTest):
 
         f1.data.y = 19
         del f1
-        
+
         gc.collect()
         assert len(session.identity_map) == 1
-        
-        session.commit()
-        
-        assert session.query(Foo).one().data == pickleable.Bar(4, 19)
-        
-        
-    @testing.uses_deprecated()
-    @testing.resolve_artifact_names
-    def test_nocomparison(self):
-        """Changes are detected on MutableTypes lacking an __eq__ method."""
 
-        f1 = Foo()
-        f1.data = pickleable.BarWithoutCompare(4,5)
-        session = create_session(autocommit=False)
-        session.add(f1)
         session.commit()
 
-        self.sql_count_(0, session.commit)
-        session.close()
-
-        session = create_session(autocommit=False)
-        f2 = session.query(Foo).filter_by(id=f1.id).one()
-        self.sql_count_(0, session.commit)
-
-        f2.data.y = 19
-        self.sql_count_(1, session.commit)
-        session.close()
-
-        session = create_session(autocommit=False)
-        f3 = session.query(Foo).filter_by(id=f1.id).one()
-        eq_((f3.data.x, f3.data.y), (4,19))
-        self.sql_count_(0, session.commit)
-        session.close()
+        assert session.query(Foo).one().data == pickleable.Bar(4, 19)
 
     @testing.resolve_artifact_names
     def test_unicode(self):
@@ -892,7 +862,7 @@ class DefaultTest(_base.MappedTest):
 
     @classmethod
     def define_tables(cls, metadata):
-        use_string_defaults = testing.against('postgres', 'oracle', 'sqlite', 'mssql')
+        use_string_defaults = testing.against('postgresql', 'oracle', 'sqlite', 'mssql')
 
         if use_string_defaults:
             hohotype = String(30)
@@ -910,15 +880,14 @@ class DefaultTest(_base.MappedTest):
             Column('id', Integer, primary_key=True,
                    test_needs_autoincrement=True),
             Column('hoho', hohotype, server_default=str(hohoval)),
-            Column('counter', Integer, default=sa.func.char_length("1234567")),
-            Column('foober', String(30), default="im foober",
-                   onupdate="im the update"))
+            Column('counter', Integer, default=sa.func.char_length("1234567", type_=Integer)),
+            Column('foober', String(30), default="im foober", onupdate="im the update"))
 
         st = Table('secondary_table', metadata,
-            Column('id', Integer, primary_key=True),
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('data', String(50)))
 
-        if testing.against('postgres', 'oracle'):
+        if testing.against('postgresql', 'oracle'):
             dt.append_column(
                 Column('secondary_id', Integer, sa.Sequence('sec_id_seq'),
                        unique=True))
@@ -1004,14 +973,14 @@ class DefaultTest(_base.MappedTest):
         # "post-update"
         mapper(Hoho, default_t)
 
-        h1 = Hoho(hoho="15", counter="15")
+        h1 = Hoho(hoho="15", counter=15)
         session = create_session()
         session.add(h1)
         session.flush()
 
         def go():
             eq_(h1.hoho, "15")
-            eq_(h1.counter, "15")
+            eq_(h1.counter, 15)
             eq_(h1.foober, "im foober")
         self.sql_count_(0, go)
 
@@ -1036,7 +1005,7 @@ class DefaultTest(_base.MappedTest):
         """A server-side default can be used as the target of a foreign key"""
 
         mapper(Hoho, default_t, properties={
-            'secondaries':relation(Secondary)})
+            'secondaries':relation(Secondary, order_by=secondary_table.c.id)})
         mapper(Secondary, secondary_table)
 
         h1 = Hoho()
@@ -1068,7 +1037,7 @@ class ColumnPropertyTest(_base.MappedTest):
     @classmethod
     def define_tables(cls, metadata):
         Table('data', metadata, 
-            Column('id', Integer, primary_key=True),
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('a', String(50)),
             Column('b', String(50))
             )
@@ -1681,7 +1650,7 @@ class ManyToOneTest(_fixtures.FixtureTest):
         l = sa.select([users, addresses],
                       sa.and_(users.c.id==addresses.c.user_id,
                               addresses.c.id==a.id)).execute()
-        eq_(l.fetchone().values(),
+        eq_(l.first().values(),
             [a.user.id, 'asdf8d', a.id, a.user_id, 'theater@foo.com'])
 
     @testing.resolve_artifact_names
@@ -2201,8 +2170,14 @@ class RowSwitchTest(_base.MappedTest):
         sess.add(o5)
         sess.flush()
 
-        assert list(sess.execute(t5.select(), mapper=T5)) == [(1, 'some t5')]
-        assert list(sess.execute(t6.select(), mapper=T5)) == [(1, 'some t6', 1), (2, 'some other t6', 1)]
+        eq_(
+            list(sess.execute(t5.select(), mapper=T5)),
+            [(1, 'some t5')]
+        )
+        eq_(
+            list(sess.execute(t6.select().order_by(t6.c.id), mapper=T5)),
+            [(1, 'some t6', 1), (2, 'some other t6', 1)]
+        )
 
         o6 = T5(data='some other t5', id=o5.id, t6s=[
             T6(data='third t6', id=3),
@@ -2212,8 +2187,14 @@ class RowSwitchTest(_base.MappedTest):
         sess.add(o6)
         sess.flush()
 
-        assert list(sess.execute(t5.select(), mapper=T5)) == [(1, 'some other t5')]
-        assert list(sess.execute(t6.select(), mapper=T5)) == [(3, 'third t6', 1), (4, 'fourth t6', 1)]
+        eq_(
+            list(sess.execute(t5.select(), mapper=T5)),
+            [(1, 'some other t5')]
+        )
+        eq_(
+            list(sess.execute(t6.select().order_by(t6.c.id), mapper=T5)),
+            [(3, 'third t6', 1), (4, 'fourth t6', 1)]
+        )
 
     @testing.resolve_artifact_names
     def test_manytomany(self):
@@ -2369,6 +2350,6 @@ class TransactionTest(_base.MappedTest):
         # todo: on 8.3 at least, the failed commit seems to close the cursor?
         # needs investigation.  leaving in the DDL above now to help verify
         # that the new deferrable support on FK isn't involved in this issue.
-        if testing.against('postgres'):
+        if testing.against('postgresql'):
             t1.bind.engine.dispose()
 
index 06533a243b20245ffaa7194b9f2eb4fec09935ff..8635ad2125e0718358d249f933ba2ebe5d62d848 100644 (file)
@@ -39,8 +39,13 @@ class ExtensionCarrierTest(TestBase):
 
         assert 'populate_instance' not in carrier
         carrier.append(interfaces.MapperExtension)
+        
+        # Py3K
+        #assert 'populate_instance' not in carrier
+        # Py2K
         assert 'populate_instance' in carrier
-
+        # end Py2K
+        
         assert carrier.interface
         for m in carrier.interface:
             assert getattr(interfaces.MapperExtension, m)
@@ -85,7 +90,10 @@ class AliasedClassTest(TestBase):
         alias = aliased(Point)
 
         assert Point.zero
+        # Py2K
+        # TODO: what is this testing ??
         assert not getattr(alias, 'zero')
+        # end Py2K
 
     def test_classmethods(self):
         class Point(object):
@@ -152,9 +160,17 @@ class AliasedClassTest(TestBase):
                 self.func = func
             def __get__(self, instance, owner):
                 if instance is None:
+                    # Py3K
+                    #args = (self.func, owner)
+                    # Py2K
                     args = (self.func, owner, owner.__class__)
+                    # end Py2K
                 else:
+                    # Py3K
+                    #args = (self.func, instance)
+                    # Py2K
                     args = (self.func, instance, owner)
+                    # end Py2K
                 return types.MethodType(*args)
 
         class PropertyDescriptor(object):
index 32877560eb98708fb6865529c7fd45c77bc06719..0491e9f95950ac81bfc69500c521ce11395c73d7 100644 (file)
@@ -2,7 +2,7 @@ import testenv; testenv.simple_setup()
 import sys, time
 from sqlalchemy import *
 from sqlalchemy.orm import *
-from testlib import profiling
+from sqlalchemy.test import profiling
 
 db = create_engine('sqlite://')
 metadata = MetaData(db)
index ae32f83e2cbbbed3936ac8517d6a308961296f38..5b8e0da555adb9f5b6ceeab457d3408ac53cbada 100644 (file)
@@ -3,7 +3,6 @@ import testenv; testenv.simple_setup()
 
 from sqlalchemy.orm import attributes
 import time
-import gc
 
 manage_attributes = True
 init_attributes = manage_attributes and True
@@ -34,7 +33,6 @@ for i in range(0,130):
             attributes.manage(a)
         a.email = 'foo@bar.com'
         u.addresses.append(a)
-#    gc.collect()
 #    print len(managed_attributes)
 #    managed_attributes.clear()
 total = time.time() - now
index 25d4b49153cdccdc4e36167d6b2e20486eeb7985..e525fcf99d49750266adcfcb1fcb34ad29b9e052 100644 (file)
@@ -1,9 +1,9 @@
 import testenv; testenv.simple_setup()
-import gc
 
 import random, string
 
 from sqlalchemy.orm import attributes
+from sqlalchemy.test.util import gc_collect
 
 # with this test, run top.  make sure the Python process doesn't grow in size arbitrarily.
 
@@ -33,4 +33,4 @@ for i in xrange(1000):
       a.user = u
   print "clearing"
   #managed_attributes.clear()
-  gc.collect()
+  gc_collect()
index a848b866cc8ef1fb376343bdf0275366484209ee..88a3ade20b19fe3ffabfb1262c72eda6c10ef587 100644 (file)
@@ -1,7 +1,6 @@
-import testenv; testenv.configure_for_tests()
 from sqlalchemy import *
 from sqlalchemy.orm import *
-from testlib import *
+from sqlalchemy.test import *
 
 NUM = 500
 DIVISOR = 50
index 9391ead2a54190e31b889cceb4a623232a46ae5d..f6cde3adfdce8a783ee16ca695db4c10a0e21eb5 100644 (file)
@@ -1,10 +1,8 @@
-import testenv; testenv.configure_for_tests()
 import time
-#import gc
 #import sqlalchemy.orm.attributes as attributes
 from sqlalchemy import *
 from sqlalchemy.orm import *
-from testlib import *
+from sqlalchemy.test import *
 
 """
 
@@ -18,16 +16,18 @@ top while it runs
 NUM = 2500
 
 class LoadTest(TestBase, AssertsExecutionResults):
-    def setUpAll(self):
+    @classmethod
+    def setup_class(cls):
         global items, meta
         meta = MetaData(testing.db)
         items = Table('items', meta,
             Column('item_id', Integer, primary_key=True),
             Column('value', String(100)))
         items.create()
-    def tearDownAll(self):
+    @classmethod
+    def teardown_class(cls):
         items.drop()
-    def setUp(self):
+    def setup(self):
         for x in range(1,NUM/500+1):
             l = []
             for y in range(x*500-500 + 1, x*500 + 1):
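
The perf scripts move from the old testlib setUpAll/tearDownAll methods to nose-style setup_class/teardown_class classmethods, with per-test setup() replacing setUp(). A minimal sketch of the resulting layout (class name and fixture body invented for illustration):

    from sqlalchemy import MetaData
    from sqlalchemy.test import TestBase, testing

    class ExampleTest(TestBase):
        @classmethod
        def setup_class(cls):
            # runs once before any test in the class
            cls.metadata = MetaData(testing.db)

        @classmethod
        def teardown_class(cls):
            # runs once after all tests in the class
            cls.metadata.drop_all()

        def setup(self):
            # runs before each individual test
            pass
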
@@ -43,7 +43,7 @@ class LoadTest(TestBase, AssertsExecutionResults):
         query = sess.query(Item)
         for x in range (1,NUM/100):
             # this is not needed with cpython which clears non-circular refs immediately
-            #gc.collect()
+            #gc_collect()
             l = query.filter(items.c.item_id.between(x*100 - 100 + 1, x*100)).all()
             assert len(l) == 100
             print "loaded ", len(l), " items "
@@ -61,5 +61,3 @@ class LoadTest(TestBase, AssertsExecutionResults):
         print "total time ", total
 
 
-if __name__ == "__main__":
-    testenv.main()
index bf65c8fdf70024a66141087c9714d36021dc5838..41acd12ccfb7ccef9d566759d709a30518f99e26 100644 (file)
@@ -1,21 +1,23 @@
-import testenv; testenv.configure_for_tests()
+import gc
 import types
 from sqlalchemy import *
 from sqlalchemy.orm import *
-from testlib import *
+from sqlalchemy.test import *
 
 
 NUM = 2500
 
 class SaveTest(TestBase, AssertsExecutionResults):
-    def setUpAll(self):
+    @classmethod
+    def setup_class(cls):
         global items, metadata
         metadata = MetaData(testing.db)
         items = Table('items', metadata,
             Column('item_id', Integer, primary_key=True),
             Column('value', String(100)))
         items.create()
-    def tearDownAll(self):
+    @classmethod
+    def teardown_class(cls):
         clear_mappers()
         metadata.drop_all()
 
@@ -50,5 +52,3 @@ class SaveTest(TestBase, AssertsExecutionResults):
             print x
 
 
-if __name__ == "__main__":
-    testenv.main()
index 896fd4c49472433b9d47caf241ed3808224b3aaa..867a396f353625a4da67249e28dde05d19da344b 100644 (file)
@@ -1,7 +1,8 @@
 import testenv; testenv.simple_setup()
-import time, gc, resource
+import time, resource
 from sqlalchemy import *
 from sqlalchemy.orm import *
+from sqlalchemy.test.util import gc_collect
 
 
 db = create_engine('sqlite://')
@@ -68,35 +69,35 @@ def all():
         usage.snap = lambda stats=None: setattr(
             usage, 'last', stats or resource.getrusage(resource.RUSAGE_SELF))
 
-        gc.collect()
+        gc_collect()
         usage.snap()
         t = time.clock()
         sqlite_select(RawPerson)
         t2 = time.clock()
         usage('sqlite select/native')
 
-        gc.collect()
+        gc_collect()
         usage.snap()
         t = time.clock()
         sqlite_select(Person)
         t2 = time.clock()
         usage('sqlite select/instrumented')
 
-        gc.collect()
+        gc_collect()
         usage.snap()
         t = time.clock()
         sql_select(RawPerson)
         t2 = time.clock()
         usage('sqlalchemy.sql select/native')
 
-        gc.collect()
+        gc_collect()
         usage.snap()
         t = time.clock()
         sql_select(Person)
         t2 = time.clock()
         usage('sqlalchemy.sql select/instrumented')
 
-        gc.collect()
+        gc_collect()
         usage.snap()
         t = time.clock()
         orm_select()
index a49eb472452b9e9dbe067f175cc4983c7ddd20d4..52224211ae4e51d78242a9076a4b54d62240b86a 100644 (file)
@@ -1,8 +1,9 @@
-import testenv; testenv.configure_for_tests()
-import time, gc, resource
+import time, resource
 from sqlalchemy import *
 from sqlalchemy.orm import *
-from testlib import *
+from sqlalchemy.test import *
+from sqlalchemy.test.util import gc_collect
+
 
 NUM = 100
 
@@ -72,14 +73,14 @@ def all():
 
         session = create_session()
 
-        gc.collect()
+        gc_collect()
         usage.snap()
         t = time.clock()
         people = orm_select(session)
         t2 = time.clock()
         usage('load objects')
 
-        gc.collect()
+        gc_collect()
         usage.snap()
         t = time.clock()
         update_and_flush(session, people)
index cdffa51a96ae60a55f26600edf3abfb4dc212a79..f9f9dee8b7b3fda16a09c142f60e8b1aff0f4efa 100644 (file)
@@ -1,11 +1,10 @@
-import testenv; testenv.configure_for_tests()
 import time
 from datetime import datetime
 
 from sqlalchemy import *
 from sqlalchemy.orm import *
-from testlib import *
-from testlib.profiling import profiled
+from sqlalchemy.test import *
+from sqlalchemy.test.profiling import profiled
 
 class Item(object):
     def __repr__(self):
index 8d66da84f4fd8984a7e86e81dda3f02a22433e70..62c66fbae67d1e4c83c75ec68161fd42d18ee97d 100644 (file)
@@ -1,9 +1,8 @@
 # load test of connection pool
-import testenv; testenv.configure_for_tests()
 import thread, time
 from sqlalchemy import *
 import sqlalchemy.pool as pool
-from testlib import testing
+from sqlalchemy.test import testing
 
 db = create_engine(testing.db.url, pool_timeout=30, echo_pool=True)
 metadata = MetaData(db)
index f4be1ee9369f2cfbd2cbf6ff2736523781f749a2..0d4cc1f014afc83f505c86d20816ea8f32804baf 100644 (file)
@@ -1,17 +1,17 @@
-import testenv; testenv.configure_for_tests()
 from sqlalchemy import *
 from sqlalchemy.orm import *
-import gc
 
-from testlib import TestBase, AssertsExecutionResults, profiling, testing
-from orm import _fixtures
+from sqlalchemy.test.util import gc_collect
+from sqlalchemy.test import TestBase, AssertsExecutionResults, profiling, testing
+from test.orm import _fixtures
 
 # in this test we are specifically looking for time spent in the attributes.InstanceState.__cleanup() method.
 
 ITERATIONS = 100
 
 class SessionTest(TestBase, AssertsExecutionResults):
-    def setUpAll(self):
+    @classmethod
+    def setup_class(cls):
         global t1, t2, metadata,T1, T2
         metadata = MetaData(testing.db)
         t1 = Table('t1', metadata,
@@ -46,7 +46,8 @@ class SessionTest(TestBase, AssertsExecutionResults):
         })
         mapper(T2, t2)
     
-    def tearDownAll(self):
+    @classmethod
+    def teardown_class(cls):
         metadata.drop_all()
         clear_mappers()
         
@@ -60,7 +61,7 @@ class SessionTest(TestBase, AssertsExecutionResults):
 
             sess.close()
             del sess
-            gc.collect()
+            gc_collect()
 
     @profiling.profiled('dirty', report=True)
     def test_session_dirty(self):
@@ -74,11 +75,11 @@ class SessionTest(TestBase, AssertsExecutionResults):
                     t2.c2 = 'this is some modified text'
 
             del t1s
-            gc.collect()
+            gc_collect()
             
             sess.close()
             del sess
-            gc.collect()
+            gc_collect()
 
     @profiling.profiled('noclose', report=True)
     def test_session_noclose(self):
@@ -89,9 +90,6 @@ class SessionTest(TestBase, AssertsExecutionResults):
                 t1s[index].t2s
 
             del sess
-            gc.collect()
-        
+            gc_collect()
 
 
-if __name__ == '__main__':
-    testenv.main()
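
The fixture hooks above migrate from unittest-style instance methods
(setUpAll/tearDownAll) to nose-style classmethods.  The resulting pattern, as
a minimal sketch:

    class SomeTest(TestBase):
        @classmethod
        def setup_class(cls):
            # runs once before the first test in the class; define
            # tables and mappers here
            pass

        @classmethod
        def teardown_class(cls):
            # runs once after the last test; drop the schema and
            # clear_mappers()
            pass
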
index 6fc8149bcd6eb9a3e84da98cc9d86cc6b9e0e44c..549c92ade8ae0b9ed4ceb7d4df00ec111af0cad1 100644 (file)
@@ -1,11 +1,10 @@
 #!/usr/bin/python
 """Uses ``wsgiref``, standard in Python 2.5 and also in the cheeseshop."""
 
-import testenv; testenv.configure_for_tests()
 from sqlalchemy import *
 from sqlalchemy.orm import *
 import thread
-from testlib import *
+from sqlalchemy.test import *
 
 port = 8000
 
index 8abeb3533817b535393e524fe02f72795aac013d..4ad52604d32e508ebec642b3412395cd8e39e61e 100644 (file)
@@ -1,10 +1,14 @@
-from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message
+from sqlalchemy.test.testing import assert_raises, assert_raises_message
 from sqlalchemy import *
-from sqlalchemy import exc
+from sqlalchemy import exc, schema
 from sqlalchemy.test import *
 from sqlalchemy.test import config, engines
+from sqlalchemy.engine import ddl
+from sqlalchemy.test.testing import eq_
+from sqlalchemy.test.assertsql import AllOf, RegexSQL, ExactSQL, CompiledSQL
+from sqlalchemy.dialects.postgresql import base as postgresql
 
-class ConstraintTest(TestBase, AssertsExecutionResults):
+class ConstraintTest(TestBase, AssertsExecutionResults, AssertsCompiledSQL):
 
     def setup(self):
         global metadata
@@ -33,11 +37,8 @@ class ConstraintTest(TestBase, AssertsExecutionResults):
     def test_double_fk_usage_raises(self):
         f = ForeignKey('b.id')
         
-        assert_raises(exc.InvalidRequestError, Table, "a", metadata,
-            Column('x', Integer, f),
-            Column('y', Integer, f)
-        )
-        
+        Column('x', Integer, f)
+        assert_raises(exc.InvalidRequestError, Column, "y", Integer, f)
         
     def test_circular_constraint(self):
         a = Table("a", metadata,
@@ -78,18 +79,9 @@ class ConstraintTest(TestBase, AssertsExecutionResults):
 
         metadata.create_all()
         foo.insert().execute(id=1,x=9,y=5)
-        try:
-            foo.insert().execute(id=2,x=5,y=9)
-            assert False
-        except exc.SQLError:
-            assert True
-
+        assert_raises(exc.SQLError, foo.insert().execute, id=2,x=5,y=9)
         bar.insert().execute(id=1,x=10)
-        try:
-            bar.insert().execute(id=2,x=5)
-            assert False
-        except exc.SQLError:
-            assert True
+        assert_raises(exc.SQLError, bar.insert().execute, id=2,x=5)
 
     def test_unique_constraint(self):
         foo = Table('foo', metadata,
@@ -106,16 +98,8 @@ class ConstraintTest(TestBase, AssertsExecutionResults):
         foo.insert().execute(id=2, value='value2')
         bar.insert().execute(id=1, value='a', value2='a')
         bar.insert().execute(id=2, value='a', value2='b')
-        try:
-            foo.insert().execute(id=3, value='value1')
-            assert False
-        except exc.SQLError:
-            assert True
-        try:
-            bar.insert().execute(id=3, value='a', value2='b')
-            assert False
-        except exc.SQLError:
-            assert True
+        assert_raises(exc.SQLError, foo.insert().execute, id=3, value='value1')
+        assert_raises(exc.SQLError, bar.insert().execute, id=3, value='a', value2='b')
 
     def test_index_create(self):
         employees = Table('employees', metadata,
@@ -174,35 +158,22 @@ class ConstraintTest(TestBase, AssertsExecutionResults):
         Index('sport_announcer', events.c.sport, events.c.announcer, unique=True)
         Index('idx_winners', events.c.winner)
 
-        index_names = [ ix.name for ix in events.indexes ]
-        assert 'ix_events_name' in index_names
-        assert 'ix_events_location' in index_names
-        assert 'sport_announcer' in index_names
-        assert 'idx_winners' in index_names
-        assert len(index_names) == 4
-
-        capt = []
-        connection = testing.db.connect()
-        # TODO: hacky, put a real connection proxy in
-        ex = connection._Connection__execute_context
-        def proxy(context):
-            capt.append(context.statement)
-            capt.append(repr(context.parameters))
-            ex(context)
-        connection._Connection__execute_context = proxy
-        schemagen = testing.db.dialect.schemagenerator(testing.db.dialect, connection)
-        schemagen.traverse(events)
-
-        assert capt[0].strip().startswith('CREATE TABLE events')
-
-        s = set([capt[x].strip() for x in [2,4,6,8]])
-
-        assert s == set([
-            'CREATE UNIQUE INDEX ix_events_name ON events (name)',
-            'CREATE INDEX ix_events_location ON events (location)',
-            'CREATE UNIQUE INDEX sport_announcer ON events (sport, announcer)',
-            'CREATE INDEX idx_winners ON events (winner)'
-            ])
+        eq_(
+            set([ ix.name for ix in events.indexes ]),
+            set(['ix_events_name', 'ix_events_location', 'sport_announcer', 'idx_winners'])
+        )
+
+        self.assert_sql_execution(
+            testing.db,
+            lambda: events.create(testing.db),
+            RegexSQL("^CREATE TABLE events"),
+            AllOf(
+                ExactSQL('CREATE UNIQUE INDEX ix_events_name ON events (name)'),
+                ExactSQL('CREATE INDEX ix_events_location ON events (location)'),
+                ExactSQL('CREATE UNIQUE INDEX sport_announcer ON events (sport, announcer)'),
+                ExactSQL('CREATE INDEX idx_winners ON events (winner)')
+            )
+        )
 
         # verify that the table is functional
         events.insert().execute(id=1, name='hockey finals', location='rink',
@@ -214,84 +185,57 @@ class ConstraintTest(TestBase, AssertsExecutionResults):
         dialect = testing.db.dialect.__class__()
         dialect.max_identifier_length = 20
 
-        schemagen = dialect.schemagenerator(dialect, None)
-        schemagen.execute = lambda : None
-
         t1 = Table("sometable", MetaData(), Column("foo", Integer))
-        schemagen.visit_index(Index("this_name_is_too_long_for_what_were_doing", t1.c.foo))
-        eq_(schemagen.buffer.getvalue(), "CREATE INDEX this_name_is_t_1 ON sometable (foo)")
-        schemagen.buffer.truncate(0)
-        schemagen.visit_index(Index("this_other_name_is_too_long_for_what_were_doing", t1.c.foo))
-        eq_(schemagen.buffer.getvalue(), "CREATE INDEX this_other_nam_2 ON sometable (foo)")
-
-        schemadrop = dialect.schemadropper(dialect, None)
-        schemadrop.execute = lambda: None
-        assert_raises(exc.IdentifierError, schemadrop.visit_index, Index("this_name_is_too_long_for_what_were_doing", t1.c.foo))
+        self.assert_compile(
+            schema.CreateIndex(Index("this_name_is_too_long_for_what_were_doing", t1.c.foo)),
+            "CREATE INDEX this_name_is_t_1 ON sometable (foo)",
+            dialect=dialect
+        )
+        
+        self.assert_compile(
+            schema.CreateIndex(Index("this_other_name_is_too_long_for_what_were_doing", t1.c.foo)),
+            "CREATE INDEX this_other_nam_1 ON sometable (foo)",
+            dialect=dialect
+        )
 
     
-class ConstraintCompilationTest(TestBase, AssertsExecutionResults):
-    class accum(object):
-        def __init__(self):
-            self.statements = []
-        def __call__(self, sql, *a, **kw):
-            self.statements.append(sql)
-        def __contains__(self, substring):
-            for s in self.statements:
-                if substring in s:
-                    return True
-            return False
-        def __str__(self):
-            return '\n'.join([repr(x) for x in self.statements])
-        def clear(self):
-            del self.statements[:]
-
-    def setup(self):
-        self.sql = self.accum()
-        opts = config.db_opts.copy()
-        opts['strategy'] = 'mock'
-        opts['executor'] = self.sql
-        self.engine = engines.testing_engine(options=opts)
-
+class ConstraintCompilationTest(TestBase, AssertsCompiledSQL):
 
     def _test_deferrable(self, constraint_factory):
-        meta = MetaData(self.engine)
-        t = Table('tbl', meta,
+        t = Table('tbl', MetaData(),
                   Column('a', Integer),
                   Column('b', Integer),
                   constraint_factory(deferrable=True))
-        t.create()
-        assert 'DEFERRABLE' in self.sql, self.sql
-        assert 'NOT DEFERRABLE' not in self.sql, self.sql
-        self.sql.clear()
-        meta.clear()
-
-        t = Table('tbl', meta,
+                  
+        sql = str(schema.CreateTable(t).compile(bind=testing.db))
+        assert 'DEFERRABLE' in sql, sql
+        assert 'NOT DEFERRABLE' not in sql, sql
+        
+        t = Table('tbl', MetaData(),
                   Column('a', Integer),
                   Column('b', Integer),
                   constraint_factory(deferrable=False))
-        t.create()
-        assert 'NOT DEFERRABLE' in self.sql
-        self.sql.clear()
-        meta.clear()
 
-        t = Table('tbl', meta,
+        sql = str(schema.CreateTable(t).compile(bind=testing.db))
+        assert 'NOT DEFERRABLE' in sql
+
+
+        t = Table('tbl', MetaData(),
                   Column('a', Integer),
                   Column('b', Integer),
                   constraint_factory(deferrable=True, initially='IMMEDIATE'))
-        t.create()
-        assert 'NOT DEFERRABLE' not in self.sql
-        assert 'INITIALLY IMMEDIATE' in self.sql
-        self.sql.clear()
-        meta.clear()
+        sql = str(schema.CreateTable(t).compile(bind=testing.db))
+        assert 'NOT DEFERRABLE' not in sql
+        assert 'INITIALLY IMMEDIATE' in sql
 
-        t = Table('tbl', meta,
+        t = Table('tbl', MetaData(),
                   Column('a', Integer),
                   Column('b', Integer),
                   constraint_factory(deferrable=True, initially='DEFERRED'))
-        t.create()
+        sql = str(schema.CreateTable(t).compile(bind=testing.db))
 
-        assert 'NOT DEFERRABLE' not in self.sql
-        assert 'INITIALLY DEFERRED' in self.sql, self.sql
+        assert 'NOT DEFERRABLE' not in sql
+        assert 'INITIALLY DEFERRED' in sql
 
     def test_deferrable_pk(self):
         factory = lambda **kw: PrimaryKeyConstraint('a', **kw)
@@ -302,15 +246,16 @@ class ConstraintCompilationTest(TestBase, AssertsExecutionResults):
         self._test_deferrable(factory)
 
     def test_deferrable_column_fk(self):
-        meta = MetaData(self.engine)
-        t = Table('tbl', meta,
+        t = Table('tbl', MetaData(),
                   Column('a', Integer),
                   Column('b', Integer,
                          ForeignKey('tbl.a', deferrable=True,
                                     initially='DEFERRED')))
-        t.create()
-        assert 'DEFERRABLE' in self.sql, self.sql
-        assert 'INITIALLY DEFERRED' in self.sql, self.sql
+
+        self.assert_compile(
+            schema.CreateTable(t),
+            "CREATE TABLE tbl (a INTEGER, b INTEGER, FOREIGN KEY(b) REFERENCES tbl (a) DEFERRABLE INITIALLY DEFERRED)",
+        )
 
     def test_deferrable_unique(self):
         factory = lambda **kw: UniqueConstraint('b', **kw)
@@ -321,15 +266,105 @@ class ConstraintCompilationTest(TestBase, AssertsExecutionResults):
         self._test_deferrable(factory)
 
     def test_deferrable_column_check(self):
-        meta = MetaData(self.engine)
-        t = Table('tbl', meta,
+        t = Table('tbl', MetaData(),
                   Column('a', Integer),
                   Column('b', Integer,
                          CheckConstraint('a < b',
                                          deferrable=True,
                                          initially='DEFERRED')))
-        t.create()
-        assert 'DEFERRABLE' in self.sql, self.sql
-        assert 'INITIALLY DEFERRED' in self.sql, self.sql
+        
+        self.assert_compile(
+            schema.CreateTable(t),
+            "CREATE TABLE tbl (a INTEGER, b INTEGER  CHECK (a < b) DEFERRABLE INITIALLY DEFERRED)"
+        )
+    
+    def test_use_alter(self):
+        m = MetaData()
+        t = Table('t', m,
+                  Column('a', Integer),
+        )
+        
+        t2 = Table('t2', m,
+                Column('a', Integer, ForeignKey('t.a', use_alter=True, name='fk_ta')),
+                Column('b', Integer, ForeignKey('t.a', name='fk_tb')), # to ensure create ordering ...
+        )
+
+        e = engines.mock_engine(dialect_name='postgresql')
+        m.create_all(e)
+        m.drop_all(e)
+
+        e.assert_sql([
+            'CREATE TABLE t (a INTEGER)', 
+            'CREATE TABLE t2 (a INTEGER, b INTEGER, CONSTRAINT fk_tb FOREIGN KEY(b) REFERENCES t (a))', 
+            'ALTER TABLE t2 ADD CONSTRAINT fk_ta FOREIGN KEY(a) REFERENCES t (a)', 
+            'ALTER TABLE t2 DROP CONSTRAINT fk_ta', 
+            'DROP TABLE t2', 
+            'DROP TABLE t'
+        ])
+        
+        
+    def test_add_drop_constraint(self):
+        m = MetaData()
+        
+        t = Table('tbl', m,
+                  Column('a', Integer),
+                  Column('b', Integer)
+        )
+        
+        t2 = Table('t2', m,
+                Column('a', Integer),
+                Column('b', Integer)
+        )
+        
+        constraint = CheckConstraint('a < b',name="my_test_constraint", deferrable=True,initially='DEFERRED', table=t)
+        self.assert_compile(
+            schema.AddConstraint(constraint),
+            "ALTER TABLE tbl ADD CONSTRAINT my_test_constraint  CHECK (a < b) DEFERRABLE INITIALLY DEFERRED"
+        )
+
+        self.assert_compile(
+            schema.DropConstraint(constraint),
+            "ALTER TABLE tbl DROP CONSTRAINT my_test_constraint"
+        )
+
+        self.assert_compile(
+            schema.DropConstraint(constraint, cascade=True),
+            "ALTER TABLE tbl DROP CONSTRAINT my_test_constraint CASCADE"
+        )
 
+        constraint = ForeignKeyConstraint(["b"], ["t2.a"])
+        t.append_constraint(constraint)
+        self.assert_compile(
+            schema.AddConstraint(constraint),
+            "ALTER TABLE tbl ADD FOREIGN KEY(b) REFERENCES t2 (a)"
+        )
 
+        constraint = ForeignKeyConstraint([t.c.a], [t2.c.b])
+        t.append_constraint(constraint)
+        self.assert_compile(
+            schema.AddConstraint(constraint),
+            "ALTER TABLE tbl ADD FOREIGN KEY(a) REFERENCES t2 (b)"
+        )
+
+        constraint = UniqueConstraint("a", "b", name="uq_cst")
+        t2.append_constraint(constraint)
+        self.assert_compile(
+            schema.AddConstraint(constraint),
+            "ALTER TABLE t2 ADD CONSTRAINT uq_cst  UNIQUE (a, b)"
+        )
+        
+        constraint = UniqueConstraint(t2.c.a, t2.c.b, name="uq_cs2")
+        self.assert_compile(
+            schema.AddConstraint(constraint),
+            "ALTER TABLE t2 ADD CONSTRAINT uq_cs2  UNIQUE (a, b)"
+        )
+        
+        assert t.c.a.primary_key is False
+        constraint = PrimaryKeyConstraint(t.c.a)
+        assert t.c.a.primary_key is True
+        self.assert_compile(
+            schema.AddConstraint(constraint),
+            "ALTER TABLE tbl ADD PRIMARY KEY (a)"
+        )
+    
+        
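
The rewritten constraint tests above never execute DDL; they compile
schema.AddConstraint / schema.DropConstraint / schema.CreateTable constructs
to strings and assert on the text.  A self-contained sketch of the same
technique (the table and constraint names here are illustrative):

    from sqlalchemy import MetaData, Table, Column, Integer, CheckConstraint
    from sqlalchemy import schema

    m = MetaData()
    t = Table('tbl', m, Column('a', Integer), Column('b', Integer))
    cst = CheckConstraint('a < b', name='ck_ab', table=t)

    # DDL constructs render without a live connection
    print schema.AddConstraint(cst)
    # ALTER TABLE tbl ADD CONSTRAINT ck_ab CHECK (a < b)
    print schema.DropConstraint(cst, cascade=True)
    # ALTER TABLE tbl DROP CONSTRAINT ck_ab CASCADE
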
index 96415746650108e921a3fb5ed32142734d705d53..5638dad77f7c1df55d904f4b3d824263de6acf59 100644 (file)
@@ -3,7 +3,7 @@ import datetime
 from sqlalchemy import Sequence, Column, func
 from sqlalchemy.sql import select, text
 import sqlalchemy as sa
-from sqlalchemy.test import testing
+from sqlalchemy.test import testing, engines
 from sqlalchemy import MetaData, Integer, String, ForeignKey, Boolean
 from sqlalchemy.test.schema import Table
 from sqlalchemy.test.testing import eq_
@@ -37,7 +37,7 @@ class DefaultTest(testing.TestBase):
                # since it's a "branched" connection
                 conn.close()
 
-        use_function_defaults = testing.against('postgres', 'mssql', 'maxdb')
+        use_function_defaults = testing.against('postgresql', 'mssql', 'maxdb')
         is_oracle = testing.against('oracle')
 
         # select "count(1)" returns different results on different DBs also
@@ -146,7 +146,7 @@ class DefaultTest(testing.TestBase):
             assert_raises_message(sa.exc.ArgumentError,
                                      ex_msg,
                                      sa.ColumnDefault, fn)
-
+    
     def test_arg_signature(self):
         def fn1(): pass
         def fn2(): pass
@@ -276,7 +276,7 @@ class DefaultTest(testing.TestBase):
         assert r.lastrow_has_defaults()
         eq_(set(r.context.postfetch_cols),
             set([t.c.col3, t.c.col5, t.c.col4, t.c.col6]))
-
+        
         eq_(t.select(t.c.col1==54).execute().fetchall(),
             [(54, 'imthedefault', f, ts, ts, ctexec, True, False,
               12, today, None)])
@@ -284,7 +284,7 @@ class DefaultTest(testing.TestBase):
     @testing.fails_on('firebird', 'Data type unknown')
     def test_insertmany(self):
         # MySQL-Python 1.2.2 breaks functions in execute_many :(
-        if (testing.against('mysql') and
+        if (testing.against('mysql') and not testing.against('+zxjdbc') and
             testing.db.dialect.dbapi.version_info[:3] == (1, 2, 2)):
             return
 
@@ -304,12 +304,12 @@ class DefaultTest(testing.TestBase):
     def test_insert_values(self):
         t.insert(values={'col3':50}).execute()
         l = t.select().execute()
-        eq_(50, l.fetchone()['col3'])
+        eq_(50, l.first()['col3'])
 
     @testing.fails_on('firebird', 'Data type unknown')
     def test_updatemany(self):
         # MySQL-Python 1.2.2 breaks functions in execute_many :(
-        if (testing.against('mysql') and
+        if (testing.against('mysql') and not testing.against('+zxjdbc') and
             testing.db.dialect.dbapi.version_info[:3] == (1, 2, 2)):
             return
 
@@ -337,11 +337,11 @@ class DefaultTest(testing.TestBase):
     @testing.fails_on('firebird', 'Data type unknown')
     def test_update(self):
         r = t.insert().execute()
-        pk = r.last_inserted_ids()[0]
+        pk = r.inserted_primary_key[0]
         t.update(t.c.col1==pk).execute(col4=None, col5=None)
         ctexec = currenttime.scalar()
         l = t.select(t.c.col1==pk).execute()
-        l = l.fetchone()
+        l = l.first()
         eq_(l,
             (pk, 'im the update', f2, None, None, ctexec, True, False,
              13, datetime.date.today(), 'py'))
@@ -350,43 +350,12 @@ class DefaultTest(testing.TestBase):
     @testing.fails_on('firebird', 'Data type unknown')
     def test_update_values(self):
         r = t.insert().execute()
-        pk = r.last_inserted_ids()[0]
+        pk = r.inserted_primary_key[0]
         t.update(t.c.col1==pk, values={'col3': 55}).execute()
         l = t.select(t.c.col1==pk).execute()
-        l = l.fetchone()
+        l = l.first()
         eq_(55, l['col3'])
 
-    @testing.fails_on_everything_except('postgres')
-    def test_passive_override(self):
-        """
-        Primarily for postgres, tests that when we get a primary key column
-        back from reflecting a table which has a default value on it, we
-        pre-execute that DefaultClause upon insert, even though DefaultClause
-        says "let the database execute this", because in postgres we must have
-        all the primary key values in memory before insert; otherwise we can't
-        locate the just inserted row.
-
-        """
-        # TODO: move this to dialect/postgres
-        try:
-            meta = MetaData(testing.db)
-            testing.db.execute("""
-             CREATE TABLE speedy_users
-             (
-                 speedy_user_id   SERIAL     PRIMARY KEY,
-
-                 user_name        VARCHAR    NOT NULL,
-                 user_password    VARCHAR    NOT NULL
-             );
-            """, None)
-
-            t = Table("speedy_users", meta, autoload=True)
-            t.insert().execute(user_name='user', user_password='lala')
-            l = t.select().execute().fetchall()
-            eq_(l, [(1, 'user', 'lala')])
-        finally:
-            testing.db.execute("drop table speedy_users", None)
-
 
 class PKDefaultTest(_base.TablesTest):
     __requires__ = ('subqueries',)
@@ -400,18 +369,27 @@ class PKDefaultTest(_base.TablesTest):
               Column('id', Integer, primary_key=True,
                      default=sa.select([func.max(t2.c.nextid)]).as_scalar()),
               Column('data', String(30)))
-
-    @testing.fails_on('mssql', 'FIXME: unknown')
+    
+    @testing.requires.returning
+    def test_with_implicit_returning(self):
+        self._test(True)
+        
+    def test_regular(self):
+        self._test(False)
+        
     @testing.resolve_artifact_names
-    def test_basic(self):
-        t2.insert().execute(nextid=1)
-        r = t1.insert().execute(data='hi')
-        eq_([1], r.last_inserted_ids())
-
-        t2.insert().execute(nextid=2)
-        r = t1.insert().execute(data='there')
-        eq_([2], r.last_inserted_ids())
+    def _test(self, returning):
+        if not returning and not testing.db.dialect.implicit_returning:
+            engine = testing.db
+        else:
+            engine = engines.testing_engine(options={'implicit_returning':returning})
+        engine.execute(t2.insert(), nextid=1)
+        r = engine.execute(t1.insert(), data='hi')
+        eq_([1], r.inserted_primary_key)
 
+        engine.execute(t2.insert(), nextid=2)
+        r = engine.execute(t1.insert(), data='there')
+        eq_([2], r.inserted_primary_key)
 
 class PKIncrementTest(_base.TablesTest):
     run_define_tables = 'each'
@@ -430,29 +408,31 @@ class PKIncrementTest(_base.TablesTest):
     def _test_autoincrement(self, bind):
         ids = set()
         rs = bind.execute(aitable.insert(), int1=1)
-        last = rs.last_inserted_ids()[0]
+        last = rs.inserted_primary_key[0]
         self.assert_(last)
         self.assert_(last not in ids)
         ids.add(last)
 
         rs = bind.execute(aitable.insert(), str1='row 2')
-        last = rs.last_inserted_ids()[0]
+        last = rs.inserted_primary_key[0]
         self.assert_(last)
         self.assert_(last not in ids)
         ids.add(last)
 
         rs = bind.execute(aitable.insert(), int1=3, str1='row 3')
-        last = rs.last_inserted_ids()[0]
+        last = rs.inserted_primary_key[0]
         self.assert_(last)
         self.assert_(last not in ids)
         ids.add(last)
 
         rs = bind.execute(aitable.insert(values={'int1':func.length('four')}))
-        last = rs.last_inserted_ids()[0]
+        last = rs.inserted_primary_key[0]
         self.assert_(last)
         self.assert_(last not in ids)
         ids.add(last)
 
+        eq_(ids, set([1,2,3,4]))
+        
         eq_(list(bind.execute(aitable.select().order_by(aitable.c.id))),
             [(1, 1, None), (2, None, 'row 2'), (3, 3, 'row 3'), (4, 4, None)])
 
@@ -510,8 +490,8 @@ class AutoIncrementTest(_base.TablesTest):
         single.create()
 
         r = single.insert().execute()
-        id_ = r.last_inserted_ids()[0]
-        assert id_ is not None
+        id_ = r.inserted_primary_key[0]
+        eq_(id_, 1)
         eq_(1, sa.select([func.count(sa.text('*'))], from_obj=single).scalar())
 
     def test_autoincrement_fk(self):
@@ -522,7 +502,7 @@ class AutoIncrementTest(_base.TablesTest):
         nodes.create()
 
         r = nodes.insert().execute(data='foo')
-        id_ = r.last_inserted_ids()[0]
+        id_ = r.inserted_primary_key[0]
         nodes.insert().execute(data='bar', parent_id=id_)
 
     @testing.fails_on('sqlite', 'FIXME: unknown')
@@ -535,7 +515,7 @@ class AutoIncrementTest(_base.TablesTest):
 
 
         try:
-            # postgres + mysql strict will fail on first row,
+            # postgresql + mysql strict will fail on first row,
             # mysql in legacy mode fails on second row
             nonai.insert().execute(data='row 1')
             nonai.insert().execute(data='row 2')
@@ -570,16 +550,17 @@ class SequenceTest(testing.TestBase):
     def testseqnonpk(self):
         """test sequences fire off as defaults on non-pk columns"""
 
-        result = sometable.insert().execute(name="somename")
+        engine = engines.testing_engine(options={'implicit_returning':False})
+        result = engine.execute(sometable.insert(), name="somename")
         assert 'id' in result.postfetch_cols()
 
-        result = sometable.insert().execute(name="someother")
+        result = engine.execute(sometable.insert(), name="someother")
         assert 'id' in result.postfetch_cols()
 
         sometable.insert().execute(
             {'name':'name3'},
             {'name':'name4'})
-        eq_(sometable.select().execute().fetchall(),
+        eq_(sometable.select().order_by(sometable.c.id).execute().fetchall(),
             [(1, "somename", 1),
              (2, "someother", 2),
              (3, "name3", 3),
@@ -590,8 +571,8 @@ class SequenceTest(testing.TestBase):
         cartitems.insert().execute(description='there')
         r = cartitems.insert().execute(description='lala')
 
-        assert r.last_inserted_ids() and r.last_inserted_ids()[0] is not None
-        id_ = r.last_inserted_ids()[0]
+        assert r.inserted_primary_key and r.inserted_primary_key[0] is not None
+        id_ = r.inserted_primary_key[0]
 
         eq_(1,
             sa.select([func.count(cartitems.c.cart_id)],
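
Across this file, ResultProxy.last_inserted_ids() becomes the 0.6 accessor
inserted_primary_key, and the tests run both with and without implicit
RETURNING via engines.testing_engine(options={'implicit_returning': flag}).
A sketch of the new spelling (table and engine names assumed):

    r = engine.execute(users.insert(), user_name='jack')
    # the full primary key of the inserted row, as a list, regardless
    # of whether it came from RETURNING, cursor.lastrowid, or a
    # pre-executed sequence/default
    pk = r.inserted_primary_key      # e.g. [1]
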
index e9bf49ce30af1016896f633652f5a65f8a2f1eef..7a0f12cac30a83d61749211be2245d98aa5e5fef 100644 (file)
@@ -24,7 +24,7 @@ class CompileTest(TestBase, AssertsCompiledSQL):
             bindtemplate = BIND_TEMPLATES[dialect.paramstyle]
             self.assert_compile(func.current_timestamp(), "CURRENT_TIMESTAMP", dialect=dialect)
             self.assert_compile(func.localtime(), "LOCALTIME", dialect=dialect)
-            if isinstance(dialect, firebird.dialect):
+            if isinstance(dialect, (firebird.dialect, maxdb.dialect, oracle.dialect)):
                 self.assert_compile(func.nosuchfunction(), "nosuchfunction", dialect=dialect)
             else:
                 self.assert_compile(func.nosuchfunction(), "nosuchfunction()", dialect=dialect)
@@ -50,7 +50,7 @@ class CompileTest(TestBase, AssertsCompiledSQL):
 
         for ret, dialect in [
             ('CURRENT_TIMESTAMP', sqlite.dialect()),
-            ('now()', postgres.dialect()),
+            ('now()', postgresql.dialect()),
             ('now()', mysql.dialect()),
             ('CURRENT_TIMESTAMP', oracle.dialect())
         ]:
@@ -62,9 +62,9 @@ class CompileTest(TestBase, AssertsCompiledSQL):
 
         for ret, dialect in [
             ('random()', sqlite.dialect()),
-            ('random()', postgres.dialect()),
+            ('random()', postgresql.dialect()),
             ('rand()', mysql.dialect()),
-            ('random()', oracle.dialect())
+            ('random', oracle.dialect())
         ]:
             self.assert_compile(func.random(), ret, dialect=dialect)
 
@@ -180,7 +180,10 @@ class CompileTest(TestBase, AssertsCompiledSQL):
 
 
 class ExecuteTest(TestBase):
-
+    @engines.close_first
+    def tearDown(self):
+        pass
+        
     def test_standalone_execute(self):
         x = testing.db.func.current_date().execute().scalar()
         y = testing.db.func.current_date().select().execute().scalar()
@@ -202,6 +205,7 @@ class ExecuteTest(TestBase):
             conn.close()
         assert (x == y == z) is True
 
+    @engines.close_first
     def test_update(self):
         """
         Tests sending functions and SQL expressions to the VALUES and SET
@@ -222,15 +226,15 @@ class ExecuteTest(TestBase):
         meta.create_all()
         try:
             t.insert(values=dict(value=func.length("one"))).execute()
-            assert t.select().execute().fetchone()['value'] == 3
+            assert t.select().execute().first()['value'] == 3
             t.update(values=dict(value=func.length("asfda"))).execute()
-            assert t.select().execute().fetchone()['value'] == 5
+            assert t.select().execute().first()['value'] == 5
 
             r = t.insert(values=dict(value=func.length("sfsaafsda"))).execute()
-            id = r.last_inserted_ids()[0]
-            assert t.select(t.c.id==id).execute().fetchone()['value'] == 9
+            id = r.inserted_primary_key[0]
+            assert t.select(t.c.id==id).execute().first()['value'] == 9
             t.update(values={t.c.value:func.length("asdf")}).execute()
-            assert t.select().execute().fetchone()['value'] == 4
+            assert t.select().execute().first()['value'] == 4
             print "--------------------------"
             t2.insert().execute()
             t2.insert(values=dict(value=func.length("one"))).execute()
@@ -245,18 +249,18 @@ class ExecuteTest(TestBase):
             t2.delete().execute()
 
             t2.insert(values=dict(value=func.length("one") + 8)).execute()
-            assert t2.select().execute().fetchone()['value'] == 11
+            assert t2.select().execute().first()['value'] == 11
 
             t2.update(values=dict(value=func.length("asfda"))).execute()
-            assert select([t2.c.value, t2.c.stuff]).execute().fetchone() == (5, "thisisstuff")
+            assert select([t2.c.value, t2.c.stuff]).execute().first() == (5, "thisisstuff")
 
             t2.update(values={t2.c.value:func.length("asfdaasdf"), t2.c.stuff:"foo"}).execute()
-            print "HI", select([t2.c.value, t2.c.stuff]).execute().fetchone()
-            assert select([t2.c.value, t2.c.stuff]).execute().fetchone() == (9, "foo")
+            print "HI", select([t2.c.value, t2.c.stuff]).execute().first()
+            assert select([t2.c.value, t2.c.stuff]).execute().first() == (9, "foo")
         finally:
             meta.drop_all()
 
-    @testing.fails_on_everything_except('postgres')
+    @testing.fails_on_everything_except('postgresql')
     def test_as_from(self):
        # TODO: shouldn't this work on oracle too?
         x = testing.db.func.current_date().execute().scalar()
@@ -266,7 +270,7 @@ class ExecuteTest(TestBase):
 
         # construct a column-based FROM object out of a function, like in [ticket:172]
         s = select([sql.column('date', type_=DateTime)], from_obj=[testing.db.func.current_date()])
-        q = s.execute().fetchone()[s.c.date]
+        q = s.execute().first()[s.c.date]
         r = s.alias('datequery').select().scalar()
 
         assert x == y == z == w == q == r
@@ -301,7 +305,7 @@ class ExecuteTest(TestBase):
                  'd': datetime.date(2010, 5, 1) })
             rs = select([extract('year', table.c.dt),
                          extract('month', table.c.d)]).execute()
-            row = rs.fetchone()
+            row = rs.first()
             assert row[0] == 2010
             assert row[1] == 5
             rs.close()
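
The fetchone() calls above become first() throughout: first() fetches one row
and then closes the ResultProxy, which is why explicit close() calls
disappear in these hunks.  A sketch of the contract:

    row = t.select().execute().first()
    # the result is closed here; fetching again from the same
    # ResultProxy would fail, so read everything needed from `row`
    value = row['value']
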
index b946b0ae9885078a769ce3a17587487ddae291ea..bcac7c01d2acfadd077277262b3d12d46ecf55d7 100644 (file)
@@ -35,6 +35,7 @@ class LongLabelsTest(TestBase, AssertsCompiledSQL):
         maxlen = testing.db.dialect.max_identifier_length
         testing.db.dialect.max_identifier_length = IDENT_LENGTH
 
+    @engines.close_first
     def teardown(self):
         table1.delete().execute()
 
@@ -92,10 +93,16 @@ class LongLabelsTest(TestBase, AssertsCompiledSQL):
         ], repr(result)
 
     def test_table_alias_names(self):
-        self.assert_compile(
-            table2.alias().select(),
-            "SELECT table_with_exactly_29_c_1.this_is_the_primarykey_column, table_with_exactly_29_c_1.this_is_the_data_column FROM table_with_exactly_29_characs AS table_with_exactly_29_c_1"
-        )
+        if testing.against('oracle'):
+            self.assert_compile(
+                table2.alias().select(),
+                "SELECT table_with_exactly_29_c_1.this_is_the_primarykey_column, table_with_exactly_29_c_1.this_is_the_data_column FROM table_with_exactly_29_characs table_with_exactly_29_c_1"
+            )
+        else:
+            self.assert_compile(
+                table2.alias().select(),
+                "SELECT table_with_exactly_29_c_1.this_is_the_primarykey_column, table_with_exactly_29_c_1.this_is_the_data_column FROM table_with_exactly_29_characs AS table_with_exactly_29_c_1"
+            )
 
         ta = table2.alias()
         dialect = default.DefaultDialect()
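
The branch above exists because Oracle renders table aliases without the AS
keyword.  Compiling the same construct against different dialects shows the
difference; a sketch (the table and alias names are illustrative):

    from sqlalchemy import MetaData, Table, Column, Integer, select
    from sqlalchemy.dialects.oracle import base as oracle

    t = Table('some_table', MetaData(), Column('id', Integer))
    a = t.alias('st_1')
    print select([a.c.id]).compile(dialect=oracle.dialect())
    # oracle:        SELECT st_1.id FROM some_table st_1
    # most dialects: SELECT st_1.id FROM some_table AS st_1
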
index 51b933e45857da943873e85e70964a1ba8e41eba..0e3b9dff209aa329fa1021b7ae2a9cb35d6a1423 100644 (file)
@@ -1,9 +1,11 @@
+from sqlalchemy.test.testing import eq_
 import datetime
 from sqlalchemy import *
 from sqlalchemy import exc, sql
 from sqlalchemy.engine import default
 from sqlalchemy.test import *
-from sqlalchemy.test.testing import eq_
+from sqlalchemy.test.testing import eq_, assert_raises_message
+from sqlalchemy.test.schema import Table, Column
 
 class QueryTest(TestBase):
 
@@ -12,11 +14,11 @@ class QueryTest(TestBase):
         global users, users2, addresses, metadata
         metadata = MetaData(testing.db)
         users = Table('query_users', metadata,
-            Column('user_id', INT, primary_key = True),
+            Column('user_id', INT, primary_key=True, test_needs_autoincrement=True),
             Column('user_name', VARCHAR(20)),
         )
         addresses = Table('query_addresses', metadata,
-            Column('address_id', Integer, primary_key=True),
+            Column('address_id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('user_id', Integer, ForeignKey('query_users.user_id')),
             Column('address', String(30)))
             
@@ -26,7 +28,8 @@ class QueryTest(TestBase):
         )
         metadata.create_all()
 
-    def tearDown(self):
+    @engines.close_first
+    def teardown(self):
         addresses.delete().execute()
         users.delete().execute()
         users2.delete().execute()
@@ -52,89 +55,133 @@ class QueryTest(TestBase):
         assert users.count().scalar() == 1
 
         users.update(users.c.user_id == 7).execute(user_name = 'fred')
-        assert users.select(users.c.user_id==7).execute().fetchone()['user_name'] == 'fred'
+        assert users.select(users.c.user_id==7).execute().first()['user_name'] == 'fred'
 
     def test_lastrow_accessor(self):
-        """Tests the last_inserted_ids() and lastrow_has_id() functions."""
+        """Tests the inserted_primary_key and lastrow_has_id() functions."""
 
-        def insert_values(table, values):
+        def insert_values(engine, table, values):
             """
             Inserts a row into a table, returns the full list of values
             INSERTed including defaults that fired off on the DB side and
             detects rows that had defaults and post-fetches.
             """
 
-            result = table.insert().execute(**values)
+            result = engine.execute(table.insert(), **values)
             ret = values.copy()
             
-            for col, id in zip(table.primary_key, result.last_inserted_ids()):
+            for col, id in zip(table.primary_key, result.inserted_primary_key):
                 ret[col.key] = id
 
             if result.lastrow_has_defaults():
-                criterion = and_(*[col==id for col, id in zip(table.primary_key, result.last_inserted_ids())])
-                row = table.select(criterion).execute().fetchone()
+                criterion = and_(*[col==id for col, id in zip(table.primary_key, result.inserted_primary_key)])
+                row = engine.execute(table.select(criterion)).first()
                 for c in table.c:
                     ret[c.key] = row[c]
             return ret
 
-        for supported, table, values, assertvalues in [
-            (
-                {'unsupported':['sqlite']},
-                Table("t1", metadata,
-                    Column('id', Integer, Sequence('t1_id_seq', optional=True), primary_key=True),
-                    Column('foo', String(30), primary_key=True)),
-                {'foo':'hi'},
-                {'id':1, 'foo':'hi'}
-            ),
-            (
-                {'unsupported':['sqlite']},
-                Table("t2", metadata,
-                    Column('id', Integer, Sequence('t2_id_seq', optional=True), primary_key=True),
-                    Column('foo', String(30), primary_key=True),
-                    Column('bar', String(30), server_default='hi')
+        if testing.against('firebird', 'postgresql', 'oracle', 'mssql'):
+            test_engines = [
+                engines.testing_engine(options={'implicit_returning':False}),
+                engines.testing_engine(options={'implicit_returning':True}),
+            ]
+        else:
+            test_engines = [testing.db]
+            
+        for engine in test_engines:
+            metadata = MetaData()
+            for supported, table, values, assertvalues in [
+                (
+                    {'unsupported':['sqlite']},
+                    Table("t1", metadata,
+                        Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
+                        Column('foo', String(30), primary_key=True)),
+                    {'foo':'hi'},
+                    {'id':1, 'foo':'hi'}
                 ),
-                {'foo':'hi'},
-                {'id':1, 'foo':'hi', 'bar':'hi'}
-            ),
-            (
-                {'unsupported':[]},
-                Table("t3", metadata,
-                    Column("id", String(40), primary_key=True),
-                    Column('foo', String(30), primary_key=True),
-                    Column("bar", String(30))
+                (
+                    {'unsupported':['sqlite']},
+                    Table("t2", metadata,
+                        Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
+                        Column('foo', String(30), primary_key=True),
+                        Column('bar', String(30), server_default='hi')
                     ),
-                    {'id':'hi', 'foo':'thisisfoo', 'bar':"thisisbar"},
-                    {'id':'hi', 'foo':'thisisfoo', 'bar':"thisisbar"}
-            ),
-            (
-                {'unsupported':[]},
-                Table("t4", metadata,
-                    Column('id', Integer, Sequence('t4_id_seq', optional=True), primary_key=True),
-                    Column('foo', String(30), primary_key=True),
-                    Column('bar', String(30), server_default='hi')
+                    {'foo':'hi'},
+                    {'id':1, 'foo':'hi', 'bar':'hi'}
                 ),
-                {'foo':'hi', 'id':1},
-                {'id':1, 'foo':'hi', 'bar':'hi'}
-            ),
-            (
-                {'unsupported':[]},
-                Table("t5", metadata,
-                    Column('id', String(10), primary_key=True),
-                    Column('bar', String(30), server_default='hi')
+                (
+                    {'unsupported':[]},
+                    Table("t3", metadata,
+                        Column("id", String(40), primary_key=True),
+                        Column('foo', String(30), primary_key=True),
+                        Column("bar", String(30))
+                        ),
+                        {'id':'hi', 'foo':'thisisfoo', 'bar':"thisisbar"},
+                        {'id':'hi', 'foo':'thisisfoo', 'bar':"thisisbar"}
                 ),
-                {'id':'id1'},
-                {'id':'id1', 'bar':'hi'},
-            ),
-        ]:
-            if testing.db.name in supported['unsupported']:
-                continue
-            try:
-                table.create()
-                i = insert_values(table, values)
-                assert i == assertvalues, repr(i) + " " + repr(assertvalues)
-            finally:
-                table.drop()
+                (
+                    {'unsupported':[]},
+                    Table("t4", metadata,
+                        Column('id', Integer, Sequence('t4_id_seq', optional=True), primary_key=True),
+                        Column('foo', String(30), primary_key=True),
+                        Column('bar', String(30), server_default='hi')
+                    ),
+                    {'foo':'hi', 'id':1},
+                    {'id':1, 'foo':'hi', 'bar':'hi'}
+                ),
+                (
+                    {'unsupported':[]},
+                    Table("t5", metadata,
+                        Column('id', String(10), primary_key=True),
+                        Column('bar', String(30), server_default='hi')
+                    ),
+                    {'id':'id1'},
+                    {'id':'id1', 'bar':'hi'},
+                ),
+            ]:
+                if testing.db.name in supported['unsupported']:
+                    continue
+                try:
+                    table.create(bind=engine, checkfirst=True)
+                    i = insert_values(engine, table, values)
+                    assert i == assertvalues, "tablename: %s %r %r" % (table.name, repr(i), repr(assertvalues))
+                finally:
+                    table.drop(bind=engine)
+
+    @testing.fails_on('sqlite', "sqlite autoincrement doesn't work with composite pks")
+    def test_misordered_lastrow(self):
+        related = Table('related', metadata,
+            Column('id', Integer, primary_key=True)
+        )
+        t6 = Table("t6", metadata,
+            Column('manual_id', Integer, ForeignKey('related.id'), primary_key=True),
+            Column('auto_id', Integer, primary_key=True, test_needs_autoincrement=True),
+        )
 
+        metadata.create_all()
+        r = related.insert().values(id=12).execute()
+        id = r.inserted_primary_key[0]
+        assert id==12
+
+        r = t6.insert().values(manual_id=id).execute()
+        eq_(r.inserted_primary_key, [12, 1])
+
+    def test_autoclose_on_insert(self):
+        if testing.against('firebird', 'postgresql', 'oracle', 'mssql'):
+            test_engines = [
+                engines.testing_engine(options={'implicit_returning':False}),
+                engines.testing_engine(options={'implicit_returning':True}),
+            ]
+        else:
+            test_engines = [testing.db]
+            
+        for engine in test_engines:
+        
+            r = engine.execute(users.insert(), 
+                {'user_name':'jack'},
+            )
+            assert r.closed
+        
     def test_row_iteration(self):
         users.insert().execute(
             {'user_id':7, 'user_name':'jack'},
@@ -147,7 +194,7 @@ class QueryTest(TestBase):
             l.append(row)
         self.assert_(len(l) == 3)
 
-    @testing.fails_on('firebird', 'Data type unknown')
+    @testing.fails_on('firebird', "kinterbasdb doesn't send full type information")
     @testing.requires.subqueries
     def test_anonymous_rows(self):
         users.insert().execute(
@@ -161,6 +208,7 @@ class QueryTest(TestBase):
             assert row['anon_1'] == 8
             assert row['anon_2'] == 10
 
+    @testing.fails_on('firebird', "kinterbasdb doesn't send full type information")
     def test_order_by_label(self):
         """test that a label within an ORDER BY works on each backend.
         
@@ -179,6 +227,11 @@ class QueryTest(TestBase):
             select([concat]).order_by(concat).execute().fetchall(),
             [("test: ed",), ("test: fred",), ("test: jack",)]
         )
+        
+        eq_(
+            select([concat]).order_by(concat).execute().fetchall(),
+            [("test: ed",), ("test: fred",), ("test: jack",)]
+        )
 
         concat = ("test: " + users.c.user_name).label('thedata')
         eq_(
@@ -195,7 +248,7 @@ class QueryTest(TestBase):
         
     def test_row_comparison(self):
         users.insert().execute(user_id = 7, user_name = 'jack')
-        rp = users.select().execute().fetchone()
+        rp = users.select().execute().first()
 
         self.assert_(rp == rp)
         self.assert_(not(rp != rp))
@@ -207,8 +260,7 @@ class QueryTest(TestBase):
         self.assert_(not (rp != equal))
         self.assert_(not (equal != equal))
 
-    @testing.fails_on('mssql', 'No support for boolean logic in column select.')
-    @testing.fails_on('oracle', 'FIXME: unknown')
+    @testing.requires.boolean_col_expressions
     def test_or_and_as_columns(self):
         true, false = literal(True), literal(False)
         
@@ -218,11 +270,11 @@ class QueryTest(TestBase):
         eq_(testing.db.execute(select([or_(false, false)])).scalar(), False)
         eq_(testing.db.execute(select([not_(or_(false, false))])).scalar(), True)
 
-        row = testing.db.execute(select([or_(false, false).label("x"), and_(true, false).label("y")])).fetchone()
+        row = testing.db.execute(select([or_(false, false).label("x"), and_(true, false).label("y")])).first()
         assert row.x == False
         assert row.y == False
 
-        row = testing.db.execute(select([or_(true, false).label("x"), and_(true, false).label("y")])).fetchone()
+        row = testing.db.execute(select([or_(true, false).label("x"), and_(true, false).label("y")])).first()
         assert row.x == True
         assert row.y == False
         
@@ -253,6 +305,9 @@ class QueryTest(TestBase):
             eq_(expr.execute().fetchall(), result)
     
 
+    @testing.fails_on("firebird", "see dialect.test_firebird:MiscTest.test_percents_in_text")
+    @testing.fails_on("oracle", "neither % nor %% are accepted")
+    @testing.fails_on("+pg8000", "can't interpret result column from '%%'")
     @testing.emits_warning('.*now automatically escapes.*')
     def test_percents_in_text(self):
         for expr, result in (
@@ -277,7 +332,7 @@ class QueryTest(TestBase):
 
         eq_(select([users.c.user_id]).where(users.c.user_name.ilike('TWO')).execute().fetchall(), [(2, )])
 
-        if testing.against('postgres'):
+        if testing.against('postgresql'):
             eq_(select([users.c.user_id]).where(users.c.user_name.like('one')).execute().fetchall(), [(1, )])
             eq_(select([users.c.user_id]).where(users.c.user_name.like('TWO')).execute().fetchall(), [])
 
@@ -373,7 +428,7 @@ class QueryTest(TestBase):
             s = select([datetable.alias('x').c.today]).as_scalar()
             s2 = select([datetable.c.id, s.label('somelabel')])
             #print s2.c.somelabel.type
-            assert isinstance(s2.execute().fetchone()['somelabel'], datetime.datetime)
+            assert isinstance(s2.execute().first()['somelabel'], datetime.datetime)
         finally:
             datetable.drop()
 
@@ -444,45 +499,58 @@ class QueryTest(TestBase):
         users.insert().execute(user_id=2, user_name='jack')
         addresses.insert().execute(address_id=1, user_id=2, address='foo@bar.com')
 
-        r = users.select(users.c.user_id==2).execute().fetchone()
+        r = users.select(users.c.user_id==2).execute().first()
         self.assert_(r.user_id == r['user_id'] == r[users.c.user_id] == 2)
         self.assert_(r.user_name == r['user_name'] == r[users.c.user_name] == 'jack')
-
-        r = text("select * from query_users where user_id=2", bind=testing.db).execute().fetchone()
+        
+        r = text("select * from query_users where user_id=2", bind=testing.db).execute().first()
         self.assert_(r.user_id == r['user_id'] == r[users.c.user_id] == 2)
         self.assert_(r.user_name == r['user_name'] == r[users.c.user_name] == 'jack')
-
+        
         # test slices
-        r = text("select * from query_addresses", bind=testing.db).execute().fetchone()
+        r = text("select * from query_addresses", bind=testing.db).execute().first()
         self.assert_(r[0:1] == (1,))
         self.assert_(r[1:] == (2, 'foo@bar.com'))
         self.assert_(r[:-1] == (1, 2))
-
+        
         # test a little sqlite weirdness - with the UNION, cols come back as "query_users.user_id" in cursor.description
         r = text("select query_users.user_id, query_users.user_name from query_users "
-            "UNION select query_users.user_id, query_users.user_name from query_users", bind=testing.db).execute().fetchone()
+            "UNION select query_users.user_id, query_users.user_name from query_users", bind=testing.db).execute().first()
         self.assert_(r['user_id'] == 1)
         self.assert_(r['user_name'] == "john")
 
         # test using literal tablename.colname
-        r = text('select query_users.user_id AS "query_users.user_id", query_users.user_name AS "query_users.user_name" from query_users', bind=testing.db).execute().fetchone()
+        r = text('select query_users.user_id AS "query_users.user_id", '
+                'query_users.user_name AS "query_users.user_name" from query_users', 
+                bind=testing.db).execute().first()
         self.assert_(r['query_users.user_id'] == 1)
         self.assert_(r['query_users.user_name'] == "john")
 
         # unary expressions
-        r = select([users.c.user_name.distinct()]).order_by(users.c.user_name).execute().fetchone()
+        r = select([users.c.user_name.distinct()]).order_by(users.c.user_name).execute().first()
         eq_(r[users.c.user_name], 'jack')
         eq_(r.user_name, 'jack')
-        r.close()
+
+    def test_result_case_sensitivity(self):
+        """test name normalization for result sets."""
         
+        row = testing.db.execute(
+            select([
+                literal_column("1").label("case_insensitive"),
+                literal_column("2").label("CaseSensitive")
+            ])
+        ).first()
+        
+        assert row.keys() == ["case_insensitive", "CaseSensitive"]
+
         
     def test_row_as_args(self):
         users.insert().execute(user_id=1, user_name='john')
-        r = users.select(users.c.user_id==1).execute().fetchone()
+        r = users.select(users.c.user_id==1).execute().first()
         users.delete().execute()
         users.insert().execute(r)
-        assert users.select().execute().fetchall() == [(1, 'john')]
-    
+        eq_(users.select().execute().fetchall(), [(1, 'john')])
+
     def test_result_as_args(self):
         users.insert().execute([dict(user_id=1, user_name='john'), dict(user_id=2, user_name='ed')])
         r = users.select().execute()
@@ -496,13 +564,12 @@ class QueryTest(TestBase):
         
     def test_ambiguous_column(self):
         users.insert().execute(user_id=1, user_name='john')
-        r = users.outerjoin(addresses).select().execute().fetchone()
-        try:
-            print r['user_id']
-            assert False
-        except exc.InvalidRequestError, e:
-            assert str(e) == "Ambiguous column name 'user_id' in result set! try 'use_labels' option on select statement." or \
-                   str(e) == "Ambiguous column name 'USER_ID' in result set! try 'use_labels' option on select statement."
+        r = users.outerjoin(addresses).select().execute().first()
+        assert_raises_message(
+            exc.InvalidRequestError,
+            "Ambiguous column name",
+            lambda: r['user_id']
+        )
 
     @testing.requires.subqueries
     def test_column_label_targeting(self):
@@ -512,31 +579,29 @@ class QueryTest(TestBase):
             users.select().alias('foo'),
             users.select().alias(users.name),
         ):
-            row = s.select(use_labels=True).execute().fetchone()
+            row = s.select(use_labels=True).execute().first()
             assert row[s.c.user_id] == 7
             assert row[s.c.user_name] == 'ed'
 
     def test_keys(self):
         users.insert().execute(user_id=1, user_name='foo')
-        r = users.select().execute().fetchone()
+        r = users.select().execute().first()
         eq_([x.lower() for x in r.keys()], ['user_id', 'user_name'])
 
     def test_items(self):
         users.insert().execute(user_id=1, user_name='foo')
-        r = users.select().execute().fetchone()
+        r = users.select().execute().first()
         eq_([(x[0].lower(), x[1]) for x in r.items()], [('user_id', 1), ('user_name', 'foo')])
 
     def test_len(self):
         users.insert().execute(user_id=1, user_name='foo')
-        r = users.select().execute().fetchone()
+        r = users.select().execute().first()
         eq_(len(r), 2)
-        r.close()
-        r = testing.db.execute('select user_name, user_id from query_users').fetchone()
+            
+        r = testing.db.execute('select user_name, user_id from query_users').first()
         eq_(len(r), 2)
-        r.close()
-        r = testing.db.execute('select user_name from query_users').fetchone()
+        r = testing.db.execute('select user_name from query_users').first()
         eq_(len(r), 1)
-        r.close()
 
     def test_cant_execute_join(self):
         try:
@@ -549,7 +614,7 @@ class QueryTest(TestBase):
     def test_column_order_with_simple_query(self):
         # should return values in column definition order
         users.insert().execute(user_id=1, user_name='foo')
-        r = users.select(users.c.user_id==1).execute().fetchone()
+        r = users.select(users.c.user_id==1).execute().first()
         eq_(r[0], 1)
         eq_(r[1], 'foo')
         eq_([x.lower() for x in r.keys()], ['user_id', 'user_name'])
@@ -558,7 +623,7 @@ class QueryTest(TestBase):
     def test_column_order_with_text_query(self):
         # should return values in query order
         users.insert().execute(user_id=1, user_name='foo')
-        r = testing.db.execute('select user_name, user_id from query_users').fetchone()
+        r = testing.db.execute('select user_name, user_id from query_users').first()
         eq_(r[0], 'foo')
         eq_(r[1], 1)
         eq_([x.lower() for x in r.keys()], ['user_name', 'user_id'])
@@ -580,7 +645,7 @@ class QueryTest(TestBase):
         shadowed.create(checkfirst=True)
         try:
             shadowed.insert().execute(shadow_id=1, shadow_name='The Shadow', parent='The Light', row='Without light there is no shadow', __parent='Hidden parent', __row='Hidden row')
-            r = shadowed.select(shadowed.c.shadow_id==1).execute().fetchone()
+            r = shadowed.select(shadowed.c.shadow_id==1).execute().first()
             self.assert_(r.shadow_id == r['shadow_id'] == r[shadowed.c.shadow_id] == 1)
             self.assert_(r.shadow_name == r['shadow_name'] == r[shadowed.c.shadow_name] == 'The Shadow')
             self.assert_(r.parent == r['parent'] == r[shadowed.c.parent] == 'The Light')
@@ -622,13 +687,13 @@ class QueryTest(TestBase):
         # Null values are not outside any set
         assert len(r) == 0
 
-        u = bindparam('search_key')
+    @testing.fails_on('firebird', "kinterbasdb doesn't send full type information")
+    def test_bind_in(self):
+        users.insert().execute(user_id = 7, user_name = 'jack')
+        users.insert().execute(user_id = 8, user_name = 'fred')
+        users.insert().execute(user_id = 9, user_name = None)
 
-        s = users.select(u.in_([]))
-        r = s.execute(search_key='john').fetchall()
-        assert len(r) == 0
-        r = s.execute(search_key=None).fetchall()
-        assert len(r) == 0
+        u = bindparam('search_key')
 
         s = users.select(not_(u.in_([])))
         r = s.execute(search_key='john').fetchall()
@@ -660,14 +725,15 @@ class QueryTest(TestBase):
 class PercentSchemaNamesTest(TestBase):
     """tests using percent signs, spaces in table and column names.
     
-    Doesn't pass for mysql, postgres, but this is really a 
+    Doesn't pass for mysql, postgresql, but this is really a 
     SQLAlchemy bug - we should be escaping out %% signs for this
     operation the same way we do for text() and column labels.
     
     """
+
     @classmethod
     @testing.crashes('mysql', 'mysqldb calls name % (params)')
-    @testing.crashes('postgres', 'postgres calls name % (params)')
+    @testing.crashes('postgresql', 'postgresql calls name % (params)')
     def setup_class(cls):
         global percent_table, metadata
         metadata = MetaData(testing.db)
@@ -680,12 +746,12 @@ class PercentSchemaNamesTest(TestBase):
 
     @classmethod
     @testing.crashes('mysql', 'mysqldb calls name % (params)')
-    @testing.crashes('postgres', 'postgres calls name % (params)')
+    @testing.crashes('postgresql', 'postgresql calls name % (params)')
     def teardown_class(cls):
         metadata.drop_all()
     
     @testing.crashes('mysql', 'mysqldb calls name % (params)')
-    @testing.crashes('postgres', 'postgres calls name % (params)')
+    @testing.crashes('postgresql', 'postgresql calls name % (params)')
     def test_roundtrip(self):
         percent_table.insert().execute(
             {'percent%':5, '%(oneofthese)s':7, 'spaces % more spaces':12},
@@ -731,7 +797,7 @@ class PercentSchemaNamesTest(TestBase):
         percent_table.update().values({percent_table.c['%(oneofthese)s']:9, percent_table.c['spaces % more spaces']:15}).execute()
 
         eq_(
-            percent_table.select().order_by(percent_table.c['%(oneofthese)s']).execute().fetchall(),
+            percent_table.select().order_by(percent_table.c['percent%']).execute().fetchall(),
             [
                 (5, 9, 15),
                 (7, 9, 15),
@@ -852,7 +918,11 @@ class CompoundTest(TestBase):
             dict(col2="t3col2r2", col3="bbb", col4="aaa"),
             dict(col2="t3col2r3", col3="ccc", col4="bbb"),
         ])
-
+        
+    @engines.close_first
+    def teardown(self):
+        pass
+        
     @classmethod
     def teardown_class(cls):
         metadata.drop_all()
@@ -878,6 +948,7 @@ class CompoundTest(TestBase):
         found2 = self._fetchall_sorted(u.alias('bar').select().execute())
         eq_(found2, wanted)
 
+    @testing.fails_on('firebird', "doesn't like ORDER BY with UNIONs")
     def test_union_ordered(self):
         (s1, s2) = (
             select([t1.c.col3.label('col3'), t1.c.col4.label('col4')],
@@ -891,6 +962,7 @@ class CompoundTest(TestBase):
                   ('ccc', 'aaa')]
         eq_(u.execute().fetchall(), wanted)
 
+    @testing.fails_on('firebird', "doesn't like ORDER BY with UNIONs")
     @testing.fails_on('maxdb', 'FIXME: unknown')
     @testing.requires.subqueries
     def test_union_ordered_alias(self):
@@ -907,6 +979,7 @@ class CompoundTest(TestBase):
         eq_(u.alias('bar').select().execute().fetchall(), wanted)
 
     @testing.crashes('oracle', 'FIXME: unknown, verify not fails_on')
+    @testing.fails_on('firebird', "has trouble extracting anonymous column from union subquery")
     @testing.fails_on('mysql', 'FIXME: unknown')
     @testing.fails_on('sqlite', 'FIXME: unknown')
     def test_union_all(self):
@@ -925,6 +998,29 @@ class CompoundTest(TestBase):
         found2 = self._fetchall_sorted(e.alias('foo').select().execute())
         eq_(found2, wanted)
 
+    def test_union_all_lightweight(self):
+        """like test_union_all, but breaks the sub-union into 
+        a subquery with an explicit column reference on the outside,
+        more palatable to a wider variety of engines.
+        
+        """
+        u = union(
+            select([t1.c.col3]),
+            select([t1.c.col3]),
+        ).alias()
+        
+        e = union_all(
+            select([t1.c.col3]),
+            select([u.c.col3])
+        )
+
+        wanted = [('aaa',),('aaa',),('bbb',), ('bbb',), ('ccc',),('ccc',)]
+        found1 = self._fetchall_sorted(e.execute())
+        eq_(found1, wanted)
+
+        found2 = self._fetchall_sorted(e.alias('foo').select().execute())
+        eq_(found2, wanted)
+
     @testing.crashes('firebird', 'Does not support intersect')
     @testing.crashes('sybase', 'FIXME: unknown, verify not fails_on')
     @testing.fails_on('mysql', 'FIXME: unknown')
@@ -1330,3 +1426,6 @@ class OperatorTest(TestBase):
                    order_by=flds.c.idcol).execute().fetchall(),
             [(2,),(1,)]
         )
+
+
+
index 64e097b85fa266cec41914ef3b05650e20dc8ff9..3198a07af4350b9b6a0c80f95a7f5f4006df511c 100644 (file)
@@ -129,7 +129,7 @@ class QuoteTest(TestBase, AssertsCompiledSQL):
     def testlabels(self):
         """test the quoting of labels.
 
-        if labels arent quoted, a query in postgres in particular will fail since it produces:
+        if labels aren't quoted, a query in postgresql in particular will fail since it produces:
 
         SELECT LaLa.lowercase, LaLa."UPPERCASE", LaLa."MixedCase", LaLa."ASC"
         FROM (SELECT DISTINCT "WorstCase1".lowercase AS lowercase, "WorstCase1"."UPPERCASE" AS UPPERCASE, "WorstCase1"."MixedCase" AS MixedCase, "WorstCase1"."ASC" AS ASC \nFROM "WorstCase1") AS LaLa
diff --git a/test/sql/test_returning.py b/test/sql/test_returning.py
new file mode 100644 (file)
index 0000000..e076f3f
--- /dev/null
@@ -0,0 +1,159 @@
+from sqlalchemy.test.testing import eq_
+from sqlalchemy import *
+from sqlalchemy.test import *
+from sqlalchemy.test.schema import Table, Column
+from sqlalchemy.types import TypeDecorator
+
+        
+class ReturningTest(TestBase, AssertsExecutionResults):
+    __unsupported_on__ = ('sqlite', 'mysql', 'maxdb', 'sybase', 'access')
+
+    def setup(self):
+        meta = MetaData(testing.db)
+        global table, GoofyType
+        
+        class GoofyType(TypeDecorator):
+            impl = String
+            
+            def process_bind_param(self, value, dialect):
+                if value is None:
+                    return None
+                return "FOO" + value
+
+            def process_result_value(self, value, dialect):
+                if value is None:
+                    return None
+                return value + "BAR"
+            
+        table = Table('tables', meta,
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
+            Column('persons', Integer),
+            Column('full', Boolean),
+            Column('goofy', GoofyType(50))
+        )
+        table.create(checkfirst=True)
+    
+    def teardown(self):
+        table.drop()
+
+    @testing.exclude('firebird', '<', (2, 0), '2.0+ feature')
+    @testing.exclude('postgresql', '<', (8, 2), '8.3+ feature')
+    def test_column_targeting(self):
+        result = table.insert().returning(table.c.id, table.c.full).execute({'persons': 1, 'full': False})
+        
+        row = result.first()
+        assert row[table.c.id] == row['id'] == 1
+        assert row[table.c.full] == row['full'] == False
+        
+        result = table.insert().values(persons=5, full=True, goofy="somegoofy").\
+                            returning(table.c.persons, table.c.full, table.c.goofy).execute()
+        row = result.first()
+        assert row[table.c.persons] == row['persons'] == 5
+        assert row[table.c.full] == row['full'] == True
+        assert row[table.c.goofy] == row['goofy'] == "FOOsomegoofyBAR"
+    
+    @testing.fails_on('firebird', "fb can't handle returning x AS y")
+    @testing.exclude('firebird', '<', (2, 0), '2.0+ feature')
+    @testing.exclude('postgresql', '<', (8, 2), '8.3+ feature')
+    def test_labeling(self):
+        result = table.insert().values(persons=6).\
+                            returning(table.c.persons.label('lala')).execute()
+        row = result.first()
+        assert row['lala'] == 6
+
+    @testing.fails_on('firebird', "fb/kintersbasdb can't handle the bind params")
+    @testing.exclude('firebird', '<', (2, 0), '2.0+ feature')
+    @testing.exclude('postgresql', '<', (8, 2), '8.3+ feature')
+    def test_anon_expressions(self):
+        result = table.insert().values(goofy="someOTHERgoofy").\
+                            returning(func.lower(table.c.goofy, type_=GoofyType)).execute()
+        row = result.first()
+        assert row[0] == "foosomeothergoofyBAR"
+
+        result = table.insert().values(persons=12).\
+                            returning(table.c.persons + 18).execute()
+        row = result.first()
+        assert row[0] == 30
+        
+    @testing.exclude('firebird', '<', (2, 1), '2.1+ feature')
+    @testing.exclude('postgresql', '<', (8, 2), '8.3+ feature')
+    def test_update_returning(self):
+        table.insert().execute([{'persons': 5, 'full': False}, {'persons': 3, 'full': False}])
+
+        result = table.update(table.c.persons > 4, dict(full=True)).returning(table.c.id).execute()
+        eq_(result.fetchall(), [(1,)])
+
+        result2 = select([table.c.id, table.c.full]).order_by(table.c.id).execute()
+        eq_(result2.fetchall(), [(1,True),(2,False)])
+
+    @testing.exclude('firebird', '<', (2, 0), '2.0+ feature')
+    @testing.exclude('postgresql', '<', (8, 2), '8.3+ feature')
+    def test_insert_returning(self):
+        result = table.insert().returning(table.c.id).execute({'persons': 1, 'full': False})
+
+        eq_(result.fetchall(), [(1,)])
+
+        @testing.fails_on('postgresql', '')
+        @testing.fails_on('oracle', '')
+        def test_executemany():
+            # return value is documented as failing with psycopg2/executemany
+            result2 = table.insert().returning(table).execute(
+                 [{'persons': 2, 'full': False}, {'persons': 3, 'full': True}])
+            
+            if testing.against('firebird', 'mssql'):
+                # Multiple inserts only return the last row
+                eq_(result2.fetchall(), [(3,3,True, None)])
+            else:
+                # nobody does this as far as we know (pg8000?)
+                eq_(result2.fetchall(), [(2, 2, False, None), (3,3,True, None)])
+
+        test_executemany()
+
+        result3 = table.insert().returning(table.c.id).execute({'persons': 4, 'full': False})
+        eq_([dict(row) for row in result3], [{'id': 4}])
+    
+        
+    @testing.exclude('firebird', '<', (2, 1), '2.1+ feature')
+    @testing.exclude('postgresql', '<', (8, 2), '8.3+ feature')
+    @testing.fails_on_everything_except('postgresql', 'firebird')
+    def test_literal_returning(self):
+        if testing.against("postgresql"):
+            literal_true = "true"
+        else:
+            literal_true = "1"
+
+        result4 = testing.db.execute('insert into tables (id, persons, "full") '
+                                        'values (5, 10, %s) returning persons' % literal_true)
+        eq_([dict(row) for row in result4], [{'persons': 10}])
+
+    @testing.exclude('firebird', '<', (2, 1), '2.1+ feature')
+    @testing.exclude('postgresql', '<', (8, 2), '8.3+ feature')
+    def test_delete_returning(self):
+        table.insert().execute([{'persons': 5, 'full': False}, {'persons': 3, 'full': False}])
+
+        result = table.delete(table.c.persons > 4).returning(table.c.id).execute()
+        eq_(result.fetchall(), [(1,)])
+
+        result2 = select([table.c.id, table.c.full]).order_by(table.c.id).execute()
+        eq_(result2.fetchall(), [(2,False),])
+
+class SequenceReturningTest(TestBase):
+    __unsupported_on__ = ('sqlite', 'mysql', 'maxdb', 'sybase', 'access', 'mssql')
+
+    def setup(self):
+        meta = MetaData(testing.db)
+        global table, seq
+        seq = Sequence('tid_seq')
+        table = Table('tables', meta,
+                    Column('id', Integer, seq, primary_key=True),
+                    Column('data', String(50))
+                )
+        table.create(checkfirst=True)
+
+    def teardown(self):
+        table.drop()
+
+    def test_insert(self):
+        r = table.insert().values(data='hi').returning(table.c.id).execute()
+        assert r.first() == (1, )
+        assert seq.execute() == 2
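
The new test_returning.py above exercises the unified returning() construct introduced with this merge. A minimal usage sketch on a backend that supports RETURNING; the connection URL and table are illustrative:

    from sqlalchemy import MetaData, Table, Column, Integer, String, create_engine

    engine = create_engine('postgresql://scott:tiger@localhost/test')
    meta = MetaData(engine)
    items = Table('items', meta,
                  Column('id', Integer, primary_key=True),
                  Column('data', String(50)))
    meta.create_all()

    # RETURNING hands back server-generated values without a second round trip
    result = items.insert().values(data='hi').returning(items.c.id).execute()
    print result.first()   # e.g. (1,)
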
index f70492fb31e949106c466da72a5708a99fefd436..9acc94eb28179548c8b2727796052d38e885a3fa 100644 (file)
@@ -5,7 +5,7 @@ from sqlalchemy import exc, sql, util
 from sqlalchemy.sql import table, column, label, compiler
 from sqlalchemy.sql.expression import ClauseList
 from sqlalchemy.engine import default
-from sqlalchemy.databases import sqlite, postgres, mysql, oracle, firebird, mssql
+from sqlalchemy.databases import *
 from sqlalchemy.test import *
 
 table1 = table('mytable',
@@ -149,12 +149,10 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A
         )
 
         self.assert_compile(
-            select([cast("data", sqlite.SLInteger)], use_labels=True),      # this will work with plain Integer in 0.6
+            select([cast("data", Integer)], use_labels=True),      # this will work with plain Integer in 0.6
             "SELECT CAST(:param_1 AS INTEGER) AS anon_1"
         )
         
-        
-        
     def test_nested_uselabels(self):
         """test nested anonymous label generation.  this
         essentially tests the ANONYMOUS_LABEL regex.
@@ -429,7 +427,12 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A
 
     def test_operators(self):
         for (py_op, sql_op) in ((operator.add, '+'), (operator.mul, '*'),
-                                (operator.sub, '-'), (operator.div, '/'),
+                                (operator.sub, '-'), 
+                                # Py3K
+                                #(operator.truediv, '/'),
+                                # Py2K
+                                (operator.div, '/'),
+                                # end Py2K
                                 ):
             for (lhs, rhs, res) in (
                 (5, table1.c.myid, ':myid_1 %s mytable.myid'),
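
The # Py3K / # Py2K markers appearing in this hunk (and throughout the merge) are the 0.6 source-conversion convention: a build step uncomments the Py3K lines and drops the Py2K lines when targeting Python 3. A simplified sketch of such a preprocessor, for illustration only:

    def strip_py2k(lines):
        out, mode = [], None
        for line in lines:
            marker = line.strip()
            if marker == '# Py3K':
                mode = 'py3k'
            elif marker == '# Py2K':
                mode = 'py2k'
            elif marker == '# end Py2K':
                mode = None
            elif mode == 'py3k':
                out.append(line.replace('#', '', 1))  # activate Py3-only code
            elif mode == 'py2k':
                pass                                  # drop Py2-only code
            else:
                out.append(line)
        return out
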
@@ -519,22 +522,22 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A
             (~table1.c.myid.like('somstr', escape='\\'), "mytable.myid NOT LIKE :myid_1 ESCAPE '\\'", None),
             (table1.c.myid.ilike('somstr', escape='\\'), "lower(mytable.myid) LIKE lower(:myid_1) ESCAPE '\\'", None),
             (~table1.c.myid.ilike('somstr', escape='\\'), "lower(mytable.myid) NOT LIKE lower(:myid_1) ESCAPE '\\'", None),
-            (table1.c.myid.ilike('somstr', escape='\\'), "mytable.myid ILIKE %(myid_1)s ESCAPE '\\'", postgres.PGDialect()),
-            (~table1.c.myid.ilike('somstr', escape='\\'), "mytable.myid NOT ILIKE %(myid_1)s ESCAPE '\\'", postgres.PGDialect()),
+            (table1.c.myid.ilike('somstr', escape='\\'), "mytable.myid ILIKE %(myid_1)s ESCAPE '\\'", postgresql.PGDialect()),
+            (~table1.c.myid.ilike('somstr', escape='\\'), "mytable.myid NOT ILIKE %(myid_1)s ESCAPE '\\'", postgresql.PGDialect()),
             (table1.c.name.ilike('%something%'), "lower(mytable.name) LIKE lower(:name_1)", None),
-            (table1.c.name.ilike('%something%'), "mytable.name ILIKE %(name_1)s", postgres.PGDialect()),
+            (table1.c.name.ilike('%something%'), "mytable.name ILIKE %(name_1)s", postgresql.PGDialect()),
             (~table1.c.name.ilike('%something%'), "lower(mytable.name) NOT LIKE lower(:name_1)", None),
-            (~table1.c.name.ilike('%something%'), "mytable.name NOT ILIKE %(name_1)s", postgres.PGDialect()),
+            (~table1.c.name.ilike('%something%'), "mytable.name NOT ILIKE %(name_1)s", postgresql.PGDialect()),
         ]:
             self.assert_compile(expr, check, dialect=dialect)
     
     def test_match(self):
         for expr, check, dialect in [
             (table1.c.myid.match('somstr'), "mytable.myid MATCH ?", sqlite.SQLiteDialect()),
-            (table1.c.myid.match('somstr'), "MATCH (mytable.myid) AGAINST (%s IN BOOLEAN MODE)", mysql.MySQLDialect()),
-            (table1.c.myid.match('somstr'), "CONTAINS (mytable.myid, :myid_1)", mssql.MSSQLDialect()),
-            (table1.c.myid.match('somstr'), "mytable.myid @@ to_tsquery(%(myid_1)s)", postgres.PGDialect()),
-            (table1.c.myid.match('somstr'), "CONTAINS (mytable.myid, :myid_1)", oracle.OracleDialect()),            
+            (table1.c.myid.match('somstr'), "MATCH (mytable.myid) AGAINST (%s IN BOOLEAN MODE)", mysql.dialect()),
+            (table1.c.myid.match('somstr'), "CONTAINS (mytable.myid, :myid_1)", mssql.dialect()),
+            (table1.c.myid.match('somstr'), "mytable.myid @@ to_tsquery(%(myid_1)s)", postgresql.dialect()),
+            (table1.c.myid.match('somstr'), "CONTAINS (mytable.myid, :myid_1)", oracle.dialect()),            
         ]:
             self.assert_compile(expr, check, dialect=dialect)
         
@@ -635,7 +638,7 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A
             select([table1.alias('foo')])
             ,"SELECT foo.myid, foo.name, foo.description FROM mytable AS foo")
 
-        for dialect in (firebird.dialect(), oracle.dialect()):
+        for dialect in (oracle.dialect(),):
             self.assert_compile(
                 select([table1.alias('foo')])
                 ,"SELECT foo.myid, foo.name, foo.description FROM mytable foo"
@@ -748,7 +751,7 @@ WHERE mytable.myid = myothertable.otherid) AS t2view WHERE t2view.mytable_myid =
                 params={},
         )
 
-        dialect = postgres.dialect()
+        dialect = postgresql.dialect()
         self.assert_compile(
             text("select * from foo where lala=:bar and hoho=:whee", bindparams=[bindparam('bar',4), bindparam('whee',7)]),
                 "select * from foo where lala=%(bar)s and hoho=%(whee)s",
@@ -1122,10 +1125,10 @@ UNION SELECT mytable.myid FROM mytable"
                 self.assert_compile(stmt, expected_positional_stmt, dialect=sqlite.dialect())
                 nonpositional = stmt.compile()
                 positional = stmt.compile(dialect=sqlite.dialect())
-                pp = positional.get_params()
+                pp = positional.params
                 assert [pp[k] for k in positional.positiontup] == expected_default_params_list
-                assert nonpositional.get_params(**test_param_dict) == expected_test_params_dict, "expected :%s got %s" % (str(expected_test_params_dict), str(nonpositional.get_params(**test_param_dict)))
-                pp = positional.get_params(**test_param_dict)
+                assert nonpositional.construct_params(test_param_dict) == expected_test_params_dict, "expected %s got %s" % (str(expected_test_params_dict), str(nonpositional.construct_params(test_param_dict)))
+                pp = positional.construct_params(test_param_dict)
                 assert [pp[k] for k in positional.positiontup] == expected_test_params_list
 
         # check that params() doesn't modify original statement
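
The get_params() calls are migrated here to the 0.6 Compiled interface: the params property yields the statement's default parameters, while construct_params() merges in runtime values. A minimal sketch:

    from sqlalchemy import bindparam, select
    from sqlalchemy.sql import table, column

    t = table('t', column('x'))
    compiled = select([t], t.c.x == bindparam('x', 5)).compile()

    print compiled.params                      # {'x': 5}, the bound default
    print compiled.construct_params({'x': 9})  # {'x': 9}, runtime value wins
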
@@ -1144,7 +1147,7 @@ UNION SELECT mytable.myid FROM mytable"
             ":myid_1) AS anon_1 FROM mytable WHERE mytable.myid = (SELECT mytable.myid FROM mytable WHERE mytable.myid = :myid_1)")
         positional = s2.compile(dialect=sqlite.dialect())
 
-        pp = positional.get_params()
+        pp = positional.params
         assert [pp[k] for k in positional.positiontup] == [12, 12]
 
         # check that conflicts with "unique" params are caught
@@ -1163,11 +1166,11 @@ UNION SELECT mytable.myid FROM mytable"
         params = dict(('in%d' % i, i) for i in range(total_params))
         sql = 'text clause %s' % ', '.join(in_clause)
         t = text(sql)
-        assert len(t.bindparams) == total_params
+        eq_(len(t.bindparams), total_params)
         c = t.compile()
         pp = c.construct_params(params)
-        assert len(set(pp)) == total_params
-        assert len(set(pp.values())) == total_params
+        eq_(len(set(pp)), total_params, '%s %s' % (len(set(pp)), len(pp)))
+        eq_(len(set(pp.values())), total_params)
         
 
     def test_bind_as_col(self):
@@ -1291,28 +1294,28 @@ UNION SELECT mytable.myid FROM mytable WHERE mytable.myid = :myid_2)")
             eq_(str(cast(tbl.c.v1, Numeric).compile(dialect=dialect)), 'CAST(casttest.v1 AS %s)' %expected_results[0])
             eq_(str(cast(tbl.c.v1, Numeric(12, 9)).compile(dialect=dialect)), 'CAST(casttest.v1 AS %s)' %expected_results[1])
             eq_(str(cast(tbl.c.ts, Date).compile(dialect=dialect)), 'CAST(casttest.ts AS %s)' %expected_results[2])
-            eq_(str(cast(1234, TEXT).compile(dialect=dialect)), 'CAST(%s AS %s)' %(literal, expected_results[3]))
+            eq_(str(cast(1234, Text).compile(dialect=dialect)), 'CAST(%s AS %s)' %(literal, expected_results[3]))
             eq_(str(cast('test', String(20)).compile(dialect=dialect)), 'CAST(%s AS %s)' %(literal, expected_results[4]))
             # fixme: shoving all of this dialect-specific stuff in one test
             # is now officially completely ridiculous AND non-obviously omits
             # coverage on other dialects.
             sel = select([tbl, cast(tbl.c.v1, Numeric)]).compile(dialect=dialect)
             if isinstance(dialect, type(mysql.dialect())):
-                eq_(str(sel), "SELECT casttest.id, casttest.v1, casttest.v2, casttest.ts, CAST(casttest.v1 AS DECIMAL(10, 2)) AS anon_1 \nFROM casttest")
+                eq_(str(sel), "SELECT casttest.id, casttest.v1, casttest.v2, casttest.ts, CAST(casttest.v1 AS DECIMAL) AS anon_1 \nFROM casttest")
             else:
-                eq_(str(sel), "SELECT casttest.id, casttest.v1, casttest.v2, casttest.ts, CAST(casttest.v1 AS NUMERIC(10, 2)) AS anon_1 \nFROM casttest")
+                eq_(str(sel), "SELECT casttest.id, casttest.v1, casttest.v2, casttest.ts, CAST(casttest.v1 AS NUMERIC) AS anon_1 \nFROM casttest")
 
         # first test with PostgreSQL engine
-        check_results(postgres.dialect(), ['NUMERIC(10, 2)', 'NUMERIC(12, 9)', 'DATE', 'TEXT', 'VARCHAR(20)'], '%(param_1)s')
+        check_results(postgresql.dialect(), ['NUMERIC', 'NUMERIC(12, 9)', 'DATE', 'TEXT', 'VARCHAR(20)'], '%(param_1)s')
 
         # then the Oracle engine
-        check_results(oracle.dialect(), ['NUMERIC(10, 2)', 'NUMERIC(12, 9)', 'DATE', 'CLOB', 'VARCHAR(20)'], ':param_1')
+        check_results(oracle.dialect(), ['NUMERIC', 'NUMERIC(12, 9)', 'DATE', 'CLOB', 'VARCHAR(20)'], ':param_1')
 
         # then the sqlite engine
-        check_results(sqlite.dialect(), ['NUMERIC(10, 2)', 'NUMERIC(12, 9)', 'DATE', 'TEXT', 'VARCHAR(20)'], '?')
+        check_results(sqlite.dialect(), ['NUMERIC', 'NUMERIC(12, 9)', 'DATE', 'TEXT', 'VARCHAR(20)'], '?')
 
         # then the MySQL engine
-        check_results(mysql.dialect(), ['DECIMAL(10, 2)', 'DECIMAL(12, 9)', 'DATE', 'CHAR', 'CHAR(20)'], '%s')
+        check_results(mysql.dialect(), ['DECIMAL', 'DECIMAL(12, 9)', 'DATE', 'CHAR', 'CHAR(20)'], '%s')
 
         self.assert_compile(cast(text('NULL'), Integer), "CAST(NULL AS INTEGER)", dialect=sqlite.dialect())
         self.assert_compile(cast(null(), Integer), "CAST(NULL AS INTEGER)", dialect=sqlite.dialect())
@@ -1360,7 +1363,6 @@ UNION SELECT mytable.myid FROM mytable WHERE mytable.myid = :myid_2)")
         s1 = select([table1.c.myid, table1.c.myid.label('foobar'), func.hoho(table1.c.name), func.lala(table1.c.name).label('gg')])
         assert s1.c.keys() == ['myid', 'foobar', 'hoho(mytable.name)', 'gg']
 
-        from sqlalchemy.databases.sqlite import SLNumeric
         meta = MetaData()
         t1 = Table('mytable', meta, Column('col1', Integer))
         
@@ -1368,7 +1370,7 @@ UNION SELECT mytable.myid FROM mytable WHERE mytable.myid = :myid_2)")
             (table1.c.name, 'name', 'mytable.name', None),
             (table1.c.myid==12, 'mytable.myid = :myid_1', 'mytable.myid = :myid_1', 'anon_1'),
             (func.hoho(table1.c.myid), 'hoho(mytable.myid)', 'hoho(mytable.myid)', 'hoho_1'),
-            (cast(table1.c.name, SLNumeric), 'CAST(mytable.name AS NUMERIC(10, 2))', 'CAST(mytable.name AS NUMERIC(10, 2))', 'anon_1'),
+            (cast(table1.c.name, Numeric), 'CAST(mytable.name AS NUMERIC)', 'CAST(mytable.name AS NUMERIC)', 'anon_1'),
             (t1.c.col1, 'col1', 'mytable.col1', None),
             (column('some wacky thing'), 'some wacky thing', '"some wacky thing"', '')
         ):
index b0501c913498fde138566bb4fdc9de2a3c0f66a9..95ca0d17bf21a2a824d0c3026c418d6bdf7832d8 100644 (file)
@@ -416,7 +416,7 @@ class ReduceTest(TestBase, AssertsExecutionResults):
             Column('magazine_page_id', Integer, ForeignKey('magazine_page.page_id'), primary_key=True),
         )
         
-       # this is essentially the union formed by the ORM's polymorphic_union function.
+        # this is essentially the union formed by the ORM's polymorphic_union function.
         # we define two versions with different ordering of selects.
 
         # the first selectable has the "real" column classified_page.magazine_page_id
@@ -432,7 +432,6 @@ class ReduceTest(TestBase, AssertsExecutionResults):
                 magazine_page_table.c.page_id, 
                 cast(null(), Integer).label('magazine_page_id')
             ]).select_from(page_table.join(magazine_page_table)),
-            
         ).alias('pjoin')
 
         eq_(
index 15799358a7094e1c2ea2882128cea93267624d1a..9c90549e29fd12acdd4eb661edba2b8d58d527f1 100644 (file)
+# coding: utf-8
 from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message
 import decimal
 import datetime, os, re
 from sqlalchemy import *
-from sqlalchemy import exc, types, util
+from sqlalchemy import exc, types, util, schema
 from sqlalchemy.sql import operators
 from sqlalchemy.test.testing import eq_
 import sqlalchemy.engine.url as url
-from sqlalchemy.databases import mssql, oracle, mysql, postgres, firebird
+from sqlalchemy.databases import *
+
 from sqlalchemy.test import *
 
 
 class AdaptTest(TestBase):
-    def testadapt(self):
-        e1 = url.URL('postgres').get_dialect()()
-        e2 = url.URL('mysql').get_dialect()()
-        e3 = url.URL('sqlite').get_dialect()()
-        e4 = url.URL('firebird').get_dialect()()
-
-        type = String(40)
-
-        t1 = type.dialect_impl(e1)
-        t2 = type.dialect_impl(e2)
-        t3 = type.dialect_impl(e3)
-        t4 = type.dialect_impl(e4)
-
-        impls = [t1, t2, t3, t4]
-        for i,ta in enumerate(impls):
-            for j,tb in enumerate(impls):
-                if i == j:
-                    assert ta == tb  # call me paranoid...  :)
+    def test_uppercase_rendering(self):
+        """Test that uppercase types from types.py always render as their type.
+        
+        As of SQLA 0.6, using an uppercase type means you want specifically that
+        type.  If the database in use doesn't support that DDL, the backend should
+        raise an error; in that case a lowercased (generic) type should be used instead.
+        
+        """
+        
+        for dialect in [
+                oracle.dialect(), 
+                mysql.dialect(), 
+                postgresql.dialect(), 
+                sqlite.dialect(), 
+                sybase.dialect(), 
+                informix.dialect(), 
+                maxdb.dialect(), 
+                mssql.dialect()]: # TODO when dialects are complete:  engines.all_dialects():
+            for type_, expected in (
+                (FLOAT, "FLOAT"),
+                (NUMERIC, "NUMERIC"),
+                (DECIMAL, "DECIMAL"),
+                (INTEGER, "INTEGER"),
+                (SMALLINT, "SMALLINT"),
+                (TIMESTAMP, "TIMESTAMP"),
+                (DATETIME, "DATETIME"),
+                (DATE, "DATE"),
+                (TIME, "TIME"),
+                (CLOB, "CLOB"),
+                (VARCHAR, "VARCHAR"),
+                (NVARCHAR, ("NVARCHAR", "NATIONAL VARCHAR")),
+                (CHAR, "CHAR"),
+                (NCHAR, ("NCHAR", "NATIONAL CHAR")),
+                (BLOB, "BLOB"),
+                (BOOLEAN, ("BOOLEAN", "BOOL"))
+            ):
+                if isinstance(expected, str):
+                    expected = (expected, )
+                for exp in expected:
+                    compiled = type_().compile(dialect=dialect)
+                    if exp in compiled:
+                        break
                 else:
-                    assert ta != tb
-
-    def testmsnvarchar(self):
-        dialect = mssql.MSSQLDialect()
-        # run the test twice to ensure the caching step works too
-        for x in range(0, 1):
-            col = Column('', Unicode(length=10))
-            dialect_type = col.type.dialect_impl(dialect)
-            assert isinstance(dialect_type, mssql.MSNVarchar)
-            assert dialect_type.get_col_spec() == 'NVARCHAR(10)'
-
-
-    def testoracletimestamp(self):
-        dialect = oracle.OracleDialect()
-        t1 = oracle.OracleTimestamp
-        t2 = oracle.OracleTimestamp()
-        t3 = types.TIMESTAMP
-        assert isinstance(dialect.type_descriptor(t1), oracle.OracleTimestamp)
-        assert isinstance(dialect.type_descriptor(t2), oracle.OracleTimestamp)
-        assert isinstance(dialect.type_descriptor(t3), oracle.OracleTimestamp)
-
-    def testmysqlbinary(self):
-        dialect = mysql.MySQLDialect()
-        t1 = mysql.MSVarBinary
-        t2 = mysql.MSVarBinary()
-        assert isinstance(dialect.type_descriptor(t1), mysql.MSVarBinary)
-        assert isinstance(dialect.type_descriptor(t2), mysql.MSVarBinary)
-
-    def teststringadapt(self):
-        """test that String with no size becomes TEXT, *all* others stay as varchar/String"""
-
-        oracle_dialect = oracle.OracleDialect()
-        mysql_dialect = mysql.MySQLDialect()
-        postgres_dialect = postgres.PGDialect()
-        firebird_dialect = firebird.FBDialect()
-
-        for dialect, start, test in [
-            (oracle_dialect, String(), oracle.OracleString),
-            (oracle_dialect, VARCHAR(), oracle.OracleString),
-            (oracle_dialect, String(50), oracle.OracleString),
-            (oracle_dialect, Unicode(), oracle.OracleString),
-            (oracle_dialect, UnicodeText(), oracle.OracleText),
-            (oracle_dialect, NCHAR(), oracle.OracleString),
-            (oracle_dialect, oracle.OracleRaw(50), oracle.OracleRaw),
-            (mysql_dialect, String(), mysql.MSString),
-            (mysql_dialect, VARCHAR(), mysql.MSString),
-            (mysql_dialect, String(50), mysql.MSString),
-            (mysql_dialect, Unicode(), mysql.MSString),
-            (mysql_dialect, UnicodeText(), mysql.MSText),
-            (mysql_dialect, NCHAR(), mysql.MSNChar),
-            (postgres_dialect, String(), postgres.PGString),
-            (postgres_dialect, VARCHAR(), postgres.PGString),
-            (postgres_dialect, String(50), postgres.PGString),
-            (postgres_dialect, Unicode(), postgres.PGString),
-            (postgres_dialect, UnicodeText(), postgres.PGText),
-            (postgres_dialect, NCHAR(), postgres.PGString),
-            (firebird_dialect, String(), firebird.FBString),
-            (firebird_dialect, VARCHAR(), firebird.FBString),
-            (firebird_dialect, String(50), firebird.FBString),
-            (firebird_dialect, Unicode(), firebird.FBString),
-            (firebird_dialect, UnicodeText(), firebird.FBText),
-            (firebird_dialect, NCHAR(), firebird.FBString),
-        ]:
-            assert isinstance(start.dialect_impl(dialect), test), "wanted %r got %r" % (test, start.dialect_impl(dialect))
-
-
+                    assert False, "%r matches none of %r for dialect %s" % (compiled, expected, dialect.name)
+            
 
 class UserDefinedTest(TestBase):
     """tests user-defined types."""
@@ -131,7 +93,7 @@ class UserDefinedTest(TestBase):
     def setup_class(cls):
         global users, metadata
 
-        class MyType(types.TypeEngine):
+        class MyType(types.UserDefinedType):
             def get_col_spec(self):
                 return "VARCHAR(100)"
             def bind_processor(self, dialect):
@@ -267,124 +229,105 @@ class ColumnsTest(TestBase, AssertsExecutionResults):
         for aCol in testTable.c:
             eq_(
                 expectedResults[aCol.name],
-                db.dialect.schemagenerator(db.dialect, db, None, None).\
+                db.dialect.ddl_compiler(db.dialect, schema.CreateTable(testTable)).\
                   get_column_specification(aCol))
 
 class UnicodeTest(TestBase, AssertsExecutionResults):
     """tests the Unicode type.  also tests the TypeDecorator with instances in the types package."""
+
     @classmethod
     def setup_class(cls):
-        global unicode_table
+        global unicode_table, metadata
         metadata = MetaData(testing.db)
         unicode_table = Table('unicode_table', metadata,
             Column('id', Integer, Sequence('uni_id_seq', optional=True), primary_key=True),
             Column('unicode_varchar', Unicode(250)),
             Column('unicode_text', UnicodeText),
-            Column('plain_varchar', String(250))
             )
-        unicode_table.create()
+        metadata.create_all()
+        
     @classmethod
     def teardown_class(cls):
-        unicode_table.drop()
+        metadata.drop_all()
 
+    @engines.close_first
     def teardown(self):
         unicode_table.delete().execute()
 
     def test_round_trip(self):
-        assert unicode_table.c.unicode_varchar.type.length == 250
-        rawdata = 'Alors vous imaginez ma surprise, au lever du jour, quand une dr\xc3\xb4le de petit voix m\xe2\x80\x99a r\xc3\xa9veill\xc3\xa9. Elle disait: \xc2\xab S\xe2\x80\x99il vous pla\xc3\xaet\xe2\x80\xa6 dessine-moi un mouton! \xc2\xbb\n'
-        unicodedata = rawdata.decode('utf-8')
-        if testing.against('sqlite'):
-            rawdata = "something"
-            
-        unicode_table.insert().execute(unicode_varchar=unicodedata,
-                                       unicode_text=unicodedata,
-                                       plain_varchar=rawdata)
-        x = unicode_table.select().execute().fetchone()
+        unicodedata = u"Alors vous imaginez ma surprise, au lever du jour, quand une drôle de petit voix m’a réveillé. Elle disait: Â« S’il vous plaît… dessine-moi un mouton! Â»"
+        
+        unicode_table.insert().execute(unicode_varchar=unicodedata,unicode_text=unicodedata)
+        
+        x = unicode_table.select().execute().first()
         self.assert_(isinstance(x['unicode_varchar'], unicode) and x['unicode_varchar'] == unicodedata)
         self.assert_(isinstance(x['unicode_text'], unicode) and x['unicode_text'] == unicodedata)
-        if isinstance(x['plain_varchar'], unicode):
-            # SQLLite and MSSQL return non-unicode data as unicode
-            self.assert_(testing.against('sqlite', 'mssql'))
-            if not testing.against('sqlite'):
-                self.assert_(x['plain_varchar'] == unicodedata)
-        else:
-            self.assert_(not isinstance(x['plain_varchar'], unicode) and x['plain_varchar'] == rawdata)
 
-    def test_union(self):
-        """ensure compiler processing works for UNIONs"""
+    def test_round_trip_executemany(self):
+        # cx_oracle was producing different behavior for cursor.executemany()
+        # vs. cursor.execute()
+        
+        unicodedata = u"Alors vous imaginez ma surprise, au lever du jour, quand une drôle de petit voix m’a réveillé. Elle disait: Â« S’il vous plaît… dessine-moi un mouton! Â»"
 
-        rawdata = 'Alors vous imaginez ma surprise, au lever du jour, quand une dr\xc3\xb4le de petit voix m\xe2\x80\x99a r\xc3\xa9veill\xc3\xa9. Elle disait: \xc2\xab S\xe2\x80\x99il vous pla\xc3\xaet\xe2\x80\xa6 dessine-moi un mouton! \xc2\xbb\n'
-        unicodedata = rawdata.decode('utf-8')
-        if testing.against('sqlite'):
-            rawdata = "something"
-        unicode_table.insert().execute(unicode_varchar=unicodedata,
-                                       unicode_text=unicodedata,
-                                       plain_varchar=rawdata)
-                                       
-        x = union(select([unicode_table.c.unicode_varchar]), select([unicode_table.c.unicode_varchar])).execute().fetchone()
+        unicode_table.insert().execute(
+                dict(unicode_varchar=unicodedata,unicode_text=unicodedata),
+                dict(unicode_varchar=unicodedata,unicode_text=unicodedata)
+        )
+
+        x = unicode_table.select().execute().first()
         self.assert_(isinstance(x['unicode_varchar'], unicode) and x['unicode_varchar'] == unicodedata)
+        self.assert_(isinstance(x['unicode_text'], unicode) and x['unicode_text'] == unicodedata)
 
-    def test_assertions(self):
-        try:
-            unicode_table.insert().execute(unicode_varchar='not unicode')
-            assert False
-        except exc.SAWarning, e:
-            assert str(e) == "Unicode type received non-unicode bind param value 'not unicode'", str(e)
+    def test_union(self):
+        """ensure compiler processing works for UNIONs"""
 
-        unicode_engine = engines.utf8_engine(options={'convert_unicode':True,
-                                                      'assert_unicode':True})
-        try:
-            try:
-                unicode_engine.execute(unicode_table.insert(), plain_varchar='im not unicode')
-                assert False
-            except exc.InvalidRequestError, e:
-                assert str(e) == "Unicode type received non-unicode bind param value 'im not unicode'"
-
-            @testing.emits_warning('.*non-unicode bind')
-            def warns():
-                # test that data still goes in if warning is emitted....
-                unicode_table.insert().execute(unicode_varchar='not unicode')
-                assert (select([unicode_table.c.unicode_varchar]).execute().fetchall() == [('not unicode', )])
-            warns()
+        unicodedata = u"Alors vous imaginez ma surprise, au lever du jour, quand une drôle de petit voix m’a réveillé. Elle disait: Â« S’il vous plaît… dessine-moi un mouton! Â»"
 
-        finally:
-            unicode_engine.dispose()
+        unicode_table.insert().execute(unicode_varchar=unicodedata,unicode_text=unicodedata)
+                                       
+        x = union(select([unicode_table.c.unicode_varchar]), select([unicode_table.c.unicode_varchar])).execute().first()
+        self.assert_(isinstance(x['unicode_varchar'], unicode) and x['unicode_varchar'] == unicodedata)
 
-    @testing.fails_on('oracle', 'FIXME: unknown')
+    @testing.fails_on('oracle', 'oracle converts empty strings to a blank space')
     def test_blank_strings(self):
         unicode_table.insert().execute(unicode_varchar=u'')
         assert select([unicode_table.c.unicode_varchar]).scalar() == u''
 
-    def test_engine_parameter(self):
-        """tests engine-wide unicode conversion"""
-        prev_unicode = testing.db.engine.dialect.convert_unicode
-        prev_assert = testing.db.engine.dialect.assert_unicode
-        try:
-            testing.db.engine.dialect.convert_unicode = True
-            testing.db.engine.dialect.assert_unicode = False
-            rawdata = 'Alors vous imaginez ma surprise, au lever du jour, quand une dr\xc3\xb4le de petit voix m\xe2\x80\x99a r\xc3\xa9veill\xc3\xa9. Elle disait: \xc2\xab S\xe2\x80\x99il vous pla\xc3\xaet\xe2\x80\xa6 dessine-moi un mouton! \xc2\xbb\n'
-            unicodedata = rawdata.decode('utf-8')
-            if testing.against('sqlite', 'mssql'):
-                rawdata = "something"
-            unicode_table.insert().execute(unicode_varchar=unicodedata,
-                                           unicode_text=unicodedata,
-                                           plain_varchar=rawdata)
-            x = unicode_table.select().execute().fetchone()
-            self.assert_(isinstance(x['unicode_varchar'], unicode) and x['unicode_varchar'] == unicodedata)
-            self.assert_(isinstance(x['unicode_text'], unicode) and x['unicode_text'] == unicodedata)
-            if not testing.against('sqlite', 'mssql'):
-                self.assert_(isinstance(x['plain_varchar'], unicode) and x['plain_varchar'] == unicodedata)
-        finally:
-            testing.db.engine.dialect.convert_unicode = prev_unicode
-            testing.db.engine.dialect.convert_unicode = prev_assert
-
-    @testing.crashes('oracle', 'FIXME: unknown, verify not fails_on')
-    @testing.fails_on('firebird', 'Data type unknown')
-    def test_length_function(self):
-        """checks the database correctly understands the length of a unicode string"""
-        teststr = u'aaa\x1234'
-        self.assert_(testing.db.func.length(teststr).scalar() == len(teststr))
+    def test_parameters(self):
+        """test the dialect convert_unicode parameters."""
+
+        unicodedata = u"Alors vous imaginez ma surprise, au lever du jour, quand une drôle de petit voix m’a réveillé. Elle disait: Â« S’il vous plaît… dessine-moi un mouton! Â»"
+
+        u = Unicode(assert_unicode=True)
+        uni = u.dialect_impl(testing.db.dialect).bind_processor(testing.db.dialect)
+        # Py3K
+        #assert_raises(exc.InvalidRequestError, uni, b'x')
+        # Py2K
+        assert_raises(exc.InvalidRequestError, uni, 'x')
+        # end Py2K
+
+        u = Unicode()
+        uni = u.dialect_impl(testing.db.dialect).bind_processor(testing.db.dialect)
+        # Py3K
+        #assert_raises(exc.SAWarning, uni, b'x')
+        # Py2K
+        assert_raises(exc.SAWarning, uni, 'x')
+        # end Py2K
+
+        unicode_engine = engines.utf8_engine(options={'convert_unicode':True,'assert_unicode':True})
+        unicode_engine.dialect.supports_unicode_binds = False
+        
+        s = String()
+        uni = s.dialect_impl(unicode_engine.dialect).bind_processor(unicode_engine.dialect)
+        # Py3K
+        #assert_raises(exc.InvalidRequestError, uni, b'x')
+        #assert isinstance(uni(unicodedata), bytes)
+        # Py2K
+        assert_raises(exc.InvalidRequestError, uni, 'x')
+        assert isinstance(uni(unicodedata), str)
+        # end Py2K
+        
+        assert uni(unicodedata) == unicodedata.encode('utf-8')
 
 class BinaryTest(TestBase, AssertsExecutionResults):
     __excluded_on__ = (
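
The new test_parameters above drives the 0.6 bind-processing chain directly: dialect_impl() selects the per-dialect type implementation, and its bind_processor() returns the encoding callable, or None when the DBAPI accepts unicode natively. A minimal sketch along the same lines:

    from sqlalchemy import Unicode
    from sqlalchemy.dialects import sqlite

    dialect = sqlite.dialect()
    processor = Unicode(50).dialect_impl(dialect).bind_processor(dialect)
    if processor is None:
        print "dialect binds unicode natively"
    else:
        print processor(u'déjà vu')   # utf-8 encoded bytes
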
@@ -409,18 +352,19 @@ class BinaryTest(TestBase, AssertsExecutionResults):
                 return value
 
         binary_table = Table('binary_table', MetaData(testing.db),
-        Column('primary_id', Integer, Sequence('binary_id_seq', optional=True), primary_key=True),
-        Column('data', Binary),
-        Column('data_slice', Binary(100)),
-        Column('misc', String(30)),
-        # construct PickleType with non-native pickle module, since cPickle uses relative module
-        # loading and confuses this test's parent package 'sql' with the 'sqlalchemy.sql' package relative
-        # to the 'types' module
-        Column('pickled', PickleType),
-        Column('mypickle', MyPickleType)
+            Column('primary_id', Integer, Sequence('binary_id_seq', optional=True), primary_key=True),
+            Column('data', Binary),
+            Column('data_slice', Binary(100)),
+            Column('misc', String(30)),
+            # construct PickleType with non-native pickle module, since cPickle uses relative module
+            # loading and confuses this test's parent package 'sql' with the 'sqlalchemy.sql' package relative
+            # to the 'types' module
+            Column('pickled', PickleType),
+            Column('mypickle', MyPickleType)
         )
         binary_table.create()
 
+    @engines.close_first
     def teardown(self):
         binary_table.delete().execute()
 
@@ -428,42 +372,65 @@ class BinaryTest(TestBase, AssertsExecutionResults):
     def teardown_class(cls):
         binary_table.drop()
 
-    @testing.fails_on('mssql', 'MSSQl BINARY type right pads the fixed length with \x00')
-    def testbinary(self):
+    def test_round_trip(self):
         testobj1 = pickleable.Foo('im foo 1')
         testobj2 = pickleable.Foo('im foo 2')
         testobj3 = pickleable.Foo('im foo 3')
 
         stream1 =self.load_stream('binary_data_one.dat')
         stream2 =self.load_stream('binary_data_two.dat')
-        binary_table.insert().execute(primary_id=1, misc='binary_data_one.dat',    data=stream1, data_slice=stream1[0:100], pickled=testobj1, mypickle=testobj3)
-        binary_table.insert().execute(primary_id=2, misc='binary_data_two.dat', data=stream2, data_slice=stream2[0:99], pickled=testobj2)
-        binary_table.insert().execute(primary_id=3, misc='binary_data_two.dat', data=None, data_slice=stream2[0:99], pickled=None)
+        binary_table.insert().execute(
+                            primary_id=1, 
+                            misc='binary_data_one.dat', 
+                            data=stream1, 
+                            data_slice=stream1[0:100], 
+                            pickled=testobj1, 
+                            mypickle=testobj3)
+        binary_table.insert().execute(
+                            primary_id=2, 
+                            misc='binary_data_two.dat', 
+                            data=stream2, 
+                            data_slice=stream2[0:99], 
+                            pickled=testobj2)
+        binary_table.insert().execute(
+                            primary_id=3, 
+                            misc='binary_data_two.dat', 
+                            data=None, 
+                            data_slice=stream2[0:99], 
+                            pickled=None)
 
         for stmt in (
             binary_table.select(order_by=binary_table.c.primary_id),
             text("select * from binary_table order by binary_table.primary_id", typemap={'pickled':PickleType, 'mypickle':MyPickleType}, bind=testing.db)
         ):
+            eq_data = lambda x, y: eq_(list(x), list(y))
+            if util.jython:
+                _eq_data = eq_data
+                def eq_data(x, y):
+                    # Jython currently returns arrays
+                    from array import ArrayType
+                    if isinstance(y, ArrayType):
+                        return eq_(x, y.tostring())
+                    return _eq_data(x, y)
             l = stmt.execute().fetchall()
-            eq_(list(stream1), list(l[0]['data']))
-            eq_(list(stream1[0:100]), list(l[0]['data_slice']))
-            eq_(list(stream2), list(l[1]['data']))
+            eq_data(stream1, l[0]['data'])
+            eq_data(stream1[0:100], l[0]['data_slice'])
+            eq_data(stream2, l[1]['data'])
             eq_(testobj1, l[0]['pickled'])
             eq_(testobj2, l[1]['pickled'])
             eq_(testobj3.moredata, l[0]['mypickle'].moredata)
             eq_(l[0]['mypickle'].stuff, 'this is the right stuff')
 
-    def load_stream(self, name, len=12579):
+    def load_stream(self, name):
         f = os.path.join(os.path.dirname(__file__), "..", name)
-        # put a number less than the typical MySQL default BLOB size
-        return file(f).read(len)
+        return open(f, mode='rb').read()
 
 class ExpressionTest(TestBase, AssertsExecutionResults):
     @classmethod
     def setup_class(cls):
         global test_table, meta
 
-        class MyCustomType(types.TypeEngine):
+        class MyCustomType(types.UserDefinedType):
             def get_col_spec(self):
                 return "INT"
             def bind_processor(self, dialect):
@@ -547,7 +514,6 @@ class DateTest(TestBase, AssertsExecutionResults):
 
         db = testing.db
         if testing.against('oracle'):
-            import sqlalchemy.databases.oracle as oracle
             insert_data =  [
                     (7, 'jack',
                      datetime.datetime(2005, 11, 10, 0, 0),
@@ -576,7 +542,7 @@ class DateTest(TestBase, AssertsExecutionResults):
             time_micro = 999
 
             # Missing or poor microsecond support:
-            if testing.against('mssql', 'mysql', 'firebird'):
+            if testing.against('mssql', 'mysql', 'firebird', '+zxjdbc'):
                 datetime_micro, time_micro = 0, 0
             # No microseconds for TIME
             elif testing.against('maxdb'):
@@ -608,7 +574,7 @@ class DateTest(TestBase, AssertsExecutionResults):
                        Column('user_date', Date),
                        Column('user_time', Time)]
 
-        if testing.against('sqlite', 'postgres'):
+        if testing.against('sqlite', 'postgresql'):
             insert_data.append(
                 (11, 'historic',
                 datetime.datetime(1850, 11, 10, 11, 52, 35, datetime_micro),
@@ -676,8 +642,8 @@ class DateTest(TestBase, AssertsExecutionResults):
             t.drop(checkfirst=True)
 
 class StringTest(TestBase, AssertsExecutionResults):
-    @testing.fails_on('mysql', 'FIXME: unknown')
-    @testing.fails_on('oracle', 'FIXME: unknown')
+
+    @testing.requires.unbounded_varchar
     def test_nolength_string(self):
         metadata = MetaData(testing.db)
         foo = Table('foo', metadata, Column('one', String))
@@ -700,10 +666,10 @@ class NumericTest(TestBase, AssertsExecutionResults):
         metadata = MetaData(testing.db)
         numeric_table = Table('numeric_table', metadata,
             Column('id', Integer, Sequence('numeric_id_seq', optional=True), primary_key=True),
-            Column('numericcol', Numeric(asdecimal=False)),
-            Column('floatcol', Float),
-            Column('ncasdec', Numeric),
-            Column('fcasdec', Float(asdecimal=True))
+            Column('numericcol', Numeric(precision=10, scale=2, asdecimal=False)),
+            Column('floatcol', Float(precision=10)),
+            Column('ncasdec', Numeric(precision=10, scale=2)),
+            Column('fcasdec', Float(precision=10, asdecimal=True))
         )
         metadata.create_all()
 
@@ -711,6 +677,7 @@ class NumericTest(TestBase, AssertsExecutionResults):
     def teardown_class(cls):
         metadata.drop_all()
 
+    @engines.close_first
     def teardown(self):
         numeric_table.delete().execute()
 
@@ -719,6 +686,7 @@ class NumericTest(TestBase, AssertsExecutionResults):
         from decimal import Decimal
         numeric_table.insert().execute(
             numericcol=3.5, floatcol=5.6, ncasdec=12.4, fcasdec=15.75)
+            
         numeric_table.insert().execute(
             numericcol=Decimal("3.5"), floatcol=Decimal("5.6"),
             ncasdec=Decimal("12.4"), fcasdec=Decimal("15.75"))
@@ -744,33 +712,6 @@ class NumericTest(TestBase, AssertsExecutionResults):
             assert isinstance(row['ncasdec'], decimal.Decimal)
             assert isinstance(row['fcasdec'], decimal.Decimal)
 
-    def test_length_deprecation(self):
-        assert_raises(exc.SADeprecationWarning, Numeric, length=8)
-        
-        @testing.uses_deprecated(".*is deprecated for Numeric")
-        def go():
-            n = Numeric(length=12)
-            assert n.scale == 12
-        go()
-        
-        n = Numeric(scale=12)
-        for dialect in engines.all_dialects():
-            n2 = dialect.type_descriptor(n)
-            eq_(n2.scale, 12, dialect.name)
-            
-            # test colspec generates successfully using 'scale'
-            assert n2.get_col_spec()
-            
-            # test constructor of the dialect-specific type
-            n3 = n2.__class__(scale=5)
-            eq_(n3.scale, 5, dialect.name)
-            
-            @testing.uses_deprecated(".*is deprecated for Numeric")
-            def go():
-                n3 = n2.__class__(length=6)
-                eq_(n3.scale, 6, dialect.name)
-            go()
-                
             
 class IntervalTest(TestBase, AssertsExecutionResults):
     @classmethod
@@ -783,6 +724,7 @@ class IntervalTest(TestBase, AssertsExecutionResults):
             )
         metadata.create_all()
 
+    @engines.close_first
     def teardown(self):
         interval_table.delete().execute()
 
@@ -790,14 +732,16 @@ class IntervalTest(TestBase, AssertsExecutionResults):
     def teardown_class(cls):
         metadata.drop_all()
 
+    @testing.fails_on("+pg8000", "Not yet known how to pass values of the INTERVAL type")
+    @testing.fails_on("postgresql+zxjdbc", "Not yet known how to pass values of the INTERVAL type")
     def test_roundtrip(self):
         delta = datetime.datetime(2006, 10, 5) - datetime.datetime(2005, 8, 17)
         interval_table.insert().execute(interval=delta)
-        assert interval_table.select().execute().fetchone()['interval'] == delta
+        assert interval_table.select().execute().first()['interval'] == delta
 
     def test_null(self):
         interval_table.insert().execute(id=1, interval=None)
-        assert interval_table.select().execute().fetchone()['interval'] is None
+        assert interval_table.select().execute().first()['interval'] is None
 
 class BooleanTest(TestBase, AssertsExecutionResults):
     @classmethod
@@ -825,30 +769,6 @@ class BooleanTest(TestBase, AssertsExecutionResults):
         assert(res2==[(2, False)])
 
 class PickleTest(TestBase):
-    def test_noeq_deprecation(self):
-        p1 = PickleType()
-        
-        assert_raises(DeprecationWarning, 
-            p1.compare_values, pickleable.BarWithoutCompare(1, 2), pickleable.BarWithoutCompare(1, 2)
-        )
-
-        assert_raises(DeprecationWarning, 
-            p1.compare_values, pickleable.OldSchoolWithoutCompare(1, 2), pickleable.OldSchoolWithoutCompare(1, 2)
-        )
-        
-        @testing.uses_deprecated()
-        def go():
-            # test actual dumps comparison
-            assert p1.compare_values(pickleable.BarWithoutCompare(1, 2), pickleable.BarWithoutCompare(1, 2))
-            assert p1.compare_values(pickleable.OldSchoolWithoutCompare(1, 2), pickleable.OldSchoolWithoutCompare(1, 2))
-        go()
-        
-        assert p1.compare_values({1:2, 3:4}, {3:4, 1:2})
-        
-        p2 = PickleType(mutable=False)
-        assert not p2.compare_values(pickleable.BarWithoutCompare(1, 2), pickleable.BarWithoutCompare(1, 2))
-        assert not p2.compare_values(pickleable.OldSchoolWithoutCompare(1, 2), pickleable.OldSchoolWithoutCompare(1, 2))
-        
     def test_eq_comparison(self):
         p1 = PickleType()
         
index d759132678e5fe0a61127e0aee9a2ee7393ca8ec..6551594f32034524bf72367974dd7a3b7746a961 100644 (file)
@@ -56,6 +56,7 @@ class UnicodeSchemaTest(TestBase):
                        )
         metadata.create_all()
 
+    @engines.close_first
     def teardown(self):
         if metadata.tables:
             t3.delete().execute()
@@ -125,11 +126,11 @@ class EscapesDefaultsTest(testing.TestBase):
             # reset the identifier preparer, so that we can force it to cache
             # a unicode identifier
             engine.dialect.identifier_preparer = engine.dialect.preparer(engine.dialect)
-            select([column(u'special_col')]).select_from(t1).execute()
+            select([column(u'special_col')]).select_from(t1).execute().close()
             assert isinstance(engine.dialect.identifier_preparer.format_sequence(Sequence('special_col')), unicode)
             
             # now execute, run the sequence.  it should run in u"Special_col.nextid" or similar as 
-            # a unicode object; cx_oracle asserts that this is None or a String (postgres lets it pass thru).
+            # a unicode object; cx_oracle asserts that this is None or a String (postgresql lets it pass thru).
             # ensure that base.DefaultRunner is encoding.
             t1.insert().execute(data='foo')
         finally:
index 5203bd866a41c3cf573351acedfc11917a42b863..126d2c5684b34505987bb7035bc9913273fe52ff 100644 (file)
@@ -1,7 +1,7 @@
 """mapper.py - defines mappers for domain objects, mapping operations"""
 
-import tables, user
-from blog import *
+from test.zblog import tables, user
+from test.zblog.blog import *
 from sqlalchemy import *
 from sqlalchemy.orm import *
 import sqlalchemy.util as util
index 36c7aeb8b19cc67a9b8621e5e7fab444766c976f..4907259e18df0e987f44e91e53ec95925a8bc1a9 100644 (file)
@@ -1,12 +1,12 @@
 """application table metadata objects are described here."""
 
 from sqlalchemy import *
-
+from sqlalchemy.test.schema import Table, Column
 
 metadata = MetaData()
 
 users = Table('users', metadata,
-    Column('user_id', Integer, Sequence('user_id_seq', optional=True), primary_key=True),
+    Column('user_id', Integer, primary_key=True, test_needs_autoincrement=True),
     Column('user_name', String(30), nullable=False),
     Column('fullname', String(100), nullable=False),
     Column('password', String(40), nullable=False),
@@ -14,14 +14,14 @@ users = Table('users', metadata,
     )
 
 blogs = Table('blogs', metadata,
-    Column('blog_id', Integer, Sequence('blog_id_seq', optional=True), primary_key=True),
+    Column('blog_id', Integer, primary_key=True, test_needs_autoincrement=True),
     Column('owner_id', Integer, ForeignKey('users.user_id'), nullable=False),
     Column('name', String(100), nullable=False),
     Column('description', String(500))
     )
 
 posts = Table('posts', metadata,
-    Column('post_id', Integer, Sequence('post_id_seq', optional=True), primary_key=True),
+    Column('post_id', Integer, primary_key=True, test_needs_autoincrement=True),
     Column('blog_id', Integer, ForeignKey('blogs.blog_id'), nullable=False),
     Column('user_id', Integer, ForeignKey('users.user_id'), nullable=False),
     Column('datetime', DateTime, nullable=False),
@@ -31,7 +31,7 @@ posts = Table('posts', metadata,
     )
 
 topics = Table('topics', metadata,
-    Column('topic_id', Integer, primary_key=True),
+    Column('topic_id', Integer, primary_key=True, test_needs_autoincrement=True),
     Column('keyword', String(50), nullable=False),
     Column('description', String(500))
    )
@@ -43,7 +43,7 @@ topic_xref = Table('topic_post_xref', metadata,
    )
 
 comments = Table('comments', metadata,
-    Column('comment_id', Integer, primary_key=True),
+    Column('comment_id', Integer, primary_key=True, test_needs_autoincrement=True),
     Column('user_id', Integer, ForeignKey('users.user_id'), nullable=False),
     Column('post_id', Integer, ForeignKey('posts.post_id'), nullable=False),
     Column('datetime', DateTime, nullable=False),
index 8170766cb253bcb5f28e5359fbfd0cf599691447..5e46c1cebcf7f448681d38d3fbe28987ae166039 100644 (file)
@@ -1,9 +1,9 @@
 from sqlalchemy import *
 from sqlalchemy.orm import *
 from sqlalchemy.test import *
-import mappers, tables
-from user import *
-from blog import *
+from test.zblog import mappers, tables
+from test.zblog.user import *
+from test.zblog.blog import *
 
 
 class ZBlogTest(TestBase, AssertsExecutionResults):
index 0a13002cd8464f392c9215a4ed57859ac2a20bf0..30f1e3da16981bae45b62ddaed28c80a85e2b21b 100644 (file)
@@ -14,9 +14,9 @@ groups = [user, administrator]
 
 def cryptpw(password, salt=None):
     if salt is None:
-        salt = string.join([chr(random.randint(ord('a'), ord('z'))),
-                            chr(random.randint(ord('a'), ord('z')))],'')
-    return sha(password + salt).hexdigest()
+        salt = "".join([chr(random.randint(ord('a'), ord('z'))),
+                            chr(random.randint(ord('a'), ord('z')))])
+    return sha((password + salt).encode('ascii')).hexdigest()
 
 def checkpw(password, dbpw):
     return cryptpw(password, dbpw[:2]) == dbpw
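
For completeness, the reworked helper above can be exercised standalone; the sketch assumes sha is bound to hashlib's sha1, matching the encode('ascii') call:

    import random
    from hashlib import sha1 as sha   # assumed binding for the module's sha()

    def cryptpw(password, salt=None):
        if salt is None:
            salt = "".join([chr(random.randint(ord('a'), ord('z'))),
                            chr(random.randint(ord('a'), ord('z')))])
        return sha((password + salt).encode('ascii')).hexdigest()

    print cryptpw('tiger', 'ab')   # deterministic once the salt is fixed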