git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
first merge from the hg repo. may need cleanup/refreshing
author     Mike Bayer <mike_mp@zzzcomputing.com>
           Tue, 13 Jan 2009 17:33:53 +0000 (17:33 +0000)
committer  Mike Bayer <mike_mp@zzzcomputing.com>
           Tue, 13 Jan 2009 17:33:53 +0000 (17:33 +0000)
48 files changed:
doc/build/copyright.rst
lib/sqlalchemy/connectors/__init__.py [new file with mode: 0644]
lib/sqlalchemy/connectors/pyodbc.py [new file with mode: 0644]
lib/sqlalchemy/connectors/zxJDBC.py [new file with mode: 0644]
lib/sqlalchemy/databases/__init__.py
lib/sqlalchemy/databases/access.py
lib/sqlalchemy/databases/firebird.py
lib/sqlalchemy/databases/informix.py
lib/sqlalchemy/databases/maxdb.py
lib/sqlalchemy/databases/mssql.py
lib/sqlalchemy/databases/mysql.py
lib/sqlalchemy/databases/oracle.py
lib/sqlalchemy/databases/sqlite.py [deleted file]
lib/sqlalchemy/databases/sybase.py
lib/sqlalchemy/dialects/__init__.py [new file with mode: 0644]
lib/sqlalchemy/dialects/postgres/__init__.py [new file with mode: 0644]
lib/sqlalchemy/dialects/postgres/base.py [moved from lib/sqlalchemy/databases/postgres.py with 63% similarity]
lib/sqlalchemy/dialects/postgres/psycopg2.py [new file with mode: 0644]
lib/sqlalchemy/dialects/sqlite/__init__.py [new file with mode: 0644]
lib/sqlalchemy/dialects/sqlite/base.py [new file with mode: 0644]
lib/sqlalchemy/dialects/sqlite/pysqlite.py [new file with mode: 0644]
lib/sqlalchemy/engine/__init__.py
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/engine/ddl.py [new file with mode: 0644]
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/engine/strategies.py
lib/sqlalchemy/engine/url.py
lib/sqlalchemy/ext/declarative.py
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/properties.py
lib/sqlalchemy/schema.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/expression.py
lib/sqlalchemy/sql/visitors.py
lib/sqlalchemy/types.py
lib/sqlalchemy/util.py
test/dialect/postgres.py
test/dialect/sqlite.py
test/engine/ddlevents.py
test/engine/reflection.py
test/ext/alltests.py
test/ext/declarative.py
test/orm/unitofwork.py
test/sql/constraints.py
test/sql/select.py
test/sql/testtypes.py
test/testlib/engines.py
test/testlib/testing.py

diff --git a/doc/build/copyright.rst b/doc/build/copyright.rst
index 227a54c9c83b8bba6c31259fc3aa429f8de04ddf..501b4ee757b4df081a09e9e84522248695b868f4 100644 (file)
@@ -4,7 +4,7 @@ Appendix:  Copyright
 
 This is the MIT license: `<http://www.opensource.org/licenses/mit-license.php>`_
 
-Copyright (c) 2005, 2006, 2007, 2008 Michael Bayer and contributors. SQLAlchemy is a trademark of Michael
+Copyright (c) 2005, 2006, 2007, 2008, 2009 Michael Bayer and contributors. SQLAlchemy is a trademark of Michael
 Bayer.
 
 Permission is hereby granted, free of charge, to any person obtaining a copy of this
diff --git a/lib/sqlalchemy/connectors/__init__.py b/lib/sqlalchemy/connectors/__init__.py
new file mode 100644 (file)
index 0000000..e69de29
diff --git a/lib/sqlalchemy/connectors/pyodbc.py b/lib/sqlalchemy/connectors/pyodbc.py
new file mode 100644 (file)
index 0000000..e69de29
diff --git a/lib/sqlalchemy/connectors/zxJDBC.py b/lib/sqlalchemy/connectors/zxJDBC.py
new file mode 100644 (file)
index 0000000..e69de29
diff --git a/lib/sqlalchemy/databases/__init__.py b/lib/sqlalchemy/databases/__init__.py
index 6588be0ae71410a180f9e13912c37c50e4a2088a..7f124d7dbda23bd5353582fff51b3d865c6ced40 100644 (file)
@@ -13,7 +13,5 @@ __all__ = (
     'mssql',
     'mysql',
     'oracle',
-    'postgres',
-    'sqlite',
     'sybase',
     )
diff --git a/lib/sqlalchemy/databases/access.py b/lib/sqlalchemy/databases/access.py
index 67af4a7a4a6808b0adb57a87c57c3fa81d875510..de4af6bcb710d23c7f74fbfbc9f12c85ce607870 100644 (file)
@@ -46,7 +46,7 @@ class AcTinyInteger(types.Integer):
     def get_col_spec(self):
         return "TINYINT"
 
-class AcSmallInteger(types.Smallinteger):
+class AcSmallInteger(types.SmallInteger):
     def get_col_spec(self):
         return "SMALLINT"
 
@@ -155,7 +155,7 @@ class AccessDialect(default.DefaultDialect):
     colspecs = {
         types.Unicode : AcUnicode,
         types.Integer : AcInteger,
-        types.Smallinteger: AcSmallInteger,
+        types.SmallInteger: AcSmallInteger,
         types.Numeric : AcNumeric,
         types.Float : AcFloat,
         types.DateTime : AcDateTime,
@@ -327,7 +327,7 @@ class AccessDialect(default.DefaultDialect):
         return names
 
 
-class AccessCompiler(compiler.DefaultCompiler):
+class AccessCompiler(compiler.SQLCompiler):
     def visit_select_precolumns(self, select):
         """Access puts TOP, it's version of LIMIT here """
         s = select.distinct and "DISTINCT " or ""
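
The hunk above, like the matching hunks in each dialect below, renames the
compiler base class from ``compiler.DefaultCompiler`` to ``compiler.SQLCompiler``;
the hook points themselves are unchanged. As an illustration of the
"pre-columns" slot that the Access docstring refers to, a minimal sketch with
a hypothetical table, using the 0.5-era ``select()`` API seen throughout this
diff::

    from sqlalchemy import MetaData, Table, Column, Integer, select

    t = Table('t', MetaData(), Column('id', Integer))
    stmt = select([t]).limit(5)
    # AccessCompiler renders the limit in the pre-columns slot, so this
    # compiles along the lines of:  SELECT TOP 5 t.id FROM t
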
diff --git a/lib/sqlalchemy/databases/firebird.py b/lib/sqlalchemy/databases/firebird.py
index 6b1af9fab08bea445365acd2d9a6c803e05021f9..f00aa963ee6b5260bc007d2a0b460f4432c17c62 100644 (file)
@@ -150,7 +150,7 @@ class FBInteger(sqltypes.Integer):
         return "INTEGER"
 
 
-class FBSmallInteger(sqltypes.Smallinteger):
+class FBSmallInteger(sqltypes.SmallInteger):
     """Handle ``SMALLINT`` datatype."""
 
     def get_col_spec(self):
@@ -231,7 +231,7 @@ class FBBoolean(sqltypes.Boolean):
 
 colspecs = {
     sqltypes.Integer : FBInteger,
-    sqltypes.Smallinteger : FBSmallInteger,
+    sqltypes.SmallInteger : FBSmallInteger,
     sqltypes.Numeric : FBNumeric,
     sqltypes.Float : FBFloat,
     sqltypes.DateTime : FBDateTime,
@@ -564,12 +564,12 @@ def _substring(s, start, length=None):
         return "SUBSTRING(%s FROM %s FOR %s)" % (s, start, length)
 
 
-class FBCompiler(sql.compiler.DefaultCompiler):
+class FBCompiler(sql.compiler.SQLCompiler):
     """Firebird specific idiosincrasies"""
 
     # Firebird lacks a builtin modulo operator, but there is
     # an equivalent function in the ib_udf library.
-    operators = sql.compiler.DefaultCompiler.operators.copy()
+    operators = sql.compiler.SQLCompiler.operators.copy()
     operators.update({
         sql.operators.mod : lambda x, y:"mod(%s, %s)" % (x, y)
         })
@@ -581,7 +581,7 @@ class FBCompiler(sql.compiler.DefaultCompiler):
         else:
             return self.process(alias.original, **kwargs)
 
-    functions = sql.compiler.DefaultCompiler.functions.copy()
+    functions = sql.compiler.SQLCompiler.functions.copy()
     functions['substring'] = _substring
 
     def function_argspec(self, func):
diff --git a/lib/sqlalchemy/databases/informix.py b/lib/sqlalchemy/databases/informix.py
index 4476af3b9c25ea85c2fad39528f3f6014d6b13cc..ad9dfd9bcec9a5d430eafba0628d17318dccda6e 100644 (file)
@@ -51,7 +51,7 @@ class InfoInteger(sqltypes.Integer):
     def get_col_spec(self):
         return "INTEGER"
 
-class InfoSmallInteger(sqltypes.Smallinteger):
+class InfoSmallInteger(sqltypes.SmallInteger):
     def get_col_spec(self):
         return "SMALLINT"
 
@@ -141,7 +141,7 @@ class InfoBoolean(sqltypes.Boolean):
 
 colspecs = {
     sqltypes.Integer : InfoInteger,
-    sqltypes.Smallinteger : InfoSmallInteger,
+    sqltypes.SmallInteger : InfoSmallInteger,
     sqltypes.Numeric : InfoNumeric,
     sqltypes.Float : InfoNumeric,
     sqltypes.DateTime : InfoDateTime,
@@ -352,7 +352,7 @@ class InfoDialect(default.DefaultDialect):
         for cons_name, cons_type, local_column in rows:
             table.primary_key.add( table.c[local_column] )
 
-class InfoCompiler(compiler.DefaultCompiler):
+class InfoCompiler(compiler.SQLCompiler):
     """Info compiler modifies the lexical structure of Select statements to work under
     non-ANSI configured Oracle databases, if the use_ansi flag is False."""
 
@@ -360,7 +360,7 @@ class InfoCompiler(compiler.DefaultCompiler):
         self.limit = 0
         self.offset = 0
 
-        compiler.DefaultCompiler.__init__( self , *args, **kwargs )
+        compiler.SQLCompiler.__init__( self , *args, **kwargs )
 
     def default_from(self):
         return " from systables where tabname = 'systables' "
@@ -393,7 +393,7 @@ class InfoCompiler(compiler.DefaultCompiler):
             if ( __label(c) not in a ):
                 select.append_column( c )
 
-        return compiler.DefaultCompiler.visit_select(self, select)
+        return compiler.SQLCompiler.visit_select(self, select)
 
     def limit_clause(self, select):
         return ""
@@ -406,7 +406,7 @@ class InfoCompiler(compiler.DefaultCompiler):
         elif func.name.lower() in ( 'current_timestamp' , 'now' ):
             return "CURRENT YEAR TO SECOND"
         else:
-            return compiler.DefaultCompiler.visit_function( self , func )
+            return compiler.SQLCompiler.visit_function( self , func )
 
     def visit_clauselist(self, list, **kwargs):
         return ', '.join([s for s in [self.process(c) for c in list.clauses] if s is not None])
diff --git a/lib/sqlalchemy/databases/maxdb.py b/lib/sqlalchemy/databases/maxdb.py
index 693295054e4a3fd5aebb98f79a4cfa07c3ff7eea..6e521297fc108ebaf975649ca69138ce4997014c 100644 (file)
@@ -344,7 +344,7 @@ class MaxBlob(sqltypes.Binary):
 
 colspecs = {
     sqltypes.Integer: MaxInteger,
-    sqltypes.Smallinteger: MaxSmallInteger,
+    sqltypes.SmallInteger: MaxSmallInteger,
     sqltypes.Numeric: MaxNumeric,
     sqltypes.Float: MaxFloat,
     sqltypes.DateTime: MaxTimestamp,
@@ -717,8 +717,8 @@ class MaxDBDialect(default.DefaultDialect):
         return found
 
 
-class MaxDBCompiler(compiler.DefaultCompiler):
-    operators = compiler.DefaultCompiler.operators.copy()
+class MaxDBCompiler(compiler.SQLCompiler):
+    operators = compiler.SQLCompiler.operators.copy()
     operators[sql_operators.mod] = lambda x, y: 'mod(%s, %s)' % (x, y)
 
     function_conversion = {
diff --git a/lib/sqlalchemy/databases/mssql.py b/lib/sqlalchemy/databases/mssql.py
index 7d23c5b27365cc700b42acc032ff92ad6c83dd58..dda0fddd24703c2097ab9438fc3d5776ab2f94ed 100644 (file)
@@ -923,7 +923,7 @@ class MSSQLDialect(default.DefaultDialect):
     colspecs = {
         sqltypes.Unicode : MSNVarchar,
         sqltypes.Integer : MSInteger,
-        sqltypes.Smallinteger: MSSmallInteger,
+        sqltypes.SmallInteger: MSSmallInteger,
         sqltypes.Numeric : MSNumeric,
         sqltypes.Float : MSFloat,
         sqltypes.DateTime : MSDateTime,
@@ -1445,14 +1445,14 @@ dialect_mapping = {
     }
 
 
-class MSSQLCompiler(compiler.DefaultCompiler):
+class MSSQLCompiler(compiler.SQLCompiler):
     operators = compiler.OPERATORS.copy()
     operators.update({
         sql_operators.concat_op: '+',
         sql_operators.match_op: lambda x, y: "CONTAINS (%s, %s)" % (x, y)
     })
 
-    functions = compiler.DefaultCompiler.functions.copy()
+    functions = compiler.SQLCompiler.functions.copy()
     functions.update (
         {
             sql_functions.now: 'CURRENT_TIMESTAMP',
@@ -1478,7 +1478,7 @@ class MSSQLCompiler(compiler.DefaultCompiler):
                     if not self.dialect.has_window_funcs:
                         raise exc.InvalidRequestError('MSSQL does not support LIMIT with an offset')
             return s
-        return compiler.DefaultCompiler.get_select_precolumns(self, select)
+        return compiler.SQLCompiler.get_select_precolumns(self, select)
 
     def limit_clause(self, select):
         # Limit in mssql is after the select keyword
@@ -1506,7 +1506,7 @@ class MSSQLCompiler(compiler.DefaultCompiler):
                 limitselect.append_whereclause("mssql_rn<=%d" % (_limit + _offset))
             return self.process(limitselect, iswrapper=True, **kwargs)
         else:
-            return compiler.DefaultCompiler.visit_select(self, select, **kwargs)
+            return compiler.SQLCompiler.visit_select(self, select, **kwargs)
 
     def _schema_aliased_table(self, table):
         if getattr(table, 'schema', None) is not None:
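
The MSSQL hunk above gates LIMIT-with-OFFSET on the dialect's
``has_window_funcs`` flag, falling back to the ``ROW_NUMBER()``-based wrapper
built in ``visit_select``. A minimal sketch of turning the flag on, with a
hypothetical DSN, assuming it is accepted as a dialect keyword through
``create_engine`` like other dialect flags of this era::

    from sqlalchemy import create_engine

    # hypothetical credentials; with the flag set, .limit(n).offset(m)
    # compiles to a ROW_NUMBER() wrapper filtered on the mssql_rn column
    engine = create_engine('mssql://user:pass@host/db', has_window_funcs=True)
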
diff --git a/lib/sqlalchemy/databases/mysql.py b/lib/sqlalchemy/databases/mysql.py
index 3d71bb72324fbfb58b3b9b2fd334f2f9fd798be3..ac4e64b59740f280a38f864cfb41a212295cfc1e 100644 (file)
@@ -630,7 +630,7 @@ class MSTinyInteger(MSInteger):
             return self._extend("TINYINT")
 
 
-class MSSmallInteger(sqltypes.Smallinteger, MSInteger):
+class MSSmallInteger(sqltypes.SmallInteger, MSInteger):
     """MySQL SMALLINTEGER type."""
 
     def __init__(self, display_width=None, **kw):
@@ -1363,7 +1363,7 @@ class MSBoolean(sqltypes.Boolean):
 
 colspecs = {
     sqltypes.Integer: MSInteger,
-    sqltypes.Smallinteger: MSSmallInteger,
+    sqltypes.SmallInteger: MSSmallInteger,
     sqltypes.Numeric: MSNumeric,
     sqltypes.Float: MSFloat,
     sqltypes.DateTime: MSDateTime,
@@ -1901,14 +1901,14 @@ class _MySQLPythonRowProxy(object):
             return item
 
 
-class MySQLCompiler(compiler.DefaultCompiler):
-    operators = compiler.DefaultCompiler.operators.copy()
+class MySQLCompiler(compiler.SQLCompiler):
+    operators = compiler.SQLCompiler.operators.copy()
     operators.update({
         sql_operators.concat_op: lambda x, y: "concat(%s, %s)" % (x, y),
         sql_operators.mod: '%%',
         sql_operators.match_op: lambda x, y: "MATCH (%s) AGAINST (%s IN BOOLEAN MODE)" % (x, y)
     })
-    functions = compiler.DefaultCompiler.functions.copy()
+    functions = compiler.SQLCompiler.functions.copy()
     functions.update ({
         sql_functions.random: 'rand%(expr)s',
         "utc_timestamp":"UTC_TIMESTAMP"
@@ -2013,7 +2013,8 @@ class MySQLCompiler(compiler.DefaultCompiler):
         self.isupdate = True
         colparams = self._get_colparams(update_stmt)
 
-        text = "UPDATE " + self.preparer.format_table(update_stmt.table) + " SET " + ', '.join(["%s=%s" % (self.preparer.format_column(c[0]), c[1]) for c in colparams])
+        text = "UPDATE " + self.preparer.format_table(update_stmt.table) + \
+                " SET " + ', '.join(["%s=%s" % (self.preparer.format_column(c[0]), c[1]) for c in colparams])
 
         if update_stmt._whereclause:
             text += " WHERE " + self.process(update_stmt._whereclause)
diff --git a/lib/sqlalchemy/databases/oracle.py b/lib/sqlalchemy/databases/oracle.py
index 6749d8e407480b444f983509f9c7292c98c2f6f5..b0ec6115b2b6a768811c5038f8f53f37dea20b5d 100644 (file)
@@ -143,7 +143,7 @@ class OracleInteger(sqltypes.Integer):
     def get_col_spec(self):
         return "INTEGER"
 
-class OracleSmallInteger(sqltypes.Smallinteger):
+class OracleSmallInteger(sqltypes.SmallInteger):
     def get_col_spec(self):
         return "SMALLINT"
 
@@ -286,7 +286,7 @@ class OracleBoolean(sqltypes.Boolean):
 
 colspecs = {
     sqltypes.Integer : OracleInteger,
-    sqltypes.Smallinteger : OracleSmallInteger,
+    sqltypes.SmallInteger : OracleSmallInteger,
     sqltypes.Numeric : OracleNumeric,
     sqltypes.Float : OracleNumeric,
     sqltypes.DateTime : OracleDateTime,
@@ -698,13 +698,13 @@ class _OuterJoinColumn(sql.ClauseElement):
     def __init__(self, column):
         self.column = column
 
-class OracleCompiler(compiler.DefaultCompiler):
+class OracleCompiler(compiler.SQLCompiler):
     """Oracle compiler modifies the lexical structure of Select
     statements to work under non-ANSI configured Oracle databases, if
     the use_ansi flag is False.
     """
 
-    operators = compiler.DefaultCompiler.operators.copy()
+    operators = compiler.SQLCompiler.operators.copy()
     operators.update(
         {
             sql_operators.mod : lambda x, y:"mod(%s, %s)" % (x, y),
@@ -712,7 +712,7 @@ class OracleCompiler(compiler.DefaultCompiler):
         }
     )
 
-    functions = compiler.DefaultCompiler.functions.copy()
+    functions = compiler.SQLCompiler.functions.copy()
     functions.update (
         {
             sql_functions.now : 'CURRENT_TIMESTAMP'
@@ -736,7 +736,7 @@ class OracleCompiler(compiler.DefaultCompiler):
 
     def visit_join(self, join, **kwargs):
         if self.dialect.use_ansi:
-            return compiler.DefaultCompiler.visit_join(self, join, **kwargs)
+            return compiler.SQLCompiler.visit_join(self, join, **kwargs)
         else:
             return self.process(join.left, asfrom=True) + ", " + self.process(join.right, asfrom=True)
 
@@ -846,7 +846,7 @@ class OracleCompiler(compiler.DefaultCompiler):
                      select = offsetselect
 
         kwargs['iswrapper'] = getattr(select, '_is_wrapper', False)
-        return compiler.DefaultCompiler.visit_select(self, select, **kwargs)
+        return compiler.SQLCompiler.visit_select(self, select, **kwargs)
 
     def limit_clause(self, select):
         return ""
diff --git a/lib/sqlalchemy/databases/sqlite.py b/lib/sqlalchemy/databases/sqlite.py
deleted file mode 100644 (file)
index 8b46132..0000000
+++ /dev/null
@@ -1,619 +0,0 @@
-# sqlite.py
-# Copyright (C) 2005, 2006, 2007, 2008, 2009 Michael Bayer mike_mp@zzzcomputing.com
-#
-# This module is part of SQLAlchemy and is released under
-# the MIT License: http://www.opensource.org/licenses/mit-license.php
-
-"""Support for the SQLite database.
-
-Driver
-------
-
-When using Python 2.5 and above, the built in ``sqlite3`` driver is 
-already installed and no additional installation is needed.  Otherwise,
-the ``pysqlite2`` driver needs to be present.  This is the same driver as
-``sqlite3``, just with a different name.
-
-The ``pysqlite2`` driver will be loaded first, and if not found, ``sqlite3``
-is loaded.  This allows an explicitly installed pysqlite driver to take
-precedence over the built in one.   As with all dialects, a specific 
-DBAPI module may be provided to :func:`~sqlalchemy.create_engine()` to control 
-this explicitly::
-
-    from sqlite3 import dbapi2 as sqlite
-    e = create_engine('sqlite:///file.db', module=sqlite)
-
-Full documentation on pysqlite is available at:
-`<http://www.initd.org/pub/software/pysqlite/doc/usage-guide.html>`_
-
-Connect Strings
----------------
-
-The file specification for the SQLite database is taken as the "database" portion of
-the URL.  Note that the format of a url is::
-
-    driver://user:pass@host/database
-    
-This means that the actual filename to be used starts with the characters to the
-**right** of the third slash.   So connecting to a relative filepath looks like::
-
-    # relative path
-    e = create_engine('sqlite:///path/to/database.db')
-    
-An absolute path, which is denoted by starting with a slash, means you need **four**
-slashes::
-
-    # absolute path
-    e = create_engine('sqlite:////path/to/database.db')
-
-To use a Windows path, regular drive specifications and backslashes can be used.  
-Double backslashes are probably needed::
-
-    # absolute path on Windows
-    e = create_engine('sqlite:///C:\\\\path\\\\to\\\\database.db')
-
-The sqlite ``:memory:`` identifier is the default if no filepath is present.  Specify
-``sqlite://`` and nothing else::
-
-    # in-memory database
-    e = create_engine('sqlite://')
-
-Threading Behavior
-------------------
-
-Pysqlite connections do not support being moved between threads, unless
-the ``check_same_thread`` Pysqlite flag is set to ``False``.  In addition,
-when using an in-memory SQLite database, the full database exists only within 
-the scope of a single connection.  It is reported that an in-memory
-database does not support being shared between threads regardless of the 
-``check_same_thread`` flag - which means that a multithreaded
-application **cannot** share data from a ``:memory:`` database across threads
-unless access to the connection is limited to a single worker thread which communicates
-through a queueing mechanism to concurrent threads.
-
-To provide a default which accommodates SQLite's default threading capabilities
-somewhat reasonably, the SQLite dialect will specify that the :class:`~sqlalchemy.pool.SingletonThreadPool`
-be used by default.  This pool maintains a single SQLite connection per thread
-that is held open up to a count of five concurrent threads.  When more than five threads
-are used, a cleanup mechanism will dispose of excess unused connections.   
-
-Two optional pool implementations that may be appropriate for particular SQLite usage scenarios:
-
- * the :class:`sqlalchemy.pool.StaticPool` might be appropriate for a multithreaded
-   application using an in-memory database, assuming the threading issues inherent in 
-   pysqlite are somehow accommodated for.  This pool holds persistently onto a single connection
-   which is never closed, and is returned for all requests.
-   
- * the :class:`sqlalchemy.pool.NullPool` might be appropriate for an application that
-   makes use of a file-based sqlite database.  This pool disables any actual "pooling"
-   behavior, and simply opens and closes real connections corresponding to the :func:`connect()`
-   and :func:`close()` methods.  SQLite can "connect" to a particular file with very high 
-   efficiency, so this option may actually perform better without the extra overhead
-   of :class:`SingletonThreadPool`.  NullPool will of course render a ``:memory:`` connection
-   useless since the database would be lost as soon as the connection is "returned" to the pool.
-
-Date and Time Types
--------------------
-
-SQLite does not have built-in DATE, TIME, or DATETIME types, and pysqlite does not provide 
-out of the box functionality for translating values between Python `datetime` objects
-and a SQLite-supported format.  SQLAlchemy's own :class:`~sqlalchemy.types.DateTime`
-and related types provide date formatting and parsing functionality when SQLite is used.
-The implementation classes are :class:`SLDateTime`, :class:`SLDate` and :class:`SLTime`.
-These types represent dates and times as ISO formatted strings, which also nicely
-support ordering.   There's no reliance on typical "libc" internals for these functions
-so historical dates are fully supported.
-
-Unicode
--------
-
-In contrast to SQLAlchemy's active handling of date and time types for pysqlite, pysqlite's 
-default behavior regarding Unicode is that all strings are returned as Python unicode objects
-in all cases.  So even if the :class:`~sqlalchemy.types.Unicode` type is 
-*not* used, you will still always receive unicode data back from a result set.  It is 
-**strongly** recommended that you do use the :class:`~sqlalchemy.types.Unicode` type
-to represent strings, since it will raise a warning if a non-unicode Python string is 
-passed from the user application.  Mixing the usage of non-unicode objects with returned unicode objects can
-quickly create confusion, particularly when using the ORM as internal data is not 
-always represented by an actual database result string.
-
-"""
-
-
-import datetime, re, time
-
-from sqlalchemy import sql, schema, exc, pool, DefaultClause
-from sqlalchemy.engine import default
-import sqlalchemy.types as sqltypes
-import sqlalchemy.util as util
-from sqlalchemy.sql import compiler, functions as sql_functions
-from types import NoneType
-
-class SLNumeric(sqltypes.Numeric):
-    def bind_processor(self, dialect):
-        type_ = self.asdecimal and str or float
-        def process(value):
-            if value is not None:
-                return type_(value)
-            else:
-                return value
-        return process
-
-    def get_col_spec(self):
-        if self.precision is None:
-            return "NUMERIC"
-        else:
-            return "NUMERIC(%(precision)s, %(scale)s)" % {'precision': self.precision, 'scale' : self.scale}
-
-class SLFloat(sqltypes.Float):
-    def bind_processor(self, dialect):
-        type_ = self.asdecimal and str or float
-        def process(value):
-            if value is not None:
-                return type_(value)
-            else:
-                return value
-        return process
-
-    def get_col_spec(self):
-        return "FLOAT"
-    
-class SLInteger(sqltypes.Integer):
-    def get_col_spec(self):
-        return "INTEGER"
-
-class SLSmallInteger(sqltypes.Smallinteger):
-    def get_col_spec(self):
-        return "SMALLINT"
-
-class DateTimeMixin(object):
-    def _bind_processor(self, format, elements):
-        def process(value):
-            if not isinstance(value, (NoneType, datetime.date, datetime.datetime, datetime.time)):
-                raise TypeError("SQLite Date, Time, and DateTime types only accept Python datetime objects as input.")
-            elif value is not None:
-                return format % tuple([getattr(value, attr, 0) for attr in elements])
-            else:
-                return None
-        return process
-
-    def _result_processor(self, fn, regexp):
-        def process(value):
-            if value is not None:
-                return fn(*[int(x or 0) for x in regexp.match(value).groups()])
-            else:
-                return None
-        return process
-
-class SLDateTime(DateTimeMixin, sqltypes.DateTime):
-    __legacy_microseconds__ = False
-
-    def get_col_spec(self):
-        return "TIMESTAMP"
-
-    def bind_processor(self, dialect):
-        if self.__legacy_microseconds__:
-            return self._bind_processor(
-                        "%4.4d-%2.2d-%2.2d %2.2d:%2.2d:%2.2d.%s", 
-                        ("year", "month", "day", "hour", "minute", "second", "microsecond")
-                        )
-        else:
-            return self._bind_processor(
-                        "%4.4d-%2.2d-%2.2d %2.2d:%2.2d:%2.2d.%06d", 
-                        ("year", "month", "day", "hour", "minute", "second", "microsecond")
-                        )
-
-    _reg = re.compile(r"(\d+)-(\d+)-(\d+)(?: (\d+):(\d+):(\d+)(?:\.(\d+))?)?")
-    def result_processor(self, dialect):
-        return self._result_processor(datetime.datetime, self._reg)
-
-class SLDate(DateTimeMixin, sqltypes.Date):
-    def get_col_spec(self):
-        return "DATE"
-
-    def bind_processor(self, dialect):
-        return self._bind_processor(
-                        "%4.4d-%2.2d-%2.2d", 
-                        ("year", "month", "day")
-                )
-
-    _reg = re.compile(r"(\d+)-(\d+)-(\d+)")
-    def result_processor(self, dialect):
-        return self._result_processor(datetime.date, self._reg)
-
-class SLTime(DateTimeMixin, sqltypes.Time):
-    __legacy_microseconds__ = False
-
-    def get_col_spec(self):
-        return "TIME"
-
-    def bind_processor(self, dialect):
-        if self.__legacy_microseconds__:
-            return self._bind_processor(
-                            "%2.2d:%2.2d:%2.2d.%s", 
-                            ("hour", "minute", "second", "microsecond")
-                    )
-        else:
-            return self._bind_processor(
-                            "%2.2d:%2.2d:%2.2d.%06d", 
-                            ("hour", "minute", "second", "microsecond")
-                    )
-
-    _reg = re.compile(r"(\d+):(\d+):(\d+)(?:\.(\d+))?")
-    def result_processor(self, dialect):
-        return self._result_processor(datetime.time, self._reg)
-
-class SLUnicodeMixin(object):
-    def bind_processor(self, dialect):
-        if self.convert_unicode or dialect.convert_unicode:
-            if self.assert_unicode is None:
-                assert_unicode = dialect.assert_unicode
-            else:
-                assert_unicode = self.assert_unicode
-                
-            if not assert_unicode:
-                return None
-                
-            def process(value):
-                if not isinstance(value, (unicode, NoneType)):
-                    if assert_unicode == 'warn':
-                        util.warn("Unicode type received non-unicode bind "
-                                  "param value %r" % value)
-                        return value
-                    else:
-                        raise exc.InvalidRequestError("Unicode type received non-unicode bind param value %r" % value)
-                else:
-                    return value
-            return process
-        else:
-            return None
-
-    def result_processor(self, dialect):
-        return None
-    
-class SLText(SLUnicodeMixin, sqltypes.Text):
-    def get_col_spec(self):
-        return "TEXT"
-
-class SLString(SLUnicodeMixin, sqltypes.String):
-    def get_col_spec(self):
-        return "VARCHAR" + (self.length and "(%d)" % self.length or "")
-
-class SLChar(SLUnicodeMixin, sqltypes.CHAR):
-    def get_col_spec(self):
-        return "CHAR" + (self.length and "(%d)" % self.length or "")
-
-class SLBinary(sqltypes.Binary):
-    def get_col_spec(self):
-        return "BLOB"
-
-class SLBoolean(sqltypes.Boolean):
-    def get_col_spec(self):
-        return "BOOLEAN"
-
-    def bind_processor(self, dialect):
-        def process(value):
-            if value is None:
-                return None
-            return value and 1 or 0
-        return process
-
-    def result_processor(self, dialect):
-        def process(value):
-            if value is None:
-                return None
-            return value and True or False
-        return process
-
-colspecs = {
-    sqltypes.Binary: SLBinary,
-    sqltypes.Boolean: SLBoolean,
-    sqltypes.CHAR: SLChar,
-    sqltypes.Date: SLDate,
-    sqltypes.DateTime: SLDateTime,
-    sqltypes.Float: SLFloat,
-    sqltypes.Integer: SLInteger,
-    sqltypes.NCHAR: SLChar,
-    sqltypes.Numeric: SLNumeric,
-    sqltypes.Smallinteger: SLSmallInteger,
-    sqltypes.String: SLString,
-    sqltypes.Text: SLText,
-    sqltypes.Time: SLTime,
-}
-
-ischema_names = {
-    'BLOB': SLBinary,
-    'BOOL': SLBoolean,
-    'BOOLEAN': SLBoolean,
-    'CHAR': SLChar,
-    'DATE': SLDate,
-    'DATETIME': SLDateTime,
-    'DECIMAL': SLNumeric,
-    'FLOAT': SLNumeric,
-    'INT': SLInteger,
-    'INTEGER': SLInteger,
-    'NUMERIC': SLNumeric,
-    'REAL': SLNumeric,
-    'SMALLINT': SLSmallInteger,
-    'TEXT': SLText,
-    'TIME': SLTime,
-    'TIMESTAMP': SLDateTime,
-    'VARCHAR': SLString,
-}
-
-class SQLiteExecutionContext(default.DefaultExecutionContext):
-    def post_exec(self):
-        if self.compiled.isinsert and not self.executemany:
-            if not len(self._last_inserted_ids) or self._last_inserted_ids[0] is None:
-                self._last_inserted_ids = [self.cursor.lastrowid] + self._last_inserted_ids[1:]
-
-class SQLiteDialect(default.DefaultDialect):
-    name = 'sqlite'
-    supports_alter = False
-    supports_unicode_statements = True
-    default_paramstyle = 'qmark'
-    supports_default_values = True
-    supports_empty_insert = False
-
-    def __init__(self, **kwargs):
-        default.DefaultDialect.__init__(self, **kwargs)
-        def vers(num):
-            return tuple([int(x) for x in num.split('.')])
-        if self.dbapi is not None:
-            sqlite_ver = self.dbapi.version_info
-            if sqlite_ver < (2, 1, '3'):
-                util.warn(
-                    ("The installed version of pysqlite2 (%s) is out-dated "
-                     "and will cause errors in some cases.  Version 2.1.3 "
-                     "or greater is recommended.") %
-                    '.'.join([str(subver) for subver in sqlite_ver]))
-            if self.dbapi.sqlite_version_info < (3, 3, 8):
-                self.supports_default_values = False
-        self.supports_cast = (self.dbapi is None or vers(self.dbapi.sqlite_version) >= vers("3.2.3"))
-
-    def dbapi(cls):
-        try:
-            from pysqlite2 import dbapi2 as sqlite
-        except ImportError, e:
-            try:
-                from sqlite3 import dbapi2 as sqlite #try the 2.5+ stdlib name.
-            except ImportError:
-                raise e
-        return sqlite
-    dbapi = classmethod(dbapi)
-
-    def server_version_info(self, connection):
-        return self.dbapi.sqlite_version_info
-
-    def create_connect_args(self, url):
-        if url.username or url.password or url.host or url.port:
-            raise exc.ArgumentError(
-                "Invalid SQLite URL: %s\n"
-                "Valid SQLite URL forms are:\n"
-                " sqlite:///:memory: (or, sqlite://)\n"
-                " sqlite:///relative/path/to/file.db\n"
-                " sqlite:////absolute/path/to/file.db" % (url,))
-        filename = url.database or ':memory:'
-
-        opts = url.query.copy()
-        util.coerce_kw_type(opts, 'timeout', float)
-        util.coerce_kw_type(opts, 'isolation_level', str)
-        util.coerce_kw_type(opts, 'detect_types', int)
-        util.coerce_kw_type(opts, 'check_same_thread', bool)
-        util.coerce_kw_type(opts, 'cached_statements', int)
-
-        return ([filename], opts)
-
-    def type_descriptor(self, typeobj):
-        return sqltypes.adapt_type(typeobj, colspecs)
-
-    def is_disconnect(self, e):
-        return isinstance(e, self.dbapi.ProgrammingError) and "Cannot operate on a closed database." in str(e)
-
-    def table_names(self, connection, schema):
-        if schema is not None:
-            qschema = self.identifier_preparer.quote_identifier(schema)
-            master = '%s.sqlite_master' % qschema
-            s = ("SELECT name FROM %s "
-                 "WHERE type='table' ORDER BY name") % (master,)
-            rs = connection.execute(s)
-        else:
-            try:
-                s = ("SELECT name FROM "
-                     " (SELECT * FROM sqlite_master UNION ALL "
-                     "  SELECT * FROM sqlite_temp_master) "
-                     "WHERE type='table' ORDER BY name")
-                rs = connection.execute(s)
-            except exc.DBAPIError:
-                raise
-                s = ("SELECT name FROM sqlite_master "
-                     "WHERE type='table' ORDER BY name")
-                rs = connection.execute(s)
-
-        return [row[0] for row in rs]
-
-    def has_table(self, connection, table_name, schema=None):
-        quote = self.identifier_preparer.quote_identifier
-        if schema is not None:
-            pragma = "PRAGMA %s." % quote(schema)
-        else:
-            pragma = "PRAGMA "
-        qtable = quote(table_name)
-        cursor = connection.execute("%stable_info(%s)" % (pragma, qtable))
-        row = cursor.fetchone()
-
-        # consume remaining rows, to work around
-        # http://www.sqlite.org/cvstrac/tktview?tn=1884
-        while cursor.fetchone() is not None:
-            pass
-
-        return (row is not None)
-
-    def reflecttable(self, connection, table, include_columns):
-        preparer = self.identifier_preparer
-        if table.schema is None:
-            pragma = "PRAGMA "
-        else:
-            pragma = "PRAGMA %s." % preparer.quote_identifier(table.schema)
-        qtable = preparer.format_table(table, False)
-
-        c = connection.execute("%stable_info(%s)" % (pragma, qtable))
-        found_table = False
-        while True:
-            row = c.fetchone()
-            if row is None:
-                break
-
-            found_table = True
-            (name, type_, nullable, default, has_default, primary_key) = (row[1], row[2].upper(), not row[3], row[4], row[4] is not None, row[5])
-            name = re.sub(r'^\"|\"$', '', name)
-            if include_columns and name not in include_columns:
-                continue
-            match = re.match(r'(\w+)(\(.*?\))?', type_)
-            if match:
-                coltype = match.group(1)
-                args = match.group(2)
-            else:
-                coltype = "VARCHAR"
-                args = ''
-
-            try:
-                coltype = ischema_names[coltype]
-            except KeyError:
-                util.warn("Did not recognize type '%s' of column '%s'" %
-                          (coltype, name))
-                coltype = sqltypes.NullType
-
-            if args is not None:
-                args = re.findall(r'(\d+)', args)
-                coltype = coltype(*[int(a) for a in args])
-
-            colargs = []
-            if has_default:
-                colargs.append(DefaultClause(sql.text(default)))
-            table.append_column(schema.Column(name, coltype, primary_key = primary_key, nullable = nullable, *colargs))
-
-        if not found_table:
-            raise exc.NoSuchTableError(table.name)
-
-        c = connection.execute("%sforeign_key_list(%s)" % (pragma, qtable))
-        fks = {}
-        while True:
-            row = c.fetchone()
-            if row is None:
-                break
-            (constraint_name, tablename, localcol, remotecol) = (row[0], row[2], row[3], row[4])
-            tablename = re.sub(r'^\"|\"$', '', tablename)
-            localcol = re.sub(r'^\"|\"$', '', localcol)
-            remotecol = re.sub(r'^\"|\"$', '', remotecol)
-            try:
-                fk = fks[constraint_name]
-            except KeyError:
-                fk = ([], [])
-                fks[constraint_name] = fk
-
-            # look up the table based on the given table's engine, not 'self',
-            # since it could be a ProxyEngine
-            remotetable = schema.Table(tablename, table.metadata, autoload=True, autoload_with=connection)
-            constrained_column = table.c[localcol].name
-            refspec = ".".join([tablename, remotecol])
-            if constrained_column not in fk[0]:
-                fk[0].append(constrained_column)
-            if refspec not in fk[1]:
-                fk[1].append(refspec)
-        for name, value in fks.iteritems():
-            table.append_constraint(schema.ForeignKeyConstraint(value[0], value[1], link_to_name=True))
-        # check for UNIQUE indexes
-        c = connection.execute("%sindex_list(%s)" % (pragma, qtable))
-        unique_indexes = []
-        while True:
-            row = c.fetchone()
-            if row is None:
-                break
-            if (row[2] == 1):
-                unique_indexes.append(row[1])
-        # loop thru unique indexes for one that includes the primary key
-        for idx in unique_indexes:
-            c = connection.execute("%sindex_info(%s)" % (pragma, idx))
-            cols = []
-            while True:
-                row = c.fetchone()
-                if row is None:
-                    break
-                cols.append(row[2])
-
-
-class SQLiteCompiler(compiler.DefaultCompiler):
-    functions = compiler.DefaultCompiler.functions.copy()
-    functions.update (
-        {
-            sql_functions.now: 'CURRENT_TIMESTAMP',
-            sql_functions.char_length: 'length%(expr)s'
-        }
-    )
-
-    def visit_cast(self, cast, **kwargs):
-        if self.dialect.supports_cast:
-            return super(SQLiteCompiler, self).visit_cast(cast)
-        else:
-            return self.process(cast.clause)
-
-    def limit_clause(self, select):
-        text = ""
-        if select._limit is not None:
-            text +=  " \n LIMIT " + str(select._limit)
-        if select._offset is not None:
-            if select._limit is None:
-                text += " \n LIMIT -1"
-            text += " OFFSET " + str(select._offset)
-        else:
-            text += " OFFSET 0"
-        return text
-
-    def for_update_clause(self, select):
-        # sqlite has no "FOR UPDATE" AFAICT
-        return ''
-
-
-class SQLiteSchemaGenerator(compiler.SchemaGenerator):
-
-    def get_column_specification(self, column, **kwargs):
-        colspec = self.preparer.format_column(column) + " " + column.type.dialect_impl(self.dialect).get_col_spec()
-        default = self.get_column_default_string(column)
-        if default is not None:
-            colspec += " DEFAULT " + default
-
-        if not column.nullable:
-            colspec += " NOT NULL"
-        return colspec
-
-class SQLiteIdentifierPreparer(compiler.IdentifierPreparer):
-    reserved_words = set([
-        'add', 'after', 'all', 'alter', 'analyze', 'and', 'as', 'asc',
-        'attach', 'autoincrement', 'before', 'begin', 'between', 'by',
-        'cascade', 'case', 'cast', 'check', 'collate', 'column', 'commit',
-        'conflict', 'constraint', 'create', 'cross', 'current_date',
-        'current_time', 'current_timestamp', 'database', 'default',
-        'deferrable', 'deferred', 'delete', 'desc', 'detach', 'distinct',
-        'drop', 'each', 'else', 'end', 'escape', 'except', 'exclusive',
-        'explain', 'false', 'fail', 'for', 'foreign', 'from', 'full', 'glob',
-        'group', 'having', 'if', 'ignore', 'immediate', 'in', 'index',
-        'initially', 'inner', 'insert', 'instead', 'intersect', 'into', 'is',
-        'isnull', 'join', 'key', 'left', 'like', 'limit', 'match', 'natural',
-        'not', 'notnull', 'null', 'of', 'offset', 'on', 'or', 'order', 'outer',
-        'plan', 'pragma', 'primary', 'query', 'raise', 'references',
-        'reindex', 'rename', 'replace', 'restrict', 'right', 'rollback',
-        'row', 'select', 'set', 'table', 'temp', 'temporary', 'then', 'to',
-        'transaction', 'trigger', 'true', 'union', 'unique', 'update', 'using',
-        'vacuum', 'values', 'view', 'virtual', 'when', 'where',
-        ])
-
-    def __init__(self, dialect):
-        super(SQLiteIdentifierPreparer, self).__init__(dialect)
-
-dialect = SQLiteDialect
-dialect.poolclass = pool.SingletonThreadPool
-dialect.statement_compiler = SQLiteCompiler
-dialect.schemagenerator = SQLiteSchemaGenerator
-dialect.preparer = SQLiteIdentifierPreparer
-dialect.execution_ctx_cls = SQLiteExecutionContext
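
The "Threading Behavior" section of the docstring deleted above recommends
:class:`StaticPool` for threaded in-memory use and :class:`NullPool` for
file-based databases. A minimal sketch of selecting a pool explicitly via
``create_engine``'s ``poolclass`` argument, using the pool classes named in
that docstring::

    from sqlalchemy import create_engine
    from sqlalchemy.pool import StaticPool, NullPool

    # in-memory database: one persistent connection serves all requests,
    # subject to the pysqlite threading caveats described above
    mem_engine = create_engine('sqlite://', poolclass=StaticPool)

    # file-based database: no pooling; each checkout opens and closes a
    # real connection, which is cheap for SQLite files
    file_engine = create_engine('sqlite:///data.db', poolclass=NullPool)
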
diff --git a/lib/sqlalchemy/databases/sybase.py b/lib/sqlalchemy/databases/sybase.py
index 6007315f264334c502eaf09565bd3005da5d9f0c..0cf0eeaf56dd62a4cba533a071d75abc2165e8dd 100644 (file)
@@ -727,8 +727,8 @@ dialect_mapping = {
     }
 
 
-class SybaseSQLCompiler(compiler.DefaultCompiler):
-    operators = compiler.DefaultCompiler.operators.copy()
+class SybaseSQLCompiler(compiler.SQLCompiler):
+    operators = compiler.SQLCompiler.operators.copy()
     operators.update({
         sql_operators.mod: lambda x, y: "MOD(%s, %s)" % (x, y),
     })
diff --git a/lib/sqlalchemy/dialects/__init__.py b/lib/sqlalchemy/dialects/__init__.py
new file mode 100644 (file)
index 0000000..075e897
--- /dev/null
@@ -0,0 +1,12 @@
+__all__ = (
+#    'access',
+#    'firebird',
+#    'informix',
+#    'maxdb',
+#    'mssql',
+#    'mysql',
+#    'oracle',
+    'postgres',
+    'sqlite',
+#    'sybase',
+    )
diff --git a/lib/sqlalchemy/dialects/postgres/__init__.py b/lib/sqlalchemy/dialects/postgres/__init__.py
new file mode 100644 (file)
index 0000000..c9ac0e1
--- /dev/null
@@ -0,0 +1,3 @@
+from sqlalchemy.dialects.postgres import base, psycopg2
+
+base.dialect = psycopg2.dialect
\ No newline at end of file
diff --git a/lib/sqlalchemy/databases/postgres.py b/lib/sqlalchemy/dialects/postgres/base.py
similarity index 63%
rename from lib/sqlalchemy/databases/postgres.py
rename to lib/sqlalchemy/dialects/postgres/base.py
index fe5ffe24a06663de397ef8a10dffac73572723b2..d33a6db935331fb5ec8934a4f91b1d678917a722 100644 (file)
@@ -4,90 +4,6 @@
 # This module is part of SQLAlchemy and is released under
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
 
-"""Support for the PostgreSQL database.
-
-Driver
-------
-
-The psycopg2 driver is supported, available at http://pypi.python.org/pypi/psycopg2/ .
-The dialect has several behaviors  which are specifically tailored towards compatibility 
-with this module.
-
-Note that psycopg1 is **not** supported.
-
-Connecting
-----------
-
-URLs are of the form `postgres://user:password@host:port/dbname[?key=value&key=value...]`.
-
-Postgres-specific keyword arguments which are accepted by :func:`~sqlalchemy.create_engine()` are:
-
-* *server_side_cursors* - Enable the usage of "server side cursors" for SQL statements which support
-  this feature.  What this essentially means from a psycopg2 point of view is that the cursor is 
-  created using a name, e.g. `connection.cursor('some name')`, which has the effect that result rows
-  are not immediately pre-fetched and buffered after statement execution, but are instead left 
-  on the server and only retrieved as needed.    SQLAlchemy's :class:`~sqlalchemy.engine.base.ResultProxy`
-  uses special row-buffering behavior when this feature is enabled, such that groups of 100 rows 
-  at a time are fetched over the wire to reduce conversational overhead.
-
-Sequences/SERIAL
-----------------
-
-Postgres supports sequences, and SQLAlchemy uses these as the default means of creating
-new primary key values for integer-based primary key columns.   When creating tables, 
-SQLAlchemy will issue the ``SERIAL`` datatype for integer-based primary key columns, 
-which generates a sequence corresponding to the column and associated with it based on
-a naming convention.
-
-To specify a specific named sequence to be used for primary key generation, use the
-:func:`~sqlalchemy.schema.Sequence` construct::
-
-    Table('sometable', metadata, 
-            Column('id', Integer, Sequence('some_id_seq'), primary_key=True)
-        )
-
-Currently, when SQLAlchemy issues a single insert statement, to fulfill the contract of
-having the "last insert identifier" available, the sequence is executed independently
-beforehand and the new value is retrieved, to be used in the subsequent insert.  Note
-that when an :func:`~sqlalchemy.sql.expression.insert()` construct is executed using 
-"executemany" semantics, the sequence is not pre-executed and normal PG SERIAL behavior
-is used.
-
-Postgres 8.3 supports an ``INSERT...RETURNING`` syntax which SQLAlchemy supports 
-as well.  A future release of SQLA will use this feature by default in lieu of 
-sequence pre-execution in order to retrieve new primary key values, when available.
-
-INSERT/UPDATE...RETURNING
--------------------------
-
-The dialect supports PG 8.3's ``INSERT..RETURNING`` and ``UPDATE..RETURNING`` syntaxes, 
-but must be explicitly enabled on a per-statement basis::
-
-    # INSERT..RETURNING
-    result = table.insert(postgres_returning=[table.c.col1, table.c.col2]).\\
-        values(name='foo')
-    print result.fetchall()
-    
-    # UPDATE..RETURNING
-    result = table.update(postgres_returning=[table.c.col1, table.c.col2]).\\
-        where(table.c.name=='foo').values(name='bar')
-    print result.fetchall()
-
-Indexes
--------
-
-PostgreSQL supports partial indexes. To create them pass a postgres_where
-option to the Index constructor::
-
-  Index('my_index', my_table.c.id, postgres_where=tbl.c.value > 10)
-
-Transactions
-------------
-
-The Postgres dialect fully supports SAVEPOINT and two-phase commit operations.
-
-
-"""
 
 import decimal, random, re, string
 
@@ -99,101 +15,23 @@ from sqlalchemy import types as sqltypes
 
 
 class PGInet(sqltypes.TypeEngine):
-    def get_col_spec(self):
-        return "INET"
+    __visit_name__ = "INET"
 
 class PGCidr(sqltypes.TypeEngine):
-    def get_col_spec(self):
-        return "CIDR"
+    __visit_name__ = "CIDR"
 
 class PGMacAddr(sqltypes.TypeEngine):
-    def get_col_spec(self):
-        return "MACADDR"
-
-class PGNumeric(sqltypes.Numeric):
-    def get_col_spec(self):
-        if not self.precision:
-            return "NUMERIC"
-        else:
-            return "NUMERIC(%(precision)s, %(scale)s)" % {'precision': self.precision, 'scale' : self.scale}
-
-    def bind_processor(self, dialect):
-        return None
-
-    def result_processor(self, dialect):
-        if self.asdecimal:
-            return None
-        else:
-            def process(value):
-                if isinstance(value, decimal.Decimal):
-                    return float(value)
-                else:
-                    return value
-            return process
-
-class PGFloat(sqltypes.Float):
-    def get_col_spec(self):
-        if not self.precision:
-            return "FLOAT"
-        else:
-            return "FLOAT(%(precision)s)" % {'precision': self.precision}
-
-
-class PGInteger(sqltypes.Integer):
-    def get_col_spec(self):
-        return "INTEGER"
-
-class PGSmallInteger(sqltypes.Smallinteger):
-    def get_col_spec(self):
-        return "SMALLINT"
-
-class PGBigInteger(PGInteger):
-    def get_col_spec(self):
-        return "BIGINT"
-
-class PGDateTime(sqltypes.DateTime):
-    def get_col_spec(self):
-        return "TIMESTAMP " + (self.timezone and "WITH" or "WITHOUT") + " TIME ZONE"
-
-class PGDate(sqltypes.Date):
-    def get_col_spec(self):
-        return "DATE"
+    __visit_name__ = "MACADDR"
 
-class PGTime(sqltypes.Time):
-    def get_col_spec(self):
-        return "TIME " + (self.timezone and "WITH" or "WITHOUT") + " TIME ZONE"
+class PGBigInteger(sqltypes.Integer):
+    __visit_name__ = "BIGINT"
 
 class PGInterval(sqltypes.TypeEngine):
-    def get_col_spec(self):
-        return "INTERVAL"
-
-class PGText(sqltypes.Text):
-    def get_col_spec(self):
-        return "TEXT"
-
-class PGString(sqltypes.String):
-    def get_col_spec(self):
-        if self.length:
-            return "VARCHAR(%(length)d)" % {'length' : self.length}
-        else:
-            return "VARCHAR"
-
-class PGChar(sqltypes.CHAR):
-    def get_col_spec(self):
-        if self.length:
-            return "CHAR(%(length)d)" % {'length' : self.length}
-        else:
-            return "CHAR"
-
-class PGBinary(sqltypes.Binary):
-    def get_col_spec(self):
-        return "BYTEA"
-
-class PGBoolean(sqltypes.Boolean):
-    def get_col_spec(self):
-        return "BOOLEAN"
+    __visit_name__ = 'INTERVAL'
 
 class PGArray(sqltypes.MutableType, sqltypes.Concatenable, sqltypes.TypeEngine):
+    __visit_name__ = 'ARRAY'
+    
     def __init__(self, item_type, mutable=True):
         if isinstance(item_type, type):
             item_type = item_type()
@@ -251,114 +89,233 @@ class PGArray(sqltypes.MutableType, sqltypes.Concatenable, sqltypes.TypeEngine):
                         return item
             return [convert_item(item) for item in value]
         return process
-    def get_col_spec(self):
-        return self.item_type.get_col_spec() + '[]'
-
-colspecs = {
-    sqltypes.Integer : PGInteger,
-    sqltypes.Smallinteger : PGSmallInteger,
-    sqltypes.Numeric : PGNumeric,
-    sqltypes.Float : PGFloat,
-    sqltypes.DateTime : PGDateTime,
-    sqltypes.Date : PGDate,
-    sqltypes.Time : PGTime,
-    sqltypes.String : PGString,
-    sqltypes.Binary : PGBinary,
-    sqltypes.Boolean : PGBoolean,
-    sqltypes.Text : PGText,
-    sqltypes.CHAR: PGChar,
-}
-
-ischema_names = {
-    'integer' : PGInteger,
-    'bigint' : PGBigInteger,
-    'smallint' : PGSmallInteger,
-    'character varying' : PGString,
-    'character' : PGChar,
-    'text' : PGText,
-    'numeric' : PGNumeric,
-    'float' : PGFloat,
-    'real' : PGFloat,
-    'inet': PGInet,
-    'cidr': PGCidr,
-    'macaddr': PGMacAddr,
-    'double precision' : PGFloat,
-    'timestamp' : PGDateTime,
-    'timestamp with time zone' : PGDateTime,
-    'timestamp without time zone' : PGDateTime,
-    'time with time zone' : PGTime,
-    'time without time zone' : PGTime,
-    'date' : PGDate,
-    'time': PGTime,
-    'bytea' : PGBinary,
-    'boolean' : PGBoolean,
-    'interval':PGInterval,
-}
-
-# TODO: filter out 'FOR UPDATE' statements
-SERVER_SIDE_CURSOR_RE = re.compile(
-    r'\s*SELECT',
-    re.I | re.UNICODE)
-
-class PGExecutionContext(default.DefaultExecutionContext):
-    def create_cursor(self):
-        # TODO: coverage for server side cursors + select.for_update()
-        is_server_side = \
-            self.dialect.server_side_cursors and \
-            ((self.compiled and isinstance(self.compiled.statement, expression.Selectable) 
-                and not getattr(self.compiled.statement, 'for_update', False)) \
-            or \
-            (
-                (not self.compiled or isinstance(self.compiled.statement, expression._TextClause)) 
-                and self.statement and SERVER_SIDE_CURSOR_RE.match(self.statement))
-            )
 
-        self.__is_server_side = is_server_side
-        if is_server_side:
-            # use server-side cursors:
-            # http://lists.initd.org/pipermail/psycopg/2007-January/005251.html
-            ident = "c_%s_%s" % (hex(id(self))[2:], hex(random.randint(0, 65535))[2:])
-            return self._connection.connection.cursor(ident)
+
+
+
+
+class PGCompiler(compiler.SQLCompiler):
+    operators = compiler.SQLCompiler.operators.copy()
+    operators.update(
+        {
+            sql_operators.mod : '%%',
+            sql_operators.ilike_op: lambda x, y, escape=None: '%s ILIKE %s' % (x, y) + (escape and ' ESCAPE \'%s\'' % escape or ''),
+            sql_operators.notilike_op: lambda x, y, escape=None: '%s NOT ILIKE %s' % (x, y) + (escape and ' ESCAPE \'%s\'' % escape or ''),
+            sql_operators.match_op: lambda x, y: '%s @@ to_tsquery(%s)' % (x, y),
+        }
+    )
+
+    functions = compiler.SQLCompiler.functions.copy()
+    functions.update (
+        {
+            'TIMESTAMP':lambda x:'TIMESTAMP %s' % x,
+        }
+    )
+
+    def visit_sequence(self, seq):
+        if seq.optional:
+            return None
+        else:
+            return "nextval('%s')" % self.preparer.format_sequence(seq)
+
+    def post_process_text(self, text):
+        if '%%' in text:
+            util.warn("The SQLAlchemy postgres dialect now automatically escapes '%' in text() expressions to '%%'.")
+        return text.replace('%', '%%')
+
+    def limit_clause(self, select):
+        text = ""
+        if select._limit is not None:
+            text +=  " \n LIMIT " + str(select._limit)
+        if select._offset is not None:
+            if select._limit is None:
+                text += " \n LIMIT ALL"
+            text += " OFFSET " + str(select._offset)
+        return text
+
+    def get_select_precolumns(self, select):
+        if select._distinct:
+            if isinstance(select._distinct, bool):
+                return "DISTINCT "
+            elif isinstance(select._distinct, (list, tuple)):
+                return "DISTINCT ON (" + ', '.join(
+                    [(isinstance(col, basestring) and col or self.process(col)) for col in select._distinct]
+                )+ ") "
+            else:
+                return "DISTINCT ON (" + unicode(select._distinct) + ") "
+        else:
+            return ""
+
+    def for_update_clause(self, select):
+        if select.for_update == 'nowait':
+            return " FOR UPDATE NOWAIT"
+        else:
+            return super(PGCompiler, self).for_update_clause(select)
+
+    def _append_returning(self, text, stmt):
+        returning_cols = stmt.kwargs['postgres_returning']
+        def flatten_columnlist(collist):
+            for c in collist:
+                if isinstance(c, expression.Selectable):
+                    for co in c.columns:
+                        yield co
+                else:
+                    yield c
+        columns = [self.process(c, within_columns_clause=True) for c in flatten_columnlist(returning_cols)]
+        text += ' RETURNING ' + string.join(columns, ', ')
+        return text
+
+    def visit_update(self, update_stmt):
+        text = super(PGCompiler, self).visit_update(update_stmt)
+        if 'postgres_returning' in update_stmt.kwargs:
+            return self._append_returning(text, update_stmt)
+        else:
+            return text
+
+    def visit_insert(self, insert_stmt):
+        text = super(PGCompiler, self).visit_insert(insert_stmt)
+        if 'postgres_returning' in insert_stmt.kwargs:
+            return self._append_returning(text, insert_stmt)
+        else:
+            return text
+
+class PGDDLCompiler(compiler.DDLCompiler):
+    def get_column_specification(self, column, **kwargs):
+        colspec = self.preparer.format_column(column)
+        if column.primary_key and \
+            len(column.foreign_keys)==0 and \
+            column.autoincrement and \
+            isinstance(column.type, sqltypes.Integer) and \
+            not isinstance(column.type, sqltypes.SmallInteger) and \
+            (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)):
+            if isinstance(column.type, PGBigInteger):
+                colspec += " BIGSERIAL"
+            else:
+                colspec += " SERIAL"
         else:
-            return self._connection.connection.cursor()
+            colspec += " " + self.dialect.type_compiler.process(column.type)
+            default = self.get_column_default_string(column)
+            if default is not None:
+                colspec += " DEFAULT " + default
+
+        if not column.nullable:
+            colspec += " NOT NULL"
+        return colspec
+
+    def visit_create_sequence(self, create):
+        return "CREATE SEQUENCE %s" % self.preparer.format_sequence(create.element)
+            
+    def visit_drop_sequence(self, drop):
+        return "DROP SEQUENCE %s" % self.preparer.format_sequence(drop.element)
+        
+    def visit_create_index(self, create):
+        preparer = self.preparer
+        index = create.element
+        text = "CREATE "
+        if index.unique:
+            text += "UNIQUE "
+        text += "INDEX %s ON %s (%s)" \
+                    % (preparer.quote(self._validate_identifier(index.name, True), index.quote),
+                       preparer.format_table(index.table),
+                       string.join([preparer.format_column(c) for c in index.columns], ', '))
+                       
+        whereclause = index.kwargs.get('postgres_where', None)
+        if whereclause is not None:
+            compiler = self._compile(whereclause, None)
+            # this might belong to the compiler class
+            inlined_clause = str(compiler) % dict(
+                [(key,bind.value) for key,bind in compiler.binds.iteritems()])
+            text += " WHERE " + inlined_clause
+        return text
+
+class PGDefaultRunner(base.DefaultRunner):
+    def __init__(self, context):
+        base.DefaultRunner.__init__(self, context)
+        # create a cursor which won't conflict with a server-side cursor
+        self.cursor = context._connection.connection.cursor()
     
-    def get_result_proxy(self):
-        if self.__is_server_side:
-            return base.BufferedRowResultProxy(self)
+    def get_column_default(self, column, isinsert=True):
+        if column.primary_key:
+            # pre-execute passive defaults on primary keys
+            if (isinstance(column.server_default, schema.DefaultClause) and
+                column.server_default.arg is not None):
+                return self.execute_string("select %s" % column.server_default.arg)
+            elif (isinstance(column.type, sqltypes.Integer) and column.autoincrement) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)):
+                sch = column.table.schema
+                # TODO: this has to build into the Sequence object so we can get the quoting
+                # logic from it
+                if sch is not None:
+                    exc = "select nextval('\"%s\".\"%s_%s_seq\"')" % (sch, column.table.name, column.name)
+                else:
+                    exc = "select nextval('\"%s_%s_seq\"')" % (column.table.name, column.name)
+                return self.execute_string(exc.encode(self.dialect.encoding))
+
+        return super(PGDefaultRunner, self).get_column_default(column)
+
+    def visit_sequence(self, seq):
+        if not seq.optional:
+            return self.execute_string("select nextval('%s')" % self.dialect.identifier_preparer.format_sequence(seq))
+        else:
+            return None
+
+class PGTypeCompiler(compiler.GenericTypeCompiler):
+    def visit_INET(self, type_):
+        return "INET"
+
+    def visit_CIDR(self, type_):
+        return "CIDR"
+
+    def visit_MACADDR(self, type_):
+        return "MACADDR"
+
+    def visit_FLOAT(self, type_):
+        if not type_.precision:
+            return "FLOAT"
         else:
-            return base.ResultProxy(self)
+            return "FLOAT(%(precision)s)" % {'precision': type_.precision}
+
+    def visit_BIGINT(self, type_):
+        return "BIGINT"
+
+    def visit_DATETIME(self, type_):
+        return "TIMESTAMP " + (type_.timezone and "WITH" or "WITHOUT") + " TIME ZONE"
+
+    def visit_TIME(self, type_):
+        return "TIME " + (type_.timezone and "WITH" or "WITHOUT") + " TIME ZONE"
+
+    def visit_INTERVAL(self, type_):
+        return "INTERVAL"
+
+    def visit_BINARY(self, type_):
+        return "BYTEA"
+
+    def visit_ARRAY(self, type_):
+        return self.process(type_.item_type) + '[]'
+
+class PGIdentifierPreparer(compiler.IdentifierPreparer):
+    def _unquote_identifier(self, value):
+        if value[0] == self.initial_quote:
+            value = value[1:-1].replace('""','"')
+        return value
 
 class PGDialect(default.DefaultDialect):
     name = 'postgres'
     supports_alter = True
-    supports_unicode_statements = False
     max_identifier_length = 63
     supports_sane_rowcount = True
-    supports_sane_multi_rowcount = False
+    supports_sequences = True
+    sequences_optional = True
     preexecute_pk_sequences = True
     supports_pk_autoincrement = False
-    default_paramstyle = 'pyformat'
     supports_default_values = True
     supports_empty_insert = False
-    
-    def __init__(self, server_side_cursors=False, **kwargs):
-        default.DefaultDialect.__init__(self, **kwargs)
-        self.server_side_cursors = server_side_cursors
 
-    def dbapi(cls):
-        import psycopg2 as psycopg
-        return psycopg
-    dbapi = classmethod(dbapi)
+    statement_compiler = PGCompiler
+    ddl_compiler = PGDDLCompiler
+    type_compiler = PGTypeCompiler
+    preparer = PGIdentifierPreparer
+    defaultrunner = PGDefaultRunner
 
-    def create_connect_args(self, url):
-        opts = url.translate_connect_args(username='user')
-        if 'port' in opts:
-            opts['port'] = int(opts['port'])
-        opts.update(url.query)
-        return ([], opts)
-
-    def type_descriptor(self, typeobj):
-        return sqltypes.adapt_type(typeobj, colspecs)
 
     def do_begin_twophase(self, connection, xid):
         self.do_begin(connection.connection)
@@ -392,48 +349,46 @@ class PGDialect(default.DefaultDialect):
         resultset = connection.execute(sql.text("SELECT gid FROM pg_prepared_xacts"))
         return [row[0] for row in resultset]
 
+    @base.connection_memoize(('dialect', 'default_schema_name'))
     def get_default_schema_name(self, connection):
         return connection.scalar("select current_schema()", None)
-    get_default_schema_name = base.connection_memoize(
-        ('dialect', 'default_schema_name'))(get_default_schema_name)
-
-    def last_inserted_ids(self):
-        if self.context.last_inserted_ids is None:
-            raise exc.InvalidRequestError("no INSERT executed, or can't use cursor.lastrowid without Postgres OIDs enabled")
-        else:
-            return self.context.last_inserted_ids
 
     def has_table(self, connection, table_name, schema=None):
         # seems like case gets folded in pg_class...
         if schema is None:
-            cursor = connection.execute("""select relname from pg_class c join pg_namespace n on n.oid=c.relnamespace where n.nspname=current_schema() and lower(relname)=%(name)s""", {'name':table_name.lower().encode(self.encoding)});
+            cursor = connection.execute(
+                sql.text("""select relname from pg_class c join pg_namespace n on n.oid=c.relnamespace where n.nspname=current_schema() and lower(relname)=:name""",
+                    bindparams=[sql.bindparam('name', unicode(table_name.lower()), type_=sqltypes.Unicode)]
+                )
+            )
         else:
-            cursor = connection.execute("""select relname from pg_class c join pg_namespace n on n.oid=c.relnamespace where n.nspname=%(schema)s and lower(relname)=%(name)s""", {'name':table_name.lower().encode(self.encoding), 'schema':schema});
-        return bool( not not cursor.rowcount )
+            cursor = connection.execute(
+                sql.text("""select relname from pg_class c join pg_namespace n on n.oid=c.relnamespace where n.nspname=:schema and lower(relname)=:name""",
+                    bindparams=[sql.bindparam('name', unicode(table_name.lower()), type_=sqltypes.Unicode),
+                        sql.bindparam('schema', unicode(schema), type_=sqltypes.Unicode)] 
+                )
+            )
+        return bool(cursor.rowcount)
 
     def has_sequence(self, connection, sequence_name):
-        cursor = connection.execute('''SELECT relname FROM pg_class WHERE relkind = 'S' AND relnamespace IN ( SELECT oid FROM pg_namespace WHERE nspname NOT LIKE 'pg_%%' AND nspname != 'information_schema' AND relname = %(seqname)s);''', {'seqname': sequence_name.encode(self.encoding)})
-        return bool(not not cursor.rowcount)
-
-    def is_disconnect(self, e):
-        if isinstance(e, self.dbapi.OperationalError):
-            return 'closed the connection' in str(e) or 'connection not open' in str(e)
-        elif isinstance(e, self.dbapi.InterfaceError):
-            return 'connection already closed' in str(e) or 'cursor already closed' in str(e)
-        elif isinstance(e, self.dbapi.ProgrammingError):
-            # yes, it really says "losed", not "closed"
-            return "losed the connection unexpectedly" in str(e)
-        else:
-            return False
+        cursor = connection.execute(
+                    sql.text("SELECT relname FROM pg_class WHERE relkind = 'S' AND "
+                        "relnamespace IN ( SELECT oid FROM pg_namespace WHERE nspname NOT LIKE 'pg_%' "
+                        "AND nspname != 'information_schema' AND relname = :seqname)", 
+                        bindparams=[sql.bindparam('seqname', unicode(sequence_name), type_=sqltypes.Unicode)]
+                    ))
+        return bool(cursor.rowcount)
 
     def table_names(self, connection, schema):
-        s = """
-        SELECT relname
-        FROM pg_class c
-        WHERE relkind = 'r'
-          AND '%(schema)s' = (select nspname from pg_namespace n where n.oid = c.relnamespace)
-        """ % locals()
-        return [row[0].decode(self.encoding) for row in connection.execute(s)]
+        result = connection.execute(
+            sql.text(u"""SELECT relname
+                FROM pg_class c
+                WHERE relkind = 'r'
+                AND '%s' = (select nspname from pg_namespace n where n.oid = c.relnamespace)""" % schema,
+                typemap = {'relname':sqltypes.Unicode}
+            )
+        )
+        return [row[0] for row in result]
 
     def server_version_info(self, connection):
         v = connection.execute("select version()").scalar()
@@ -525,19 +480,19 @@ class PGDialect(default.DefaultDialect):
             elif attype == 'timestamp without time zone':
                 kwargs['timezone'] = False
 
-            if attype in ischema_names:
-                coltype = ischema_names[attype]
+            if attype in self.ischema_names:
+                coltype = self.ischema_names[attype]
             else:
                 if attype in domains:
                     domain = domains[attype]
-                    if domain['attype'] in ischema_names:
+                    if domain['attype'] in self.ischema_names:
                         # A table can't override whether the domain is nullable.
                         nullable = domain['nullable']
 
                         if domain['default'] and not default:
                             # It can, however, override the default value, but can't set it to null.
                             default = domain['default']
-                        coltype = ischema_names[domain['attype']]
+                        coltype = self.ischema_names[domain['attype']]
                 else:
                     coltype = None
 
@@ -693,180 +648,3 @@ class PGDialect(default.DefaultDialect):
 
         return domains
 
-
-
-class PGCompiler(compiler.DefaultCompiler):
-    operators = compiler.DefaultCompiler.operators.copy()
-    operators.update(
-        {
-            sql_operators.mod : '%%',
-            sql_operators.ilike_op: lambda x, y, escape=None: '%s ILIKE %s' % (x, y) + (escape and ' ESCAPE \'%s\'' % escape or ''),
-            sql_operators.notilike_op: lambda x, y, escape=None: '%s NOT ILIKE %s' % (x, y) + (escape and ' ESCAPE \'%s\'' % escape or ''),
-            sql_operators.match_op: lambda x, y: '%s @@ to_tsquery(%s)' % (x, y),
-        }
-    )
-
-    functions = compiler.DefaultCompiler.functions.copy()
-    functions.update (
-        {
-            'TIMESTAMP':lambda x:'TIMESTAMP %s' % x,
-        }
-    )
-
-    def visit_sequence(self, seq):
-        if seq.optional:
-            return None
-        else:
-            return "nextval('%s')" % self.preparer.format_sequence(seq)
-
-    def post_process_text(self, text):
-        if '%%' in text:
-            util.warn("The SQLAlchemy psycopg2 dialect now automatically escapes '%' in text() expressions to '%%'.")
-        return text.replace('%', '%%')
-
-    def limit_clause(self, select):
-        text = ""
-        if select._limit is not None:
-            text +=  " \n LIMIT " + str(select._limit)
-        if select._offset is not None:
-            if select._limit is None:
-                text += " \n LIMIT ALL"
-            text += " OFFSET " + str(select._offset)
-        return text
-
-    def get_select_precolumns(self, select):
-        if select._distinct:
-            if isinstance(select._distinct, bool):
-                return "DISTINCT "
-            elif isinstance(select._distinct, (list, tuple)):
-                return "DISTINCT ON (" + ', '.join(
-                    [(isinstance(col, basestring) and col or self.process(col)) for col in select._distinct]
-                )+ ") "
-            else:
-                return "DISTINCT ON (" + unicode(select._distinct) + ") "
-        else:
-            return ""
-
-    def for_update_clause(self, select):
-        if select.for_update == 'nowait':
-            return " FOR UPDATE NOWAIT"
-        else:
-            return super(PGCompiler, self).for_update_clause(select)
-
-    def _append_returning(self, text, stmt):
-        returning_cols = stmt.kwargs['postgres_returning']
-        def flatten_columnlist(collist):
-            for c in collist:
-                if isinstance(c, expression.Selectable):
-                    for co in c.columns:
-                        yield co
-                else:
-                    yield c
-        columns = [self.process(c, within_columns_clause=True) for c in flatten_columnlist(returning_cols)]
-        text += ' RETURNING ' + string.join(columns, ', ')
-        return text
-
-    def visit_update(self, update_stmt):
-        text = super(PGCompiler, self).visit_update(update_stmt)
-        if 'postgres_returning' in update_stmt.kwargs:
-            return self._append_returning(text, update_stmt)
-        else:
-            return text
-
-    def visit_insert(self, insert_stmt):
-        text = super(PGCompiler, self).visit_insert(insert_stmt)
-        if 'postgres_returning' in insert_stmt.kwargs:
-            return self._append_returning(text, insert_stmt)
-        else:
-            return text
-
-class PGSchemaGenerator(compiler.SchemaGenerator):
-    def get_column_specification(self, column, **kwargs):
-        colspec = self.preparer.format_column(column)
-        if column.primary_key and len(column.foreign_keys)==0 and column.autoincrement and isinstance(column.type, sqltypes.Integer) and not isinstance(column.type, sqltypes.SmallInteger) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)):
-            if isinstance(column.type, PGBigInteger):
-                colspec += " BIGSERIAL"
-            else:
-                colspec += " SERIAL"
-        else:
-            colspec += " " + column.type.dialect_impl(self.dialect).get_col_spec()
-            default = self.get_column_default_string(column)
-            if default is not None:
-                colspec += " DEFAULT " + default
-
-        if not column.nullable:
-            colspec += " NOT NULL"
-        return colspec
-
-    def visit_sequence(self, sequence):
-        if not sequence.optional and (not self.checkfirst or not self.dialect.has_sequence(self.connection, sequence.name)):
-            self.append("CREATE SEQUENCE %s" % self.preparer.format_sequence(sequence))
-            self.execute()
-
-    def visit_index(self, index):
-        preparer = self.preparer
-        self.append("CREATE ")
-        if index.unique:
-            self.append("UNIQUE ")
-        self.append("INDEX %s ON %s (%s)" \
-                    % (preparer.quote(self._validate_identifier(index.name, True), index.quote),
-                       preparer.format_table(index.table),
-                       string.join([preparer.format_column(c) for c in index.columns], ', ')))
-        whereclause = index.kwargs.get('postgres_where', None)
-        if whereclause is not None:
-            compiler = self._compile(whereclause, None)
-            # this might belong to the compiler class
-            inlined_clause = str(compiler) % dict(
-                [(key,bind.value) for key,bind in compiler.binds.iteritems()])
-            self.append(" WHERE " + inlined_clause)
-        self.execute()
-
-class PGSchemaDropper(compiler.SchemaDropper):
-    def visit_sequence(self, sequence):
-        if not sequence.optional and (not self.checkfirst or self.dialect.has_sequence(self.connection, sequence.name)):
-            self.append("DROP SEQUENCE %s" % self.preparer.format_sequence(sequence))
-            self.execute()
-
-class PGDefaultRunner(base.DefaultRunner):
-    def __init__(self, context):
-        base.DefaultRunner.__init__(self, context)
-        # craete cursor which won't conflict with a server-side cursor
-        self.cursor = context._connection.connection.cursor()
-    
-    def get_column_default(self, column, isinsert=True):
-        if column.primary_key:
-            # pre-execute passive defaults on primary keys
-            if (isinstance(column.server_default, schema.DefaultClause) and
-                column.server_default.arg is not None):
-                return self.execute_string("select %s" % column.server_default.arg)
-            elif (isinstance(column.type, sqltypes.Integer) and column.autoincrement) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)):
-                sch = column.table.schema
-                # TODO: this has to build into the Sequence object so we can get the quoting
-                # logic from it
-                if sch is not None:
-                    exc = "select nextval('\"%s\".\"%s_%s_seq\"')" % (sch, column.table.name, column.name)
-                else:
-                    exc = "select nextval('\"%s_%s_seq\"')" % (column.table.name, column.name)
-                return self.execute_string(exc.encode(self.dialect.encoding))
-
-        return super(PGDefaultRunner, self).get_column_default(column)
-
-    def visit_sequence(self, seq):
-        if not seq.optional:
-            return self.execute_string(("select nextval('%s')" % self.dialect.identifier_preparer.format_sequence(seq)))
-        else:
-            return None
-
-class PGIdentifierPreparer(compiler.IdentifierPreparer):
-    def _unquote_identifier(self, value):
-        if value[0] == self.initial_quote:
-            value = value[1:-1].replace('""','"')
-        return value
-
-dialect = PGDialect
-dialect.statement_compiler = PGCompiler
-dialect.schemagenerator = PGSchemaGenerator
-dialect.schemadropper = PGSchemaDropper
-dialect.preparer = PGIdentifierPreparer
-dialect.defaultrunner = PGDefaultRunner
-dialect.execution_ctx_cls = PGExecutionContext
diff --git a/lib/sqlalchemy/dialects/postgres/psycopg2.py b/lib/sqlalchemy/dialects/postgres/psycopg2.py
new file mode 100644 (file)
index 0000000..5cda71b
--- /dev/null
@@ -0,0 +1,215 @@
+"""Support for the PostgreSQL database via the psycopg2 driver.
+
+Driver
+------
+
+The psycopg2 driver is supported, available at http://pypi.python.org/pypi/psycopg2/ .
+The dialect has several behaviors which are specifically tailored towards compatibility
+with this module.
+
+Note that psycopg1 is **not** supported.
+
+Connecting
+----------
+
+URLs are of the form `postgres+psycopg2://user:password@host:port/dbname[?key=value&key=value...]`.
+
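+For example, a minimal sketch of creating an engine (the credentials, host and
+database name here are placeholders)::
+
+    from sqlalchemy import create_engine
+
+    e = create_engine('postgres+psycopg2://scott:tiger@localhost:5432/test')
+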
+psycopg2-specific keyword arguments which are accepted by :func:`~sqlalchemy.create_engine()` are:
+
+* *server_side_cursors* - Enable the usage of "server side cursors" for SQL statements which support
+  this feature.  What this essentially means from a psycopg2 point of view is that the cursor is 
+  created using a name, e.g. `connection.cursor('some name')`, which has the effect that result rows
+  are not immediately pre-fetched and buffered after statement execution, but are instead left 
+  on the server and only retrieved as needed.    SQLAlchemy's :class:`~sqlalchemy.engine.base.ResultProxy`
+  uses special row-buffering behavior when this feature is enabled, such that groups of 100 rows 
+  at a time are fetched over the wire to reduce conversational overhead.
+
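+  As a sketch, the flag is passed to :func:`~sqlalchemy.create_engine()`; the URL
+  here is a placeholder::
+
+      e = create_engine('postgres+psycopg2://scott:tiger@localhost/test',
+                        server_side_cursors=True)
+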
+Sequences/SERIAL
+----------------
+
+Postgres supports sequences, and SQLAlchemy uses these as the default means of creating
+new primary key values for integer-based primary key columns.   When creating tables, 
+SQLAlchemy will issue the ``SERIAL`` datatype for integer-based primary key columns, 
+which generates a sequence corresponding to the column and associated with it based on
+a naming convention.
+
+To specify a specific named sequence to be used for primary key generation, use the
+:func:`~sqlalchemy.schema.Sequence` construct::
+
+    Table('sometable', metadata, 
+            Column('id', Integer, Sequence('some_id_seq'), primary_key=True)
+        )
+
+Currently, when SQLAlchemy issues a single insert statement, to fulfill the contract of
+having the "last insert identifier" available, the sequence is executed independently
+beforehand and the new value is retrieved, to be used in the subsequent insert.  Note
+that when an :func:`~sqlalchemy.sql.expression.insert()` construct is executed using 
+"executemany" semantics, the sequence is not pre-executed and normal PG SERIAL behavior
+is used.
+
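+As an illustrative sketch (``sometable`` and ``conn`` are assumed to be an
+existing integer-primary-key table and an open connection)::
+
+    # single execution: the sequence is pre-executed, its value used in the INSERT
+    conn.execute(sometable.insert(), data='d1')
+
+    # "executemany": no pre-execution; the SERIAL default fires in the database
+    conn.execute(sometable.insert(), [{'data': 'd1'}, {'data': 'd2'}])
+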
+Postgres 8.3 supports an ``INSERT...RETURNING`` syntax which SQLAlchemy supports 
+as well.  A future release of SQLA will use this feature by default in lieu of 
+sequence pre-execution in order to retrieve new primary key values, when available.
+
+INSERT/UPDATE...RETURNING
+-------------------------
+
+The dialect supports PG 8.3's ``INSERT...RETURNING`` and ``UPDATE...RETURNING`` syntaxes,
+but must be explicitly enabled on a per-statement basis::
+
+    # INSERT..RETURNING
+    result = table.insert(postgres_returning=[table.c.col1, table.c.col2]).\\
+        values(name='foo')
+    print result.fetchall()
+    
+    # UPDATE..RETURNING
+    result = table.update(postgres_returning=[table.c.col1, table.c.col2]).\\
+        where(table.c.name=='foo').values(name='bar')
+    print result.fetchall()
+
+Indexes
+-------
+
+PostgreSQL supports partial indexes. To create them pass a postgres_where
+option to the Index constructor::
+
+  Index('my_index', my_table.c.id, postgres_where=my_table.c.value > 10)
+
+Transactions
+------------
+
+The psycopg2 dialect fully supports SAVEPOINT and two-phase commit operations.
+
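+A brief sketch of each, using the generic :class:`~sqlalchemy.engine.base.Connection`
+API (``engine`` is assumed to be an existing psycopg2-backed engine)::
+
+    conn = engine.connect()
+
+    # SAVEPOINT via a nested transaction
+    trans = conn.begin()
+    savepoint = conn.begin_nested()
+    savepoint.rollback()      # rolls back to the SAVEPOINT only
+    trans.commit()
+
+    # two-phase commit
+    xa = conn.begin_twophase()
+    xa.prepare()
+    xa.commit()
+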
+
+"""
+
+import decimal, random, re, string
+
+from sqlalchemy import sql, schema, exc, util
+from sqlalchemy.engine import base, default
+from sqlalchemy.sql import compiler, expression
+from sqlalchemy.sql import operators as sql_operators
+from sqlalchemy import types as sqltypes
+from sqlalchemy.dialects.postgres.base import PGDialect, PGInet, PGCidr, PGMacAddr, PGArray, \
+ PGBigInteger, PGInterval
+
+class PGNumeric(sqltypes.Numeric):
+    def bind_processor(self, dialect):
+        return None
+
+    def result_processor(self, dialect):
+        if self.asdecimal:
+            return None
+        else:
+            def process(value):
+                if isinstance(value, decimal.Decimal):
+                    return float(value)
+                else:
+                    return value
+            return process
+
+
+colspecs = {
+    sqltypes.Numeric : PGNumeric,
+    sqltypes.Float: sqltypes.Float,  # prevents PGNumeric from being used
+}
+
+ischema_names = {
+    'integer' : sqltypes.Integer,
+    'bigint' : PGBigInteger,
+    'smallint' : sqltypes.SmallInteger,
+    'character varying' : sqltypes.String,
+    'character' : sqltypes.CHAR,
+    'text' : sqltypes.Text,
+    'numeric' : PGNumeric,
+    'float' : sqltypes.Float,
+    'real' : sqltypes.Float,
+    'inet': PGInet,
+    'cidr': PGCidr,
+    'macaddr': PGMacAddr,
+    'double precision' : sqltypes.Float,
+    'timestamp' : sqltypes.DateTime,
+    'timestamp with time zone' : sqltypes.DateTime,
+    'timestamp without time zone' : sqltypes.DateTime,
+    'time with time zone' : sqltypes.Time,
+    'time without time zone' : sqltypes.Time,
+    'date' : sqltypes.Date,
+    'time': sqltypes.Time,
+    'bytea' : sqltypes.Binary,
+    'boolean' : sqltypes.Boolean,
+    'interval':PGInterval,
+}
+
+# TODO: filter out 'FOR UPDATE' statements
+SERVER_SIDE_CURSOR_RE = re.compile(
+    r'\s*SELECT',
+    re.I | re.UNICODE)
+
+class Postgres_psycopg2ExecutionContext(default.DefaultExecutionContext):
+    def create_cursor(self):
+        # TODO: coverage for server side cursors + select.for_update()
+        is_server_side = \
+            self.dialect.server_side_cursors and \
+            ((self.compiled and isinstance(self.compiled.statement, expression.Selectable) 
+                and not getattr(self.compiled.statement, 'for_update', False)) \
+            or \
+            (
+                (not self.compiled or isinstance(self.compiled.statement, expression._TextClause)) 
+                and self.statement and SERVER_SIDE_CURSOR_RE.match(self.statement))
+            )
+
+        self.__is_server_side = is_server_side
+        if is_server_side:
+            # use server-side cursors:
+            # http://lists.initd.org/pipermail/psycopg/2007-January/005251.html
+            ident = "c_%s_%s" % (hex(id(self))[2:], hex(random.randint(0, 65535))[2:])
+            return self._connection.connection.cursor(ident)
+        else:
+            return self._connection.connection.cursor()
+
+    def get_result_proxy(self):
+        if self.__is_server_side:
+            return base.BufferedRowResultProxy(self)
+        else:
+            return base.ResultProxy(self)
+
+class Postgres_psycopg2(PGDialect):
+    driver = 'psycopg2'
+    supports_unicode_statements = False
+    default_paramstyle = 'pyformat'
+    supports_sane_multi_rowcount = False
+    execution_ctx_cls = Postgres_psycopg2ExecutionContext
+    ischema_names = ischema_names
+    
+    def __init__(self, server_side_cursors=False, **kwargs):
+        PGDialect.__init__(self, **kwargs)
+        self.server_side_cursors = server_side_cursors
+
+    @classmethod
+    def dbapi(cls):
+        psycopg = __import__('psycopg2')
+        return psycopg
+
+    def create_connect_args(self, url):
+        opts = url.translate_connect_args(username='user')
+        if 'port' in opts:
+            opts['port'] = int(opts['port'])
+        opts.update(url.query)
+        return ([], opts)
+
+    def type_descriptor(self, typeobj):
+        return sqltypes.adapt_type(typeobj, colspecs)
+
+    def is_disconnect(self, e):
+        if isinstance(e, self.dbapi.OperationalError):
+            return 'closed the connection' in str(e) or 'connection not open' in str(e)
+        elif isinstance(e, self.dbapi.InterfaceError):
+            return 'connection already closed' in str(e) or 'cursor already closed' in str(e)
+        elif isinstance(e, self.dbapi.ProgrammingError):
+            # yes, it really says "losed", not "closed"
+            return "losed the connection unexpectedly" in str(e)
+        else:
+            return False
+
+dialect = Postgres_psycopg2
+    
\ No newline at end of file
diff --git a/lib/sqlalchemy/dialects/sqlite/__init__.py b/lib/sqlalchemy/dialects/sqlite/__init__.py
new file mode 100644 (file)
index 0000000..3cc0887
--- /dev/null
@@ -0,0 +1,4 @@
+from sqlalchemy.dialects.sqlite import base, pysqlite
+
+# default dialect
+base.dialect = pysqlite.dialect
\ No newline at end of file
diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py
new file mode 100644 (file)
index 0000000..a080b94
--- /dev/null
@@ -0,0 +1,339 @@
+# sqlite.py
+# Copyright (C) 2005, 2006, 2007, 2008, 2009 Michael Bayer mike_mp@zzzcomputing.com
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: http://www.opensource.org/licenses/mit-license.php
+
+import datetime, re, time
+
+from sqlalchemy import sql, schema, exc, pool, DefaultClause
+from sqlalchemy.engine import default
+from sqlalchemy import types as sqltypes
+from sqlalchemy import util
+from sqlalchemy.sql import compiler, functions as sql_functions
+from types import NoneType
+
+class NumericMixin(object):
+    def bind_processor(self, dialect):
+        type_ = self.asdecimal and str or float
+        def process(value):
+            if value is not None:
+                return type_(value)
+            else:
+                return value
+        return process
+
+class SLNumeric(NumericMixin, sqltypes.Numeric):
+    pass
+
+class SLFloat(NumericMixin, sqltypes.Float):
+    pass
+
+# since SQLite has no date types, we're assuming that SQLite via ODBC
+# or JDBC would similarly have no built-in date support, so the "string" based logic
+# would apply to all implementing dialects.
+class DateTimeMixin(object):
+    def _bind_processor(self, format, elements):
+        def process(value):
+            if not isinstance(value, (NoneType, datetime.date, datetime.datetime, datetime.time)):
+                raise TypeError("SQLite Date, Time, and DateTime types only accept Python datetime objects as input.")
+            elif value is not None:
+                return format % tuple([getattr(value, attr, 0) for attr in elements])
+            else:
+                return None
+        return process
+
+    def _result_processor(self, fn, regexp):
+        def process(value):
+            if value is not None:
+                return fn(*[int(x or 0) for x in regexp.match(value).groups()])
+            else:
+                return None
+        return process
+
+class SLDateTime(DateTimeMixin, sqltypes.DateTime):
+    __legacy_microseconds__ = False
+
+    def bind_processor(self, dialect):
+        if self.__legacy_microseconds__:
+            return self._bind_processor(
+                        "%4.4d-%2.2d-%2.2d %2.2d:%2.2d:%2.2d.%s", 
+                        ("year", "month", "day", "hour", "minute", "second", "microsecond")
+                        )
+        else:
+            return self._bind_processor(
+                        "%4.4d-%2.2d-%2.2d %2.2d:%2.2d:%2.2d.%06d", 
+                        ("year", "month", "day", "hour", "minute", "second", "microsecond")
+                        )
+
+    _reg = re.compile(r"(\d+)-(\d+)-(\d+)(?: (\d+):(\d+):(\d+)(?:\.(\d+))?)?")
+    def result_processor(self, dialect):
+        return self._result_processor(datetime.datetime, self._reg)
+
+class SLDate(DateTimeMixin, sqltypes.Date):
+    def bind_processor(self, dialect):
+        return self._bind_processor(
+                        "%4.4d-%2.2d-%2.2d", 
+                        ("year", "month", "day")
+                )
+
+    _reg = re.compile(r"(\d+)-(\d+)-(\d+)")
+    def result_processor(self, dialect):
+        return self._result_processor(datetime.date, self._reg)
+
+class SLTime(DateTimeMixin, sqltypes.Time):
+    __legacy_microseconds__ = False
+
+    def bind_processor(self, dialect):
+        if self.__legacy_microseconds__:
+            return self._bind_processor(
+                            "%2.2d:%2.2d:%2.2d.%s", 
+                            ("hour", "minute", "second", "microsecond")
+                    )
+        else:
+            return self._bind_processor(
+                            "%2.2d:%2.2d:%2.2d.%06d", 
+                            ("hour", "minute", "second", "microsecond")
+                    )
+
+    _reg = re.compile(r"(\d+):(\d+):(\d+)(?:\.(\d+))?")
+    def result_processor(self, dialect):
+        return self._result_processor(datetime.time, self._reg)
+
+
+class SLBoolean(sqltypes.Boolean):
+    def bind_processor(self, dialect):
+        def process(value):
+            if value is None:
+                return None
+            return value and 1 or 0
+        return process
+
+    def result_processor(self, dialect):
+        def process(value):
+            if value is None:
+                return None
+            return value and True or False
+        return process
+
+
+class SQLiteCompiler(compiler.SQLCompiler):
+    functions = compiler.SQLCompiler.functions.copy()
+    functions.update (
+        {
+            sql_functions.now: 'CURRENT_TIMESTAMP',
+            sql_functions.char_length: 'length%(expr)s'
+        }
+    )
+
+    def visit_cast(self, cast, **kwargs):
+        if self.dialect.supports_cast:
+            return super(SQLiteCompiler, self).visit_cast(cast)
+        else:
+            return self.process(cast.clause)
+
+    def limit_clause(self, select):
+        text = ""
+        if select._limit is not None:
+            text +=  " \n LIMIT " + str(select._limit)
+        if select._offset is not None:
+            if select._limit is None:
+                text += " \n LIMIT -1"
+            text += " OFFSET " + str(select._offset)
+        else:
+            text += " OFFSET 0"
+        return text
+
+    def for_update_clause(self, select):
+        # sqlite has no "FOR UPDATE" AFAICT
+        return ''
+
+
+class SQLiteDDLCompiler(compiler.DDLCompiler):
+
+    def get_column_specification(self, column, **kwargs):
+        colspec = self.preparer.format_column(column) + " " + self.dialect.type_compiler.process(column.type)
+        default = self.get_column_default_string(column)
+        if default is not None:
+            colspec += " DEFAULT " + default
+
+        if not column.nullable:
+            colspec += " NOT NULL"
+        return colspec
+
+class SQLiteTypeCompiler(compiler.GenericTypeCompiler):
+    def visit_binary(self, type_):
+        return self.visit_BLOB(type_)
+    
+    def visit_CLOB(self, type_):
+        return self.visit_TEXT(type_)
+
+    def visit_NCHAR(self, type_):
+        return self.visit_CHAR(type_)
+    
+class SQLiteIdentifierPreparer(compiler.IdentifierPreparer):
+    reserved_words = set([
+        'add', 'after', 'all', 'alter', 'analyze', 'and', 'as', 'asc',
+        'attach', 'autoincrement', 'before', 'begin', 'between', 'by',
+        'cascade', 'case', 'cast', 'check', 'collate', 'column', 'commit',
+        'conflict', 'constraint', 'create', 'cross', 'current_date',
+        'current_time', 'current_timestamp', 'database', 'default',
+        'deferrable', 'deferred', 'delete', 'desc', 'detach', 'distinct',
+        'drop', 'each', 'else', 'end', 'escape', 'except', 'exclusive',
+        'explain', 'false', 'fail', 'for', 'foreign', 'from', 'full', 'glob',
+        'group', 'having', 'if', 'ignore', 'immediate', 'in', 'index',
+        'initially', 'inner', 'insert', 'instead', 'intersect', 'into', 'is',
+        'isnull', 'join', 'key', 'left', 'like', 'limit', 'match', 'natural',
+        'not', 'notnull', 'null', 'of', 'offset', 'on', 'or', 'order', 'outer',
+        'plan', 'pragma', 'primary', 'query', 'raise', 'references',
+        'reindex', 'rename', 'replace', 'restrict', 'right', 'rollback',
+        'row', 'select', 'set', 'table', 'temp', 'temporary', 'then', 'to',
+        'transaction', 'trigger', 'true', 'union', 'unique', 'update', 'using',
+        'vacuum', 'values', 'view', 'virtual', 'when', 'where',
+        ])
+
+class SQLiteDialect(default.DefaultDialect):
+    name = 'sqlite'
+    supports_alter = False
+    supports_unicode_statements = True
+    supports_default_values = True
+    supports_empty_insert = False
+    supports_cast = True
+    statement_compiler = SQLiteCompiler
+    ddl_compiler = SQLiteDDLCompiler
+    type_compiler = SQLiteTypeCompiler
+    preparer = SQLiteIdentifierPreparer
+
+    def table_names(self, connection, schema):
+        if schema is not None:
+            qschema = self.identifier_preparer.quote_identifier(schema)
+            master = '%s.sqlite_master' % qschema
+            s = ("SELECT name FROM %s "
+                 "WHERE type='table' ORDER BY name") % (master,)
+            rs = connection.execute(s)
+        else:
+            try:
+                s = ("SELECT name FROM "
+                     " (SELECT * FROM sqlite_master UNION ALL "
+                     "  SELECT * FROM sqlite_temp_master) "
+                     "WHERE type='table' ORDER BY name")
+                rs = connection.execute(s)
+            except exc.DBAPIError:
+                # fall back for connections where sqlite_temp_master is not available
+                s = ("SELECT name FROM sqlite_master "
+                     "WHERE type='table' ORDER BY name")
+                rs = connection.execute(s)
+
+        return [row[0] for row in rs]
+
+    def has_table(self, connection, table_name, schema=None):
+        quote = self.identifier_preparer.quote_identifier
+        if schema is not None:
+            pragma = "PRAGMA %s." % quote(schema)
+        else:
+            pragma = "PRAGMA "
+        qtable = quote(table_name)
+        cursor = connection.execute("%stable_info(%s)" % (pragma, qtable))
+        row = cursor.fetchone()
+
+        # consume remaining rows, to work around
+        # http://www.sqlite.org/cvstrac/tktview?tn=1884
+        while cursor.fetchone() is not None:
+            pass
+
+        return (row is not None)
+
+    def reflecttable(self, connection, table, include_columns):
+        preparer = self.identifier_preparer
+        if table.schema is None:
+            pragma = "PRAGMA "
+        else:
+            pragma = "PRAGMA %s." % preparer.quote_identifier(table.schema)
+        qtable = preparer.format_table(table, False)
+
+        c = connection.execute("%stable_info(%s)" % (pragma, qtable))
+        found_table = False
+        while True:
+            row = c.fetchone()
+            if row is None:
+                break
+
+            found_table = True
+            (name, type_, nullable, default, has_default, primary_key) = \
+                (row[1], row[2].upper(), not row[3], row[4], row[4] is not None, row[5])
+            name = re.sub(r'^\"|\"$', '', name)
+            if include_columns and name not in include_columns:
+                continue
+            match = re.match(r'(\w+)(\(.*?\))?', type_)
+            if match:
+                coltype = match.group(1)
+                args = match.group(2)
+            else:
+                coltype = "VARCHAR"
+                args = ''
+
+            try:
+                coltype = self.ischema_names[coltype]
+            except KeyError:
+                util.warn("Did not recognize type '%s' of column '%s'" %
+                          (coltype, name))
+                coltype = sqltypes.NullType
+
+            if args is not None:
+                args = re.findall(r'(\d+)', args)
+                coltype = coltype(*[int(a) for a in args])
+
+            colargs = []
+            if has_default:
+                colargs.append(DefaultClause(sql.text(default)))
+            table.append_column(schema.Column(name, coltype, primary_key = primary_key, nullable = nullable, *colargs))
+
+        if not found_table:
+            raise exc.NoSuchTableError(table.name)
+
+        c = connection.execute("%sforeign_key_list(%s)" % (pragma, qtable))
+        fks = {}
+        while True:
+            row = c.fetchone()
+            if row is None:
+                break
+            (constraint_name, tablename, localcol, remotecol) = (row[0], row[2], row[3], row[4])
+            tablename = re.sub(r'^\"|\"$', '', tablename)
+            localcol = re.sub(r'^\"|\"$', '', localcol)
+            remotecol = re.sub(r'^\"|\"$', '', remotecol)
+            try:
+                fk = fks[constraint_name]
+            except KeyError:
+                fk = ([], [])
+                fks[constraint_name] = fk
+
+            # look up the table based on the given table's engine, not 'self',
+            # since it could be a ProxyEngine
+            remotetable = schema.Table(tablename, table.metadata, autoload=True, autoload_with=connection)
+            constrained_column = table.c[localcol].name
+            refspec = ".".join([tablename, remotecol])
+            if constrained_column not in fk[0]:
+                fk[0].append(constrained_column)
+            if refspec not in fk[1]:
+                fk[1].append(refspec)
+        for name, value in fks.iteritems():
+            table.append_constraint(schema.ForeignKeyConstraint(value[0], value[1], link_to_name=True))
+        # check for UNIQUE indexes
+        c = connection.execute("%sindex_list(%s)" % (pragma, qtable))
+        unique_indexes = []
+        while True:
+            row = c.fetchone()
+            if row is None:
+                break
+            if (row[2] == 1):
+                unique_indexes.append(row[1])
+        # loop through unique indexes for one that includes the primary key
+        for idx in unique_indexes:
+            c = connection.execute("%sindex_info(%s)" % (pragma, idx))
+            cols = []
+            while True:
+                row = c.fetchone()
+                if row is None:
+                    break
+                cols.append(row[2])
+
diff --git a/lib/sqlalchemy/dialects/sqlite/pysqlite.py b/lib/sqlalchemy/dialects/sqlite/pysqlite.py
new file mode 100644 (file)
index 0000000..55ac8bd
--- /dev/null
@@ -0,0 +1,264 @@
+"""Support for the SQLite database via pysqlite.
+
+Driver
+------
+
+When using Python 2.5 and above, the built in ``sqlite3`` driver is 
+already installed and no additional installation is needed.  Otherwise,
+the ``pysqlite2`` driver needs to be present.  This is the same driver as
+``sqlite3``, just with a different name.
+
+The ``pysqlite2`` driver will be loaded first, and if not found, ``sqlite3``
+is loaded.  This allows an explicitly installed pysqlite driver to take
+precedence over the built-in one.  As with all dialects, a specific
+DBAPI module may be provided to :func:`~sqlalchemy.create_engine()` to control 
+this explicitly::
+
+    from sqlite3 import dbapi2 as sqlite
+    e = create_engine('sqlite+pysqlite:///file.db', module=sqlite)
+
+Full documentation on pysqlite is available at:
+`<http://www.initd.org/pub/software/pysqlite/doc/usage-guide.html>`_
+
+Connect Strings
+---------------
+
+The file specification for the SQLite database is taken as the "database" portion of
+the URL.  Note that the format of a URL is::
+
+    driver://user:pass@host/database
+    
+This means that the actual filename to be used starts with the characters to the
+**right** of the third slash.   So connecting to a relative filepath looks like::
+
+    # relative path
+    e = create_engine('sqlite:///path/to/database.db')
+    
+An absolute path, which is denoted by starting with a slash, means you need **four**
+slashes::
+
+    # absolute path
+    e = create_engine('sqlite:////path/to/database.db')
+
+To use a Windows path, regular drive specifications and backslashes can be used.  
+Double backslashes are probably needed::
+
+    # absolute path on Windows
+    e = create_engine('sqlite:///C:\\\\path\\\\to\\\\database.db')
+
+The sqlite ``:memory:`` identifier is the default if no filepath is present.  Specify
+``sqlite://`` and nothing else::
+
+    # in-memory database
+    e = create_engine('sqlite://')
+
+Threading Behavior
+------------------
+
+Pysqlite connections do not support being moved between threads, unless
+the ``check_same_thread`` Pysqlite flag is set to ``False``.  In addition,
+when using an in-memory SQLite database, the full database exists only within 
+the scope of a single connection.  It is reported that an in-memory
+database does not support being shared between threads regardless of the 
+``check_same_thread`` flag - which means that a multithreaded
+application **cannot** share data from a ``:memory:`` database across threads
+unless access to the connection is limited to a single worker thread which communicates
+through a queueing mechanism to concurrent threads.
+
+To provide a default which accommodates SQLite's default threading capabilities
+somewhat reasonably, the SQLite dialect will specify that the :class:`~sqlalchemy.pool.SingletonThreadPool`
+be used by default.  This pool maintains a single SQLite connection per thread
+that is held open up to a count of five concurrent threads.  When more than five threads
+are used, a cleanup mechanism will dispose of excess unused connections.   
+
+Two optional pool implementations that may be appropriate for particular SQLite usage scenarios (see the sketch following this list):
+
+ * the :class:`sqlalchemy.pool.StaticPool` might be appropriate for a multithreaded
+   application using an in-memory database, assuming the threading issues inherent in 
+   pysqlite are somehow accommodated.  This pool persistently holds onto a single connection
+   which is never closed, and is returned for all requests.
+   
+ * the :class:`sqlalchemy.pool.NullPool` might be appropriate for an application that
+   makes use of a file-based sqlite database.  This pool disables any actual "pooling"
+   behavior, and simply opens and closes real connections corresponding to the :func:`connect()`
+   and :func:`close()` methods.  SQLite can "connect" to a particular file with very high 
+   efficiency, so this option may actually perform better without the extra overhead
+   of :class:`SingletonThreadPool`.  NullPool will of course render a ``:memory:`` connection
+   useless since the database would be lost as soon as the connection is "returned" to the pool.
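+
+As a sketch, either pool can be selected with the ``poolclass`` argument to
+:func:`~sqlalchemy.create_engine()` (the file path here is a placeholder)::
+
+    from sqlalchemy import create_engine
+    from sqlalchemy.pool import NullPool
+
+    # one real connect/close per checkout; suits file-based databases
+    e = create_engine('sqlite:///path/to/database.db', poolclass=NullPool)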
+
+Date and Time Types
+-------------------
+
+SQLite does not have built-in DATE, TIME, or DATETIME types, and pysqlite does not provide 
+out-of-the-box functionality for translating values between Python `datetime` objects
+and a SQLite-supported format.  SQLAlchemy's own :class:`~sqlalchemy.types.DateTime`
+and related types provide date formatting and parsing functionality when SQLite is used.
+The implementation classes are :class:`SLDateTime`, :class:`SLDate` and :class:`SLTime`.
+These types represent dates and times as ISO formatted strings, which also nicely
+support ordering.  There's no reliance on typical "libc" internals for these functions,
+so historical dates are fully supported.
+
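+For example, a :class:`~sqlalchemy.types.DateTime` column on SQLite stores an ISO
+string such as ``'2009-01-13 17:33:53.000000'``.  A sketch, with placeholder table
+and column names::
+
+    from sqlalchemy import Table, Column, Integer, DateTime, MetaData
+
+    metadata = MetaData()
+    log = Table('log', metadata,
+                Column('id', Integer, primary_key=True),
+                # bound datetime objects are rendered to/from ISO strings
+                Column('created', DateTime))
+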
+Unicode
+-------
+
+In contrast to SQLAlchemy's active handling of date and time types for pysqlite, pysqlite's 
+default behavior regarding Unicode is that all strings are returned as Python unicode objects
+in all cases.  So even if the :class:`~sqlalchemy.types.Unicode` type is 
+*not* used, you will still always receive unicode data back from a result set.  It is 
+**strongly** recommended that you do use the :class:`~sqlalchemy.types.Unicode` type
+to represent strings, since it will raise a warning if a non-unicode Python string is 
+passed from the user application.  Mixing the usage of non-unicode objects with returned unicode objects can
+quickly create confusion, particularly when using the ORM, as internal data is not
+always represented by an actual database result string.
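+
+As a sketch, with placeholder table and column names::
+
+    from sqlalchemy import Table, Column, Integer, Unicode, MetaData
+
+    metadata = MetaData()
+    users = Table('users', metadata,
+                  Column('id', Integer, primary_key=True),
+                  # warns if a non-unicode string is passed as a bind value
+                  Column('name', Unicode(50)))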
+
+"""
+
+from sqlalchemy.dialects.sqlite.base import SLNumeric, SLFloat, SQLiteDialect, SLBoolean, SLDate, SLDateTime, SLTime
+from sqlalchemy import schema, exc, pool
+from sqlalchemy.engine import default
+from sqlalchemy import types as sqltypes
+from sqlalchemy import util
+from types import NoneType
+
+class SLUnicodeMixin(object):
+    def bind_processor(self, dialect):
+        if self.convert_unicode or dialect.convert_unicode:
+            if self.assert_unicode is None:
+                assert_unicode = dialect.assert_unicode
+            else:
+                assert_unicode = self.assert_unicode
+                
+            if not assert_unicode:
+                return None
+                
+            def process(value):
+                if not isinstance(value, (unicode, NoneType)):
+                    if assert_unicode == 'warn':
+                        util.warn("Unicode type received non-unicode bind "
+                                  "param value %r" % value)
+                        return value
+                    else:
+                        raise exc.InvalidRequestError("Unicode type received non-unicode bind param value %r" % value)
+                else:
+                    return value
+            return process
+        else:
+            return None
+
+    def result_processor(self, dialect):
+        return None
+    
+class SLText(SLUnicodeMixin, sqltypes.Text):
+    pass
+
+class SLString(SLUnicodeMixin, sqltypes.String):
+    pass
+
+class SLChar(SLUnicodeMixin, sqltypes.CHAR):
+    pass
+
+
+colspecs = {
+    sqltypes.Boolean: SLBoolean,
+    sqltypes.CHAR: SLChar,
+    sqltypes.Date: SLDate,
+    sqltypes.DateTime: SLDateTime,
+    sqltypes.Float: SLFloat,
+    sqltypes.NCHAR: SLChar,
+    sqltypes.Numeric: SLNumeric,
+    sqltypes.String: SLString,
+    sqltypes.Text: SLText,
+    sqltypes.Time: SLTime,
+}
+
+ischema_names = {
+    'BLOB': sqltypes.Binary,
+    'BOOL': SLBoolean,
+    'BOOLEAN': SLBoolean,
+    'CHAR': SLChar,
+    'DATE': SLDate,
+    'DATETIME': SLDateTime,
+    'DECIMAL': SLNumeric,
+    'FLOAT': SLNumeric,
+    'INT': sqltypes.Integer,
+    'INTEGER': sqltypes.Integer,
+    'NUMERIC': SLNumeric,
+    'REAL': SLNumeric,
+    'SMALLINT': sqltypes.SmallInteger,
+    'TEXT': SLText,
+    'TIME': SLTime,
+    'TIMESTAMP': SLDateTime,
+    'VARCHAR': SLString,
+}
+
+
+class SQLite_pysqliteExecutionContext(default.DefaultExecutionContext):
+    def post_exec(self):
+        if self.isinsert and not self.executemany:
+            if not len(self._last_inserted_ids) or self._last_inserted_ids[0] is None:
+                self._last_inserted_ids = [self.cursor.lastrowid] + self._last_inserted_ids[1:]
+
+
+class SQLite_pysqlite(SQLiteDialect):
+    default_paramstyle = 'qmark'
+    poolclass = pool.SingletonThreadPool
+    execution_ctx_cls = SQLite_pysqliteExecutionContext
+    driver = 'pysqlite'
+    ischema_names = ischema_names
+    
+    def __init__(self, **kwargs):
+        SQLiteDialect.__init__(self, **kwargs)
+        def vers(num):
+            return tuple([int(x) for x in num.split('.')])
+        if self.dbapi is not None:
+            sqlite_ver = self.dbapi.version_info
+            if sqlite_ver < (2, 1, 3):
+                util.warn(
+                    ("The installed version of pysqlite2 (%s) is out-dated "
+                     "and will cause errors in some cases.  Version 2.1.3 "
+                     "or greater is recommended.") %
+                    '.'.join([str(subver) for subver in sqlite_ver]))
+            if self.dbapi.sqlite_version_info < (3, 3, 8):
+                self.supports_default_values = False
+        self.supports_cast = (self.dbapi is None or vers(self.dbapi.sqlite_version) >= vers("3.2.3"))
+
+    @classmethod
+    def dbapi(cls):
+        try:
+            from pysqlite2 import dbapi2 as sqlite
+        except ImportError, e:
+            try:
+                from sqlite3 import dbapi2 as sqlite #try the 2.5+ stdlib name.
+            except ImportError:
+                raise e
+        return sqlite
+
+    def server_version_info(self, connection):
+        return self.dbapi.sqlite_version_info
+
+    def create_connect_args(self, url):
+        if url.username or url.password or url.host or url.port:
+            raise exc.ArgumentError(
+                "Invalid SQLite URL: %s\n"
+                "Valid SQLite URL forms are:\n"
+                " sqlite:///:memory: (or, sqlite://)\n"
+                " sqlite:///relative/path/to/file.db\n"
+                " sqlite:////absolute/path/to/file.db" % (url,))
+        filename = url.database or ':memory:'
+
+        opts = url.query.copy()
+        util.coerce_kw_type(opts, 'timeout', float)
+        util.coerce_kw_type(opts, 'isolation_level', str)
+        util.coerce_kw_type(opts, 'detect_types', int)
+        util.coerce_kw_type(opts, 'check_same_thread', bool)
+        util.coerce_kw_type(opts, 'cached_statements', int)
+
+        return ([filename], opts)
+
+    def type_descriptor(self, typeobj):
+        return sqltypes.adapt_type(typeobj, colspecs)
+
+    def is_disconnect(self, e):
+        return isinstance(e, self.dbapi.ProgrammingError) and "Cannot operate on a closed database." in str(e)
+
+dialect = SQLite_pysqlite
index b0f446598558c4eee0a060e02b7bcafd293ef1c2..6def864e89d076e6706ea47d299a40887b0138c4 100644 (file)
@@ -66,9 +66,9 @@ from sqlalchemy.engine.base import (
     ResultProxy,
     RootTransaction,
     RowProxy,
-    SchemaIterator,
     Transaction,
-    TwoPhaseTransaction
+    TwoPhaseTransaction,
+    TypeCompiler
     )
 from sqlalchemy.engine import strategies
 from sqlalchemy import util
@@ -89,9 +89,9 @@ __all__ = (
     'ResultProxy',
     'RootTransaction',
     'RowProxy',
-    'SchemaIterator',
     'Transaction',
     'TwoPhaseTransaction',
+    'TypeCompiler',
     'create_engine',
     'engine_from_config',
     )
index 39085c359617067c8eef4ae953e37389a1d0728d..f95da22731c172a1b00889e27d9e08c95690b091 100644 (file)
@@ -34,7 +34,11 @@ class Dialect(object):
     All Dialects implement the following attributes:
     
     name
-      identifying name for the dialect (i.e. 'sqlite')
+      identifying name for the dialect from a DBAPI-neutral point of view
+      (i.e. 'sqlite')
+    
+    driver
+      identifying name for the dialect's DBAPI
       
     positional
       True if the paramstyle for this Dialect is positional.
@@ -51,21 +55,21 @@ class Dialect(object):
       type of encoding to use for unicode, usually defaults to
       'utf-8'.
 
-    schemagenerator
-      a :class:`~sqlalchemy.schema.SchemaVisitor` class which generates
-      schemas.
-
-    schemadropper
-      a :class:`~sqlalchemy.schema.SchemaVisitor` class which drops schemas.
-
     defaultrunner
       a :class:`~sqlalchemy.schema.SchemaVisitor` class which executes
       defaults.
 
     statement_compiler
-      a :class:`~sqlalchemy.engine.base.Compiled` class used to compile SQL
+      a :class:`~Compiled` class used to compile SQL
       statements
 
+    ddl_compiler
+      a :class:`~Compiled` class used to compile DDL
+      statements
+
+    execution_ctx_cls
+      an :class:`ExecutionContext` class used to handle statement execution
+
     preparer
       a :class:`~sqlalchemy.sql.compiler.IdentifierPreparer` class used to
       quote identifiers.
@@ -107,11 +111,6 @@ class Dialect(object):
 
     supports_default_values
       Indicates if the construct ``INSERT INTO tablename DEFAULT VALUES`` is supported
-
-    description_encoding
-      type of encoding to use for unicode when working with metadata
-      descriptions. If set to ``None`` no encoding will be done.
-      This usually defaults to 'utf-8'.
     """
 
     def create_connect_args(self, url):
@@ -401,7 +400,7 @@ class ExecutionContext(object):
 
 
 class Compiled(object):
-    """Represent a compiled SQL expression.
+    """Represent a compiled SQL or DDL expression.
 
     The ``__str__`` method of the ``Compiled`` object should produce
     the actual text of the statement.  ``Compiled`` objects are
@@ -411,9 +410,10 @@ class Compiled(object):
     ``Compiled`` object be dependent on the actual values of those
     bind parameters, even though it may reference those values as
     defaults.
+    
     """
 
-    def __init__(self, dialect, statement, column_keys=None, bind=None):
+    def __init__(self, dialect, statement, bind=None):
         """Construct a new ``Compiled`` object.
 
         dialect
@@ -422,41 +422,40 @@ class Compiled(object):
         statement
           ``ClauseElement`` to be compiled.
 
-        column_keys
-          a list of column names to be compiled into an INSERT or UPDATE
-          statement.
-
         bind
           Optional Engine or Connection to compile this statement against.
           
         """
         self.dialect = dialect
         self.statement = statement
-        self.column_keys = column_keys
         self.bind = bind
         self.can_execute = statement.supports_execution
 
     def compile(self):
         """Produce the internal string representation of this element."""
 
-        raise NotImplementedError()
+        self.string = self.process(self.statement)
+
+    def process(self, obj, **kwargs):
+        return obj._compiler_dispatch(self, **kwargs)
 
     def __str__(self):
-        """Return the string text of the generated SQL statement."""
+        """Return the string text of the generated SQL or DDL."""
 
-        raise NotImplementedError()
+        return self.string or ''
 
     @util.deprecated('Deprecated. Use construct_params(). '
                      '(supports Unicode key names.)')
     def get_params(self, **params):
         return self.construct_params(params)
 
-    def construct_params(self, params):
+    def construct_params(self, params=None):
         """Return the bind params for this compiled object.
 
         `params` is a dict of string/object pairs whose
         values will override bind values compiled into
         the statement.
+        
         """
         raise NotImplementedError()
 
@@ -473,6 +472,15 @@ class Compiled(object):
 
         return self.execute(*multiparams, **params).scalar()
 
+class TypeCompiler(object):
+    """Produces DDL specification for TypeEngine objects."""
+    
+    def __init__(self, dialect):
+        self.dialect = dialect
+        
+    def process(self, type_):
+        return type_._compiler_dispatch(self)
+        
 
 class Connectable(object):
     """Interface for an object which supports execution of SQL constructs.
@@ -480,6 +488,9 @@ class Connectable(object):
     The two implementations of ``Connectable`` are :class:`Connection` and
     :class:`Engine`.
     
+    Connectable must also implement the 'dialect' member which references a
+    :class:`Dialect` instance.
+    
     """
 
     def contextual_connect(self):
@@ -813,9 +824,6 @@ class Connection(Connectable):
 
         return self.execute(object, *multiparams, **params).scalar()
 
-    def statement_compiler(self, statement, **kwargs):
-        return self.dialect.statement_compiler(self.dialect, statement, bind=self, **kwargs)
-
     def execute(self, object, *multiparams, **params):
         """Executes and returns a ResultProxy."""
 
@@ -860,6 +868,13 @@ class Connection(Connectable):
     def _execute_default(self, default, multiparams, params):
         return self.engine.dialect.defaultrunner(self.__create_execution_context()).traverse_single(default)
 
+    def _execute_ddl(self, ddl, params, multiparams):
+        context = self.__create_execution_context(
+                        compiled_ddl=ddl.compile(dialect=self.dialect), 
+                        parameters=None
+                    )
+        return self.__execute_context(context)
+
     def _execute_clauseelement(self, elem, multiparams, params):
         params = self.__distill_params(multiparams, params)
         if params:
@@ -868,7 +883,7 @@ class Connection(Connectable):
             keys = []
 
         context = self.__create_execution_context(
-                        compiled=elem.compile(dialect=self.dialect, column_keys=keys, inline=len(params) > 1), 
+                        compiled_sql=elem.compile(dialect=self.dialect, column_keys=keys, inline=len(params) > 1), 
                         parameters=params
                     )
         return self.__execute_context(context)
@@ -877,7 +892,7 @@ class Connection(Connectable):
         """Execute a sql.Compiled object."""
 
         context = self.__create_execution_context(
-                    compiled=compiled, 
+                    compiled_sql=compiled, 
                     parameters=self.__distill_params(multiparams, params)
                 )
         return self.__execute_context(context)
@@ -900,13 +915,6 @@ class Connection(Connectable):
             self._commit_impl()
         return context.get_result_proxy()
         
-    def _execute_ddl(self, ddl, params, multiparams):
-        if params:
-            schema_item, params = params[0], params[1:]
-        else:
-            schema_item = None
-        return ddl(None, schema_item, self, *params, **multiparams)
-
     def _handle_dbapi_exception(self, e, statement, parameters, cursor, context):
         if getattr(self, '_reentrant_error', False):
             raise exc.DBAPIError.instance(None, None, e)
@@ -966,7 +974,7 @@ class Connection(Connectable):
         expression.ClauseElement: _execute_clauseelement,
         Compiled: _execute_compiled,
         schema.SchemaItem: _execute_default,
-        schema.DDL: _execute_ddl,
+        schema.DDLElement: _execute_ddl,
         basestring: _execute_text
     }
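
With schema.DDLElement registered in the executor dispatch above, the new DDL
constructs can be passed straight to Connection.execute(). A sketch, assuming
the freshly converted sqlite dialect supplies the DDL compiler hooks:

    from sqlalchemy import create_engine, MetaData, Table, Column, Integer
    from sqlalchemy.schema import CreateTable, DropTable

    engine = create_engine('sqlite://')
    users = Table('users', MetaData(), Column('id', Integer, primary_key=True))

    conn = engine.connect()
    # Both calls dispatch to Connection._execute_ddl via the DDLElement entry.
    conn.execute(CreateTable(users))
    conn.execute(DropTable(users))
    conn.close()
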
 
@@ -1126,12 +1134,16 @@ class Engine(Connectable):
     def create(self, entity, connection=None, **kwargs):
         """Create a table or index within this engine's database connection given a schema.Table object."""
 
-        self._run_visitor(self.dialect.schemagenerator, entity, connection=connection, **kwargs)
+        from sqlalchemy.engine import ddl
+
+        self._run_visitor(ddl.SchemaGenerator, entity, connection=connection, **kwargs)
 
     def drop(self, entity, connection=None, **kwargs):
         """Drop a table or index within this engine's database connection given a schema.Table object."""
 
-        self._run_visitor(self.dialect.schemadropper, entity, connection=connection, **kwargs)
+        from sqlalchemy.engine import ddl
+
+        self._run_visitor(ddl.SchemaDropper, entity, connection=connection, **kwargs)
 
     def _execute_default(self, default):
         connection = self.contextual_connect()
@@ -1212,9 +1224,6 @@ class Engine(Connectable):
         connection = self.contextual_connect(close_with_result=True)
         return connection._execute_compiled(compiled, multiparams, params)
 
-    def statement_compiler(self, statement, **kwargs):
-        return self.dialect.statement_compiler(self.dialect, statement, bind=self, **kwargs)
-
     def connect(self, **kwargs):
         """Return a newly allocated Connection object."""
 
@@ -1790,29 +1799,6 @@ class BufferedColumnResultProxy(ResultProxy):
             l.append(row)
         return l
 
-
-class SchemaIterator(schema.SchemaVisitor):
-    """A visitor that can gather text into a buffer and execute the contents of the buffer."""
-
-    def __init__(self, connection):
-        """Construct a new SchemaIterator."""
-        
-        self.connection = connection
-        self.buffer = StringIO.StringIO()
-
-    def append(self, s):
-        """Append content to the SchemaIterator's query buffer."""
-
-        self.buffer.write(s)
-
-    def execute(self):
-        """Execute the contents of the SchemaIterator's buffer."""
-
-        try:
-            return self.connection.execute(self.buffer.getvalue())
-        finally:
-            self.buffer.truncate(0)
-
 class DefaultRunner(schema.SchemaVisitor):
     """A visitor which accepts ColumnDefault objects, produces the
     dialect-specific SQL corresponding to their execution, and
diff --git a/lib/sqlalchemy/engine/ddl.py b/lib/sqlalchemy/engine/ddl.py
new file mode 100644 (file)
index 0000000..2fc09a2
--- /dev/null
@@ -0,0 +1,126 @@
+"""routines to handle CREATE/DROP workflow."""
+
+### TODO: CREATE TABLE and DROP TABLE have been moved out so far.
+### Index, ForeignKey, etc. still need to move.
+
+from sqlalchemy import engine, schema
+from sqlalchemy.sql import util as sql_util
+
+class DDLBase(schema.SchemaVisitor):
+    def __init__(self, connection):
+        self.connection = connection
+    
+    def find_alterables(self, tables):
+        alterables = []
+        class FindAlterables(schema.SchemaVisitor):
+            def visit_foreign_key_constraint(self, constraint):
+                if constraint.use_alter and constraint.table in tables:
+                    alterables.append(constraint)
+        findalterables = FindAlterables()
+        for table in tables:
+            for c in table.constraints:
+                findalterables.traverse(c)
+        return alterables
+
+
+class SchemaGenerator(DDLBase):
+    def __init__(self, dialect, connection, checkfirst=False, tables=None, **kwargs):
+        super(SchemaGenerator, self).__init__(connection, **kwargs)
+        self.checkfirst = checkfirst
+        self.tables = tables and set(tables) or None
+        self.preparer = dialect.identifier_preparer
+        self.dialect = dialect
+
+    def _can_create(self, table):
+        self.dialect.validate_identifier(table.name)
+        if table.schema:
+            self.dialect.validate_identifier(table.schema)
+        return not self.checkfirst or not self.dialect.has_table(self.connection, table.name, schema=table.schema)
+
+    def visit_metadata(self, metadata):
+        if self.tables:
+            tables = self.tables
+        else:
+            tables = metadata.tables.values()
+        collection = [t for t in sql_util.sort_tables(tables) if self._can_create(t)]
+        for table in collection:
+            self.traverse_single(table)
+        if self.dialect.supports_alter:
+            for alterable in self.find_alterables(collection):
+                self.connection.execute(schema.AddForeignKey(alterable))
+
+    def visit_table(self, table):
+        for listener in table.ddl_listeners['before-create']:
+            listener('before-create', table, self.connection)
+
+        for column in table.columns:
+            if column.default is not None:
+                self.traverse_single(column.default)
+
+        self.connection.execute(schema.CreateTable(table))
+
+        if hasattr(table, 'indexes'):
+            for index in table.indexes:
+                self.traverse_single(index)
+
+        for listener in table.ddl_listeners['after-create']:
+            listener('after-create', table, self.connection)
+
+    def visit_sequence(self, sequence):
+        if self.dialect.supports_sequences:
+            if \
+                (not self.dialect.sequences_optional or not sequence.optional) and \
+                (not self.checkfirst or not self.dialect.has_sequence(self.connection, sequence.name)):
+                self.connection.execute(schema.CreateSequence(sequence))
+
+    def visit_index(self, index):
+        self.connection.execute(schema.CreateIndex(index))
+
+class SchemaDropper(DDLBase):
+    def __init__(self, dialect, connection, checkfirst=False, tables=None, **kwargs):
+        super(SchemaDropper, self).__init__(connection, **kwargs)
+        self.checkfirst = checkfirst
+        self.tables = tables
+        self.preparer = dialect.identifier_preparer
+        self.dialect = dialect
+
+    def visit_metadata(self, metadata):
+        if self.tables:
+            tables = self.tables
+        else:
+            tables = metadata.tables.values()
+        collection = [t for t in reversed(sql_util.sort_tables(tables)) if self._can_drop(t)]
+        if self.dialect.supports_alter:
+            for alterable in self.find_alterables(collection):
+                self.connection.execute(schema.DropForeignKey(alterable))
+        for table in collection:
+            self.traverse_single(table)
+
+    def _can_drop(self, table):
+        self.dialect.validate_identifier(table.name)
+        if table.schema:
+            self.dialect.validate_identifier(table.schema)
+        return not self.checkfirst or self.dialect.has_table(self.connection, table.name, schema=table.schema)
+
+    def visit_index(self, index):
+        self.connection.execute(schema.DropIndex(index))
+
+    def visit_table(self, table):
+        for listener in table.ddl_listeners['before-drop']:
+            listener('before-drop', table, self.connection)
+
+        for column in table.columns:
+            if column.default is not None:
+                self.traverse_single(column.default)
+        
+        self.connection.execute(schema.DropTable(table))
+
+        for listener in table.ddl_listeners['after-drop']:
+            listener('after-drop', table, self.connection)
+
+    def visit_sequence(self, sequence):
+        if self.dialect.supports_sequences:
+            if \
+                (not self.dialect.sequences_optional or not sequence.optional) and \
+                (not self.checkfirst or self.dialect.has_sequence(self.connection, sequence.name)):
+                self.connection.execute(schema.DropSequence(sequence))
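
Taken together, these visitors mean Engine.create()/drop() now emit compilable
DDL constructs instead of appending strings to a buffer. A workflow sketch
(the sqlite URL is assumed for illustration):

    from sqlalchemy import MetaData, Table, Column, Integer, create_engine

    engine = create_engine('sqlite://')
    meta = MetaData()
    t = Table('t', meta, Column('id', Integer, primary_key=True))

    # Engine.create() runs the SchemaGenerator above, which issues
    # schema.CreateTable / schema.CreateIndex constructs on the connection;
    # Engine.drop() does the same through SchemaDropper.
    engine.create(t)
    engine.drop(t)
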
diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py
index 1ffc7bb04c8f60b057d19275d8f85cf97a11546a..12b1661925da9cb5a919c0dab4193a8b98add426 100644 (file)
@@ -23,12 +23,14 @@ AUTOCOMMIT_REGEXP = re.compile(r'\s*(?:UPDATE|INSERT|CREATE|DELETE|DROP|ALTER)',
 class DefaultDialect(base.Dialect):
     """Default implementation of Dialect"""
 
-    schemagenerator = compiler.SchemaGenerator
-    schemadropper = compiler.SchemaDropper
-    statement_compiler = compiler.DefaultCompiler
+    statement_compiler = compiler.SQLCompiler
+    ddl_compiler = compiler.DDLCompiler
+    type_compiler = compiler.GenericTypeCompiler
     preparer = compiler.IdentifierPreparer
     defaultrunner = base.DefaultRunner
     supports_alter = True
+    supports_sequences = False
+    sequences_optional = False
     supports_unicode_statements = False
     max_identifier_length = 9999
     supports_sane_rowcount = True
@@ -57,6 +59,8 @@ class DefaultDialect(base.Dialect):
             self.paramstyle = self.default_paramstyle
         self.positional = self.paramstyle in ('qmark', 'format', 'numeric')
         self.identifier_preparer = self.preparer(self)
+        self.type_compiler = self.type_compiler(self)
+        
         if label_length and label_length > self.max_identifier_length:
             raise exc.ArgumentError("Label length of %d is greater than this dialect's maximum identifier length of %d" % (label_length, self.max_identifier_length))
         self.label_length = label_length
@@ -67,8 +71,9 @@ class DefaultDialect(base.Dialect):
         the generic object which comes from the types module.
 
         Subclasses will usually use the ``adapt_type()`` method in the
-        types module to make this job easy."""
-
+        types module to make this job easy.
+        
+        """
         if type(typeobj) is type:
             typeobj = typeobj()
         return typeobj
@@ -126,13 +131,29 @@ class DefaultDialect(base.Dialect):
 
 
 class DefaultExecutionContext(base.ExecutionContext):
-    def __init__(self, dialect, connection, compiled=None, statement=None, parameters=None):
+    def __init__(self, dialect, connection, compiled_sql=None, compiled_ddl=None, statement=None, parameters=None):
         self.dialect = dialect
         self._connection = self.root_connection = connection
-        self.compiled = compiled
         self.engine = connection.engine
 
-        if compiled is not None:
+        if compiled_ddl is not None:
+            self.compiled = compiled = compiled_ddl
+            if not dialect.supports_unicode_statements:
+                self.statement = unicode(compiled).encode(self.dialect.encoding)
+            else:
+                self.statement = unicode(compiled)
+            self.isinsert = self.isupdate = self.executemany = False
+            self.should_autocommit = True
+            self.result_map = None
+            self.cursor = self.create_cursor()
+            self.compiled_parameters = []
+            if self.dialect.positional:
+                self.parameters = [()]
+            else:
+                self.parameters = [{}]
+        elif compiled_sql is not None:
+            self.compiled = compiled = compiled_sql
+
             # compiled clauseelement.  process bind params, process table defaults,
             # track collections used by ResultProxy to target and process results
 
@@ -172,8 +193,8 @@ class DefaultExecutionContext(base.ExecutionContext):
             self.parameters = self.__convert_compiled_params(self.compiled_parameters)
 
         elif statement is not None:
-            # plain text statement.
-            self.result_map = None
+            # plain text statement
+            self.result_map = self.compiled = None
             self.parameters = self.__encode_param_keys(parameters)
             self.executemany = len(parameters) > 1
             if isinstance(statement, unicode) and not dialect.supports_unicode_statements:
@@ -185,7 +206,7 @@ class DefaultExecutionContext(base.ExecutionContext):
             self.should_autocommit = self.should_autocommit_text(statement)
         else:
             # no statement. used for standalone ColumnDefault execution.
-            self.statement = None
+            self.statement = self.compiled = None
             self.isinsert = self.isupdate = self.executemany = self.should_autocommit = False
             self.cursor = self.create_cursor()
 
diff --git a/lib/sqlalchemy/engine/strategies.py b/lib/sqlalchemy/engine/strategies.py
index fa608df65ee639f4d18edb93a252ac8bd869f202..b1261da0a88be3015217e097f359cb1f55fba437 100644 (file)
@@ -201,11 +201,14 @@ class MockEngineStrategy(EngineStrategy):
 
         def create(self, entity, **kwargs):
             kwargs['checkfirst'] = False
-            self.dialect.schemagenerator(self.dialect, self, **kwargs).traverse(entity)
+            from sqlalchemy.engine import ddl
+            
+            ddl.SchemaGenerator(self.dialect, self, **kwargs).traverse(entity)
 
         def drop(self, entity, **kwargs):
             kwargs['checkfirst'] = False
-            self.dialect.schemadropper(self.dialect, self, **kwargs).traverse(entity)
+            from sqlalchemy.engine import ddl
+            ddl.SchemaDropper(self.dialect, self, **kwargs).traverse(entity)
 
         def execute(self, object, *multiparams, **params):
             raise NotImplementedError()
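
Because the mock connection hands every construct to the user-supplied
executor, the usual "dump DDL as strings" recipe keeps working against the
new ddl module. A sketch, assuming the 0.5-style strategy/executor keywords:

    from sqlalchemy import create_engine, MetaData, Table, Column, Integer

    meta = MetaData()
    t = Table('t', meta, Column('id', Integer, primary_key=True))

    def dump(sql, *multiparams, **params):
        # sql arrives as a schema.CreateTable construct; compiling it
        # against the mock engine's dialect yields the DDL string.
        print sql.compile(dialect=mock.dialect)

    mock = create_engine('sqlite://', strategy='mock', executor=dump)
    mock.create(t)
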
diff --git a/lib/sqlalchemy/engine/url.py b/lib/sqlalchemy/engine/url.py
index 5c8e68ce45a6f15199b8a7303ede237ac90051ba..8000cbc6c32cc3b7b9c4cfa24b6ce0bb63e85ca6 100644 (file)
@@ -88,7 +88,15 @@ class URL(object):
         """Return the SQLAlchemy database dialect class corresponding to this URL's driver name."""
         
         try:
-            module = getattr(__import__('sqlalchemy.databases.%s' % self.drivername).databases, self.drivername)
+            if '+' in self.drivername:
+                dialect, driver = self.drivername.split('+')
+            else:
+                dialect, driver = self.drivername, 'base'
+
+            module = __import__('sqlalchemy.dialects.%s.%s' % (dialect, driver)).dialects
+            module = getattr(module, dialect)
+            module = getattr(module, driver)
+            
             return module.dialect
         except ImportError:
             if sys.exc_info()[2].tb_next is None:
@@ -140,7 +148,7 @@ def make_url(name_or_url):
 
 def _parse_rfc1738_args(name):
     pattern = re.compile(r'''
-            (?P<name>\w+)://
+            (?P<name>[\w\+]+)://
             (?:
                 (?P<username>[^:/]*)
                 (?::(?P<password>[^/]*))?
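
Together these two hunks introduce the two-part "dialect+driver" URL form. A
sketch of the effect, using the postgres/psycopg2 modules added in this commit:

    from sqlalchemy.engine import url

    u = url.make_url('postgres+psycopg2://scott:tiger@localhost/test')
    # _parse_rfc1738_args now accepts '+' in the name, and get_dialect()
    # splits it to import sqlalchemy.dialects.postgres.psycopg2; a bare
    # 'postgres://' URL falls back to the dialect's 'base' module.
    dialect_cls = u.get_dialect()
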
diff --git a/lib/sqlalchemy/ext/declarative.py b/lib/sqlalchemy/ext/declarative.py
index 3b4880403ae13cca96d8361835cfe75effc6002e..a5cb6e9d2d11fd30ef339516a75773840804d5fb 100644 (file)
@@ -478,9 +478,7 @@ def _as_declarative(cls, classname, dict_):
                                           *(tuple(cols) + tuple(args)), **table_kw)
     else:
         table = cls.__table__
-        if cols:
-            raise exceptions.ArgumentError("Can't add additional columns when specifying __table__")
-            
+
     mapper_args = getattr(cls, '__mapper_args__', {})
     if 'inherits' not in mapper_args:
         inherits = cls.__mro__[1]
@@ -532,7 +530,7 @@ def _as_declarative(cls, classname, dict_):
             mapper_args['exclude_properties'] = exclude_properties = \
                 set([c.key for c in inherited_table.c if c not in inherited_mapper._columntoproperty])
             exclude_properties.difference_update([c.key for c in cols])
-    
+        
     cls.__mapper__ = mapper_cls(cls, table, properties=our_stuff, **mapper_args)
 
 class DeclarativeMeta(type):
diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py
index 6bcc89b3c23cf96d1319beed23d6d877b7017f6f..04fc9d0ef1381691e3005c69bbc2a21e472e33ee 100644 (file)
@@ -1305,15 +1305,13 @@ class Mapper(object):
                                 if col in pks:
                                     if history.deleted:
                                         params[col._label] = prop.get_col_value(col, history.deleted[0])
-                                        hasdata = True
                                     else:
                                         # row switch logic can reach us here
                                         # remove the pk from the update params so the update doesn't
                                         # attempt to include the pk in the update statement
                                         del params[col.key]
                                         params[col._label] = prop.get_col_value(col, history.added[0])
-                                else:
-                                    hasdata = True
+                                hasdata = True
                             elif col in pks:
                                 params[col._label] = mapper._get_state_attr_by_column(state, col)
                     if hasdata:
diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py
index a4561d443df6300764f107c9cd3b5d31149252c6..0211b9707ac69ba6f16b38133a1a8d76323be4aa 100644 (file)
@@ -32,24 +32,12 @@ class ColumnProperty(StrategizedProperty):
     """Describes an object attribute that corresponds to a table column."""
 
     def __init__(self, *columns, **kwargs):
-        """Construct a ColumnProperty.
-
-        :param \*columns: The list of `columns` describes a single
-          object property. If there are multiple tables joined
-          together for the mapper, this list represents the equivalent
-          column as it appears across each table.
-
-        :param group:
-
-        :param deferred:
-
-        :param comparator_factory:
-
-        :param descriptor:
-
-        :param extension:
-
+        """The list of `columns` describes a single object
+        property. If there are multiple tables joined together for the
+        mapper, this list represents the equivalent column as it
+        appears across each table.
         """
+
         self.columns = [expression._labeled(c) for c in columns]
         self.group = kwargs.pop('group', None)
         self.deferred = kwargs.pop('deferred', False)
@@ -57,11 +45,6 @@ class ColumnProperty(StrategizedProperty):
         self.comparator_factory = kwargs.pop('comparator_factory', self.__class__.Comparator)
         self.descriptor = kwargs.pop('descriptor', None)
         self.extension = kwargs.pop('extension', None)
-        if kwargs:
-            raise TypeError(
-                "%s received unexpected keyword argument(s): %s" % (
-                    self.__class__.__name__, ', '.join(sorted(kwargs.keys()))))
-
         util.set_creation_order(self)
         if self.no_instrument:
             self.strategy_class = strategies.UninstrumentedColumnLoader
@@ -1153,4 +1136,4 @@ mapper.ColumnProperty = ColumnProperty
 mapper.SynonymProperty = SynonymProperty
 mapper.ComparableProperty = ComparableProperty
 mapper.RelationProperty = RelationProperty
-mapper.ConcreteInheritedProperty = ConcreteInheritedProperty
+mapper.ConcreteInheritedProperty = ConcreteInheritedProperty
\ No newline at end of file
diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py
index d454bc7cff322fde73e1e84bc7443ab044532e7f..c9dc152b9845432303afb6da2d0a1cc98df91619 100644 (file)
@@ -609,9 +609,7 @@ class Column(SchemaItem, expression.ColumnClause):
                 "Unknown arguments passed to Column: " + repr(kwargs.keys()))
 
     def __str__(self):
-        if self.name is None:
-            return "(no name)"
-        elif self.table is not None:
+        if self.table is not None:
             if self.table.named_with_column:
                 return (self.table.description + "." + self.description)
             else:
@@ -619,9 +617,9 @@ class Column(SchemaItem, expression.ColumnClause):
         else:
             return self.description
 
-    @property
     def bind(self):
         return self.table.bind
+    bind = property(bind)
 
     def references(self, column):
         """Return True if this Column references the given column via foreign key."""
@@ -1884,7 +1882,30 @@ class SchemaVisitor(visitors.ClauseVisitor):
     __traverse_options__ = {'schema_visitor':True}
 
 
-class DDL(object):
+class DDLElement(expression.ClauseElement):
+    """Base class for DDL expression constructs."""
+    
+    supports_execution = True
+    _autocommit = True
+
+    def bind(self):
+        if self._bind:
+            return self._bind
+    def _set_bind(self, bind):
+        self._bind = bind
+    bind = property(bind, _set_bind)
+
+    def _generate(self):
+        s = self.__class__.__new__(self.__class__)
+        s.__dict__ = self.__dict__.copy()
+        return s
+    
+    def _compiler(self, dialect, **kw):
+        """Return a compiler appropriate for this ClauseElement, given a Dialect."""
+        
+        return dialect.ddl_compiler(dialect, self, **kw)
+
+class DDL(DDLElement):
     """A literal DDL statement.
 
     Specifies literal SQL DDL to be executed by the database.  DDL objects can
@@ -1905,6 +1926,8 @@ class DDL(object):
       connection.execute(drop_spow)
     """
 
+    __visit_name__ = "ddl"
+    
     def __init__(self, statement, on=None, context=None, bind=None):
         """Create a DDL statement.
 
@@ -1964,6 +1987,7 @@ class DDL(object):
         self.on = on
         self.context = context or {}
         self._bind = bind
+        self.schema_item = None
 
     def execute(self, bind=None, schema_item=None):
         """Execute this DDL immediately.
@@ -1985,10 +2009,9 @@ class DDL(object):
 
         if bind is None:
             bind = _bind_or_error(self)
-        # no SQL bind params are supported
+
         if self._should_execute(None, schema_item, bind):
-            executable = expression.text(self._expand(schema_item, bind))
-            return bind.execute(executable)
+            return bind.execute(self.against(schema_item))
         else:
             bind.engine.logger.info("DDL execution skipped, criteria not met.")
 
@@ -2040,39 +2063,18 @@ class DDL(object):
                 (', '.join(schema_item.ddl_events), event))
         schema_item.ddl_listeners[event].append(self)
         return self
-
-    def bind(self):
-        """An Engine or Connection to which this DDL is bound.
-
-        This property may be assigned an ``Engine`` or ``Connection``, or
-        assigned a string or URL to automatically create a basic ``Engine``
-        for this bind with ``create_engine()``.
-        """
-        return self._bind
-
-    def _bind_to(self, bind):
-        """Bind this MetaData to an Engine, Connection, string or URL."""
-
-        global URL
-        if URL is None:
-            from sqlalchemy.engine.url import URL
-
-        if isinstance(bind, (basestring, URL)):
-            from sqlalchemy import create_engine
-            self._bind = create_engine(bind)
-        else:
-            self._bind = bind
-    bind = property(bind, _bind_to)
-
+    
+    @expression._generative
+    def against(self, schema_item):
+        """Return a copy of this DDL against a specific schema item."""
+        
+        self.schema_item = schema_item
+        
     def __call__(self, event, schema_item, bind):
         """Execute the DDL as a ddl_listener."""
 
         if self._should_execute(event, schema_item, bind):
-            statement = expression.text(self._expand(schema_item, bind))
-            return bind.execute(statement)
-
-    def _expand(self, schema_item, bind):
-        return self.statement % self._prepare_context(schema_item, bind)
+            return bind.execute(self.against(schema_item))
 
     def _should_execute(self, event, schema_item, bind):
         if self.on is None:
@@ -2082,25 +2084,6 @@ class DDL(object):
         else:
             return self.on(event, schema_item, bind)
 
-    def _prepare_context(self, schema_item, bind):
-        # table events can substitute table and schema name
-        if isinstance(schema_item, Table):
-            context = self.context.copy()
-
-            preparer = bind.dialect.identifier_preparer
-            path = preparer.format_table_seq(schema_item)
-            if len(path) == 1:
-                table, schema = path[0], ''
-            else:
-                table, schema = path[-1], path[0]
-
-            context.setdefault('table', table)
-            context.setdefault('schema', schema)
-            context.setdefault('fullname', preparer.format_table(schema_item))
-            return context
-        else:
-            return self.context
-
     def __repr__(self):
         return '<%s@%s; %s>' % (
             type(self).__name__, id(self),
@@ -2110,11 +2093,76 @@ class DDL(object):
                        if getattr(self, key)]))
 
 def _to_schema_column(element):
-    if hasattr(element, '__clause_element__'):
-        element = element.__clause_element__()
-    if not isinstance(element, Column):
-        raise exc.ArgumentError("schema.Column object expected")
-    return element
+    if hasattr(element, '__clause_element__'):
+        element = element.__clause_element__()
+    if not isinstance(element, Column):
+        raise exc.ArgumentError("schema.Column object expected")
+    return element
+
+class _CreateDropBase(DDLElement):
+    """Base class for DDL constucts that represent CREATE and DROP or equivalents.
+
+    The common theme of _CreateDropBase is a single
+    ``element`` attribute which refers to the element
+    to be created or dropped.
+    
+    """
+    
+    def __init__(self, element):
+        self.element = element
+        
+    def bind(self):
+        if self._bind:
+            return self._bind
+        if self.element:
+            e = self.element.bind
+            if e:
+                return e
+        return None
+
+    def _set_bind(self, bind):
+        self._bind = bind
+    bind = property(bind, _set_bind)
+
+class CreateTable(_CreateDropBase):
+    """Represent a CREATE TABLE statement."""
+    
+    __visit_name__ = "create_table"
+    
+class DropTable(_CreateDropBase):
+    """Represent a DROP TABLE statement."""
+
+    __visit_name__ = "drop_table"
+
+class AddForeignKey(_CreateDropBase):
+    """Represent an ALTER TABLE ADD FOREIGN KEY statement."""
+    
+    __visit_name__ = "add_foreignkey"
+    
+class DropForeignKey(_CreateDropBase):
+    """Represent an ALTER TABLE DROP FOREIGN KEY statement."""
+    
+    __visit_name__ = "drop_foreignkey"
+
+class CreateSequence(_CreateDropBase):
+    """Represent a CREATE SEQUENCE statement."""
+    
+    __visit_name__ = "create_sequence"
+
+class DropSequence(_CreateDropBase):
+    """Represent a DROP SEQUENCE statement."""
+
+    __visit_name__ = "drop_sequence"
+    
+class CreateIndex(_CreateDropBase):
+    """Represent a CREATE INDEX statement."""
+    
+    __visit_name__ = "create_index"
+
+class DropIndex(_CreateDropBase):
+    """Represent a DROP INDEX statement."""
+
+    __visit_name__ = "drop_index"
     
 def _bind_or_error(schemaitem):
     bind = schemaitem.bind
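
The constructs above are ordinary ClauseElements, so they compile to strings
like any other expression. A brief sketch (engine URL and table are
illustrative assumptions):

    from sqlalchemy import MetaData, Table, Column, Integer, create_engine
    from sqlalchemy.schema import DDL, DropTable

    engine = create_engine('sqlite://')
    users = Table('users', MetaData(), Column('id', Integer, primary_key=True))

    # DropTable compiles through the dialect's ddl_compiler.
    print DropTable(users).compile(dialect=engine.dialect)

    # For literal DDL, against() attaches a schema item so that visit_ddl()
    # can fill in %(table)s / %(schema)s / %(fullname)s via the preparer.
    drop = DDL('DROP TABLE %(fullname)s')
    print drop.against(users).compile(dialect=engine.dialect)
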
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index d5c85d71d62ec597edbae5de4bb5cb7c21775c07..3e61b459b4b21104221d8e2dcd87694ccb7596b8 100644 (file)
@@ -123,7 +123,7 @@ class _CompileLabel(visitors.Visitable):
     def quote(self):
         return self.element.quote
 
-class DefaultCompiler(engine.Compiled):
+class SQLCompiler(engine.Compiled):
     """Default implementation of Compiled.
 
     Compiles ClauseElements into SQL strings.   Uses a similar visit
@@ -134,8 +134,9 @@ class DefaultCompiler(engine.Compiled):
     operators = OPERATORS
     functions = FUNCTIONS
 
-    # if we are insert/update/delete. 
-    # set to true when we visit an INSERT, UPDATE or DELETE
+    # class-level defaults which can be set at the instance
+    # level to define if this Compiled instance represents
+    # INSERT/UPDATE/DELETE
     isdelete = isinsert = isupdate = False
 
     def __init__(self, dialect, statement, column_keys=None, inline=False, **kwargs):
@@ -152,7 +153,9 @@ class DefaultCompiler(engine.Compiled):
           statement.
 
         """
-        engine.Compiled.__init__(self, dialect, statement, column_keys, **kwargs)
+        engine.Compiled.__init__(self, dialect, statement, **kwargs)
+
+        self.column_keys = column_keys
 
         # compile INSERT/UPDATE defaults/sequences inlined (no pre-execute)
         self.inline = inline or getattr(statement, 'inline', False)
@@ -192,12 +195,6 @@ class DefaultCompiler(engine.Compiled):
         # or dialect.max_identifier_length
         self.truncated_names = {}
 
-    def compile(self):
-        self.string = self.process(self.statement)
-
-    def process(self, obj, **kwargs):
-        return obj._compiler_dispatch(self, **kwargs)
-
     def is_subquery(self):
         return len(self.stack) > 1
 
@@ -292,7 +289,7 @@ class DefaultCompiler(engine.Compiled):
         return index.name
 
     def visit_typeclause(self, typeclause, **kwargs):
-        return typeclause.type.dialect_impl(self.dialect).get_col_spec()
+        return self.dialect.type_compiler.process(typeclause.type)
 
     def post_process_text(self, text):
         return text
@@ -739,110 +736,117 @@ class DefaultCompiler(engine.Compiled):
     def visit_release_savepoint(self, savepoint_stmt):
         return "RELEASE SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt)
 
-    def __str__(self):
-        return self.string or ''
-
-class DDLBase(engine.SchemaIterator):
-    def find_alterables(self, tables):
-        alterables = []
-        class FindAlterables(schema.SchemaVisitor):
-            def visit_foreign_key_constraint(self, constraint):
-                if constraint.use_alter and constraint.table in tables:
-                    alterables.append(constraint)
-        findalterables = FindAlterables()
-        for table in tables:
-            for c in table.constraints:
-                findalterables.traverse(c)
-        return alterables
 
-    def _validate_identifier(self, ident, truncate):
-        if truncate:
-            if len(ident) > self.dialect.max_identifier_length:
-                counter = getattr(self, 'counter', 0)
-                self.counter = counter + 1
-                return ident[0:self.dialect.max_identifier_length - 6] + "_" + hex(self.counter)[2:]
+class DDLCompiler(engine.Compiled):
+    @property
+    def preparer(self):
+        return self.dialect.identifier_preparer
+        
+    def visit_ddl(self, ddl, **kwargs):
+        # table events can substitute table and schema name
+        context = ddl.context
+        if isinstance(ddl.schema_item, schema.Table):
+            context = context.copy()
+
+            preparer = self.dialect.identifier_preparer
+            path = preparer.format_table_seq(ddl.schema_item)
+            if len(path) == 1:
+                table, sch = path[0], ''
             else:
-                return ident
-        else:
-            self.dialect.validate_identifier(ident)
-            return ident
-
-
-class SchemaGenerator(DDLBase):
-    def __init__(self, dialect, connection, checkfirst=False, tables=None, **kwargs):
-        super(SchemaGenerator, self).__init__(connection, **kwargs)
-        self.checkfirst = checkfirst
-        self.tables = tables and set(tables) or None
-        self.preparer = dialect.identifier_preparer
-        self.dialect = dialect
-
-    def get_column_specification(self, column, first_pk=False):
-        raise NotImplementedError()
-
-    def _can_create(self, table):
-        self.dialect.validate_identifier(table.name)
-        if table.schema:
-            self.dialect.validate_identifier(table.schema)
-        return not self.checkfirst or not self.dialect.has_table(self.connection, table.name, schema=table.schema)
+                table, sch = path[-1], path[0]
 
-    def visit_metadata(self, metadata):
-        if self.tables:
-            tables = self.tables
-        else:
-            tables = metadata.tables.values()
-        collection = [t for t in sql_util.sort_tables(tables) if self._can_create(t)]
-        for table in collection:
-            self.traverse_single(table)
-        if self.dialect.supports_alter:
-            for alterable in self.find_alterables(collection):
-                self.add_foreignkey(alterable)
-
-    def visit_table(self, table):
-        for listener in table.ddl_listeners['before-create']:
-            listener('before-create', table, self.connection)
+            context.setdefault('table', table)
+            context.setdefault('schema', sch)
+            context.setdefault('fullname', preparer.format_table(ddl.schema_item))
+        
+        return ddl.statement % context
 
-        for column in table.columns:
-            if column.default is not None:
-                self.traverse_single(column.default)
+    def visit_create_table(self, create):
+        table = create.element
+        preparer = self.dialect.identifier_preparer
 
-        self.append("\n" + " ".join(['CREATE'] +
-                                    table._prefixes +
+        text = "\n" + " ".join(['CREATE'] + \
+                                    table._prefixes + \
                                     ['TABLE',
-                                     self.preparer.format_table(table),
-                                     "("]))
+                                     preparer.format_table(table),
+                                     "("])
         separator = "\n"
 
         # if only one primary key, specify it along with the column
         first_pk = False
         for column in table.columns:
-            self.append(separator)
+            text += separator
             separator = ", \n"
-            self.append("\t" + self.get_column_specification(column, first_pk=column.primary_key and not first_pk))
+            text += "\t" + self.get_column_specification(column, first_pk=column.primary_key and not first_pk)
             if column.primary_key:
                 first_pk = True
             for constraint in column.constraints:
-                self.traverse_single(constraint)
+                text += self.process(constraint)
 
         # On some DB order is significant: visit PK first, then the
         # other constraints (engine.ReflectionTest.testbasic failed on FB2)
         if table.primary_key:
-            self.traverse_single(table.primary_key)
+            text += self.process(table.primary_key)
+
         for constraint in [c for c in table.constraints if c is not table.primary_key]:
-            self.traverse_single(constraint)
+            text += self.process(constraint)
 
-        self.append("\n)%s\n\n" % self.post_create_table(table))
-        self.execute()
+        text += "\n)%s\n\n" % self.post_create_table(table)
+        return text
+        
+    def visit_drop_table(self, drop):
+        return "\nDROP TABLE " + self.preparer.format_table(drop.element)
+        
+    def visit_add_foreignkey(self, add):
+        return "ALTER TABLE %s ADD " % self.preparer.format_table(add.element.table) + \
+            self.define_foreign_key(add.element)
 
-        if hasattr(table, 'indexes'):
-            for index in table.indexes:
-                self.traverse_single(index)
+    def visit_drop_foreignkey(self, drop):
+        return "ALTER TABLE %s DROP CONSTRAINT %s" % (
+            self.preparer.format_table(drop.element.table),
+            self.preparer.format_constraint(drop.element))
 
-        for listener in table.ddl_listeners['after-create']:
-            listener('after-create', table, self.connection)
+    def visit_create_index(self, create):
+        index = create.element
+        preparer = self.preparer
+        text = "CREATE "
+        if index.unique:
+            text += "UNIQUE "
+        text += "INDEX %s ON %s (%s)" \
+                    % (preparer.quote(self._validate_identifier(index.name, True), index.quote),
+                       preparer.format_table(index.table),
+                       ', '.join(preparer.quote(c.name, c.quote)
+                                 for c in index.columns))
+        return text
+
+    def visit_drop_index(self, drop):
+        index = drop.element
+        return "\nDROP INDEX " + self.preparer.quote(self._validate_identifier(index.name, False), index.quote)
+
+    def get_column_specification(self, column, first_pk=False):
+        raise NotImplementedError()
 
     def post_create_table(self, table):
         return ''
 
+    def _compile(self, tocompile, parameters):
+        """compile the given string/parameters using this SchemaGenerator's dialect."""
+        compiler = self.dialect.statement_compiler(self.dialect, tocompile, parameters)
+        compiler.compile()
+        return compiler
+
+    def _validate_identifier(self, ident, truncate):
+        if truncate:
+            if len(ident) > self.dialect.max_identifier_length:
+                counter = getattr(self, 'counter', 0)
+                self.counter = counter + 1
+                return ident[0:self.dialect.max_identifier_length - 6] + "_" + hex(self.counter)[2:]
+            else:
+                return ident
+        else:
+            self.dialect.validate_identifier(ident)
+            return ident
+
     def get_column_default_string(self, column):
         if isinstance(column.server_default, schema.DefaultClause):
             if isinstance(column.server_default.arg, basestring):
@@ -852,149 +856,174 @@ class SchemaGenerator(DDLBase):
         else:
             return None
 
-    def _compile(self, tocompile, parameters):
-        """compile the given string/parameters using this SchemaGenerator's dialect."""
-        compiler = self.dialect.statement_compiler(self.dialect, tocompile, parameters)
-        compiler.compile()
-        return compiler
-
     def visit_check_constraint(self, constraint):
-        self.append(", \n\t")
+        text = ", \n\t"
         if constraint.name is not None:
-            self.append("CONSTRAINT %s " %
-                        self.preparer.format_constraint(constraint))
-        self.append(" CHECK (%s)" % constraint.sqltext)
-        self.define_constraint_deferrability(constraint)
+            text += "CONSTRAINT %s " % \
+                        self.preparer.format_constraint(constraint)
+        text += " CHECK (%s)" % constraint.sqltext
+        text += self.define_constraint_deferrability(constraint)
+        return text
 
     def visit_column_check_constraint(self, constraint):
-        self.append(" CHECK (%s)" % constraint.sqltext)
-        self.define_constraint_deferrability(constraint)
+        text = " CHECK (%s)" % constraint.sqltext
+        text += self.define_constraint_deferrability(constraint)
+        return text
 
     def visit_primary_key_constraint(self, constraint):
         if len(constraint) == 0:
-            return
-        self.append(", \n\t")
+            return ''
+        text = ", \n\t"
         if constraint.name is not None:
-            self.append("CONSTRAINT %s " % self.preparer.format_constraint(constraint))
-        self.append("PRIMARY KEY ")
-        self.append("(%s)" % ', '.join(self.preparer.quote(c.name, c.quote)
-                                       for c in constraint))
-        self.define_constraint_deferrability(constraint)
+            text += "CONSTRAINT %s " % self.preparer.format_constraint(constraint)
+        text += "PRIMARY KEY "
+        text += "(%s)" % ', '.join(self.preparer.quote(c.name, c.quote)
+                                       for c in constraint)
+        text += self.define_constraint_deferrability(constraint)
+        return text
 
     def visit_foreign_key_constraint(self, constraint):
         if constraint.use_alter and self.dialect.supports_alter:
-            return
-        self.append(", \n\t ")
-        self.define_foreign_key(constraint)
-
-    def add_foreignkey(self, constraint):
-        self.append("ALTER TABLE %s ADD " % self.preparer.format_table(constraint.table))
-        self.define_foreign_key(constraint)
-        self.execute()
+            return ''
+        
+        return ", \n\t " + self.define_foreign_key(constraint)
 
     def define_foreign_key(self, constraint):
-        preparer = self.preparer
+        preparer = self.dialect.identifier_preparer
+        text = ""
         if constraint.name is not None:
-            self.append("CONSTRAINT %s " %
-                        preparer.format_constraint(constraint))
+            text += "CONSTRAINT %s " % \
+                        preparer.format_constraint(constraint)
         table = list(constraint.elements)[0].column.table
-        self.append("FOREIGN KEY(%s) REFERENCES %s (%s)" % (
+        text += "FOREIGN KEY(%s) REFERENCES %s (%s)" % (
             ', '.join(preparer.quote(f.parent.name, f.parent.quote)
                       for f in constraint.elements),
             preparer.format_table(table),
             ', '.join(preparer.quote(f.column.name, f.column.quote)
                       for f in constraint.elements)
-        ))
+        )
         if constraint.ondelete is not None:
-            self.append(" ON DELETE %s" % constraint.ondelete)
+            text += " ON DELETE %s" % constraint.ondelete
         if constraint.onupdate is not None:
-            self.append(" ON UPDATE %s" % constraint.onupdate)
-        self.define_constraint_deferrability(constraint)
+            text += " ON UPDATE %s" % constraint.onupdate
+        text += self.define_constraint_deferrability(constraint)
+        return text
 
     def visit_unique_constraint(self, constraint):
-        self.append(", \n\t")
+        text = ", \n\t"
         if constraint.name is not None:
-            self.append("CONSTRAINT %s " %
-                        self.preparer.format_constraint(constraint))
-        self.append(" UNIQUE (%s)" % (', '.join(self.preparer.quote(c.name, c.quote) for c in constraint)))
-        self.define_constraint_deferrability(constraint)
+            text += "CONSTRAINT %s " % self.preparer.format_constraint(constraint)
+        text += " UNIQUE (%s)" % (', '.join(self.preparer.quote(c.name, c.quote) for c in constraint))
+        text += self.define_constraint_deferrability(constraint)
+        return text
 
     def define_constraint_deferrability(self, constraint):
+        text = ""
         if constraint.deferrable is not None:
             if constraint.deferrable:
-                self.append(" DEFERRABLE")
+                text += " DEFERRABLE"
             else:
-                self.append(" NOT DEFERRABLE")
+                text += " NOT DEFERRABLE"
         if constraint.initially is not None:
-            self.append(" INITIALLY %s" % constraint.initially)
+            text += " INITIALLY %s" % constraint.initially
+        return text
+        
+        
+# PLACEHOLDERS to get non-converted dialects to compile
+class SchemaGenerator(object):
+    pass
+    
+class SchemaDropper(object):
+    pass
+    
+    
+class GenericTypeCompiler(engine.TypeCompiler):
+    def visit_CHAR(self, type_):
+        return "CHAR" + (type_.length and "(%d)" % type_.length or "")
 
-    def visit_column(self, column):
-        pass
+    def visit_NCHAR(self, type_):
+        return "NCHAR" + (type_.length and "(%d)" % type_.length or "")
+    
+    def visit_FLOAT(self, type_):
+        return "FLOAT"
 
-    def visit_index(self, index):
-        preparer = self.preparer
-        self.append("CREATE ")
-        if index.unique:
-            self.append("UNIQUE ")
-        self.append("INDEX %s ON %s (%s)" \
-                    % (preparer.quote(self._validate_identifier(index.name, True), index.quote),
-                       preparer.format_table(index.table),
-                       ', '.join(preparer.quote(c.name, c.quote)
-                                 for c in index.columns)))
-        self.execute()
+    def visit_NUMERIC(self, type_):
+        if type_.precision is None:
+            return "NUMERIC"
+        else:
+            return "NUMERIC(%(precision)s, %(scale)s)" % {'precision': type_.precision, 'scale' : type_.scale}
 
+    def visit_DECIMAL(self, type_):
+        return "DECIMAL"
+        
+    def visit_INTEGER(self, type_):
+        return "INTEGER"
 
-class SchemaDropper(DDLBase):
-    def __init__(self, dialect, connection, checkfirst=False, tables=None, **kwargs):
-        super(SchemaDropper, self).__init__(connection, **kwargs)
-        self.checkfirst = checkfirst
-        self.tables = tables
-        self.preparer = dialect.identifier_preparer
-        self.dialect = dialect
+    def visit_SMALLINT(self, type_):
+        return "SMALLINT"
 
-    def visit_metadata(self, metadata):
-        if self.tables:
-            tables = self.tables
-        else:
-            tables = metadata.tables.values()
-        collection = [t for t in reversed(sql_util.sort_tables(tables)) if self._can_drop(t)]
-        if self.dialect.supports_alter:
-            for alterable in self.find_alterables(collection):
-                self.drop_foreignkey(alterable)
-        for table in collection:
-            self.traverse_single(table)
-
-    def _can_drop(self, table):
-        self.dialect.validate_identifier(table.name)
-        if table.schema:
-            self.dialect.validate_identifier(table.schema)
-        return not self.checkfirst or self.dialect.has_table(self.connection, table.name, schema=table.schema)
-
-    def visit_index(self, index):
-        self.append("\nDROP INDEX " + self.preparer.quote(self._validate_identifier(index.name, False), index.quote))
-        self.execute()
-
-    def drop_foreignkey(self, constraint):
-        self.append("ALTER TABLE %s DROP CONSTRAINT %s" % (
-            self.preparer.format_table(constraint.table),
-            self.preparer.format_constraint(constraint)))
-        self.execute()
-
-    def visit_table(self, table):
-        for listener in table.ddl_listeners['before-drop']:
-            listener('before-drop', table, self.connection)
+    def visit_TIMESTAMP(self, type_):
+        return 'TIMESTAMP'
 
-        for column in table.columns:
-            if column.default is not None:
-                self.traverse_single(column.default)
+    def visit_DATETIME(self, type_):
+        return "DATETIME"
 
-        self.append("\nDROP TABLE " + self.preparer.format_table(table))
-        self.execute()
+    def visit_DATE(self, type_):
+        return "DATE"
 
-        for listener in table.ddl_listeners['after-drop']:
-            listener('after-drop', table, self.connection)
+    def visit_TIME(self, type_):
+        return "TIME"
 
+    def visit_CLOB(self, type_):
+        return "CLOB"
 
+    def visit_VARCHAR(self, type_):
+        return "VARCHAR" + (type_.length and "(%d)" % type_.length or "")
+
+    def visit_BLOB(self, type_):
+        return "BLOB"
+    
+    def visit_BINARY(self, type_):
+        return "BINARY"
+        
+    def visit_BOOLEAN(self, type_):
+        return "BOOLEAN"
+    
+    def visit_TEXT(self, type_):
+        return "TEXT"
+    
+    def visit_binary(self, type_):
+        return self.visit_BINARY(type_)
+    def visit_boolean(self, type_): 
+        return self.visit_BOOLEAN(type_)
+    def visit_time(self, type_): 
+        return self.visit_TIME(type_)
+    def visit_datetime(self, type_): 
+        return self.visit_DATETIME(type_)
+    def visit_date(self, type_): 
+        return self.visit_DATE(type_)
+    def visit_small_integer(self, type_): 
+        return self.visit_SMALLINT(type_)
+    def visit_integer(self, type_): 
+        return self.visit_INTEGER(type_)
+    def visit_float(self, type_): 
+        return self.visit_FLOAT(type_)
+    def visit_numeric(self, type_): 
+        return self.visit_NUMERIC(type_)
+    def visit_string(self, type_): 
+        return self.visit_VARCHAR(type_)
+    def visit_text(self, type_): 
+        return self.visit_TEXT(type_)
+    
+    def visit_null(self, type_):
+        raise NotImplementedError("Can't generate DDL for the null type")
+        
+    def visit_type_decorator(self, type_):
+        return self.process(type_.dialect_impl(self.dialect).impl)
+        
+    def visit_user_defined(self, type_):
+        return type_.get_col_spec()
+    
 class IdentifierPreparer(object):
     """Handle quoting and case-folding of identifiers based on options."""
 
diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py
index f527c6351f7e580d2ae46b530fac382b0e1c8aa0..6be867dbf52b27ed60f841a07986997590cef99a 100644 (file)
@@ -32,10 +32,9 @@ from operator import attrgetter
 from sqlalchemy import util, exc
 from sqlalchemy.sql import operators
 from sqlalchemy.sql.visitors import Visitable, cloned_traverse
-from sqlalchemy import types as sqltypes
 import operator
 
-functions, schema, sql_util = None, None, None
+functions, schema, sql_util, sqltypes = None, None, None, None
 DefaultDialect, ClauseAdapter, Annotated = None, None, None
 
 __all__ = [
@@ -974,7 +973,8 @@ class ClauseElement(Visitable):
     _annotations = {}
     supports_execution = False
     _from_objects = []
-
+    _bind = None
+    
     def _clone(self):
         """Create a shallow copy of this ClauseElement.
 
@@ -1106,11 +1106,9 @@ class ClauseElement(Visitable):
     def bind(self):
         """Returns the Engine or Connection to which this ClauseElement is bound, or None if none found."""
 
-        try:
-            if self._bind is not None:
-                return self._bind
-        except AttributeError:
-            pass
+        if self._bind is not None:
+            return self._bind
+
         for f in _from_objects(self):
             if f is self:
                 continue
@@ -1139,7 +1137,7 @@ class ClauseElement(Visitable):
 
         return self.execute(*multiparams, **params).scalar()
 
-    def compile(self, bind=None, column_keys=None, compiler=None, dialect=None, inline=False):
+    def compile(self, bind=None, dialect=None, **kw):
         """Compile this SQL expression.
 
         The return value is a :class:`~sqlalchemy.engine.Compiled` object.
@@ -1154,52 +1152,57 @@ class ClauseElement(Visitable):
           takes precedence over this ``ClauseElement``'s
           bound engine, if any.
 
-        column_keys
-          Used for INSERT and UPDATE statements, a list of
-          column names which should be present in the VALUES clause
-          of the compiled statement.  If ``None``, all columns
-          from the target table object are rendered.
-
-        compiler
-          A ``Compiled`` instance which will be used to compile
-          this expression.  This argument takes precedence
-          over the `bind` and `dialect` arguments as well as
-          this ``ClauseElement``'s bound engine, if
-          any.
-
         dialect
           A ``Dialect`` instance from which a ``Compiled``
           will be acquired.  This argument takes precedence
           over the `bind` argument as well as this
           ``ClauseElement``'s bound engine, if any.
 
-        inline
-          Used for INSERT statements, for a dialect which does
-          not support inline retrieval of newly generated
-          primary key columns, will force the expression used
-          to create the new primary key value to be rendered
-          inline within the INSERT statement's VALUES clause.
-          This typically refers to Sequence execution but
-          may also refer to any server-side default generation
-          function associated with a primary key `Column`.
+        \**kw
+        
+          Keyword arguments are passed along to the compiler, 
+          which can affect the string produced.
+          
+          Keywords for a statement compiler are:
+        
+          column_keys
+            Used for INSERT and UPDATE statements, a list of
+            column names which should be present in the VALUES clause
+            of the compiled statement.  If ``None``, all columns
+            from the target table object are rendered.
+
+          inline
+            Used for INSERT statements, for a dialect which does
+            not support inline retrieval of newly generated
+            primary key columns, will force the expression used
+            to create the new primary key value to be rendered
+            inline within the INSERT statement's VALUES clause.
+            This typically refers to Sequence execution but
+            may also refer to any server-side default generation
+            function associated with a primary key `Column`.
 
         """
-        if compiler is None:
-            if dialect is not None:
-                compiler = dialect.statement_compiler(dialect, self, column_keys=column_keys, inline=inline)
-            elif bind is not None:
-                compiler = bind.statement_compiler(self, column_keys=column_keys, inline=inline)
-            elif self.bind is not None:
-                compiler = self.bind.statement_compiler(self, column_keys=column_keys, inline=inline)
+        
+        if not dialect:
+            if bind:
+                dialect = bind.dialect
+            elif self.bind:
+                dialect = self.bind.dialect
+                bind = self.bind
             else:
                 global DefaultDialect
                 if DefaultDialect is None:
                     from sqlalchemy.engine.default import DefaultDialect
                 dialect = DefaultDialect()
-                compiler = dialect.statement_compiler(dialect, self, column_keys=column_keys, inline=inline)
+        compiler = self._compiler(dialect, bind=bind, **kw)
         compiler.compile()
         return compiler
-
+    
+    def _compiler(self, dialect, **kw):
+        """Return a compiler appropriate for this ClauseElement, given a Dialect."""
+        
+        return dialect.statement_compiler(dialect, self, **kw)
+        
     def __str__(self):
         return unicode(self.compile()).encode('ascii', 'backslashreplace')
 
@@ -1230,6 +1233,12 @@ class ClauseElement(Visitable):
 class _Immutable(object):
     """mark a ClauseElement as 'immutable' when expressions are cloned."""
 
+    def unique_params(self, *optionaldict, **kwargs):
+        raise NotImplementedError("Immutable objects do not support copying")
+
+    def params(self, *optionaldict, **kwargs):
+        raise NotImplementedError("Immutable objects do not support copying")
+
     def _clone(self):
         return self
 
diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py
index a5bd497aedf9bc59f2e1c36d7a549da89d406884..4471d4fb0d3425c87f6e03ecfacca3dbc97903f3 100644 (file)
@@ -34,13 +34,10 @@ class VisitableType(type):
     """
     
     def __init__(cls, clsname, bases, clsdict):
-        if cls.__name__ == 'Visitable':
+        if cls.__name__ == 'Visitable' or not hasattr(cls, '__visit_name__'):
             super(VisitableType, cls).__init__(clsname, bases, clsdict)
             return
         
-        assert hasattr(cls, '__visit_name__'), "`Visitable` descendants " \
-                                               "should define `__visit_name__`"
-        
         # set up an optimized visit dispatch function
         # for use by the compiler
         visit_name = cls.__visit_name__
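With the assertion removed, only classes that actually declare ``__visit_name__`` receive the optimized dispatch function; others now pass through silently.  A minimal sketch of the convention (both classes are illustrative, not part of the library)::

    from sqlalchemy.sql.visitors import Visitable

    class Widget(Visitable):
        __visit_name__ = 'widget'   # dispatched to a visitor's visit_widget()

    class AbstractNode(Visitable):
        pass                        # no __visit_name__: skipped, no longer an error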
index 38aba026c482b4385b523de40e5b964cddb121d0..986d3d1332d811ca026c0d4317a0758e6c1c0023 100644 (file)
@@ -15,7 +15,7 @@ __all__ = [ 'TypeEngine', 'TypeDecorator', 'AbstractType',
             'INT', 'CHAR', 'VARCHAR', 'NCHAR', 'TEXT', 'Text', 'FLOAT',
             'NUMERIC', 'DECIMAL', 'TIMESTAMP', 'DATETIME', 'CLOB', 'BLOB',
             'BOOLEAN', 'SMALLINT', 'DATE', 'TIME',
-            'String', 'Integer', 'SmallInteger','Smallinteger',
+            'String', 'Integer', 'SmallInteger',
             'Numeric', 'Float', 'DateTime', 'Date', 'Time', 'Binary',
             'Boolean', 'Unicode', 'MutableType', 'Concatenable', 'UnicodeText', 'PickleType', 'Interval',
             'type_map'
@@ -27,10 +27,14 @@ from decimal import Decimal as _python_Decimal
 
 from sqlalchemy import exc
 from sqlalchemy.util import pickle
+from sqlalchemy.sql.visitors import Visitable
+from sqlalchemy.sql import expression
+import sys
+expression.sqltypes = sys.modules[__name__]
 import sqlalchemy.util as util
 NoneType = type(None)
     
-class AbstractType(object):
+class AbstractType(Visitable):
 
     def __init__(self, *args, **kwargs):
         pass
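The ``expression.sqltypes = sys.modules[__name__]`` line is a self-registration idiom for breaking an import cycle: ``expression`` cannot import ``types`` at import time, so ``types`` installs itself onto ``expression`` once it has finished loading.  The general shape, simulated here with stand-in modules::

    import sys
    from types import ModuleType

    expression = ModuleType('expression')   # stands in for sqlalchemy.sql.expression
    sqltypes = ModuleType('sqltypes')       # stands in for sqlalchemy.types
    sys.modules['sqltypes'] = sqltypes

    expression.sqltypes = sys.modules['sqltypes']
    assert expression.sqltypes is sqltypes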
@@ -89,37 +93,7 @@ class AbstractType(object):
                       for k in inspect.getargspec(self.__init__)[0][1:]))
 
 class TypeEngine(AbstractType):
-    """Base for built-in types.
-
-    May be sub-classed to create entirely new types.  Example::
-
-      import sqlalchemy.types as types
-
-      class MyType(types.TypeEngine):
-          def __init__(self, precision = 8):
-              self.precision = precision
-
-          def get_col_spec(self):
-              return "MYTYPE(%s)" % self.precision
-
-          def bind_processor(self, dialect):
-              def process(value):
-                  return value
-              return process
-
-          def result_processor(self, dialect):
-              def process(value):
-                  return value
-              return process
-
-    Once the type is made, it's immediately usable::
-
-      table = Table('foo', meta,
-          Column('id', Integer, primary_key=True),
-          Column('data', MyType(16))
-          )
-
-    """
+    """Base for built-in types."""
 
     def dialect_impl(self, dialect, **kwargs):
         try:
@@ -135,10 +109,6 @@ class TypeEngine(AbstractType):
         d['_impl_dict'] = {}
         return d
 
-    def get_col_spec(self):
-        """Return the DDL representation for this type."""
-        raise NotImplementedError()
-
     def bind_processor(self, dialect):
         """Return a conversion function for processing bind values.
 
@@ -174,6 +144,42 @@ class TypeEngine(AbstractType):
 
         return self.__class__.__mro__[0:-1]
 
+class UserDefinedType(TypeEngine):
+    """Base for user defined types.
+    
+    This should be the base of new types.  Note that
+    for most cases, :class:`TypeDecorator` is probably
+    more appropriate.
+
+      import sqlalchemy.types as types
+
+      class MyType(types.UserDefinedType):
+          def __init__(self, precision = 8):
+              self.precision = precision
+
+          def get_col_spec(self):
+              return "MYTYPE(%s)" % self.precision
+
+          def bind_processor(self, dialect):
+              def process(value):
+                  return value
+              return process
+
+          def result_processor(self, dialect):
+              def process(value):
+                  return value
+              return process
+
+    Once the type is made, it's immediately usable::
+
+      table = Table('foo', meta,
+          Column('id', Integer, primary_key=True),
+          Column('data', MyType(16))
+          )
+
+    """
+    __visit_name__ = "user_defined"
+    
 class TypeDecorator(AbstractType):
     """Allows the creation of types which add additional functionality
     to an existing type.
@@ -214,6 +220,8 @@ class TypeDecorator(AbstractType):
 
     """
 
+    __visit_name__ = "type_decorator"
+    
     def __init__(self, *args, **kwargs):
         if not hasattr(self.__class__, 'impl'):
             raise AssertionError("TypeDecorator implementations require a class-level variable 'impl' which refers to the class of type being decorated")
@@ -253,9 +261,6 @@ class TypeDecorator(AbstractType):
 
         return getattr(self.impl, key)
 
-    def get_col_spec(self):
-        return self.impl.get_col_spec()
-
     def process_bind_param(self, value, dialect):
         raise NotImplementedError()
 
@@ -370,9 +375,7 @@ class NullType(TypeEngine):
     encountered during a :meth:`~sqlalchemy.Table.create` operation.
 
     """
-
-    def get_col_spec(self):
-        raise NotImplementedError()
+    __visit_name__ = 'null'
 
 NullTypeEngine = NullType
 
@@ -400,6 +403,8 @@ class String(Concatenable, TypeEngine):
 
     """
 
+    __visit_name__ = 'string'
+    
     def __init__(self, length=None, convert_unicode=False, assert_unicode=None):
         """
         Create a string-holding type.
@@ -485,6 +490,9 @@ class Text(String):
     params (and the reverse for result sets.)
 
     """
+    
+    __visit_name__ = 'text'
+    
     def dialect_impl(self, dialect, **kwargs):
         return TypeEngine.dialect_impl(self, dialect, **kwargs)
 
@@ -555,7 +563,9 @@ class UnicodeText(Text):
 
 class Integer(TypeEngine):
     """A type for ``int`` integers."""
-
+    
+    __visit_name__ = 'integer'
+    
     def get_dbapi_type(self, dbapi):
         return dbapi.NUMBER
 
@@ -568,7 +578,7 @@ class SmallInteger(Integer):
 
     """
 
-Smallinteger = SmallInteger
+    __visit_name__ = 'small_integer'
 
 class Numeric(TypeEngine):
     """A type for fixed precision numbers.
@@ -578,6 +588,8 @@ class Numeric(TypeEngine):
 
     """
 
+    __visit_name__ = 'numeric'
+    
     def __init__(self, precision=10, scale=2, asdecimal=True, length=None):
         """
         Construct a Numeric.
@@ -628,6 +640,8 @@ class Numeric(TypeEngine):
 class Float(Numeric):
     """A type for ``float`` numbers."""
 
+    __visit_name__ = 'float'
+    
     def __init__(self, precision=10, asdecimal=False, **kwargs):
         """
         Construct a Float.
@@ -652,7 +666,9 @@ class DateTime(TypeEngine):
     converted back to datetime objects when rows are returned.
 
     """
-
+    
+    __visit_name__ = 'datetime'
+    
     def __init__(self, timezone=False):
         self.timezone = timezone
 
@@ -666,6 +682,8 @@ class DateTime(TypeEngine):
 class Date(TypeEngine):
     """A type for ``datetime.date()`` objects."""
 
+    __visit_name__ = 'date'
+    
     def get_dbapi_type(self, dbapi):
         return dbapi.DATETIME
 
@@ -673,6 +691,8 @@ class Date(TypeEngine):
 class Time(TypeEngine):
     """A type for ``datetime.time()`` objects."""
 
+    __visit_name__ = 'time'
+
     def __init__(self, timezone=False):
         self.timezone = timezone
 
@@ -692,6 +712,8 @@ class Binary(TypeEngine):
 
     """
 
+    __visit_name__ = 'binary'
+
     def __init__(self, length=None):
         """
         Construct a Binary type.
@@ -806,6 +828,7 @@ class Boolean(TypeEngine):
 
     """
 
+    __visit_name__ = 'boolean'
 
 class Interval(TypeDecorator):
     """A type for ``datetime.timedelta()`` objects.
@@ -821,7 +844,7 @@ class Interval(TypeDecorator):
 
     def __init__(self):
         super(Interval, self).__init__()
-        import sqlalchemy.databases.postgres as pg
+        import sqlalchemy.dialects.postgres.base as pg
         self.__supported = {pg.PGDialect:pg.PGInterval}
         del pg
 
@@ -850,66 +873,96 @@ class Interval(TypeDecorator):
 class FLOAT(Float):
     """The SQL FLOAT type."""
 
+    __visit_name__ = 'FLOAT'
 
 class NUMERIC(Numeric):
     """The SQL NUMERIC type."""
 
+    __visit_name__ = 'NUMERIC'
+
 
 class DECIMAL(Numeric):
     """The SQL DECIMAL type."""
 
+    __visit_name__ = 'DECIMAL'
 
-class INT(Integer):
+
+class INTEGER(Integer):
     """The SQL INT or INTEGER type."""
 
+    __visit_name__ = 'INTEGER'
+INT = INTEGER
 
-INTEGER = INT
 
-class SMALLINT(Smallinteger):
+class SMALLINT(SmallInteger):
     """The SQL SMALLINT type."""
 
+    __visit_name__ = 'SMALLINT'
+
 
 class TIMESTAMP(DateTime):
     """The SQL TIMESTAMP type."""
 
+    __visit_name__ = 'TIMESTAMP'
+
 
 class DATETIME(DateTime):
     """The SQL DATETIME type."""
 
+    __visit_name__ = 'DATETIME'
+
 
 class DATE(Date):
     """The SQL DATE type."""
 
+    __visit_name__ = 'DATE'
+
 
 class TIME(Time):
     """The SQL TIME type."""
 
+    __visit_name__ = 'TIME'
 
-TEXT = Text
+class TEXT(Text):
+    """The SQL TEXT type."""
+    
+    __visit_name__ = 'TEXT'
 
 class CLOB(Text):
     """The SQL CLOB type."""
 
+    __visit_name__ = 'CLOB'
+
 
 class VARCHAR(String):
     """The SQL VARCHAR type."""
 
+    __visit_name__ = 'VARCHAR'
+
 
 class CHAR(String):
     """The SQL CHAR type."""
 
+    __visit_name__ = 'CHAR'
+
 
 class NCHAR(Unicode):
     """The SQL NCHAR type."""
 
+    __visit_name__ = 'NCHAR'
+
 
 class BLOB(Binary):
     """The SQL BLOB type."""
 
+    __visit_name__ = 'BLOB'
+
 
 class BOOLEAN(Boolean):
     """The SQL BOOLEAN type."""
 
+    __visit_name__ = 'BOOLEAN'
+
 NULLTYPE = NullType()
 
 # using VARCHAR/NCHAR so that we don't get the genericized "String"
@@ -927,3 +980,4 @@ type_map = {
     dt.timedelta : Interval,
     type(None): NullType
 }
+
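With ``get_col_spec()`` gone from the built-in types, their DDL strings now come from a per-dialect type compiler keyed on ``__visit_name__``.  A sketch of the intended usage, assuming the ``type_compiler`` attribute the updated tests below rely on::

    from sqlalchemy.engine.default import DefaultDialect
    from sqlalchemy import types

    dialect = DefaultDialect()
    print dialect.type_compiler.process(types.VARCHAR(30))   # VARCHAR(30)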
index 619888135df487529947d4207d44796a96f328a6..12f155d606420157320f88616aa992bc64675621 100644 (file)
@@ -299,6 +299,7 @@ def get_cls_kwargs(cls):
         class_ = stack.pop()
         ctr = class_.__dict__.get('__init__', False)
         if not ctr or not isinstance(ctr, types.FunctionType):
+            stack.update(class_.__bases__)
             continue
         names, _, has_kw, _ = inspect.getargspec(ctr)
         args.update(names)
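The added ``stack.update(class_.__bases__)`` fixes a traversal that previously stopped at any class lacking its own ``__init__``, dropping keyword arguments defined further up the hierarchy.  A minimal illustration::

    from sqlalchemy.util import get_cls_kwargs

    class A(object):
        def __init__(self, x=None):
            pass

    class B(A):
        pass    # no __init__ of its own; the walk used to end here

    assert 'x' in get_cls_kwargs(B)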
index 75c0918b81bc874e215ec4d4f66a2ca27c7951e9..84b2a167486a4bf6e6aa410be7e413dd34e31c8c 100644 (file)
@@ -2,12 +2,12 @@ import testenv; testenv.configure_for_tests()
 import datetime
 from sqlalchemy import *
 from sqlalchemy.orm import *
-from sqlalchemy import exc
-from sqlalchemy.databases import postgres
+from sqlalchemy import exc, schema
+from sqlalchemy.dialects.postgres import base as postgres
 from sqlalchemy.engine.strategies import MockEngineStrategy
 from testlib import *
 from sqlalchemy.sql import table, column
-
+from testlib.testing import eq_
 
 class SequenceTest(TestBase, AssertsCompiledSQL):
     def test_basic(self):
@@ -58,6 +58,14 @@ class CompileTest(TestBase, AssertsCompiledSQL):
         i = insert(table1, values=dict(name='foo'), postgres_returning=[func.length(table1.c.name)])
         self.assert_compile(i, "INSERT INTO mytable (name) VALUES (%(name)s) RETURNING length(mytable.name)", dialect=dialect)
 
+    def test_create_partial_index(self):
+        tbl = Table('testtbl', MetaData(), Column('data',Integer))
+        idx = Index('test_idx1', tbl.c.data, postgres_where=and_(tbl.c.data > 5, tbl.c.data < 10))
+
+        self.assert_compile(schema.CreateIndex(idx), 
+            "CREATE INDEX test_idx1 ON testtbl (data) WHERE testtbl.data > 5 AND testtbl.data < 10", dialect=postgres.dialect())
+
+
 class ReturningTest(TestBase, AssertsExecutionResults):
     __only_on__ = 'postgres'
 
@@ -406,7 +414,7 @@ class DomainReflectionTest(TestBase, AssertsExecutionResults):
         metadata = MetaData(testing.db)
         table = Table('testtable', metadata, autoload=True)
         self.assertEquals(set(table.columns.keys()), set(['question', 'answer']), "Columns of reflected table didn't equal expected columns")
-        self.assertEquals(table.c.answer.type.__class__, postgres.PGInteger)
+        assert isinstance(table.c.answer.type, Integer)
 
     def test_domain_is_reflected(self):
         metadata = MetaData(testing.db)
@@ -418,7 +426,7 @@ class DomainReflectionTest(TestBase, AssertsExecutionResults):
         metadata = MetaData(testing.db)
         table = Table('testtable', metadata, autoload=True, schema='alt_schema')
         self.assertEquals(set(table.columns.keys()), set(['question', 'answer', 'anything']), "Columns of reflected table didn't equal expected columns")
-        self.assertEquals(table.c.anything.type.__class__, postgres.PGInteger)
+        assert isinstance(table.c.anything.type, Integer)
 
     def test_schema_domain_is_reflected(self):
         metadata = MetaData(testing.db)
@@ -432,7 +440,7 @@ class DomainReflectionTest(TestBase, AssertsExecutionResults):
         self.assertEquals(str(table.columns.answer.server_default.arg), '0', "Reflected default value didn't equal expected value")
         self.assertTrue(table.columns.answer.nullable, "Expected reflected column to be nullable.")
 
-class MiscTest(TestBase, AssertsExecutionResults):
+class MiscTest(TestBase, AssertsExecutionResults, AssertsCompiledSQL):
     __only_on__ = 'postgres'
 
     def test_date_reflection(self):
@@ -666,17 +674,6 @@ class MiscTest(TestBase, AssertsExecutionResults):
             warnings.warn = capture_warnings._orig_showwarning
             m1.drop_all()
 
-    def test_create_partial_index(self):
-        tbl = Table('testtbl', MetaData(), Column('data',Integer))
-        idx = Index('test_idx1', tbl.c.data, postgres_where=and_(tbl.c.data > 5, tbl.c.data < 10))
-
-        executed_sql = []
-        mock_strategy = MockEngineStrategy()
-        mock_conn = mock_strategy.create('postgres://', executed_sql.append)
-
-        idx.create(mock_conn)
-
-        assert executed_sql == ['CREATE INDEX test_idx1 ON testtbl (data) WHERE testtbl.data > 5 AND testtbl.data < 10']
 
 class TimezoneTest(TestBase, AssertsExecutionResults):
     """Test timezone-aware datetimes.
index 97d12bf60357e1fb7ec5bdd08e09577ae1c0931e..29beec8d355f1f6266f006ecebbd079dab03fb7c 100644 (file)
@@ -4,7 +4,7 @@ import testenv; testenv.configure_for_tests()
 import datetime
 from sqlalchemy import *
 from sqlalchemy import exc
-from sqlalchemy.databases import sqlite
+from sqlalchemy.dialects.sqlite import base as sqlite, pysqlite as pysqlite_dialect
 from testlib import *
 
 
@@ -50,23 +50,18 @@ class TestTypes(TestBase, AssertsExecutionResults):
     @testing.uses_deprecated('Using String type with no length')
     def test_type_reflection(self):
         # (ask_for, roundtripped_as_if_different)
-        specs = [( String(), sqlite.SLString(), ),
-                 ( String(1), sqlite.SLString(1), ),
-                 ( String(3), sqlite.SLString(3), ),
-                 ( Text(), sqlite.SLText(), ),
-                 ( Unicode(), sqlite.SLString(), ),
-                 ( Unicode(1), sqlite.SLString(1), ),
-                 ( Unicode(3), sqlite.SLString(3), ),
-                 ( UnicodeText(), sqlite.SLText(), ),
-                 ( CLOB, sqlite.SLText(), ),
-                 ( sqlite.SLChar(1), ),
-                 ( CHAR(3), sqlite.SLChar(3), ),
-                 ( NCHAR(2), sqlite.SLChar(2), ),
-                 ( SmallInteger(), sqlite.SLSmallInteger(), ),
-                 ( sqlite.SLSmallInteger(), ),
-                 ( Binary(3), sqlite.SLBinary(), ),
-                 ( Binary(), sqlite.SLBinary() ),
-                 ( sqlite.SLBinary(3), sqlite.SLBinary(), ),
+        specs = [( String(), pysqlite_dialect.SLString(), ),
+                 ( String(1), pysqlite_dialect.SLString(1), ),
+                 ( String(3), pysqlite_dialect.SLString(3), ),
+                 ( Text(), pysqlite_dialect.SLText(), ),
+                 ( Unicode(), pysqlite_dialect.SLString(), ),
+                 ( Unicode(1), pysqlite_dialect.SLString(1), ),
+                 ( Unicode(3), pysqlite_dialect.SLString(3), ),
+                 ( UnicodeText(), pysqlite_dialect.SLText(), ),
+                 ( CLOB, pysqlite_dialect.SLText(), ),
+                 ( pysqlite_dialect.SLChar(1), ),
+                 ( CHAR(3), pysqlite_dialect.SLChar(3), ),
+                 ( NCHAR(2), pysqlite_dialect.SLChar(2), ),
                  ( NUMERIC, sqlite.SLNumeric(), ),
                  ( NUMERIC(10,2), sqlite.SLNumeric(10,2), ),
                  ( Numeric, sqlite.SLNumeric(), ),
@@ -75,9 +70,6 @@ class TestTypes(TestBase, AssertsExecutionResults):
                  ( DECIMAL(10, 2), sqlite.SLNumeric(10, 2), ),
                  ( Float, sqlite.SLNumeric(), ),
                  ( sqlite.SLNumeric(), ),
-                 ( INT, sqlite.SLInteger(), ),
-                 ( Integer, sqlite.SLInteger(), ),
-                 ( sqlite.SLInteger(), ),
                  ( TIMESTAMP, sqlite.SLDateTime(), ),
                  ( DATETIME, sqlite.SLDateTime(), ),
                  ( DateTime, sqlite.SLDateTime(), ),
@@ -113,7 +105,8 @@ class TestTypes(TestBase, AssertsExecutionResults):
             finally:
                 db.execute('DROP VIEW types_v')
         finally:
-            m.drop_all()
+            pass
+            #m.drop_all()
 
 
 class TestDefaults(TestBase, AssertsExecutionResults):
index 8274c63476ea01ea1e45de32e749bba074163429..4c929b766c6d677c67c05f1981706b8139dde70c 100644 (file)
@@ -4,7 +4,7 @@ from sqlalchemy import create_engine
 from testlib.sa import MetaData, Table, Column, Integer, String
 import testlib.sa as tsa
 from testlib import TestBase, testing, engines
-
+from testlib.testing import AssertsCompiledSQL
 
 class DDLEventTest(TestBase):
     class Canary(object):
@@ -284,7 +284,7 @@ class DDLExecutionTest(TestBase):
                 r = eval(py)
                 assert list(r) == [(1,)], py
 
-class DDLTest(TestBase):
+class DDLTest(TestBase, AssertsCompiledSQL):
     def mock_engine(self):
         executor = lambda *a, **kw: None
         engine = create_engine(testing.db.name + '://',
@@ -303,20 +303,21 @@ class DDLTest(TestBase):
 
         ddl = DDL('%(schema)s-%(table)s-%(fullname)s')
 
-        self.assertEquals(ddl._expand(sane_alone, bind), '-t-t')
-        self.assertEquals(ddl._expand(sane_schema, bind), 's-t-s.t')
-        self.assertEquals(ddl._expand(insane_alone, bind), '-"t t"-"t t"')
-        self.assertEquals(ddl._expand(insane_schema, bind),
-                          '"s s"-"t t"-"s s"."t t"')
+        dialect = bind.dialect
+        self.assert_compile(ddl.against(sane_alone), '-t-t', dialect=dialect)
+        self.assert_compile(ddl.against(sane_schema), 's-t-s.t', dialect=dialect)
+        self.assert_compile(ddl.against(insane_alone), '-"t t"-"t t"', dialect=dialect)
+        self.assert_compile(ddl.against(insane_schema), '"s s"-"t t"-"s s"."t t"', dialect=dialect)
 
         # overrides are used piecemeal and verbatim.
         ddl = DDL('%(schema)s-%(table)s-%(fullname)s-%(bonus)s',
                   context={'schema':'S S', 'table': 'T T', 'bonus': 'b'})
-        self.assertEquals(ddl._expand(sane_alone, bind), 'S S-T T-t-b')
-        self.assertEquals(ddl._expand(sane_schema, bind), 'S S-T T-s.t-b')
-        self.assertEquals(ddl._expand(insane_alone, bind), 'S S-T T-"t t"-b')
-        self.assertEquals(ddl._expand(insane_schema, bind),
-                          'S S-T T-"s s"."t t"-b')
+
+        self.assert_compile(ddl.against(sane_alone), 'S S-T T-t-b', dialect=dialect)
+        self.assert_compile(ddl.against(sane_schema), 'S S-T T-s.t-b', dialect=dialect)
+        self.assert_compile(ddl.against(insane_alone), 'S S-T T-"t t"-b', dialect=dialect)
+        self.assert_compile(ddl.against(insane_schema), 'S S-T T-"s s"."t t"-b', dialect=dialect)
+
     def test_filter(self):
         cx = self.mock_engine()
 
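These tests replace the private ``DDL._expand()`` with ``DDL.against()`` plus ``assert_compile``; the same substitution can be observed directly by compiling the bound DDL.  A sketch, assuming DDL constructs compile like any other clause in this revision::

    from sqlalchemy import DDL, MetaData, Table, Column, Integer

    t = Table('t', MetaData(), Column('id', Integer))
    ddl = DDL('%(schema)s-%(table)s-%(fullname)s').against(t)
    print ddl.compile()     # '-t-t' for a table with no schema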
index 8e6a3df987c73f4426b2832861456ae916d63b0a..ac245981e8224f3c43e2f77e0a1b7b179e97fc1d 100644 (file)
@@ -1,6 +1,7 @@
 import testenv; testenv.configure_for_tests()
 import StringIO, unicodedata
 import sqlalchemy as sa
+from sqlalchemy import schema
 from testlib.sa import MetaData, Table, Column
 from testlib import TestBase, ComparesTables, testing, engines, sa as tsa
 
@@ -49,8 +50,7 @@ class ReflectionTest(TestBase, ComparesTables):
             self.assert_tables_equal(users, reflected_users)
             self.assert_tables_equal(addresses, reflected_addresses)
         finally:
-            addresses.drop()
-            users.drop()
+            meta.drop_all()
 
     def test_include_columns(self):
         meta = MetaData(testing.db)
@@ -87,20 +87,9 @@ class ReflectionTest(TestBase, ComparesTables):
         t = Table("test", meta,
             Column('foo', sa.DateTime))
 
-        import sys
-        dialect_module = sys.modules[testing.db.dialect.__module__]
-
-        # we're relying on the presence of "ischema_names" in the
-        # dialect module, else we can't test this.  we need to be able
-        # to get the dialect to not be aware of some type so we temporarily
-        # monkeypatch.  not sure what a better way for this could be,
-        # except for an established dialect hook or dialect-specific tests
-        if not hasattr(dialect_module, 'ischema_names'):
-            return
-
-        ischema_names = dialect_module.ischema_names
+        ischema_names = testing.db.dialect.ischema_names
         t.create()
-        dialect_module.ischema_names = {}
+        testing.db.dialect.ischema_names = {}
         try:
             m2 = MetaData(testing.db)
             self.assertRaises(tsa.exc.SAWarning, Table, "test", m2, autoload=True)
@@ -112,7 +101,7 @@ class ReflectionTest(TestBase, ComparesTables):
                 assert t3.c.foo.type.__class__ == sa.types.NullType
 
         finally:
-            dialect_module.ischema_names = ischema_names
+            testing.db.dialect.ischema_names = ischema_names
             t.drop()
 
     def test_basic_override(self):
@@ -718,8 +707,9 @@ class UnicodeReflectionTest(TestBase):
             r.drop_all()
             r.create_all()
         finally:
-            metadata.drop_all()
-            bind.dispose()
+            pass
+#            metadata.drop_all()
+#            bind.dispose()
 
 
 class SchemaTest(TestBase):
@@ -733,23 +723,15 @@ class SchemaTest(TestBase):
             Column('col1', sa.Integer, primary_key=True),
             Column('col2', sa.Integer, sa.ForeignKey('someschema.table1.col1')),
             schema='someschema')
-        # ensure this doesnt crash
-        print [t for t in metadata.sorted_tables]
-        buf = StringIO.StringIO()
-        def foo(s, p=None):
-            buf.write(s)
-        gen = sa.create_engine(testing.db.name + "://", strategy="mock", executor=foo)
-        gen = gen.dialect.schemagenerator(gen.dialect, gen)
-        gen.traverse(table1)
-        gen.traverse(table2)
-        buf = buf.getvalue()
-        print buf
+
+        t1 = str(schema.CreateTable(table1).compile(bind=testing.db))
+        t2 = str(schema.CreateTable(table2).compile(bind=testing.db))
         if testing.db.dialect.preparer(testing.db.dialect).omit_schema:
-            assert buf.index("CREATE TABLE table1") > -1
-            assert buf.index("CREATE TABLE table2") > -1
+            assert t1.index("CREATE TABLE table1") > -1
+            assert t2.index("CREATE TABLE table2") > -1
         else:
-            assert buf.index("CREATE TABLE someschema.table1") > -1
-            assert buf.index("CREATE TABLE someschema.table2") > -1
+            assert t1.index("CREATE TABLE someschema.table1") > -1
+            assert t2.index("CREATE TABLE someschema.table2") > -1
 
     @testing.crashes('firebird', 'No schema support')
     @testing.fails_on('sqlite', 'FIXME: unknown')
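``schema.CreateTable`` wraps a ``Table`` as a compilable DDL expression, which is the idiom this hunk (and the constraint tests further down) switch to.  A sketch against the default dialect::

    from sqlalchemy import MetaData, Table, Column, Integer, schema
    from sqlalchemy.engine.default import DefaultDialect

    t = Table('x', MetaData(), Column('id', Integer, primary_key=True))
    print schema.CreateTable(t).compile(dialect=DefaultDialect())
    # roughly: CREATE TABLE x (id INTEGER NOT NULL, PRIMARY KEY (id))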
index 3f0360e85ebf12f3a712bcd02cd0e263c87d4812..4733292483dc3adbbdb5ddec36b6aefc4798b59d 100644 (file)
@@ -1,6 +1,5 @@
 import testenv; testenv.configure_for_tests()
 import doctest, sys
-
 from testlib import sa_unittest as unittest
 
 
index c9477b5d85c1e53f51fa6dbb7b1c1dd6b90f2984..3176832f309fbd65be37676f496751eaa75900e5 100644 (file)
@@ -63,26 +63,6 @@ class DeclarativeTest(DeclarativeTestBase):
             class User(Base):
                 id = Column('id', Integer, primary_key=True)
         self.assertRaisesMessage(sa.exc.InvalidRequestError, "does not have a __table__", go)
-
-    def test_cant_add_columns(self):
-        t = Table('t', Base.metadata, Column('id', Integer, primary_key=True))
-        def go():
-            class User(Base):
-                __table__ = t
-                foo = Column(Integer, primary_key=True)
-        self.assertRaisesMessage(sa.exc.ArgumentError, "add additional columns", go)
-    
-    def test_undefer_column_name(self):
-        # TODO: not sure if there was an explicit
-        # test for this elsewhere
-        foo = Column(Integer)
-        eq_(str(foo), '(no name)')
-        eq_(foo.key, None)
-        eq_(foo.name, None)
-        decl._undefer_column_name('foo', foo)
-        eq_(str(foo), 'foo')
-        eq_(foo.key, 'foo')
-        eq_(foo.name, 'foo')
         
     def test_recompile_on_othermapper(self):
         """declarative version of the same test in mappers.py"""
index 553713da539fe233bcbe1d1f2d9d793463d9e331..a4363b5e5ff2a06a686a7676033e7da59c1db4bc 100644 (file)
@@ -2216,49 +2216,6 @@ class RowSwitchTest(_base.MappedTest):
         assert list(sess.execute(t5.select(), mapper=T5)) == [(2, 'some other t5')]
         assert list(sess.execute(t6.select(), mapper=T5)) == [(1, 'some other t6', 2)]
 
-class InheritingRowSwitchTest(_base.MappedTest):
-    def define_tables(self, metadata):
-        Table('parent', metadata,
-            Column('id', Integer, primary_key=True),
-            Column('pdata', String(30))
-        )
-        Table('child', metadata,
-            Column('id', Integer, primary_key=True),
-            Column('pid', Integer, ForeignKey('parent.id')),
-            Column('cdata', String(30))
-        )
-
-    def setup_classes(self):
-        class P(_base.ComparableEntity):
-            pass
-
-        class C(P):
-            pass
-    
-    @testing.resolve_artifact_names
-    def test_row_switch_no_child_table(self):
-        mapper(P, parent)
-        mapper(C, child, inherits=P)
-        
-        sess = create_session()
-        c1 = C(id=1, pdata='c1', cdata='c1')
-        sess.add(c1)
-        sess.flush()
-        
-        # establish a row switch between c1 and c2.
-        # c2 has no value for the "child" table
-        c2 = C(id=1, pdata='c2')
-        sess.add(c2)
-        sess.delete(c1)
-
-        self.assert_sql_execution(testing.db, sess.flush,
-            CompiledSQL("UPDATE parent SET pdata=:pdata WHERE parent.id = :parent_id",
-                {'pdata':'c2', 'parent_id':1}
-            )
-        )
-        
-        
-
 class TransactionTest(_base.MappedTest):
     __requires__ = ('deferrable_constraints',)
 
index d019aa0378bb30098d70efed426d50fb8377fc95..b03005c00efb39d4649f481777cc2fe4f4fa4566 100644 (file)
@@ -1,10 +1,13 @@
 import testenv; testenv.configure_for_tests()
 from sqlalchemy import *
-from sqlalchemy import exc
+from sqlalchemy import exc, schema
 from testlib import *
 from testlib import config, engines
+from sqlalchemy.engine import ddl
+from testlib.testing import eq_
+from testlib.assertsql import AllOf, RegexSQL, ExactSQL, CompiledSQL
 
-class ConstraintTest(TestBase, AssertsExecutionResults):
+class ConstraintTest(TestBase, AssertsExecutionResults, AssertsCompiledSQL):
 
     def setUp(self):
         global metadata
@@ -38,7 +41,6 @@ class ConstraintTest(TestBase, AssertsExecutionResults):
             Column('y', Integer, f)
         )
         
-        
     def test_circular_constraint(self):
         a = Table("a", metadata,
             Column('id', Integer, primary_key=True),
@@ -78,18 +80,9 @@ class ConstraintTest(TestBase, AssertsExecutionResults):
 
         metadata.create_all()
         foo.insert().execute(id=1,x=9,y=5)
-        try:
-            foo.insert().execute(id=2,x=5,y=9)
-            assert False
-        except exc.SQLError:
-            assert True
-
+        self.assertRaises(exc.SQLError, foo.insert().execute, id=2,x=5,y=9)
         bar.insert().execute(id=1,x=10)
-        try:
-            bar.insert().execute(id=2,x=5)
-            assert False
-        except exc.SQLError:
-            assert True
+        self.assertRaises(exc.SQLError, bar.insert().execute, id=2,x=5)
 
     def test_unique_constraint(self):
         foo = Table('foo', metadata,
@@ -106,16 +99,8 @@ class ConstraintTest(TestBase, AssertsExecutionResults):
         foo.insert().execute(id=2, value='value2')
         bar.insert().execute(id=1, value='a', value2='a')
         bar.insert().execute(id=2, value='a', value2='b')
-        try:
-            foo.insert().execute(id=3, value='value1')
-            assert False
-        except exc.SQLError:
-            assert True
-        try:
-            bar.insert().execute(id=3, value='a', value2='b')
-            assert False
-        except exc.SQLError:
-            assert True
+        self.assertRaises(exc.SQLError, foo.insert().execute, id=3, value='value1')
+        self.assertRaises(exc.SQLError, bar.insert().execute, id=3, value='a', value2='b')
 
     def test_index_create(self):
         employees = Table('employees', metadata,
@@ -174,35 +159,22 @@ class ConstraintTest(TestBase, AssertsExecutionResults):
         Index('sport_announcer', events.c.sport, events.c.announcer, unique=True)
         Index('idx_winners', events.c.winner)
 
-        index_names = [ ix.name for ix in events.indexes ]
-        assert 'ix_events_name' in index_names
-        assert 'ix_events_location' in index_names
-        assert 'sport_announcer' in index_names
-        assert 'idx_winners' in index_names
-        assert len(index_names) == 4
-
-        capt = []
-        connection = testing.db.connect()
-        # TODO: hacky, put a real connection proxy in
-        ex = connection._Connection__execute_context
-        def proxy(context):
-            capt.append(context.statement)
-            capt.append(repr(context.parameters))
-            ex(context)
-        connection._Connection__execute_context = proxy
-        schemagen = testing.db.dialect.schemagenerator(testing.db.dialect, connection)
-        schemagen.traverse(events)
-
-        assert capt[0].strip().startswith('CREATE TABLE events')
-
-        s = set([capt[x].strip() for x in [2,4,6,8]])
-
-        assert s == set([
-            'CREATE UNIQUE INDEX ix_events_name ON events (name)',
-            'CREATE INDEX ix_events_location ON events (location)',
-            'CREATE UNIQUE INDEX sport_announcer ON events (sport, announcer)',
-            'CREATE INDEX idx_winners ON events (winner)'
-            ])
+        eq_(
+            set([ ix.name for ix in events.indexes ]),
+            set(['ix_events_name', 'ix_events_location', 'sport_announcer', 'idx_winners'])
+        )
+
+        self.assert_sql_execution(
+            testing.db,
+            lambda: events.create(testing.db),
+            RegexSQL("^CREATE TABLE events"),
+            AllOf(
+                ExactSQL('CREATE UNIQUE INDEX ix_events_name ON events (name)'),
+                ExactSQL('CREATE INDEX ix_events_location ON events (location)'),
+                ExactSQL('CREATE UNIQUE INDEX sport_announcer ON events (sport, announcer)'),
+                ExactSQL('CREATE INDEX idx_winners ON events (winner)')
+            )
+        )
 
         # verify that the table is functional
         events.insert().execute(id=1, name='hockey finals', location='rink',
@@ -214,84 +186,57 @@ class ConstraintTest(TestBase, AssertsExecutionResults):
         dialect = testing.db.dialect.__class__()
         dialect.max_identifier_length = 20
 
-        schemagen = dialect.schemagenerator(dialect, None)
-        schemagen.execute = lambda : None
-
         t1 = Table("sometable", MetaData(), Column("foo", Integer))
-        schemagen.visit_index(Index("this_name_is_too_long_for_what_were_doing", t1.c.foo))
-        self.assertEquals(schemagen.buffer.getvalue(), "CREATE INDEX this_name_is_t_1 ON sometable (foo)")
-        schemagen.buffer.truncate(0)
-        schemagen.visit_index(Index("this_other_name_is_too_long_for_what_were_doing", t1.c.foo))
-        self.assertEquals(schemagen.buffer.getvalue(), "CREATE INDEX this_other_nam_2 ON sometable (foo)")
-
-        schemadrop = dialect.schemadropper(dialect, None)
-        schemadrop.execute = lambda: None
-        self.assertRaises(exc.IdentifierError, schemadrop.visit_index, Index("this_name_is_too_long_for_what_were_doing", t1.c.foo))
+        self.assert_compile(
+            schema.CreateIndex(Index("this_name_is_too_long_for_what_were_doing", t1.c.foo)),
+            "CREATE INDEX this_name_is_t_1 ON sometable (foo)",
+            dialect=dialect
+        )
+        
+        self.assert_compile(
+            schema.CreateIndex(Index("this_other_name_is_too_long_for_what_were_doing", t1.c.foo)),
+            "CREATE INDEX this_other_nam_1 ON sometable (foo)",
+            dialect=dialect
+        )
 
     
-class ConstraintCompilationTest(TestBase, AssertsExecutionResults):
-    class accum(object):
-        def __init__(self):
-            self.statements = []
-        def __call__(self, sql, *a, **kw):
-            self.statements.append(sql)
-        def __contains__(self, substring):
-            for s in self.statements:
-                if substring in s:
-                    return True
-            return False
-        def __str__(self):
-            return '\n'.join([repr(x) for x in self.statements])
-        def clear(self):
-            del self.statements[:]
-
-    def setUp(self):
-        self.sql = self.accum()
-        opts = config.db_opts.copy()
-        opts['strategy'] = 'mock'
-        opts['executor'] = self.sql
-        self.engine = engines.testing_engine(options=opts)
-
+class ConstraintCompilationTest(TestBase):
 
     def _test_deferrable(self, constraint_factory):
-        meta = MetaData(self.engine)
-        t = Table('tbl', meta,
+        t = Table('tbl', MetaData(),
                   Column('a', Integer),
                   Column('b', Integer),
                   constraint_factory(deferrable=True))
-        t.create()
-        assert 'DEFERRABLE' in self.sql, self.sql
-        assert 'NOT DEFERRABLE' not in self.sql, self.sql
-        self.sql.clear()
-        meta.clear()
-
-        t = Table('tbl', meta,
+                  
+        sql = str(schema.CreateTable(t).compile(bind=testing.db))
+        assert 'DEFERRABLE' in sql, sql
+        assert 'NOT DEFERRABLE' not in sql, sql
+        
+        t = Table('tbl', MetaData(),
                   Column('a', Integer),
                   Column('b', Integer),
                   constraint_factory(deferrable=False))
-        t.create()
-        assert 'NOT DEFERRABLE' in self.sql
-        self.sql.clear()
-        meta.clear()
 
-        t = Table('tbl', meta,
+        sql = str(schema.CreateTable(t).compile(bind=testing.db))
+        assert 'NOT DEFERRABLE' in sql
+
+
+        t = Table('tbl', MetaData(),
                   Column('a', Integer),
                   Column('b', Integer),
                   constraint_factory(deferrable=True, initially='IMMEDIATE'))
-        t.create()
-        assert 'NOT DEFERRABLE' not in self.sql
-        assert 'INITIALLY IMMEDIATE' in self.sql
-        self.sql.clear()
-        meta.clear()
+        sql = str(schema.CreateTable(t).compile(bind=testing.db))
+        assert 'NOT DEFERRABLE' not in sql
+        assert 'INITIALLY IMMEDIATE' in sql
 
-        t = Table('tbl', meta,
+        t = Table('tbl', MetaData(),
                   Column('a', Integer),
                   Column('b', Integer),
                   constraint_factory(deferrable=True, initially='DEFERRED'))
-        t.create()
+        sql = str(schema.CreateTable(t).compile(bind=testing.db))
 
-        assert 'NOT DEFERRABLE' not in self.sql
-        assert 'INITIALLY DEFERRED' in self.sql, self.sql
+        assert 'NOT DEFERRABLE' not in sql
+        assert 'INITIALLY DEFERRED' in sql
 
     def test_deferrable_pk(self):
         factory = lambda **kw: PrimaryKeyConstraint('a', **kw)
@@ -302,15 +247,15 @@ class ConstraintCompilationTest(TestBase, AssertsExecutionResults):
         self._test_deferrable(factory)
 
     def test_deferrable_column_fk(self):
-        meta = MetaData(self.engine)
-        t = Table('tbl', meta,
+        t = Table('tbl', MetaData(),
                   Column('a', Integer),
                   Column('b', Integer,
                          ForeignKey('tbl.a', deferrable=True,
                                     initially='DEFERRED')))
-        t.create()
-        assert 'DEFERRABLE' in self.sql, self.sql
-        assert 'INITIALLY DEFERRED' in self.sql, self.sql
+
+        sql = str(schema.CreateTable(t).compile(bind=testing.db))
+        assert 'DEFERRABLE' in sql
+        assert 'INITIALLY DEFERRED' in sql
 
     def test_deferrable_unique(self):
         factory = lambda **kw: UniqueConstraint('b', **kw)
@@ -321,16 +266,15 @@ class ConstraintCompilationTest(TestBase, AssertsExecutionResults):
         self._test_deferrable(factory)
 
     def test_deferrable_column_check(self):
-        meta = MetaData(self.engine)
-        t = Table('tbl', meta,
+        t = Table('tbl', MetaData(),
                   Column('a', Integer),
                   Column('b', Integer,
                          CheckConstraint('a < b',
                                          deferrable=True,
                                          initially='DEFERRED')))
-        t.create()
-        assert 'DEFERRABLE' in self.sql, self.sql
-        assert 'INITIALLY DEFERRED' in self.sql, self.sql
+        sql = str(schema.CreateTable(t).compile(bind=testing.db))
+        assert 'DEFERRABLE' in sql
+        assert 'INITIALLY DEFERRED' in sql
 
 
 if __name__ == "__main__":
index ea9f27cdf2dd06dcf124f774619e228d9b4b953b..671ccab1a03d79d3073dc3c1e1e1a62b0cbe029e 100644 (file)
@@ -5,7 +5,9 @@ from sqlalchemy import exc, sql, util
 from sqlalchemy.sql import table, column, label, compiler
 from sqlalchemy.sql.expression import ClauseList
 from sqlalchemy.engine import default
-from sqlalchemy.databases import sqlite, postgres, mysql, oracle, firebird, mssql
+from sqlalchemy.databases import mysql, oracle, firebird, mssql
+from sqlalchemy.dialects.sqlite import pysqlite as sqlite
+from sqlalchemy.dialects.postgres import psycopg2 as postgres
 from testlib import *
 
 table1 = table('mytable',
index 44b83defd86afe7bb184a2d30d49241c878bfff6..da649d09703c0e64a0ff1f1256b44c3ab1a80a57 100644 (file)
@@ -2,11 +2,14 @@ import decimal
 import testenv; testenv.configure_for_tests()
 import datetime, os, pickleable, re
 from sqlalchemy import *
-from sqlalchemy import exc, types, util
+from sqlalchemy import exc, types, util, schema
 from sqlalchemy.sql import operators
 from testlib.testing import eq_
 import sqlalchemy.engine.url as url
-from sqlalchemy.databases import mssql, oracle, mysql, postgres, firebird
+from sqlalchemy.databases import mssql, oracle, mysql, firebird
+from sqlalchemy.dialects.sqlite import pysqlite as sqlite
+from sqlalchemy.dialects.postgres import psycopg2 as postgres
+
 from testlib import *
 
 
@@ -80,12 +83,12 @@ class AdaptTest(TestBase):
             (mysql_dialect, Unicode(), mysql.MSString),
             (mysql_dialect, UnicodeText(), mysql.MSText),
             (mysql_dialect, NCHAR(), mysql.MSNChar),
-            (postgres_dialect, String(), postgres.PGString),
-            (postgres_dialect, VARCHAR(), postgres.PGString),
-            (postgres_dialect, String(50), postgres.PGString),
-            (postgres_dialect, Unicode(), postgres.PGString),
-            (postgres_dialect, UnicodeText(), postgres.PGText),
-            (postgres_dialect, NCHAR(), postgres.PGString),
+            (postgres_dialect, String(), String),
+            (postgres_dialect, VARCHAR(), String),
+            (postgres_dialect, String(50), String),
+            (postgres_dialect, Unicode(), String),
+            (postgres_dialect, UnicodeText(), Text),
+            (postgres_dialect, NCHAR(), String),
             (firebird_dialect, String(), firebird.FBString),
             (firebird_dialect, VARCHAR(), firebird.FBString),
             (firebird_dialect, String(50), firebird.FBString),
@@ -100,11 +103,6 @@ class AdaptTest(TestBase):
 class UserDefinedTest(TestBase):
     """tests user-defined types."""
 
-    def testbasic(self):
-        print users.c.goofy4.type
-        print users.c.goofy4.type.dialect_impl(testing.db.dialect)
-        print users.c.goofy4.type.dialect_impl(testing.db.dialect).get_col_spec()
-
     def testprocessing(self):
 
         global users
@@ -135,7 +133,7 @@ class UserDefinedTest(TestBase):
     def setUpAll(self):
         global users, metadata
 
-        class MyType(types.TypeEngine):
+        class MyType(types.UserDefinedType):
             def get_col_spec(self):
                 return "VARCHAR(100)"
             def bind_processor(self, dialect):
@@ -259,7 +257,6 @@ class ColumnsTest(TestBase, AssertsExecutionResults):
             for key, value in expectedResults.items():
                 expectedResults[key] = '%s NULL' % value
 
-        print db.engine.__module__
         testTable = Table('testColumns', MetaData(db),
             Column('int_column', Integer),
             Column('smallint_column', SmallInteger),
@@ -271,7 +268,7 @@ class ColumnsTest(TestBase, AssertsExecutionResults):
         for aCol in testTable.c:
             self.assertEquals(
                 expectedResults[aCol.name],
-                db.dialect.schemagenerator(db.dialect, db, None, None).\
+                db.dialect.ddl_compiler(db.dialect, schema.CreateTable(testTable)).\
                   get_column_specification(aCol))
 
 class UnicodeTest(TestBase, AssertsExecutionResults):
@@ -469,7 +466,7 @@ class ExpressionTest(TestBase, AssertsExecutionResults):
     def setUpAll(self):
         global test_table, meta
 
-        class MyCustomType(types.TypeEngine):
+        class MyCustomType(types.UserDefinedType):
             def get_col_spec(self):
                 return "INT"
             def bind_processor(self, dialect):
@@ -712,6 +709,7 @@ class NumericTest(TestBase, AssertsExecutionResults):
         from decimal import Decimal
         numeric_table.insert().execute(
             numericcol=3.5, floatcol=5.6, ncasdec=12.4, fcasdec=15.75)
+            
         numeric_table.insert().execute(
             numericcol=Decimal("3.5"), floatcol=Decimal("5.6"),
             ncasdec=Decimal("12.4"), fcasdec=Decimal("15.75"))
@@ -753,7 +751,7 @@ class NumericTest(TestBase, AssertsExecutionResults):
             eq_(n2.scale, 12, dialect.name)
             
             # test colspec generates successfully using 'scale'
-            assert n2.get_col_spec()
+            assert dialect.type_compiler.process(n2)
             
             # test constructor of the dialect-specific type
             n3 = n2.__class__(scale=5)
index 4068f43d0a9f5a877a196acb9cc05e9baf8984d9..df1d37d3cd4fec647ac457965515add9647a47dc 100644 (file)
@@ -71,6 +71,10 @@ def all_dialects():
     for name in d.__all__:
         mod = getattr(__import__('sqlalchemy.databases.%s' % name).databases, name)
         yield mod.dialect()
+    import sqlalchemy.dialects as d
+    for name in d.__all__:
+        mod = getattr(__import__('sqlalchemy.dialects.%s.base' % name).dialects, name).base
+        yield mod.dialect()
         
 class ReconnectFixture(object):
     def __init__(self, dbapi):
index fffb301f2f2996aadd3c7098e06a6c4bb37bf2ea..fb77b07bb15413d16a8ae6160fa702afe7bfc472 100644 (file)
@@ -615,14 +615,13 @@ class AssertsCompiledSQL(object):
         if dialect is None:
             dialect = getattr(self, '__dialect__', None)
 
-        if params is None:
-            keys = None
-        else:
-            keys = params.keys()
+        kw = {}
+        if params is not None:
+            kw['column_keys'] = params.keys()
 
-        c = clause.compile(column_keys=keys, dialect=dialect)
+        c = clause.compile(dialect=dialect, **kw)
 
-        print "\nSQL String:\n" + str(c) + repr(c.params)
+        print "\nSQL String:\n" + str(c) + repr(getattr(c, 'params', {}))
 
         cc = re.sub(r'\n', '', str(c))