]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
1. Module layout. sql.py and related move into a package called "sql".
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 18 Aug 2007 21:37:48 +0000 (21:37 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 18 Aug 2007 21:37:48 +0000 (21:37 +0000)
2. compiler names changed to be less verbose, unused classes removed.
3. Methods on Dialect which return compilers, schema generators, identifier preparers
have changed to direct class references, typically on the Dialect class itself
or optionally as attributes on an individual Dialect instance if conditional behavior is needed.
This takes away the need for Dialect subclasses to know how to instantiate these
objects, and also reduces method overhead by one call for each one.
4. as a result of 3., some internal signatures have changed for things like compiler() (now statement_compiler()), preparer(), etc., mostly in that the dialect needs to be passed explicitly as the first argument (since they are just class references now).  The compiler() method on Engine and Connection is now also named statement_compiler(), but as before does not take the dialect as an argument.

5. changed _process_row function on RowProxy to be a class reference, cuts out 50K method calls from insertspeed.py

36 files changed:
lib/sqlalchemy/databases/access.py
lib/sqlalchemy/databases/firebird.py
lib/sqlalchemy/databases/informix.py
lib/sqlalchemy/databases/mssql.py
lib/sqlalchemy/databases/mysql.py
lib/sqlalchemy/databases/oracle.py
lib/sqlalchemy/databases/postgres.py
lib/sqlalchemy/databases/sqlite.py
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/ext/sqlsoup.py
lib/sqlalchemy/orm/interfaces.py
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/properties.py
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/orm/strategies.py
lib/sqlalchemy/orm/sync.py
lib/sqlalchemy/orm/util.py
lib/sqlalchemy/schema.py
lib/sqlalchemy/sql/__init__.py [new file with mode: 0644]
lib/sqlalchemy/sql/compiler.py [moved from lib/sqlalchemy/ansisql.py with 93% similarity]
lib/sqlalchemy/sql/expression.py [moved from lib/sqlalchemy/sql.py with 93% similarity]
lib/sqlalchemy/sql/operators.py [moved from lib/sqlalchemy/operators.py with 100% similarity]
lib/sqlalchemy/sql/util.py [moved from lib/sqlalchemy/sql_util.py with 68% similarity]
lib/sqlalchemy/sql/visitors.py [new file with mode: 0644]
test/dialect/mysql.py
test/engine/reflection.py
test/orm/dynamic.py
test/orm/query.py
test/sql/constraints.py
test/sql/generative.py
test/sql/labels.py
test/sql/query.py
test/sql/quote.py
test/sql/testtypes.py
test/testlib/testing.py

index 6bf8b96e969c696c2dc1d7ad62ceaa85f20947f7..4aa773239d330d6fa4834ff3b56f929cc73ad70c 100644 (file)
@@ -347,7 +347,7 @@ class AccessDialect(ansisql.ANSIDialect):
         return names
 
 
-class AccessCompiler(ansisql.ANSICompiler):
+class AccessCompiler(compiler.DefaultCompiler):
     def visit_select_precolumns(self, select):
         """Access puts TOP, it's version of LIMIT here """
         s = select.distinct and "DISTINCT " or ""
@@ -387,7 +387,7 @@ class AccessCompiler(ansisql.ANSICompiler):
         return ''
 
 
-class AccessSchemaGenerator(ansisql.ANSISchemaGenerator):
+class AccessSchemaGenerator(compiler.SchemaGenerator):
     def get_column_specification(self, column, **kwargs):
         colspec = self.preparer.format_column(column) + " " + column.type.dialect_impl(self.dialect).get_col_spec()
 
@@ -410,7 +410,7 @@ class AccessSchemaGenerator(ansisql.ANSISchemaGenerator):
 
         return colspec
 
-class AccessSchemaDropper(ansisql.ANSISchemaDropper):
+class AccessSchemaDropper(compiler.SchemaDropper):
     def visit_index(self, index):
         self.append("\nDROP INDEX [%s].[%s]" % (index.table.name, index.name))
         self.execute()
@@ -418,7 +418,7 @@ class AccessSchemaDropper(ansisql.ANSISchemaDropper):
 class AccessDefaultRunner(ansisql.ANSIDefaultRunner):
     pass
 
-class AccessIdentifierPreparer(ansisql.ANSIIdentifierPreparer):
+class AccessIdentifierPreparer(compiler.IdentifierPreparer):
     def __init__(self, dialect):
         super(AccessIdentifierPreparer, self).__init__(dialect, initial_quote='[', final_quote=']')
 
index 307fceb483e97f2abc94709ca2bf6c522727a04a..9cccb53e85cbe27d71eb8da5fc478a338448e479 100644 (file)
@@ -7,9 +7,10 @@
 
 import warnings
 
-from sqlalchemy import util, sql, schema, ansisql, exceptions
-import sqlalchemy.engine.default as default
-import sqlalchemy.types as sqltypes
+from sqlalchemy import util, sql, schema, exceptions
+from sqlalchemy.sql import compiler
+from sqlalchemy.engine import default, base
+from sqlalchemy import types as sqltypes
 
 
 _initialized_kb = False
@@ -99,9 +100,9 @@ class FBExecutionContext(default.DefaultExecutionContext):
         return True
 
 
-class FBDialect(ansisql.ANSIDialect):
+class FBDialect(default.DefaultDialect):
     def __init__(self, type_conv=200, concurrency_level=1, **kwargs):
-        ansisql.ANSIDialect.__init__(self, **kwargs)
+        default.DefaultDialect.__init__(self, **kwargs)
 
         self.type_conv = type_conv
         self.concurrency_level= concurrency_level
@@ -135,21 +136,6 @@ class FBDialect(ansisql.ANSIDialect):
     def supports_sane_rowcount(self):
         return False
 
-    def compiler(self, statement, bindparams, **kwargs):
-        return FBCompiler(self, statement, bindparams, **kwargs)
-
-    def schemagenerator(self, *args, **kwargs):
-        return FBSchemaGenerator(self, *args, **kwargs)
-
-    def schemadropper(self, *args, **kwargs):
-        return FBSchemaDropper(self, *args, **kwargs)
-
-    def defaultrunner(self, connection):
-        return FBDefaultRunner(connection)
-
-    def preparer(self):
-        return FBIdentifierPreparer(self)
-
     def max_identifier_length(self):
         return 31
     
@@ -307,7 +293,7 @@ class FBDialect(ansisql.ANSIDialect):
         connection.commit(True)
 
 
-class FBCompiler(ansisql.ANSICompiler):
+class FBCompiler(compiler.DefaultCompiler):
     """Firebird specific idiosincrasies"""
 
     def visit_alias(self, alias, asfrom=False, **kwargs):
@@ -346,7 +332,7 @@ class FBCompiler(ansisql.ANSICompiler):
         return ""
 
 
-class FBSchemaGenerator(ansisql.ANSISchemaGenerator):
+class FBSchemaGenerator(compiler.SchemaGenerator):
     def get_column_specification(self, column, **kwargs):
         colspec = self.preparer.format_column(column)
         colspec += " " + column.type.dialect_impl(self.dialect).get_col_spec()
@@ -365,13 +351,13 @@ class FBSchemaGenerator(ansisql.ANSISchemaGenerator):
         self.execute()
 
 
-class FBSchemaDropper(ansisql.ANSISchemaDropper):
+class FBSchemaDropper(compiler.SchemaDropper):
     def visit_sequence(self, sequence):
         self.append("DROP GENERATOR %s" % sequence.name)
         self.execute()
 
 
-class FBDefaultRunner(ansisql.ANSIDefaultRunner):
+class FBDefaultRunner(base.DefaultRunner):
     def exec_default_sql(self, default):
         c = sql.select([default.arg], from_obj=["rdb$database"]).compile(bind=self.connection)
         return self.connection.execute_compiled(c).scalar()
@@ -421,7 +407,7 @@ RESERVED_WORDS = util.Set(
      "whenever", "where", "while", "with", "work", "write", "year", "yearday" ])
 
 
-class FBIdentifierPreparer(ansisql.ANSIIdentifierPreparer):
+class FBIdentifierPreparer(compiler.IdentifierPreparer):
     def __init__(self, dialect):
         super(FBIdentifierPreparer,self).__init__(dialect, omit_schema=True)
 
@@ -430,3 +416,9 @@ class FBIdentifierPreparer(ansisql.ANSIIdentifierPreparer):
 
 
 dialect = FBDialect
+dialect.statement_compiler = FBCompiler
+dialect.schemagenerator = FBSchemaGenerator
+dialect.schemadropper = FBSchemaDropper
+dialect.defaultrunner = FBDefaultRunner
+dialect.preparer = FBIdentifierPreparer
+
index 21ecf15381d3e70f5ab7666e2d94d9e2e72b25b9..ceb56903a6500a7838f0152483f0917629cc98e4 100644 (file)
@@ -7,9 +7,10 @@
 
 import datetime, warnings
 
-from sqlalchemy import sql, schema, ansisql, exceptions, pool
-import sqlalchemy.engine.default as default
-import sqlalchemy.types as sqltypes
+from sqlalchemy import sql, schema, exceptions, pool
+from sqlalchemy.sql import compiler
+from sqlalchemy.engine import default
+from sqlalchemy import types as sqltypes
 
 
 # for offset
@@ -203,11 +204,11 @@ class InfoExecutionContext(default.DefaultExecutionContext):
     def create_cursor( self ):
         return informix_cursor( self.connection.connection )
         
-class InfoDialect(ansisql.ANSIDialect):
+class InfoDialect(default.DefaultDialect):
     
     def __init__(self, use_ansi=True,**kwargs):
         self.use_ansi = use_ansi
-        ansisql.ANSIDialect.__init__(self, **kwargs)
+        default.DefaultDialect.__init__(self, **kwargs)
         self.paramstyle = 'qmark'
 
     def dbapi(cls):
@@ -252,18 +253,6 @@ class InfoDialect(ansisql.ANSIDialect):
     def oid_column_name(self,column):
         return "rowid"
     
-    def preparer(self):
-        return InfoIdentifierPreparer(self)
-
-    def compiler(self, statement, bindparams, **kwargs):
-        return InfoCompiler(self, statement, bindparams, **kwargs)
-        
-    def schemagenerator(self, *args, **kwargs):
-        return InfoSchemaGenerator( self , *args, **kwargs)
-    
-    def schemadropper(self, *args, **params):
-        return InfoSchemaDroper( self , *args , **params)
-    
     def table_names(self, connection, schema):
         s = "select tabname from systables"
         return [row[0] for row in connection.execute(s)]
@@ -376,14 +365,14 @@ class InfoDialect(ansisql.ANSIDialect):
         for cons_name, cons_type, local_column in rows:
             table.primary_key.add( table.c[local_column] )
 
-class InfoCompiler(ansisql.ANSICompiler):
+class InfoCompiler(compiler.DefaultCompiler):
     """Info compiler modifies the lexical structure of Select statements to work under 
     non-ANSI configured Oracle databases, if the use_ansi flag is False."""
     def __init__(self, dialect, statement, parameters=None, **kwargs):
         self.limit = 0
         self.offset = 0
         
-        ansisql.ANSICompiler.__init__( self , dialect , statement , parameters , **kwargs )
+        compiler.DefaultCompiler.__init__( self , dialect , statement , parameters , **kwargs )
     
     def default_from(self):
         return " from systables where tabname = 'systables' "
@@ -416,7 +405,7 @@ class InfoCompiler(ansisql.ANSICompiler):
             if ( __label(c) not in a ) and getattr( c , 'name' , '' ) != 'oid':
                 select.append_column( c )
         
-        return ansisql.ANSICompiler.visit_select(self, select)
+        return compiler.DefaultCompiler.visit_select(self, select)
         
     def limit_clause(self, select):
         return ""
@@ -437,7 +426,7 @@ class InfoCompiler(ansisql.ANSICompiler):
         elif func.name.lower() in ( 'current_timestamp' , 'now' ):
             return "CURRENT YEAR TO SECOND"
         else:
-            return ansisql.ANSICompiler.visit_function( self , func )
+            return compiler.DefaultCompiler.visit_function( self , func )
             
     def visit_clauselist(self, list):
         try:
@@ -446,7 +435,7 @@ class InfoCompiler(ansisql.ANSICompiler):
             li = [ c for c in list.clauses ]
         return ', '.join([s for s in [self.process(c) for c in li] if s is not None])
 
-class InfoSchemaGenerator(ansisql.ANSISchemaGenerator):
+class InfoSchemaGenerator(compiler.SchemaGenerator):
     def get_column_specification(self, column, first_pk=False):
         colspec = self.preparer.format_column(column)
         if column.primary_key and len(column.foreign_keys)==0 and column.autoincrement and \
@@ -507,7 +496,7 @@ class InfoSchemaGenerator(ansisql.ANSISchemaGenerator):
             return
         super(InfoSchemaGenerator, self).visit_index(index)
 
-class InfoIdentifierPreparer(ansisql.ANSIIdentifierPreparer):
+class InfoIdentifierPreparer(compiler.IdentifierPreparer):
     def __init__(self, dialect):
         super(InfoIdentifierPreparer, self).__init__(dialect, initial_quote="'")
     
@@ -517,10 +506,14 @@ class InfoIdentifierPreparer(ansisql.ANSIIdentifierPreparer):
     def _requires_quotes(self, value):
         return False
 
-class InfoSchemaDroper(ansisql.ANSISchemaDropper):
+class InfoSchemaDroper(compiler.SchemaDropper):
     def drop_foreignkey(self, constraint):
         if constraint.name is not None:
             super( InfoSchemaDroper , self ).drop_foreignkey( constraint )
 
 dialect = InfoDialect
 poolclass = pool.SingletonThreadPool
+dialect.statement_compiler = InfoCompiler
+dialect.schemagenerator = InfoSchemaGenerator
+dialect.schemadropper = InfoSchemaDropper
+dialect.preparer = InfoIdentifierPreparer
index 619e072d911e315242260ada50d6ba2904a7eabc..0caccca9569ed132012ed555bb0c67746bd4ac75 100644 (file)
@@ -39,10 +39,10 @@ Known issues / TODO:
 
 import datetime, random, warnings, re
 
-from sqlalchemy import sql, schema, ansisql, exceptions
-import sqlalchemy.types as sqltypes
-from sqlalchemy.engine import default
-import operator, sys
+from sqlalchemy import util, sql, schema, exceptions
+from sqlalchemy.sql import compiler, expression
+from sqlalchemy.engine import default, base
+from sqlalchemy import types as sqltypes
     
 class MSNumeric(sqltypes.Numeric):
     def result_processor(self, dialect):
@@ -366,7 +366,7 @@ class MSSQLExecutionContext_pyodbc (MSSQLExecutionContext):
             super(MSSQLExecutionContext_pyodbc, self).post_exec()
 
 
-class MSSQLDialect(ansisql.ANSIDialect):
+class MSSQLDialect(default.DefaultDialect):
     colspecs = {
         sqltypes.Unicode : MSNVarchar,
         sqltypes.Integer : MSInteger,
@@ -476,21 +476,6 @@ class MSSQLDialect(ansisql.ANSIDialect):
     def supports_sane_rowcount(self):
         raise NotImplementedError()
 
-    def compiler(self, statement, bindparams, **kwargs):
-        return MSSQLCompiler(self, statement, bindparams, **kwargs)
-
-    def schemagenerator(self, *args, **kwargs):
-        return MSSQLSchemaGenerator(self, *args, **kwargs)
-
-    def schemadropper(self, *args, **kwargs):
-        return MSSQLSchemaDropper(self, *args, **kwargs)
-
-    def defaultrunner(self, connection, **kwargs):
-        return MSSQLDefaultRunner(connection, **kwargs)
-
-    def preparer(self):
-        return MSSQLIdentifierPreparer(self)
-
     def get_default_schema_name(self, connection):
         return self.schema_name
 
@@ -878,7 +863,7 @@ dialect_mapping = {
     }
 
 
-class MSSQLCompiler(ansisql.ANSICompiler):
+class MSSQLCompiler(compiler.DefaultCompiler):
     def __init__(self, dialect, statement, parameters, **kwargs):
         super(MSSQLCompiler, self).__init__(dialect, statement, parameters, **kwargs)
         self.tablealiases = {}
@@ -931,13 +916,13 @@ class MSSQLCompiler(ansisql.ANSICompiler):
 
     def visit_binary(self, binary):
         """Move bind parameters to the right-hand side of an operator, where possible."""
-        if isinstance(binary.left, sql._BindParamClause) and binary.operator == operator.eq:
-            return self.process(sql._BinaryExpression(binary.right, binary.left, binary.operator))
+        if isinstance(binary.left, expression._BindParamClause) and binary.operator == operator.eq:
+            return self.process(expression._BinaryExpression(binary.right, binary.left, binary.operator))
         else:
             return super(MSSQLCompiler, self).visit_binary(binary)
 
     def label_select_column(self, select, column):
-        if isinstance(column, sql._Function):
+        if isinstance(column, expression._Function):
             return column.label(column.name + "_" + hex(random.randint(0, 65535))[2:])        
         else:
             return super(MSSQLCompiler, self).label_select_column(select, column)
@@ -963,7 +948,7 @@ class MSSQLCompiler(ansisql.ANSICompiler):
             return ""
 
 
-class MSSQLSchemaGenerator(ansisql.ANSISchemaGenerator):
+class MSSQLSchemaGenerator(compiler.SchemaGenerator):
     def get_column_specification(self, column, **kwargs):
         colspec = self.preparer.format_column(column) + " " + column.type.dialect_impl(self.dialect).get_col_spec()
         
@@ -986,7 +971,7 @@ class MSSQLSchemaGenerator(ansisql.ANSISchemaGenerator):
         
         return colspec
 
-class MSSQLSchemaDropper(ansisql.ANSISchemaDropper):
+class MSSQLSchemaDropper(compiler.SchemaDropper):
     def visit_index(self, index):
         self.append("\nDROP INDEX %s.%s" % (
             self.preparer.quote_identifier(index.table.name),
@@ -995,11 +980,11 @@ class MSSQLSchemaDropper(ansisql.ANSISchemaDropper):
         self.execute()
 
 
-class MSSQLDefaultRunner(ansisql.ANSIDefaultRunner):
+class MSSQLDefaultRunner(base.DefaultRunner):
     # TODO: does ms-sql have standalone sequences ?
     pass
 
-class MSSQLIdentifierPreparer(ansisql.ANSIIdentifierPreparer):
+class MSSQLIdentifierPreparer(compiler.IdentifierPreparer):
     def __init__(self, dialect):
         super(MSSQLIdentifierPreparer, self).__init__(dialect, initial_quote='[', final_quote=']')
 
@@ -1012,6 +997,11 @@ class MSSQLIdentifierPreparer(ansisql.ANSIIdentifierPreparer):
         return value
 
 dialect = MSSQLDialect
+dialect.statement_compiler = MSSQLCompiler
+dialect.schemagenerator = MSSQLSchemaGenerator
+dialect.schemadropper = MSSQLSchemaDropper
+dialect.preparer = MSSQLIdentifierPreparer
+dialect.defaultrunner = MSSQLDefaultRunner
 
 
 
index d5fd3b6c53047e1e000a877ce52877f93c006ff8..41c6ec70f3beb48cd397453019f135bcc7d100a2 100644 (file)
@@ -126,11 +126,12 @@ information affecting MySQL in SQLAlchemy.
 import re, datetime, inspect, warnings, sys
 from array import array as _array
 
-from sqlalchemy import ansisql, exceptions, logging, schema, sql, util
-from sqlalchemy import operators as sql_operators
+from sqlalchemy import exceptions, logging, schema, sql, util
+from sqlalchemy.sql import operators as sql_operators
+from sqlalchemy.sql import compiler
 
 from sqlalchemy.engine import base as engine_base, default
-import sqlalchemy.types as sqltypes
+from sqlalchemy import types as sqltypes
 
 
 __all__ = (
@@ -1328,13 +1329,17 @@ class MySQLExecutionContext(default.DefaultExecutionContext):
         return AUTOCOMMIT_RE.match(self.statement)
 
 
-class MySQLDialect(ansisql.ANSIDialect):
+class MySQLDialect(default.DefaultDialect):
     """Details of the MySQL dialect.  Not used directly in application code."""
 
     def __init__(self, use_ansiquotes=False, **kwargs):
         self.use_ansiquotes = use_ansiquotes
         kwargs.setdefault('default_paramstyle', 'format')
-        ansisql.ANSIDialect.__init__(self, **kwargs)
+        if self.use_ansiquotes:
+            self.preparer = MySQLANSIIdentifierPreparer
+        else:
+            self.preparer = MySQLIdentifierPreparer
+        default.DefaultDialect.__init__(self, **kwargs)
 
     def dbapi(cls):
         import MySQLdb as mysql
@@ -1393,7 +1398,7 @@ class MySQLDialect(ansisql.ANSIDialect):
         return True
 
     def compiler(self, statement, bindparams, **kwargs):
-        return MySQLCompiler(self, statement, bindparams, **kwargs)
+        return MySQLCompiler(statement, bindparams, dialect=self, **kwargs)
 
     def schemagenerator(self, *args, **kwargs):
         return MySQLSchemaGenerator(self, *args, **kwargs)
@@ -1401,12 +1406,6 @@ class MySQLDialect(ansisql.ANSIDialect):
     def schemadropper(self, *args, **kwargs):
         return MySQLSchemaDropper(self, *args, **kwargs)
 
-    def preparer(self):
-        if self.use_ansiquotes:
-            return MySQLANSIIdentifierPreparer(self)
-        else:
-            return MySQLIdentifierPreparer(self)
-
     def do_executemany(self, cursor, statement, parameters,
                        context=None, **kwargs):
         rowcount = cursor.executemany(statement, parameters)
@@ -1733,8 +1732,8 @@ class _MySQLPythonRowProxy(object):
             return item
 
 
-class MySQLCompiler(ansisql.ANSICompiler):
-    operators = ansisql.ANSICompiler.operators.copy()
+class MySQLCompiler(compiler.DefaultCompiler):
+    operators = compiler.DefaultCompiler.operators.copy()
     operators.update(
         {
             sql_operators.concat_op: \
@@ -1783,7 +1782,7 @@ class MySQLCompiler(ansisql.ANSICompiler):
 #       In older versions, the indexes must be created explicitly or the
 #       creation of foreign key constraints fails."
 
-class MySQLSchemaGenerator(ansisql.ANSISchemaGenerator):
+class MySQLSchemaGenerator(compiler.SchemaGenerator):
     def get_column_specification(self, column, override_pk=False,
                                  first_pk=False):
         """Builds column DDL."""
@@ -1827,7 +1826,7 @@ class MySQLSchemaGenerator(ansisql.ANSISchemaGenerator):
         return ' '.join(table_opts)
 
 
-class MySQLSchemaDropper(ansisql.ANSISchemaDropper):
+class MySQLSchemaDropper(compiler.SchemaDropper):
     def visit_index(self, index):
         self.append("\nDROP INDEX %s ON %s" %
                     (self.preparer.format_index(index),
@@ -2368,7 +2367,7 @@ class MySQLSchemaReflector(object):
 MySQLSchemaReflector.logger = logging.class_logger(MySQLSchemaReflector)
 
 
-class _MySQLIdentifierPreparer(ansisql.ANSIIdentifierPreparer):
+class _MySQLIdentifierPreparer(compiler.IdentifierPreparer):
     """MySQL-specific schema identifier configuration."""
     
     def __init__(self, dialect, **kw):
@@ -2433,3 +2432,6 @@ def _re_compile(regex):
     return re.compile(regex, re.I | re.UNICODE)
 
 dialect = MySQLDialect
+dialect.statement_compiler = MySQLCompiler
+dialect.schemagenerator = MySQLSchemaGenerator
+dialect.schemadropper = MySQLSchemaDropper
index a35db198209e24fec7d27e96d7754eee986896ef..2d8f2940f809721107ce4b6c11a89f29d6c03ecd 100644 (file)
@@ -5,11 +5,13 @@
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
 
 
-import re, warnings, operator, random
+import re, warnings, random
 
-from sqlalchemy import util, sql, schema, ansisql, exceptions, logging
+from sqlalchemy import util, sql, schema, exceptions, logging
 from sqlalchemy.engine import default, base
-import sqlalchemy.types as sqltypes
+from sqlalchemy.sql import compiler, visitors
+from sqlalchemy.sql import operators as sql_operators
+from sqlalchemy import types as sqltypes
 
 import datetime
 
@@ -229,9 +231,9 @@ class OracleExecutionContext(default.DefaultExecutionContext):
         
         return base.ResultProxy(self)
 
-class OracleDialect(ansisql.ANSIDialect):
+class OracleDialect(default.DefaultDialect):
     def __init__(self, use_ansi=True, auto_setinputsizes=True, auto_convert_lobs=True, threaded=True, allow_twophase=True, **kwargs):
-        ansisql.ANSIDialect.__init__(self, default_paramstyle='named', **kwargs)
+        default.DefaultDialect.__init__(self, default_paramstyle='named', **kwargs)
         self.use_ansi = use_ansi
         self.threaded = threaded
         self.allow_twophase = allow_twophase
@@ -333,21 +335,6 @@ class OracleDialect(ansisql.ANSIDialect):
     def create_execution_context(self, *args, **kwargs):
         return OracleExecutionContext(self, *args, **kwargs)
 
-    def compiler(self, statement, bindparams, **kwargs):
-        return OracleCompiler(self, statement, bindparams, **kwargs)
-
-    def preparer(self):
-        return OracleIdentifierPreparer(self)
-
-    def schemagenerator(self, *args, **kwargs):
-        return OracleSchemaGenerator(self, *args, **kwargs)
-
-    def schemadropper(self, *args, **kwargs):
-        return OracleSchemaDropper(self, *args, **kwargs)
-
-    def defaultrunner(self, connection, **kwargs):
-        return OracleDefaultRunner(connection, **kwargs)
-
     def has_table(self, connection, table_name, schema=None):
         cursor = connection.execute("""select table_name from all_tables where table_name=:name""", {'name':self._denormalize_name(table_name)})
         return bool( cursor.fetchone() is not None )
@@ -560,16 +547,16 @@ class _OuterJoinColumn(sql.ClauseElement):
     def __init__(self, column):
         self.column = column
         
-class OracleCompiler(ansisql.ANSICompiler):
+class OracleCompiler(compiler.DefaultCompiler):
     """Oracle compiler modifies the lexical structure of Select
     statements to work under non-ANSI configured Oracle databases, if
     the use_ansi flag is False.
     """
 
-    operators = ansisql.ANSICompiler.operators.copy()
+    operators = compiler.DefaultCompiler.operators.copy()
     operators.update(
         {
-            operator.mod : lambda x, y:"mod(%s, %s)" % (x, y)
+            sql_operators.mod : lambda x, y:"mod(%s, %s)" % (x, y)
         }
     )
 
@@ -590,13 +577,13 @@ class OracleCompiler(ansisql.ANSICompiler):
 
     def visit_join(self, join, **kwargs):
         if self.dialect.use_ansi:
-            return ansisql.ANSICompiler.visit_join(self, join, **kwargs)
+            return compiler.DefaultCompiler.visit_join(self, join, **kwargs)
 
         (where, parentjoin) = self.__wheres.get(join, (None, None))
 
-        class VisitOn(sql.ClauseVisitor):
+        class VisitOn(visitors.ClauseVisitor):
             def visit_binary(s, binary):
-                if binary.operator == operator.eq:
+                if binary.operator == sql_operators.eq:
                     if binary.left.table is join.right:
                         binary.left = _OuterJoinColumn(binary.left)
                     elif binary.right.table is join.right:
@@ -640,7 +627,7 @@ class OracleCompiler(ansisql.ANSICompiler):
         for c in insert.table.primary_key:
             if c.key not in self.parameters:
                 self.parameters[c.key] = None
-        return ansisql.ANSICompiler.visit_insert(self, insert)
+        return compiler.DefaultCompiler.visit_insert(self, insert)
 
     def _TODO_visit_compound_select(self, select):
         """Need to determine how to get ``LIMIT``/``OFFSET`` into a ``UNION`` for Oracle."""
@@ -672,7 +659,7 @@ class OracleCompiler(ansisql.ANSICompiler):
                 limitselect.append_whereclause("ora_rn<=%d" % select._limit)
             return self.process(limitselect, **kwargs)
         else:
-            return ansisql.ANSICompiler.visit_select(self, select, **kwargs)
+            return compiler.DefaultCompiler.visit_select(self, select, **kwargs)
 
     def limit_clause(self, select):
         return ""
@@ -684,7 +671,7 @@ class OracleCompiler(ansisql.ANSICompiler):
             return super(OracleCompiler, self).for_update_clause(select)
 
 
-class OracleSchemaGenerator(ansisql.ANSISchemaGenerator):
+class OracleSchemaGenerator(compiler.SchemaGenerator):
     def get_column_specification(self, column, **kwargs):
         colspec = self.preparer.format_column(column)
         colspec += " " + column.type.dialect_impl(self.dialect).get_col_spec()
@@ -701,13 +688,13 @@ class OracleSchemaGenerator(ansisql.ANSISchemaGenerator):
             self.append("CREATE SEQUENCE %s" % self.preparer.format_sequence(sequence))
             self.execute()
 
-class OracleSchemaDropper(ansisql.ANSISchemaDropper):
+class OracleSchemaDropper(compiler.SchemaDropper):
     def visit_sequence(self, sequence):
         if not self.checkfirst or self.dialect.has_sequence(self.connection, sequence.name):
             self.append("DROP SEQUENCE %s" % self.preparer.format_sequence(sequence))
             self.execute()
 
-class OracleDefaultRunner(ansisql.ANSIDefaultRunner):
+class OracleDefaultRunner(base.DefaultRunner):
     def exec_default_sql(self, default):
         c = sql.select([default.arg], from_obj=["DUAL"]).compile(bind=self.connection)
         return self.connection.execute(c).scalar()
@@ -715,10 +702,15 @@ class OracleDefaultRunner(ansisql.ANSIDefaultRunner):
     def visit_sequence(self, seq):
         return self.connection.execute("SELECT " + self.dialect.identifier_preparer.format_sequence(seq) + ".nextval FROM DUAL").scalar()
 
-class OracleIdentifierPreparer(ansisql.ANSIIdentifierPreparer):
+class OracleIdentifierPreparer(compiler.IdentifierPreparer):
     def format_savepoint(self, savepoint):
         name = re.sub(r'^_+', '', savepoint.ident)
         return super(OracleIdentifierPreparer, self).format_savepoint(savepoint, name)
 
     
 dialect = OracleDialect
+dialect.statement_compiler = OracleCompiler
+dialect.schemagenerator = OracleSchemaGenerator
+dialect.schemadropper = OracleSchemaDropper
+dialect.preparer = OracleIdentifierPreparer
+dialect.defaultrunner = OracleDefaultRunner
index 74a3ef13f222647db5e1b6e462367316fdcbf3db..29d84ad4db7f4168fcbb4e3b55d1f975ff906f6d 100644 (file)
@@ -4,11 +4,13 @@
 # This module is part of SQLAlchemy and is released under
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
 
-import re, random, warnings, operator
+import re, random, warnings
 
-from sqlalchemy import sql, schema, ansisql, exceptions, util
+from sqlalchemy import sql, schema, exceptions, util
 from sqlalchemy.engine import base, default
-import sqlalchemy.types as sqltypes
+from sqlalchemy.sql import compiler
+from sqlalchemy.sql import operators as sql_operators
+from sqlalchemy import types as sqltypes
 
 
 class PGInet(sqltypes.TypeEngine):
@@ -220,9 +222,9 @@ class PGExecutionContext(default.DefaultExecutionContext):
                 self._last_inserted_ids = [v for v in row]
         super(PGExecutionContext, self).post_exec()
         
-class PGDialect(ansisql.ANSIDialect):
+class PGDialect(default.DefaultDialect):
     def __init__(self, use_oids=False, server_side_cursors=False, **kwargs):
-        ansisql.ANSIDialect.__init__(self, default_paramstyle='pyformat', **kwargs)
+        default.DefaultDialect.__init__(self, default_paramstyle='pyformat', **kwargs)
         self.use_oids = use_oids
         self.server_side_cursors = server_side_cursors
         self.paramstyle = 'pyformat'
@@ -249,15 +251,6 @@ class PGDialect(ansisql.ANSIDialect):
     def type_descriptor(self, typeobj):
         return sqltypes.adapt_type(typeobj, colspecs)
 
-    def compiler(self, statement, bindparams, **kwargs):
-        return PGCompiler(self, statement, bindparams, **kwargs)
-
-    def schemagenerator(self, *args, **kwargs):
-        return PGSchemaGenerator(self, *args, **kwargs)
-
-    def schemadropper(self, *args, **kwargs):
-        return PGSchemaDropper(self, *args, **kwargs)
-
     def do_begin_twophase(self, connection, xid):
         self.do_begin(connection.connection)
 
@@ -286,12 +279,6 @@ class PGDialect(ansisql.ANSIDialect):
         resultset = connection.execute(sql.text("SELECT gid FROM pg_prepared_xacts"))
         return [row[0] for row in resultset]
 
-    def defaultrunner(self, context, **kwargs):
-        return PGDefaultRunner(context, **kwargs)
-
-    def preparer(self):
-        return PGIdentifierPreparer(self)
-
     def get_default_schema_name(self, connection):
         if not hasattr(self, '_default_schema_name'):
             self._default_schema_name = connection.scalar("select current_schema()", None)
@@ -556,11 +543,11 @@ class PGDialect(ansisql.ANSIDialect):
         
         
         
-class PGCompiler(ansisql.ANSICompiler):
-    operators = ansisql.ANSICompiler.operators.copy()
+class PGCompiler(compiler.DefaultCompiler):
+    operators = compiler.DefaultCompiler.operators.copy()
     operators.update(
         {
-            operator.mod : '%%'
+            sql_operators.mod : '%%'
         }
     )
 
@@ -597,7 +584,7 @@ class PGCompiler(ansisql.ANSICompiler):
         else:
             return super(PGCompiler, self).for_update_clause(select)
 
-class PGSchemaGenerator(ansisql.ANSISchemaGenerator):
+class PGSchemaGenerator(compiler.SchemaGenerator):
     def get_column_specification(self, column, **kwargs):
         colspec = self.preparer.format_column(column)
         if column.primary_key and len(column.foreign_keys)==0 and column.autoincrement and isinstance(column.type, sqltypes.Integer) and not isinstance(column.type, sqltypes.SmallInteger) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)):
@@ -620,13 +607,13 @@ class PGSchemaGenerator(ansisql.ANSISchemaGenerator):
             self.append("CREATE SEQUENCE %s" % self.preparer.format_sequence(sequence))
             self.execute()
 
-class PGSchemaDropper(ansisql.ANSISchemaDropper):
+class PGSchemaDropper(compiler.SchemaDropper):
     def visit_sequence(self, sequence):
         if not sequence.optional and (not self.checkfirst or self.dialect.has_sequence(self.connection, sequence.name)):
             self.append("DROP SEQUENCE %s" % self.preparer.format_sequence(sequence))
             self.execute()
 
-class PGDefaultRunner(ansisql.ANSIDefaultRunner):
+class PGDefaultRunner(base.DefaultRunner):
     def get_column_default(self, column, isinsert=True):
         if column.primary_key:
             # passive defaults on primary keys have to be overridden
@@ -642,7 +629,7 @@ class PGDefaultRunner(ansisql.ANSIDefaultRunner):
                     exc = "select nextval('\"%s_%s_seq\"')" % (column.table.name, column.name)
                 return self.connection.execute(exc).scalar()
 
-        return super(ansisql.ANSIDefaultRunner, self).get_column_default(column)
+        return super(PGDefaultRunner, self).get_column_default(column)
 
     def visit_sequence(self, seq):
         if not seq.optional:
@@ -650,7 +637,7 @@ class PGDefaultRunner(ansisql.ANSIDefaultRunner):
         else:
             return None
 
-class PGIdentifierPreparer(ansisql.ANSIIdentifierPreparer):
+class PGIdentifierPreparer(compiler.IdentifierPreparer):
     def _fold_identifier_case(self, value):
         return value.lower()
 
@@ -660,3 +647,8 @@ class PGIdentifierPreparer(ansisql.ANSIIdentifierPreparer):
         return value
 
 dialect = PGDialect
+dialect.statement_compiler = PGCompiler
+dialect.schemagenerator = PGSchemaGenerator
+dialect.schemadropper = PGSchemaDropper
+dialect.preparer = PGIdentifierPreparer
+dialect.defaultrunner = PGDefaultRunner
index d96422236ffb939786204c40bff1f609c9ca620e..c2aced4d03f00fe0004dee4c0a7aa194ea23f203 100644 (file)
@@ -7,11 +7,12 @@
 
 import re
 
-from sqlalchemy import schema, ansisql, exceptions, pool, PassiveDefault
-import sqlalchemy.engine.default as default
+from sqlalchemy import schema, exceptions, pool, PassiveDefault
+from sqlalchemy.engine import default
 import sqlalchemy.types as sqltypes
 import datetime,time, warnings
 import sqlalchemy.util as util
+from sqlalchemy.sql import compiler
 
 
 SELECT_REGEXP = re.compile(r'\s*(?:SELECT|PRAGMA)', re.I | re.UNICODE)
@@ -172,10 +173,10 @@ class SQLiteExecutionContext(default.DefaultExecutionContext):
     def is_select(self):
         return SELECT_REGEXP.match(self.statement)
         
-class SQLiteDialect(ansisql.ANSIDialect):
+class SQLiteDialect(default.DefaultDialect):
     
     def __init__(self, **kwargs):
-        ansisql.ANSIDialect.__init__(self, default_paramstyle='qmark', **kwargs)
+        default.DefaultDialect.__init__(self, default_paramstyle='qmark', **kwargs)
         def vers(num):
             return tuple([int(x) for x in num.split('.')])
         if self.dbapi is not None:
@@ -195,24 +196,12 @@ class SQLiteDialect(ansisql.ANSIDialect):
         return sqlite
     dbapi = classmethod(dbapi)
 
-    def compiler(self, statement, bindparams, **kwargs):
-        return SQLiteCompiler(self, statement, bindparams, **kwargs)
-
-    def schemagenerator(self, *args, **kwargs):
-        return SQLiteSchemaGenerator(self, *args, **kwargs)
-
-    def schemadropper(self, *args, **kwargs):
-        return SQLiteSchemaDropper(self, *args, **kwargs)
-
     def server_version_info(self, connection):
         return self.dbapi.sqlite_version_info
 
     def supports_alter(self):
         return False
 
-    def preparer(self):
-        return SQLiteIdentifierPreparer(self)
-
     def create_connect_args(self, url):
         filename = url.database or ':memory:'
 
@@ -255,7 +244,7 @@ class SQLiteDialect(ansisql.ANSIDialect):
         return (row is not None)
 
     def reflecttable(self, connection, table, include_columns):
-        c = connection.execute("PRAGMA table_info(%s)" % self.preparer().format_table(table), {})
+        c = connection.execute("PRAGMA table_info(%s)" % self.identifier_preparer.format_table(table), {})
         found_table = False
         while True:
             row = c.fetchone()
@@ -295,7 +284,7 @@ class SQLiteDialect(ansisql.ANSIDialect):
         if not found_table:
             raise exceptions.NoSuchTableError(table.name)
 
-        c = connection.execute("PRAGMA foreign_key_list(%s)" % self.preparer().format_table(table), {})
+        c = connection.execute("PRAGMA foreign_key_list(%s)" % self.identifier_preparer.format_table(table), {})
         fks = {}
         while True:
             row = c.fetchone()
@@ -324,7 +313,7 @@ class SQLiteDialect(ansisql.ANSIDialect):
         for name, value in fks.iteritems():
             table.append_constraint(schema.ForeignKeyConstraint(value[0], value[1]))
         # check for UNIQUE indexes
-        c = connection.execute("PRAGMA index_list(%s)" % self.preparer().format_table(table), {})
+        c = connection.execute("PRAGMA index_list(%s)" % self.identifier_preparer.format_table(table), {})
         unique_indexes = []
         while True:
             row = c.fetchone()
@@ -343,7 +332,7 @@ class SQLiteDialect(ansisql.ANSIDialect):
                 cols.append(row[2])
                 col = table.columns[row[2]]
 
-class SQLiteCompiler(ansisql.ANSICompiler):
+class SQLiteCompiler(compiler.DefaultCompiler):
     def visit_cast(self, cast):
         if self.dialect.supports_cast:
             return super(SQLiteCompiler, self).visit_cast(cast)
@@ -369,7 +358,8 @@ class SQLiteCompiler(ansisql.ANSICompiler):
         # sqlite has no "FOR UPDATE" AFAICT
         return ''
 
-class SQLiteSchemaGenerator(ansisql.ANSISchemaGenerator):
+
+class SQLiteSchemaGenerator(compiler.SchemaGenerator):
 
     def get_column_specification(self, column, **kwargs):
         colspec = self.preparer.format_column(column) + " " + column.type.dialect_impl(self.dialect).get_col_spec()
@@ -391,12 +381,17 @@ class SQLiteSchemaGenerator(ansisql.ANSISchemaGenerator):
     #    else:
     #        super(SQLiteSchemaGenerator, self).visit_primary_key_constraint(constraint)
 
-class SQLiteSchemaDropper(ansisql.ANSISchemaDropper):
+class SQLiteSchemaDropper(compiler.SchemaDropper):
     pass
 
-class SQLiteIdentifierPreparer(ansisql.ANSIIdentifierPreparer):
+class SQLiteIdentifierPreparer(compiler.IdentifierPreparer):
     def __init__(self, dialect):
         super(SQLiteIdentifierPreparer, self).__init__(dialect, omit_schema=True)
 
 dialect = SQLiteDialect
 dialect.poolclass = pool.SingletonThreadPool
+dialect.statement_compiler = SQLiteCompiler
+dialect.schemagenerator = SQLiteSchemaGenerator
+dialect.schemadropper = SQLiteSchemaDropper
+dialect.preparer = SQLiteIdentifierPreparer
+
index faeb00cc98ca0942b77644a0e35a0448027f5c28..553c8df845c28a8646959b585ec3d32127f3f9b2 100644 (file)
@@ -8,7 +8,8 @@
 higher-level statement-construction, connection-management, 
 execution and result contexts."""
 
-from sqlalchemy import exceptions, sql, schema, util, types, logging
+from sqlalchemy import exceptions, schema, util, types, logging
+from sqlalchemy.sql import expression, visitors
 import StringIO, sys
 
 
@@ -35,6 +36,22 @@ class Dialect(object):
 
       encoding
         type of encoding to use for unicode, usually defaults to 'utf-8'
+
+      schemagenerator
+        a [sqlalchemy.schema#SchemaVisitor] class which generates schemas.
+
+      schemadropper
+        a [sqlalchemy.schema#SchemaVisitor] class which drops schemas.
+
+      defaultrunner
+        a [sqlalchemy.schema#SchemaVisitor] class which executes defaults.
+
+      statement_compiler
+        a [sqlalchemy.engine.base#Compiled] class used to compile SQL statements
+
+      preparer
+        a [sqlalchemy.sql.compiler#IdentifierPreparer] class used to quote
+        identifiers.
     """
 
     def create_connect_args(self, url):
@@ -105,48 +122,6 @@ class Dialect(object):
 
         raise NotImplementedError()
 
-    def schemagenerator(self, connection, **kwargs):
-        """Return a [sqlalchemy.schema#SchemaVisitor] instance that can generate schemas.
-
-            connection
-                a [sqlalchemy.engine#Connection] to use for statement execution
-                
-        `schemagenerator()` is called via the `create()` method on Table,
-        Index, and others.
-        """
-
-        raise NotImplementedError()
-
-    def schemadropper(self, connection, **kwargs):
-        """Return a [sqlalchemy.schema#SchemaVisitor] instance that can drop schemas.
-
-            connection
-                a [sqlalchemy.engine#Connection] to use for statement execution
-
-        `schemadropper()` is called via the `drop()` method on Table,
-        Index, and others.
-        """
-
-        raise NotImplementedError()
-
-    def defaultrunner(self, execution_context):
-        """Return a [sqlalchemy.schema#SchemaVisitor] instance that can execute defaults.
-        
-            execution_context
-                a [sqlalchemy.engine#ExecutionContext] to use for statement execution
-        
-        """
-
-        raise NotImplementedError()
-
-    def compiler(self, statement, parameters):
-        """Return a [sqlalchemy.sql#Compiled] object for the given statement/parameters.
-
-        The returned object is usually a subclass of [sqlalchemy.ansisql#ANSICompiler].
-
-        """
-
-        raise NotImplementedError()
 
     def server_version_info(self, connection):
         """Return a tuple of the database's version number."""
@@ -266,16 +241,6 @@ class Dialect(object):
 
         raise NotImplementedError()
 
-
-    def compile(self, clauseelement, parameters=None):
-        """Compile the given [sqlalchemy.sql#ClauseElement] using this Dialect.
-        
-        Returns [sqlalchemy.sql#Compiled].  A convenience method which 
-        flips around the compile() call on ``ClauseElement``.
-        """
-
-        return clauseelement.compile(dialect=self, parameters=parameters)
-
     def is_disconnect(self, e):
         """Return True if the given DBAPI error indicates an invalid connection"""
         
@@ -304,7 +269,7 @@ class ExecutionContext(object):
             DBAPI cursor procured from the connection
             
         compiled
-            if passed to constructor, sql.Compiled object being executed
+            if passed to constructor, sqlalchemy.engine.base.Compiled object being executed
         
         statement
             string version of the statement to be executed.  Is either
@@ -439,6 +404,9 @@ class Compiled(object):
     def __init__(self, dialect, statement, parameters, bind=None):
         """Construct a new ``Compiled`` object.
 
+        dialect
+          ``Dialect`` to compile against.
+          
         statement
           ``ClauseElement`` to be compiled.
 
@@ -724,8 +692,8 @@ class Connection(Connectable):
     def scalar(self, object, *multiparams, **params):
         return self.execute(object, *multiparams, **params).scalar()
 
-    def compiler(self, statement, parameters, **kwargs):
-        return self.dialect.compiler(statement, parameters, bind=self, **kwargs)
+    def statement_compiler(self, statement, parameters, **kwargs):
+        return self.dialect.statement_compiler(self.dialect, statement, parameters, bind=self, **kwargs)
 
     def execute(self, object, *multiparams, **params):
         for c in type(object).__mro__:
@@ -822,9 +790,9 @@ class Connection(Connectable):
 
     # poor man's multimethod/generic function thingy
     executors = {
-        sql._Function : _execute_function,
-        sql.ClauseElement : _execute_clauseelement,
-        sql.ClauseVisitor : _execute_compiled,
+        expression._Function : _execute_function,
+        expression.ClauseElement : _execute_clauseelement,
+        visitors.ClauseVisitor : _execute_compiled,
         schema.SchemaItem:_execute_default,
         str.__mro__[-2] : _execute_text
     }
@@ -989,14 +957,14 @@ class Engine(Connectable):
             connection.close()
 
     def _func(self):
-        return sql._FunctionGenerator(bind=self)
+        return expression._FunctionGenerator(bind=self)
 
     func = property(_func)
 
     def text(self, text, *args, **kwargs):
         """Return a sql.text() object for performing literal queries."""
 
-        return sql.text(text, bind=self, *args, **kwargs)
+        return expression.text(text, bind=self, *args, **kwargs)
 
     def _run_visitor(self, visitorcallable, element, connection=None, **kwargs):
         if connection is None:
@@ -1004,7 +972,7 @@ class Engine(Connectable):
         else:
             conn = connection
         try:
-            visitorcallable(conn, **kwargs).traverse(element)
+            visitorcallable(self.dialect, conn, **kwargs).traverse(element)
         finally:
             if connection is None:
                 conn.close()
@@ -1057,8 +1025,8 @@ class Engine(Connectable):
         connection = self.contextual_connect(close_with_result=True)
         return connection._execute_compiled(compiled, multiparams, params)
 
-    def compiler(self, statement, parameters, **kwargs):
-        return self.dialect.compiler(statement, parameters, bind=self, **kwargs)
+    def statement_compiler(self, statement, parameters, **kwargs):
+        return self.dialect.statement_compiler(self.dialect, statement, parameters, bind=self, **kwargs)
 
     def connect(self, **kwargs):
         """Return a newly allocated Connection object."""
@@ -1159,6 +1127,7 @@ class ResultProxy(object):
         self.closed = False
         self.cursor = context.cursor
         self.__echo = logging.is_debug_enabled(context.engine.logger)
+        self._process_row = self._row_processor()
         if context.is_select():
             self._init_metadata()
             self._rowcount = None
@@ -1222,7 +1191,7 @@ class ResultProxy(object):
                 rec = props[key]
             elif isinstance(key, basestring) and key.lower() in props:
                 rec = props[key.lower()]
-            elif isinstance(key, sql.ColumnElement):
+            elif isinstance(key, expression.ColumnElement):
                 label = context.column_labels.get(key._label, key.name).lower()
                 if label in props:
                     rec = props[label]
@@ -1320,21 +1289,21 @@ class ResultProxy(object):
         return self.cursor.fetchmany(size)
     def _fetchall_impl(self):
         return self.cursor.fetchall()
+
+    def _row_processor(self):
+        return RowProxy
         
-    def _process_row(self, row):
-        return RowProxy(self, row)
-            
     def fetchall(self):
         """Fetch all rows, just like DBAPI ``cursor.fetchall()``."""
 
-        l = [self._process_row(row) for row in self._fetchall_impl()]
+        l = [self._process_row(self, row) for row in self._fetchall_impl()]
         self.close()
         return l
 
     def fetchmany(self, size=None):
         """Fetch many rows, just like DBAPI ``cursor.fetchmany(size=cursor.arraysize)``."""
 
-        l = [self._process_row(row) for row in self._fetchmany_impl(size)]
+        l = [self._process_row(self, row) for row in self._fetchmany_impl(size)]
         if len(l) == 0:
             self.close()
         return l
@@ -1343,7 +1312,7 @@ class ResultProxy(object):
         """Fetch one row, just like DBAPI ``cursor.fetchone()``."""
         row = self._fetchone_impl()
         if row is not None:
-            return self._process_row(row)
+            return self._process_row(self, row)
         else:
             self.close()
             return None
@@ -1353,7 +1322,7 @@ class ResultProxy(object):
         row = self._fetchone_impl()
         try:
             if row is not None:
-                return self._process_row(row)[0]
+                return self._process_row(self, row)[0]
             else:
                 return None
         finally:
@@ -1425,11 +1394,9 @@ class BufferedColumnResultProxy(ResultProxy):
     def _get_col(self, row, key):
         rec = self._key_cache[key]
         return row[rec[2]]
-    
-    def _process_row(self, row):
-        sup = super(BufferedColumnResultProxy, self)
-        row = [sup._get_col(row, i) for i in xrange(len(row))]
-        return RowProxy(self, row)
+
+    def _row_processor(self):
+        return BufferedColumnRow
 
     def fetchall(self):
         l = []
@@ -1523,6 +1490,11 @@ class RowProxy(object):
     def __len__(self):
         return len(self.__row)
 
+class BufferedColumnRow(RowProxy):
+    def __init__(self, parent, row):
+        row = [ResultProxy._get_col(parent, row, i) for i in xrange(len(row))]
+        super(BufferedColumnRow, self).__init__(parent, row)
+
 class SchemaIterator(schema.SchemaVisitor):
     """A visitor that can gather text into a buffer and execute the contents of the buffer."""
 
@@ -1590,11 +1562,11 @@ class DefaultRunner(schema.SchemaVisitor):
         return None
 
     def exec_default_sql(self, default):
-        c = sql.select([default.arg]).compile(bind=self.connection)
+        c = expression.select([default.arg]).compile(bind=self.connection)
         return self.connection._execute_compiled(c).scalar()
 
     def visit_column_onupdate(self, onupdate):
-        if isinstance(onupdate.arg, sql.ClauseElement):
+        if isinstance(onupdate.arg, expression.ClauseElement):
             return self.exec_default_sql(onupdate)
         elif callable(onupdate.arg):
             return onupdate.arg(self.context)
@@ -1602,7 +1574,7 @@ class DefaultRunner(schema.SchemaVisitor):
             return onupdate.arg
 
     def visit_column_default(self, default):
-        if isinstance(default.arg, sql.ClauseElement):
+        if isinstance(default.arg, expression.ClauseElement):
             return self.exec_default_sql(default)
         elif callable(default.arg):
             return default.arg(self.context)
index ccaf080e75826892e4f534132f3af828a1ee3a4e..059395921fcbd3b13c269de8dfb74a68beffb75f 100644 (file)
@@ -9,7 +9,7 @@
 from sqlalchemy import schema, exceptions, sql, util
 import re, random
 from sqlalchemy.engine import base
-
+from sqlalchemy.sql import compiler, expression
 
 AUTOCOMMIT_REGEXP = re.compile(r'\s*(?:UPDATE|INSERT|CREATE|DELETE|DROP|ALTER)',
                                re.I | re.UNICODE)
@@ -18,6 +18,12 @@ SELECT_REGEXP = re.compile(r'\s*SELECT', re.I | re.UNICODE)
 class DefaultDialect(base.Dialect):
     """Default implementation of Dialect"""
 
+    schemagenerator = compiler.SchemaGenerator
+    schemadropper = compiler.SchemaDropper
+    statement_compiler = compiler.DefaultCompiler
+    preparer = compiler.IdentifierPreparer
+    defaultrunner = base.DefaultRunner
+
     def __init__(self, convert_unicode=False, encoding='utf-8', default_paramstyle='named', paramstyle=None, dbapi=None, **kwargs):
         self.convert_unicode = convert_unicode
         self.encoding = encoding
@@ -25,6 +31,7 @@ class DefaultDialect(base.Dialect):
         self._ischema = None
         self.dbapi = dbapi
         self._figure_paramstyle(paramstyle=paramstyle, default=default_paramstyle)
+        self.identifier_preparer = self.preparer(self)
     
     def dbapi_type_map(self):
         # most DBAPIs have problems with this (such as, psycocpg2 types 
@@ -46,6 +53,7 @@ class DefaultDialect(base.Dialect):
             typeobj = typeobj()
         return typeobj
 
+
     def supports_unicode_statements(self):
         """indicate whether the DBAPI can receive SQL statements as Python unicode strings"""
         return False
@@ -96,13 +104,13 @@ class DefaultDialect(base.Dialect):
         return "_sa_%032x" % random.randint(0,2**128)
         
     def do_savepoint(self, connection, name):
-        connection.execute(sql.SavepointClause(name))
+        connection.execute(expression.SavepointClause(name))
 
     def do_rollback_to_savepoint(self, connection, name):
-        connection.execute(sql.RollbackToSavepointClause(name))
+        connection.execute(expression.RollbackToSavepointClause(name))
 
     def do_release_savepoint(self, connection, name):
-        connection.execute(sql.ReleaseSavepointClause(name))
+        connection.execute(expression.ReleaseSavepointClause(name))
 
     def do_executemany(self, cursor, statement, parameters, **kwargs):
         cursor.executemany(statement, parameters)
@@ -110,8 +118,6 @@ class DefaultDialect(base.Dialect):
     def do_execute(self, cursor, statement, parameters, **kwargs):
         cursor.execute(statement, parameters)
 
-    def defaultrunner(self, context):
-        return base.DefaultRunner(context)
 
     def is_disconnect(self, e):
         return False
index 65d2ab3d2caf2c143fea350a625824e40819e612..258bddb4a5ab6be31a0ce7c7e5f04b8a6e38ecf0 100644 (file)
@@ -294,7 +294,7 @@ from sqlalchemy import *
 from sqlalchemy.orm import *
 from sqlalchemy.ext.sessioncontext import SessionContext
 from sqlalchemy.exceptions import *
-
+from sqlalchemy.sql import expression
 
 _testsql = """
 CREATE TABLE books (
@@ -415,7 +415,7 @@ def _selectable_name(selectable):
         return x
 
 def class_for_table(selectable, **mapper_kwargs):
-    selectable = sql._selectable(selectable)
+    selectable = expression._selectable(selectable)
     mapname = 'Mapped' + _selectable_name(selectable)
     if isinstance(selectable, Table):
         klass = TableClassType(mapname, (object,), {})
@@ -499,7 +499,7 @@ class SqlSoup:
 
     def with_labels(self, item):
         # TODO give meaningful aliases
-        return self.map(sql._selectable(item).select(use_labels=True).alias('foo'))
+        return self.map(expression._selectable(item).select(use_labels=True).alias('foo'))
 
     def join(self, *args, **kwargs):
         j = join(*args, **kwargs)
index c54eee4381bb7251506a96d3e0e47a48506bfcbc..9000a8df58fdc00436ca2c30d4194edd9420a867 100644 (file)
@@ -6,6 +6,7 @@
 
 
 from sqlalchemy import util, logging, sql
+from sqlalchemy.sql import expression
 
 __all__ = ['EXT_CONTINUE', 'EXT_STOP', 'EXT_PASS', 'MapperExtension',
            'MapperProperty', 'PropComparator', 'StrategizedProperty', 
@@ -363,7 +364,7 @@ class MapperProperty(object):
 
         return operator(self.comparator, value)
 
-class PropComparator(sql.ColumnOperators):
+class PropComparator(expression.ColumnOperators):
     """defines comparison operations for MapperProperty objects"""
     
     def expression_element(self):
index b4836841774503f790e74ff9ff6c11da5fe9abc0..60d4526ec31ca9b1c44286cf1d00cf93721a2d72 100644 (file)
@@ -6,7 +6,8 @@
 
 import weakref, warnings, operator
 from sqlalchemy import sql, util, exceptions, logging
-from sqlalchemy import sql_util as sqlutil
+from sqlalchemy.sql import expression
+from sqlalchemy.sql import util as sqlutil
 from sqlalchemy.orm import util as mapperutil
 from sqlalchemy.orm.util import ExtensionCarrier
 from sqlalchemy.orm import sync
@@ -77,7 +78,7 @@ class Mapper(object):
             raise exceptions.ArgumentError("Class '%s' is not a new-style class" % class_.__name__)
 
         for table in (local_table, select_table):
-            if table is not None and isinstance(table, sql._SelectBaseMixin):
+            if table is not None and isinstance(table, expression._SelectBaseMixin):
                 # some db's, noteably postgres, dont want to select from a select
                 # without an alias.  also if we make our own alias internally, then
                 # the configured properties on the mapper are not matched against the alias
@@ -438,7 +439,7 @@ class Mapper(object):
             # against the "mapped_table" of this mapper.
             equivalent_columns = self._get_equivalent_columns()
         
-            primary_key = sql.ColumnSet()
+            primary_key = expression.ColumnSet()
 
             for col in (self.primary_key_argument or self.pks_by_table[self.mapped_table]):
                 c = self.mapped_table.corresponding_column(col, raiseerr=False)
@@ -644,9 +645,9 @@ class Mapper(object):
             props = {}
             if self.properties is not None:
                 for key, prop in self.properties.iteritems():
-                    if sql.is_column(prop):
+                    if expression.is_column(prop):
                         props[key] = self.select_table.corresponding_column(prop)
-                    elif (isinstance(prop, list) and sql.is_column(prop[0])):
+                    elif (isinstance(prop, list) and expression.is_column(prop[0])):
                         props[key] = [self.select_table.corresponding_column(c) for c in prop]
             self.__surrogate_mapper = Mapper(self.class_, self.select_table, non_primary=True, properties=props, _polymorphic_map=self.polymorphic_map, polymorphic_on=self.select_table.corresponding_column(self.polymorphic_on), primary_key=self.primary_key_argument)
 
@@ -768,7 +769,7 @@ class Mapper(object):
 
     def _create_prop_from_column(self, column):
         column = util.to_list(column)
-        if not sql.is_column(column[0]):
+        if not expression.is_column(column[0]):
             return None
         mapped_column = []
         for c in column:
index 670fcccc993a322e38b2c236091a5d619b41d049..20cbcb2351332b1c5d6e28b49a57358fe2146376 100644 (file)
@@ -11,7 +11,8 @@ operations.  PropertyLoader also relies upon the dependency.py module
 to handle flush-time dependency sorting and processing.
 """
 
-from sqlalchemy import sql, schema, util, exceptions, sql_util, logging
+from sqlalchemy import sql, schema, util, exceptions, logging
+from sqlalchemy.sql import util as sql_util
 from sqlalchemy.orm import mapper, sync, strategies, attributes, dependency
 from sqlalchemy.orm import session as sessionlib
 from sqlalchemy.orm import util as mapperutil
index 44329468f5b048869228180beee2e38c4c9fe80a..5cbe19ce2c3c23354215e5a128b61057ca1ff534 100644 (file)
@@ -4,7 +4,9 @@
 # This module is part of SQLAlchemy and is released under
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
 
-from sqlalchemy import sql, util, exceptions, sql_util, logging
+from sqlalchemy import sql, util, exceptions, logging
+from sqlalchemy.sql import util as sql_util
+from sqlalchemy.sql import expression, visitors
 from sqlalchemy.orm import mapper, object_mapper
 from sqlalchemy.orm import util as mapperutil
 from sqlalchemy.orm.interfaces import OperationContext, LoaderStack
@@ -312,7 +314,7 @@ class Query(object):
         clause = self._from_obj[-1]
 
         currenttables = [clause]
-        class FindJoinedTables(sql.NoColumnVisitor):
+        class FindJoinedTables(visitors.NoColumnVisitor):
             def visit_join(self, join):
                 currenttables.append(join.left)
                 currenttables.append(join.right)
@@ -836,7 +838,7 @@ class Query(object):
             # if theres an order by, add those columns to the column list
             # of the "rowcount" query we're going to make
             if order_by:
-                order_by = [sql._literal_as_text(o) for o in util.to_list(order_by) or []]
+                order_by = [expression._literal_as_text(o) for o in util.to_list(order_by) or []]
                 cf = sql_util.ColumnFinder()
                 for o in order_by:
                     cf.traverse(o)
index 6565c8d775fae52975f744aa2a774c181f828271..bdb17e1d64da72d2561c074d295c3fdcf6f22662 100644 (file)
@@ -6,7 +6,9 @@
 
 """sqlalchemy.orm.interfaces.LoaderStrategy implementations, and related MapperOptions."""
 
-from sqlalchemy import sql, util, exceptions, sql_util, logging
+from sqlalchemy import sql, util, exceptions, logging
+from sqlalchemy.sql import util as sql_util
+from sqlalchemy.sql import visitors
 from sqlalchemy.orm import mapper, attributes
 from sqlalchemy.orm.interfaces import LoaderStrategy, StrategizedOption, MapperOption, PropertyOption
 from sqlalchemy.orm import session as sessionlib
@@ -292,7 +294,7 @@ class LazyLoader(AbstractRelationLoader):
             (criterion, lazybinds, rev) = LazyLoader._create_lazy_clause(self.parent_property, reverse_direction=reverse_direction)
         bind_to_col = dict([(lazybinds[col].key, col) for col in lazybinds])
 
-        class Visitor(sql.ClauseVisitor):
+        class Visitor(visitors.ClauseVisitor):
             def visit_bindparam(s, bindparam):
                 mapper = reverse_direction and self.parent_property.mapper or self.parent_property.parent
                 if bindparam.key in bind_to_col:
@@ -396,7 +398,7 @@ class LazyLoader(AbstractRelationLoader):
             if not isinstance(expr, sql.ColumnElement):
                 return None
             columns = []
-            class FindColumnInColumnClause(sql.ClauseVisitor):
+            class FindColumnInColumnClause(visitors.ClauseVisitor):
                 def visit_column(self, c):
                     columns.append(c)
             FindColumnInColumnClause().traverse(expr)
index cf48202b0f9bcf0d3f2748b1c0bed7f12a16d2c6..49661a95ef691c9fcff21430f40c0461a8ca4719 100644 (file)
@@ -10,9 +10,9 @@ clause that compares column values.
 """
 
 from sqlalchemy import sql, schema, exceptions
+from sqlalchemy.sql import visitors, operators
 from sqlalchemy import logging
 from sqlalchemy.orm import util as mapperutil
-import operator
 
 ONETOMANY = 0
 MANYTOONE = 1
@@ -43,7 +43,7 @@ class ClauseSynchronizer(object):
         def compile_binary(binary):
             """Assemble a SyncRule given a single binary condition."""
 
-            if binary.operator != operator.eq or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column):
+            if binary.operator != operators.eq or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column):
                 return
 
             source_column = None
@@ -144,7 +144,7 @@ class SyncRule(object):
 
 SyncRule.logger = logging.class_logger(SyncRule)
 
-class BinaryVisitor(sql.ClauseVisitor):
+class BinaryVisitor(visitors.ClauseVisitor):
     def __init__(self, func):
         self.func = func
 
index 6b0956dc4d9f5d51102a705c4fabe71f9f27c139..b3f58c954b2f99b7a94f6fd5998fcee3a4c82f9b 100644 (file)
@@ -4,7 +4,9 @@
 # This module is part of SQLAlchemy and is released under
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
 
-from sqlalchemy import sql, util, exceptions, sql_util
+from sqlalchemy import sql, util, exceptions
+from sqlalchemy.sql import util as sql_util
+from sqlalchemy.sql import visitors
 from sqlalchemy.orm.interfaces import MapperExtension, EXT_CONTINUE
 
 all_cascades = util.Set(["delete", "delete-orphan", "all", "merge",
@@ -161,7 +163,7 @@ class ExtensionCarrier(MapperExtension):
     before_delete = _create_do('before_delete')
     after_delete = _create_do('after_delete')
 
-class BinaryVisitor(sql.ClauseVisitor):
+class BinaryVisitor(visitors.ClauseVisitor):
     def __init__(self, func):
         self.func = func
 
@@ -196,7 +198,7 @@ class AliasedClauses(object):
         # for column-level subqueries, swap out its selectable with our
         # eager version as appropriate, and manually build the 
         # "correlation" list of the subquery.  
-        class ModifySubquery(sql.ClauseVisitor):
+        class ModifySubquery(visitors.ClauseVisitor):
             def visit_select(s, select):
                 select._should_correlate = False
                 select.append_correlation(self.alias)
index 99803d665aebe7afa132f9d797fcf3ef389b9864..99ca2389bec31c9c2e76e10b659714e5054fcd6c 100644 (file)
@@ -18,8 +18,11 @@ objects as well as the visitor interface, so that the schema package
 """
 
 import re, inspect
-from sqlalchemy import sql, types, exceptions, util, databases
+from sqlalchemy import types, exceptions, util, databases
+from sqlalchemy.sql import expression, visitors
 import sqlalchemy
+
+
 URL = None
 
 __all__ = ['SchemaItem', 'Table', 'Column', 'ForeignKey', 'Sequence', 'Index',
@@ -31,7 +34,7 @@ __all__ = ['SchemaItem', 'Table', 'Column', 'ForeignKey', 'Sequence', 'Index',
 class SchemaItem(object):
     """Base class for items that define a database schema."""
 
-    __metaclass__ = sql._FigureVisitName
+    __metaclass__ = expression._FigureVisitName
 
     def _init_items(self, *args):
         """Initialize the list of child items for this SchemaItem."""
@@ -84,7 +87,7 @@ def _get_table_key(name, schema):
     else:
         return schema + "." + name
 
-class _TableSingleton(sql._FigureVisitName):
+class _TableSingleton(expression._FigureVisitName):
     """A metaclass used by the ``Table`` object to provide singleton behavior."""
 
     def __call__(self, name, metadata, *args, **kwargs):
@@ -124,10 +127,10 @@ class _TableSingleton(sql._FigureVisitName):
             return table
 
 
-class Table(SchemaItem, sql.TableClause):
+class Table(SchemaItem, expression.TableClause):
     """Represent a relational database table.
 
-    This subclasses ``sql.TableClause`` to provide a table that is
+    This subclasses ``expression.TableClause`` to provide a table that is
     associated with an instance of ``MetaData``, which in turn
     may be associated with an instance of ``Engine``.  
 
@@ -229,7 +232,7 @@ class Table(SchemaItem, sql.TableClause):
         self.schema = kwargs.pop('schema', None)
         self.indexes = util.Set()
         self.constraints = util.Set()
-        self._columns = sql.ColumnCollection()
+        self._columns = expression.ColumnCollection()
         self.primary_key = PrimaryKeyConstraint()
         self._foreign_keys = util.OrderedSet()
         self.quote = kwargs.pop('quote', False)
@@ -291,7 +294,7 @@ class Table(SchemaItem, sql.TableClause):
 
     def get_children(self, column_collections=True, schema_visitor=False, **kwargs):
         if not schema_visitor:
-            return sql.TableClause.get_children(self, column_collections=column_collections, **kwargs)
+            return expression.TableClause.get_children(self, column_collections=column_collections, **kwargs)
         else:
             if column_collections:
                 return [c for c in self.columns]
@@ -338,10 +341,10 @@ class Table(SchemaItem, sql.TableClause):
                 args.append(c.copy())
             return Table(self.name, metadata, schema=schema, *args)
 
-class Column(SchemaItem, sql._ColumnClause):
+class Column(SchemaItem, expression._ColumnClause):
     """Represent a column in a database table.
 
-    This is a subclass of ``sql.ColumnClause`` and represents an
+    This is a subclass of ``expression.ColumnClause`` and represents an
     actual existing table in the database, in a similar fashion as
     ``TableClause``/``Table``.
     """
@@ -575,7 +578,7 @@ class Column(SchemaItem, sql._ColumnClause):
             return [x for x in (self.default, self.onupdate) if x is not None] + \
                 list(self.foreign_keys) + list(self.constraints)
         else:
-            return sql._ColumnClause.get_children(self, **kwargs)
+            return expression._ColumnClause.get_children(self, **kwargs)
 
 
 class ForeignKey(SchemaItem):
@@ -806,7 +809,7 @@ class Constraint(SchemaItem):
 
     def __init__(self, name=None):
         self.name = name
-        self.columns = sql.ColumnCollection()
+        self.columns = expression.ColumnCollection()
 
     def __contains__(self, x):
         return self.columns.contains_column(x)
@@ -1124,12 +1127,12 @@ class MetaData(SchemaItem):
         del self.tables[table.key]
         
     def table_iterator(self, reverse=True, tables=None):
-        import sqlalchemy.sql_util
+        from sqlalchemy.sql import util as sql_util
         if tables is None:
             tables = self.tables.values()
         else:
             tables = util.Set(tables).intersection(self.tables.values())
-        sorter = sqlalchemy.sql_util.TableCollection(list(tables))
+        sorter = sql_util.TableCollection(list(tables))
         return iter(sorter.sort(reverse=reverse))
 
     def _get_parent(self):
@@ -1356,7 +1359,7 @@ class ThreadLocalMetaData(MetaData):
                 e.dispose()
 
 
-class SchemaVisitor(sql.ClauseVisitor):
+class SchemaVisitor(visitors.ClauseVisitor):
     """Define the visiting for ``SchemaItem`` objects."""
 
     __traverse_options__ = {'schema_visitor':True}
diff --git a/lib/sqlalchemy/sql/__init__.py b/lib/sqlalchemy/sql/__init__.py
new file mode 100644 (file)
index 0000000..06b9f1f
--- /dev/null
@@ -0,0 +1,3 @@
+from sqlalchemy.sql.expression import *
+from sqlalchemy.sql.visitors import ClauseVisitor, NoColumnVisitor
+
similarity index 93%
rename from lib/sqlalchemy/ansisql.py
rename to lib/sqlalchemy/sql/compiler.py
index 5f5e1c1713ec395a50406463888af87cc8e84af5..6053c72be941e24effdbe37fdb6675c997199055 100644 (file)
@@ -1,22 +1,18 @@
-# ansisql.py
+# compiler.py
 # Copyright (C) 2005, 2006, 2007 Michael Bayer mike_mp@zzzcomputing.com
 #
 # This module is part of SQLAlchemy and is released under
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
 
-"""Defines ANSI SQL operations.
+"""SQL expression compilation routines and DDL implementations."""
 
-Contains default implementations for the abstract objects in the sql
-module.
-"""
+import string, re
+from sqlalchemy import schema, engine, util, exceptions
+from sqlalchemy.sql import operators, visitors
+from sqlalchemy.sql import util as sql_util
+from sqlalchemy.sql import expression as sql
 
-import string, re, sets, operator
-
-from sqlalchemy import schema, sql, engine, util, exceptions, operators
-from  sqlalchemy.engine import default
-
-
-ANSI_FUNCS = sets.ImmutableSet([
+ANSI_FUNCS = util.Set([
     'CURRENT_DATE', 'CURRENT_TIME', 'CURRENT_TIMESTAMP',
     'CURRENT_USER', 'LOCALTIME', 'LOCALTIMESTAMP',
     'SESSION_USER', 'USER'])
@@ -77,7 +73,6 @@ OPERATORS =  {
     operators.comma_op : ', ',
     operators.desc_op : 'DESC',
     operators.asc_op : 'ASC',
-    
     operators.from_ : 'FROM',
     operators.as_ : 'AS',
     operators.exists : 'EXISTS',
@@ -85,36 +80,10 @@ OPERATORS =  {
     operators.isnot : 'IS NOT'
 }
 
-class ANSIDialect(default.DefaultDialect):
-    def __init__(self, cache_identifiers=True, **kwargs):
-        super(ANSIDialect,self).__init__(**kwargs)
-        self.identifier_preparer = self.preparer()
-        self.cache_identifiers = cache_identifiers
-
-    def create_connect_args(self):
-        return ([],{})
-
-    def schemagenerator(self, *args, **kwargs):
-        return ANSISchemaGenerator(self, *args, **kwargs)
-
-    def schemadropper(self, *args, **kwargs):
-        return ANSISchemaDropper(self, *args, **kwargs)
-
-    def compiler(self, statement, parameters, **kwargs):
-        return ANSICompiler(self, statement, parameters, **kwargs)
-
-    def preparer(self):
-        """Return an IdentifierPreparer.
-
-        This object is used to format table and column names including
-        proper quoting and case conventions.
-        """
-        return ANSIIdentifierPreparer(self)
-
-class ANSICompiler(engine.Compiled, sql.ClauseVisitor):
+class DefaultCompiler(engine.Compiled, visitors.ClauseVisitor):
     """Default implementation of Compiled.
 
-    Compiles ClauseElements into ANSI-compliant SQL strings.
+    Compiles ClauseElements into SQL strings.
     """
 
     __traverse_options__ = {'column_collections':False, 'entry':True}
@@ -122,7 +91,7 @@ class ANSICompiler(engine.Compiled, sql.ClauseVisitor):
     operators = OPERATORS
     
     def __init__(self, dialect, statement, parameters=None, **kwargs):
-        """Construct a new ``ANSICompiler`` object.
+        """Construct a new ``DefaultCompiler`` object.
 
         dialect
           Dialect to be used
@@ -139,7 +108,7 @@ class ANSICompiler(engine.Compiled, sql.ClauseVisitor):
           correspond to the keys present in the parameters.
         """
         
-        super(ANSICompiler, self).__init__(dialect, statement, parameters, **kwargs)
+        super(DefaultCompiler, self).__init__(dialect, statement, parameters, **kwargs)
 
         # if we are insert/update.  set to true when we visit an INSERT or UPDATE
         self.isinsert = self.isupdate = False
@@ -170,17 +139,17 @@ class ANSICompiler(engine.Compiled, sql.ClauseVisitor):
         self.bindtemplate = ":%s"
 
         # paramstyle from the dialect (comes from DBAPI)
-        self.paramstyle = dialect.paramstyle
+        self.paramstyle = self.dialect.paramstyle
 
         # true if the paramstyle is positional
-        self.positional = dialect.positional
+        self.positional = self.dialect.positional
 
         # a list of the compiled's bind parameter names, used to help
         # formulate a positional argument list
         self.positiontup = []
 
-        # an ANSIIdentifierPreparer that formats the quoting of identifiers
-        self.preparer = dialect.identifier_preparer
+        # an IdentifierPreparer that formats the quoting of identifiers
+        self.preparer = self.dialect.identifier_preparer
         
         # for UPDATE and INSERT statements, a set of columns whos values are being set
         # from a SQL expression (i.e., not one of the bind parameter values).  if present,
@@ -244,7 +213,7 @@ class ANSICompiler(engine.Compiled, sql.ClauseVisitor):
         return None
 
     def construct_params(self, params):
-        """Return a sql.ClauseParameters object.
+        """Return a sql.util.ClauseParameters object.
         
         Combines the given bind parameter dictionary (string keys to object values)
         with the _BindParamClause objects stored within this Compiled object
@@ -252,7 +221,7 @@ class ANSICompiler(engine.Compiled, sql.ClauseVisitor):
         for a single statement execution, or one element of an executemany execution.
         """
         
-        d = sql.ClauseParameters(self.dialect, self.positiontup)
+        d = sql_util.ClauseParameters(self.dialect, self.positiontup)
 
         pd = self.parameters or {}
         pd.update(params)
@@ -781,7 +750,7 @@ class ANSICompiler(engine.Compiled, sql.ClauseVisitor):
     def __str__(self):
         return self.string
 
-class ANSISchemaBase(engine.SchemaIterator):
+class DDLBase(engine.SchemaIterator):
     def find_alterables(self, tables):
         alterables = []
         class FindAlterables(schema.SchemaVisitor):
@@ -794,12 +763,12 @@ class ANSISchemaBase(engine.SchemaIterator):
                 findalterables.traverse(c)
         return alterables
 
-class ANSISchemaGenerator(ANSISchemaBase):
+class SchemaGenerator(DDLBase):
     def __init__(self, dialect, connection, checkfirst=False, tables=None, **kwargs):
-        super(ANSISchemaGenerator, self).__init__(connection, **kwargs)
+        super(SchemaGenerator, self).__init__(connection, **kwargs)
         self.checkfirst = checkfirst
         self.tables = tables and util.Set(tables) or None
-        self.preparer = dialect.preparer()
+        self.preparer = dialect.identifier_preparer
         self.dialect = dialect
 
     def get_column_specification(self, column, first_pk=False):
@@ -860,7 +829,7 @@ class ANSISchemaGenerator(ANSISchemaBase):
 
     def _compile(self, tocompile, parameters):
         """compile the given string/parameters using this SchemaGenerator's dialect."""
-        compiler = self.dialect.compiler(tocompile, parameters)
+        compiler = self.dialect.statement_compiler(self.dialect, tocompile, parameters)
         compiler.compile()
         return compiler
 
@@ -930,12 +899,12 @@ class ANSISchemaGenerator(ANSISchemaBase):
                        string.join([preparer.format_column(c) for c in index.columns], ', ')))
         self.execute()
 
-class ANSISchemaDropper(ANSISchemaBase):
+class SchemaDropper(DDLBase):
     def __init__(self, dialect, connection, checkfirst=False, tables=None, **kwargs):
-        super(ANSISchemaDropper, self).__init__(connection, **kwargs)
+        super(SchemaDropper, self).__init__(connection, **kwargs)
         self.checkfirst = checkfirst
         self.tables = tables
-        self.preparer = dialect.preparer()
+        self.preparer = dialect.identifier_preparer
         self.dialect = dialect
 
     def visit_metadata(self, metadata):
@@ -964,14 +933,11 @@ class ANSISchemaDropper(ANSISchemaBase):
         self.append("\nDROP TABLE " + self.preparer.format_table(table))
         self.execute()
 
-class ANSIDefaultRunner(engine.DefaultRunner):
-    pass
-
-class ANSIIdentifierPreparer(object):
+class IdentifierPreparer(object):
     """Handle quoting and case-folding of identifiers based on options."""
 
     def __init__(self, dialect, initial_quote='"', final_quote=None, omit_schema=False):
-        """Construct a new ``ANSIIdentifierPreparer`` object.
+        """Construct a new ``IdentifierPreparer`` object.
 
         initial_quote
           Character that begins a delimited identifier.
@@ -1049,20 +1015,14 @@ class ANSIIdentifierPreparer(object):
     def __generic_obj_format(self, obj, ident):
         if getattr(obj, 'quote', False):
             return self.quote_identifier(ident)
-        if self.dialect.cache_identifiers:
-            try:
-                return self.__strings[ident]
-            except KeyError:
-                if self._requires_quotes(ident):
-                    self.__strings[ident] = self.quote_identifier(ident)
-                else:
-                    self.__strings[ident] = ident
-                return self.__strings[ident]
-        else:
+        try:
+            return self.__strings[ident]
+        except KeyError:
             if self._requires_quotes(ident):
-                return self.quote_identifier(ident)
+                self.__strings[ident] = self.quote_identifier(ident)
             else:
-                return ident
+                self.__strings[ident] = ident
+            return self.__strings[ident]
 
     def should_quote(self, object):
         return object.quote or self._requires_quotes(object.name)
@@ -1152,5 +1112,3 @@ class ANSIIdentifierPreparer(object):
         return [self._unescape_identifier(i)
                 for i in [a or b for a, b in r.findall(identifiers)]]
 
-
-dialect = ANSIDialect
similarity index 93%
rename from lib/sqlalchemy/sql.py
rename to lib/sqlalchemy/sql/expression.py
index 7c73f7cb7ff2b0fc853c5119406db5b68d9cdc34..e117e3f47ab33112d61b8a29180b8d6af62cb27b 100644 (file)
@@ -26,13 +26,14 @@ classes usually have few or no public methods and are less guaranteed
 to stay the same in future releases.
 """
 
-from sqlalchemy import util, exceptions, operators
+from sqlalchemy import util, exceptions
+from sqlalchemy.sql import operators, visitors
 from sqlalchemy import types as sqltypes
 import re
 
 __all__ = [
-    'Alias', 'ClauseElement', 'ClauseParameters',
-    'ClauseVisitor', 'ColumnCollection', 'ColumnElement',
+    'Alias', 'ClauseElement', 
+    'ColumnCollection', 'ColumnElement',
     'CompoundSelect', 'Delete', 'FromClause', 'Insert', 'Join',
     'Select', 'Selectable', 'TableClause', 'Update', 'alias', 'and_', 'asc',
     'between', 'bindparam', 'case', 'cast', 'column', 'delete',
@@ -810,187 +811,6 @@ def is_column(col):
     """True if ``col`` is an instance of ``ColumnElement``."""
     return isinstance(col, ColumnElement)
 
-class ClauseParameters(object):
-    """Represent a dictionary/iterator of bind parameter key names/values.
-
-    Tracks the original [sqlalchemy.sql#_BindParamClause] objects
-    as well as the keys/position of each parameter, and can return
-    parameters as a dictionary or a list.  Will process parameter
-    values according to the ``TypeEngine`` objects present in the
-    ``_BindParamClause`` instances.
-    """
-
-    def __init__(self, dialect, positional=None):
-        self.dialect = dialect
-        self.__binds = {}
-        self.positional = positional or []
-
-    def get_parameter(self, key):
-        return self.__binds[key]
-
-    def set_parameter(self, bindparam, value, name):
-        self.__binds[name] = [bindparam, name, value]
-
-    def get_original(self, key):
-        return self.__binds[key][2]
-
-    def get_type(self, key):
-        return self.__binds[key][0].type
-
-    def get_processors(self):
-        """return a dictionary of bind 'processing' functions"""
-        return dict([
-            (key, value) for key, value in
-            [(
-                key,
-                self.__binds[key][0].bind_processor(self.dialect)
-            ) for key in self.__binds]
-            if value is not None
-        ])
-
-    def get_processed(self, key, processors):
-        return key in processors and processors[key](self.__binds[key][2]) or self.__binds[key][2]
-
-    def keys(self):
-        return self.__binds.keys()
-
-    def __iter__(self):
-        return iter(self.keys())
-
-    def __getitem__(self, key):
-        (bind, name, value) = self.__binds[key]
-        processor = bind.bind_processor(self.dialect)
-        return processor is not None and processor(value) or value
-
-    def __contains__(self, key):
-        return key in self.__binds
-
-    def set_value(self, key, value):
-        self.__binds[key][2] = value
-
-    def get_original_dict(self):
-        return dict([(name, value) for (b, name, value) in self.__binds.values()])
-
-    def __get_processed(self, key, processors):
-        if key in processors:
-            return processors[key](self.__binds[key][2])
-        else:
-            return self.__binds[key][2]
-
-    def get_raw_list(self, processors):
-        return [self.__get_processed(key, processors) for key in self.positional]
-
-    def get_raw_dict(self, processors, encode_keys=False):
-        if encode_keys:
-            return dict([
-                (
-                    key.encode(self.dialect.encoding),
-                    self.__get_processed(key, processors)
-                )
-                for key in self.keys()
-            ])
-        else:
-            return dict([
-                (
-                    key,
-                    self.__get_processed(key, processors)
-                )
-                for key in self.keys()
-            ])
-
-    def __repr__(self):
-        return self.__class__.__name__ + ":" + repr(self.get_original_dict())
-
-class ClauseVisitor(object):
-    """A class that knows how to traverse and visit ``ClauseElements``.
-
-    Calls visit_XXX() methods dynamically generated for each
-    particualr ``ClauseElement`` subclass encountered.  Traversal of a
-    hierarchy of ``ClauseElements`` is achieved via the ``traverse()``
-    method, which is passed the lead ``ClauseElement``.
-
-    By default, ``ClauseVisitor`` traverses all elements fully.
-    Options can be specified at the class level via the
-    ``__traverse_options__`` dictionary which will be passed to the
-    ``get_children()`` method of each ``ClauseElement``; these options
-    can indicate modifications to the set of elements returned, such
-    as to not return column collections (column_collections=False) or
-    to return Schema-level items (schema_visitor=True).
-
-    ``ClauseVisitor`` also supports a simultaneous copy-and-traverse
-    operation, which will produce a copy of a given ``ClauseElement``
-    structure while at the same time allowing ``ClauseVisitor``
-    subclasses to modify the new structure in-place.
-    """
-
-    __traverse_options__ = {}
-
-    def traverse_single(self, obj, **kwargs):
-        meth = getattr(self, "visit_%s" % obj.__visit_name__, None)
-        if meth:
-            return meth(obj, **kwargs)
-
-    def iterate(self, obj, stop_on=None):
-        stack = [obj]
-        traversal = []
-        while len(stack) > 0:
-            t = stack.pop()
-            if stop_on is None or t not in stop_on:
-                yield t
-                traversal.insert(0, t)
-                for c in t.get_children(**self.__traverse_options__):
-                    stack.append(c)
-
-    def traverse(self, obj, stop_on=None, clone=False):
-        if clone:
-            obj = obj._clone()
-
-        stack = [obj]
-        traversal = []
-        while len(stack) > 0:
-            t = stack.pop()
-            if stop_on is None or t not in stop_on:
-                traversal.insert(0, t)
-                if clone:
-                    t._copy_internals()
-                for c in t.get_children(**self.__traverse_options__):
-                    stack.append(c)
-        for target in traversal:
-            v = self
-            while v is not None:
-                meth = getattr(v, "visit_%s" % target.__visit_name__, None)
-                if meth:
-                    meth(target)
-                v = getattr(v, '_next', None)
-        return obj
-
-    def chain(self, visitor):
-        """'chain' an additional ClauseVisitor onto this ClauseVisitor.
-
-        The chained visitor will receive all visit events after this one.
-        """
-
-        tail = self
-        while getattr(tail, '_next', None) is not None:
-            tail = tail._next
-        tail._next = visitor
-        return self
-
-class NoColumnVisitor(ClauseVisitor):
-    """A ClauseVisitor that will not traverse exported column collections.
-
-    Will not traverse the exported Column collections on Table, Alias,
-    Select, and CompoundSelect objects (i.e. their 'columns' or 'c'
-    attribute).
-
-    This is useful because most traversals don't need those columns,
-    or in the case of ANSICompiler it traverses them explicitly; so
-    skipping their traversal here greatly cuts down on method call
-    overhead.
-    """
-
-    __traverse_options__ = {'column_collections': False}
-
 
 class _FigureVisitName(type):
     def __init__(cls, clsname, bases, dict):
@@ -1061,7 +881,7 @@ class ClauseElement(object):
         elif len(optionaldict) > 1:
             raise exceptions.ArgumentError("params() takes zero or one positional dictionary argument")
 
-        class Vis(ClauseVisitor):
+        class Vis(visitors.ClauseVisitor):
             def visit_bindparam(self, bind):
                 if bind.key in kwargs:
                     bind.value = kwargs[bind.key]
@@ -1156,7 +976,7 @@ class ClauseElement(object):
         if any.
 
         Finally, if there is no bound ``Engine``, uses an
-        ``ANSIDialect`` to create a default ``Compiler``.
+        ``DefaultDialect`` to create a default ``Compiler``.
 
         `parameters` is a dictionary representing the default bind
         parameters to be used with the statement.  If `parameters` is
@@ -1175,15 +995,16 @@ class ClauseElement(object):
 
         if compiler is None:
             if dialect is not None:
-                compiler = dialect.compiler(self, parameters)
+                compiler = dialect.statement_compiler(dialect, self, parameters)
             elif bind is not None:
-                compiler = bind.compiler(self, parameters)
+                compiler = bind.statement_compiler(self, parameters)
             elif self.bind is not None:
-                compiler = self.bind.compiler(self, parameters)
+                compiler = self.bind.statement_compiler(self, parameters)
 
         if compiler is None:
-            import sqlalchemy.ansisql as ansisql
-            compiler = ansisql.ANSIDialect().compiler(self, parameters=parameters)
+            from sqlalchemy.engine.default import DefaultDialect
+            dialect = DefaultDialect()
+            compiler = dialect.statement_compiler(dialect, self, parameters=parameters)
         compiler.compile()
         return compiler
 
@@ -1727,7 +1548,7 @@ class FromClause(Selectable):
 
     def _get_all_embedded_columns(self):
         ret = []
-        class FindCols(ClauseVisitor):
+        class FindCols(visitors.ClauseVisitor):
             def visit_column(self, col):
                 ret.append(col)
         FindCols().traverse(self)
@@ -1744,8 +1565,8 @@ class FromClause(Selectable):
     def replace_selectable(self, old, alias):
       """replace all occurences of FromClause 'old' with the given Alias object, returning a copy of this ``FromClause``."""
 
-      from sqlalchemy import sql_util
-      return sql_util.ClauseAdapter(alias).traverse(self, clone=True)
+      from sqlalchemy.sql import util
+      return util.ClauseAdapter(alias).traverse(self, clone=True)
 
     def corresponding_column(self, column, raiseerr=True, keys_ok=False, require_embedded=False):
         """Given a ``ColumnElement``, return the exported ``ColumnElement``
@@ -2376,7 +2197,7 @@ class Join(FromClause):
                 else:
                     equivs[x] = util.Set([y])
 
-        class BinaryVisitor(ClauseVisitor):
+        class BinaryVisitor(visitors.ClauseVisitor):
             def visit_binary(self, binary):
                 if binary.operator == operators.eq:
                     add_equiv(binary.left, binary.right)
@@ -2460,7 +2281,7 @@ class Join(FromClause):
             return self.__folded_equivalents
         if equivs is None:
             equivs = util.Set()
-        class LocateEquivs(NoColumnVisitor):
+        class LocateEquivs(visitors.NoColumnVisitor):
             def visit_binary(self, binary):
                 if binary.operator == operators.eq and binary.left.name == binary.right.name:
                     equivs.add(binary.right)
@@ -3331,7 +3152,7 @@ class Select(_SelectBaseMixin, FromClause):
         return intersect_all(self, other, **kwargs)
 
     def _table_iterator(self):
-        for t in NoColumnVisitor().iterate(self):
+        for t in visitors.NoColumnVisitor().iterate(self):
             if isinstance(t, TableClause):
                 yield t
 
similarity index 68%
rename from lib/sqlalchemy/sql_util.py
rename to lib/sqlalchemy/sql/util.py
index cc6325822486beecd4214fd87643e8bc8c90a797..2c7294e663c4ef3267cceff091259227fcaeec5e 100644 (file)
@@ -1,7 +1,100 @@
-from sqlalchemy import sql, util, schema, topological
+from sqlalchemy import util, schema, topological
+from sqlalchemy.sql import expression, visitors
 
 """Utility functions that build upon SQL and Schema constructs."""
 
+class ClauseParameters(object):
+    """Represent a dictionary/iterator of bind parameter key names/values.
+
+    Tracks the original [sqlalchemy.sql#_BindParamClause] objects as well as the
+    keys/position of each parameter, and can return parameters as a
+    dictionary or a list.  Will process parameter values according to
+    the ``TypeEngine`` objects present in the ``_BindParamClause`` instances.
+    """
+
+    def __init__(self, dialect, positional=None):
+        self.dialect = dialect
+        self.__binds = {}
+        self.positional = positional or []
+
+    def get_parameter(self, key):
+        return self.__binds[key]
+
+    def set_parameter(self, bindparam, value, name):
+        self.__binds[name] = [bindparam, name, value]
+        
+    def get_original(self, key):
+        return self.__binds[key][2]
+
+    def get_type(self, key):
+        return self.__binds[key][0].type
+
+    def get_processors(self):
+        """return a dictionary of bind 'processing' functions"""
+        return dict([
+            (key, value) for key, value in 
+            [(
+                key,
+                self.__binds[key][0].bind_processor(self.dialect)
+            ) for key in self.__binds]
+            if value is not None
+        ])
+    
+    def get_processed(self, key, processors):
+        return key in processors and processors[key](self.__binds[key][2]) or self.__binds[key][2]
+            
+    def keys(self):
+        return self.__binds.keys()
+
+    def __iter__(self):
+        return iter(self.keys())
+        
+    def __getitem__(self, key):
+        (bind, name, value) = self.__binds[key]
+        processor = bind.bind_processor(self.dialect)
+        return processor is not None and processor(value) or value
+    def __contains__(self, key):
+        return key in self.__binds
+    
+    def set_value(self, key, value):
+        self.__binds[key][2] = value
+            
+    def get_original_dict(self):
+        return dict([(name, value) for (b, name, value) in self.__binds.values()])
+
+    def __get_processed(self, key, processors):
+        if key in processors:
+            return processors[key](self.__binds[key][2])
+        else:
+            return self.__binds[key][2]
+            
+    def get_raw_list(self, processors):
+        return [self.__get_processed(key, processors) for key in self.positional]
+
+    def get_raw_dict(self, processors, encode_keys=False):
+        if encode_keys:
+            return dict([
+                (
+                    key.encode(self.dialect.encoding),
+                    self.__get_processed(key, processors)
+                )
+                for key in self.keys()
+            ])
+        else:
+            return dict([
+                (
+                    key,
+                    self.__get_processed(key, processors)
+                )
+                for key in self.keys()
+            ])
+
+    def __repr__(self):
+        return self.__class__.__name__ + ":" + repr(self.get_original_dict())
+
+
+
 class TableCollection(object):
     def __init__(self, tables=None):
         self.tables = tables or []
@@ -64,7 +157,7 @@ class TableCollection(object):
         return sequence
 
 
-class TableFinder(TableCollection, sql.NoColumnVisitor):
+class TableFinder(TableCollection, visitors.NoColumnVisitor):
     """locate all Tables within a clause."""
 
     def __init__(self, clause, check_columns=False, include_aliases=False):
@@ -85,7 +178,7 @@ class TableFinder(TableCollection, sql.NoColumnVisitor):
         if self.check_columns:
             self.tables.append(column.table)
 
-class ColumnFinder(sql.ClauseVisitor):
+class ColumnFinder(visitors.ClauseVisitor):
     def __init__(self):
         self.columns = util.Set()
 
@@ -95,7 +188,7 @@ class ColumnFinder(sql.ClauseVisitor):
     def __iter__(self):
         return iter(self.columns)
 
-class ColumnsInClause(sql.ClauseVisitor):
+class ColumnsInClause(visitors.ClauseVisitor):
     """Given a selectable, visit clauses and determine if any columns
     from the clause are in the selectable.
     """
@@ -108,7 +201,7 @@ class ColumnsInClause(sql.ClauseVisitor):
         if self.selectable.c.get(column.key) is column:
             self.result = True
 
-class AbstractClauseProcessor(sql.NoColumnVisitor):
+class AbstractClauseProcessor(visitors.NoColumnVisitor):
     """Traverse a clause and attempt to convert the contents of container elements
     to a converted element.
 
@@ -224,10 +317,10 @@ class ClauseAdapter(AbstractClauseProcessor):
         self.equivalents = equivalents
 
     def convert_element(self, col):
-        if isinstance(col, sql.FromClause):
+        if isinstance(col, expression.FromClause):
             if self.selectable.is_derived_from(col):
                 return self.selectable
-        if not isinstance(col, sql.ColumnElement):
+        if not isinstance(col, expression.ColumnElement):
             return None
         if self.include is not None:
             if col not in self.include:
diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py
new file mode 100644 (file)
index 0000000..98e4de6
--- /dev/null
@@ -0,0 +1,87 @@
+class ClauseVisitor(object):
+    """A class that knows how to traverse and visit
+    ``ClauseElements``.
+    
+    Calls visit_XXX() methods dynamically generated for each particualr
+    ``ClauseElement`` subclass encountered.  Traversal of a
+    hierarchy of ``ClauseElements`` is achieved via the
+    ``traverse()`` method, which is passed the lead
+    ``ClauseElement``.
+    
+    By default, ``ClauseVisitor`` traverses all elements
+    fully.  Options can be specified at the class level via the 
+    ``__traverse_options__`` dictionary which will be passed
+    to the ``get_children()`` method of each ``ClauseElement``;
+    these options can indicate modifications to the set of 
+    elements returned, such as to not return column collections
+    (column_collections=False) or to return Schema-level items
+    (schema_visitor=True).
+    
+    ``ClauseVisitor`` also supports a simultaneous copy-and-traverse
+    operation, which will produce a copy of a given ``ClauseElement``
+    structure while at the same time allowing ``ClauseVisitor`` subclasses
+    to modify the new structure in-place.
+    
+    """
+    __traverse_options__ = {}
+    
+    def traverse_single(self, obj, **kwargs):
+        meth = getattr(self, "visit_%s" % obj.__visit_name__, None)
+        if meth:
+            return meth(obj, **kwargs)
+
+    def iterate(self, obj, stop_on=None):
+        stack = [obj]
+        traversal = []
+        while len(stack) > 0:
+            t = stack.pop()
+            if stop_on is None or t not in stop_on:
+                yield t
+                traversal.insert(0, t)
+                for c in t.get_children(**self.__traverse_options__):
+                    stack.append(c)
+        
+    def traverse(self, obj, stop_on=None, clone=False):
+        if clone:
+            obj = obj._clone()
+            
+        stack = [obj]
+        traversal = []
+        while len(stack) > 0:
+            t = stack.pop()
+            if stop_on is None or t not in stop_on:
+                traversal.insert(0, t)
+                if clone:
+                    t._copy_internals()
+                for c in t.get_children(**self.__traverse_options__):
+                    stack.append(c)
+        for target in traversal:
+            v = self
+            while v is not None:
+                meth = getattr(v, "visit_%s" % target.__visit_name__, None)
+                if meth:
+                    meth(target)
+                v = getattr(v, '_next', None)
+        return obj
+
+    def chain(self, visitor):
+        """'chain' an additional ClauseVisitor onto this ClauseVisitor.
+        
+        the chained visitor will receive all visit events after this one."""
+        tail = self
+        while getattr(tail, '_next', None) is not None:
+            tail = tail._next
+        tail._next = visitor
+        return self
+
+class NoColumnVisitor(ClauseVisitor):
+    """a ClauseVisitor that will not traverse the exported Column 
+    collections on Table, Alias, Select, and CompoundSelect objects
+    (i.e. their 'columns' or 'c' attribute).
+    
+    this is useful because most traversals don't need those columns, or
+    in the case of DefaultCompiler it traverses them explicitly; so
+    skipping their traversal here greatly cuts down on method call overhead.
+    """
+    
+    __traverse_options__ = {'column_collections':False}
index 46e0a71376c9f36957026b9c298df3906af3bc61..2948428543b2a25ca1999ce77cc3a02afaa24885 100644 (file)
@@ -154,7 +154,7 @@ class TypesTest(AssertMixin):
             table_args.append(Column('c%s' % index, type_(*args, **kw)))
 
         numeric_table = Table(*table_args)
-        gen = testbase.db.dialect.schemagenerator(testbase.db, None, None)
+        gen = testbase.db.dialect.schemagenerator(testbase.db.dialect, testbase.db, None, None)
         
         for col in numeric_table.c:
             index = int(col.name[1:])
@@ -238,7 +238,7 @@ class TypesTest(AssertMixin):
             table_args.append(Column('c%s' % index, type_(*args, **kw)))
 
         charset_table = Table(*table_args)
-        gen = testbase.db.dialect.schemagenerator(testbase.db, None, None)
+        gen = testbase.db.dialect.schemagenerator(testbase.db.dialect, testbase.db, None, None)
         
         for col in charset_table.c:
             index = int(col.name[1:])
@@ -707,7 +707,7 @@ class SQLTest(AssertMixin):
 
 
 def colspec(c):
-    return testbase.db.dialect.schemagenerator(
+    return testbase.db.dialect.schemagenerator(testbase.db.dialect, 
         testbase.db, None, None).get_column_specification(c)
 
 if __name__ == "__main__":
index da6f75149b82e91930cfe820469143c57ff1bc64..2345a328ae3beef21399ab76523239bee99abaa6 100644 (file)
@@ -2,7 +2,6 @@ import testbase
 import pickle, StringIO, unicodedata
 
 from sqlalchemy import *
-import sqlalchemy.ansisql as ansisql
 from sqlalchemy.exceptions import NoSuchTableError
 from testlib import *
 from testlib import engines
@@ -686,7 +685,7 @@ class SchemaTest(PersistTest):
         def foo(s, p=None):
             buf.write(s)
         gen = create_engine(testbase.db.name + "://", strategy="mock", executor=foo)
-        gen = gen.dialect.schemagenerator(gen)
+        gen = gen.dialect.schemagenerator(gen.dialect, gen)
         gen.traverse(table1)
         gen.traverse(table2)
         buf = buf.getvalue()
index 0c824d372d5a09544c1062f317ac43c43d515d4d..2e84c83d18c611605b3b8bfd5141552b495c6055 100644 (file)
@@ -1,7 +1,6 @@
 import testbase
 import operator
 from sqlalchemy import *
-from sqlalchemy import ansisql
 from sqlalchemy.orm import *
 from testlib import *
 from testlib.fixtures import *
index e0b8bf4f3f7911e7363d6f46bc3869aa7d5a2bce..e3f6ed42c819f53c4b70a9ceb55e3cb6b8c27c78 100644 (file)
@@ -1,7 +1,8 @@
 import testbase
 import operator
+from sqlalchemy.sql import compiler
 from sqlalchemy import *
-from sqlalchemy import ansisql
+from sqlalchemy.engine import default
 from sqlalchemy.orm import *
 from testlib import *
 from testlib.fixtures import *
@@ -141,7 +142,7 @@ class OperatorTest(QueryTest):
     """test sql.Comparator implementation for MapperProperties"""
     
     def _test(self, clause, expected):
-        c = str(clause.compile(dialect=ansisql.ANSIDialect()))
+        c = str(clause.compile(dialect = default.DefaultDialect()))
         assert c == expected, "%s != %s" % (c, expected)
         
     def test_arithmetic(self):
@@ -182,7 +183,7 @@ class OperatorTest(QueryTest):
 
                 # the compiled clause should match either (e.g.):
                 # 'a' < 'b' -or- 'b' > 'a'.
-                compiled = str(py_op(lhs, rhs).compile(dialect=ansisql.ANSIDialect()))
+                compiled = str(py_op(lhs, rhs).compile(dialect=default.DefaultDialect()))
                 fwd_sql = "%s %s %s" % (l_sql, fwd_op, r_sql)
                 rev_sql = "%s %s %s" % (r_sql, rev_op, l_sql)
 
@@ -201,7 +202,7 @@ class OperatorTest(QueryTest):
             # this one would require adding compile() to InstrumentedScalarAttribute.  do we want this ?
             #(User.id, "users.id")
         ):
-            c = expr.compile(dialect=ansisql.ANSIDialect())
+            c = expr.compile(dialect=default.DefaultDialect())
             assert str(c) == compare, "%s != %s" % (str(c), compare)
             
             
index 3120185d59d9c01598bc7a3f8b746a215bcaaf7b..a8b642b9b9fd388f59c7ae131913176fc4f8fe7f 100644 (file)
@@ -179,7 +179,7 @@ class ConstraintTest(AssertMixin):
             capt.append(repr(context.parameters))
             ex(context)
         connection._Connection__execute = proxy
-        schemagen = testbase.db.dialect.schemagenerator(connection)
+        schemagen = testbase.db.dialect.schemagenerator(testbase.db.dialect, connection)
         schemagen.traverse(events)
         
         assert capt[0].strip().startswith('CREATE TABLE events')
index d79af577f2c6ebe3c30a4df60617659f424e0ed5..9bd97b3054b45b05df7ef6bd8761353f7ddcfc47 100644 (file)
@@ -1,6 +1,7 @@
 import testbase
 from sqlalchemy import *
 from testlib import *
+from sqlalchemy.sql.visitors import *
 
 class TraversalTest(AssertMixin):
     """test ClauseVisitor's traversal, particularly its ability to copy and modify
@@ -213,7 +214,7 @@ class ClauseTest(SQLCompileTest):
         self.assert_compile(Vis().traverse(s, clone=True), "SELECT * FROM table1 WHERE table1.col1 = table2.col1 AND table1.col2 = :table1_col2")
 
     def test_clause_adapter(self):
-        from sqlalchemy import sql_util
+        from sqlalchemy.sql import util as sql_util
         
         t1alias = t1.alias('t1alias')
         
index 553a3a3bc3c9723ddd723aada093aeddcd4d21f8..dee76428df8042f21e0dc6a3711f17cd1bb92d7a 100644 (file)
@@ -78,7 +78,7 @@ class LongLabelsTest(PersistTest):
       # this is the test that fails if the "max identifier length" is shorter than the 
       # length of the actual columns created, because the column names get truncated.
       # if you try to separate "physical columns" from "labels", and only truncate the labels,
-      # the ansisql.visit_select() logic which auto-labels columns in a subquery (for the purposes of sqlite compat) breaks the code,
+      # the compiler.DefaultCompiler.visit_select() logic which auto-labels columns in a subquery (for the purposes of sqlite compat) breaks the code,
       # since it is creating "labels" on the fly but not affecting derived columns, which think they are
       # still "physical"
       q = table1.select(table1.c.this_is_the_primarykey_column == 4).alias('foo')
index 4f569f1c02d8b6741da989115b50d19e3af692fd..8ec3190b4d883cf799f2ddae103342a8cce1a7fd 100644 (file)
@@ -1,7 +1,8 @@
 import testbase
 import datetime
 from sqlalchemy import *
-from sqlalchemy import exceptions, ansisql
+from sqlalchemy import exceptions
+from sqlalchemy.engine import default
 from testlib import *
 
 
@@ -166,14 +167,14 @@ class QueryTest(PersistTest):
         assert len(r) == 1
 
     def test_bindparam_detection(self):
-        dialect = ansisql.ANSIDialect(default_paramstyle='qmark')
-        prep = lambda q: dialect.compile(sql.text(q)).string
+        dialect = default.DefaultDialect(default_paramstyle='qmark')
+        prep = lambda q: str(sql.text(q).compile(dialect=dialect))
 
         def a_eq(got, wanted):
             if got != wanted:
                 print "Wanted %s" % wanted
                 print "Received %s" % got
-            self.assert_(got == wanted)
+            self.assert_(got == wanted, got)
 
         a_eq(prep('select foo'), 'select foo')
         a_eq(prep("time='12:30:00'"), "time='12:30:00'")
index ad25619df9c5ae4695f8aee936a1ce5b0c955c2a..0c414af3a66720d76b610a4d2569d1ef15c476d3 100644 (file)
@@ -1,7 +1,7 @@
 import testbase
 from sqlalchemy import *
 from testlib import *
-
+from sqlalchemy.sql import compiler
 
 class QuoteTest(PersistTest):
     def setUpAll(self):
@@ -98,10 +98,10 @@ class QuoteTest(PersistTest):
    
 
 class PreparerTest(PersistTest):
-    """Test the db-agnostic quoting services of ANSIIdentifierPreparer."""
+    """Test the db-agnostic quoting services of IdentifierPreparer."""
 
     def test_unformat(self):
-        prep = ansisql.ANSIIdentifierPreparer(None)
+        prep = compiler.IdentifierPreparer(None)
         unformat = prep.unformat_identifiers
 
         def a_eq(have, want):
@@ -120,7 +120,7 @@ class PreparerTest(PersistTest):
         a_eq(unformat('"foo"."b""a""r"."baz"'), ['foo', 'b"a"r', 'baz'])
 
     def test_unformat_custom(self):
-        class Custom(ansisql.ANSIIdentifierPreparer):
+        class Custom(compiler.IdentifierPreparer):
             def __init__(self, dialect):
                 super(Custom, self).__init__(dialect, initial_quote='`',
                                              final_quote='`')
index c497fbcbd778e2400d32a71e3ceeb8378034fbee..4ffb4c5916bf53311325369b84763e702bdb7d31 100644 (file)
@@ -195,7 +195,7 @@ class ColumnsTest(AssertMixin):
         )
 
         for aCol in testTable.c:
-            self.assertEquals(expectedResults[aCol.name], db.dialect.schemagenerator(db, None, None).get_column_specification(aCol))
+            self.assertEquals(expectedResults[aCol.name], db.dialect.schemagenerator(db.dialect, db, None, None).get_column_specification(aCol))
         
 class UnicodeTest(AssertMixin):
     """tests the Unicode type.  also tests the TypeDecorator with instances in the types package."""
index 6830fb63c9d713cf540d21eecafca4b3a408a99d..58fd7c0d10adca60e3a59045861e49cbbe8e4261 100644 (file)
@@ -161,7 +161,8 @@ class ExecutionContextWrapper(object):
             if params is not None and isinstance(params, list) and len(params) == 1:
                 params = params[0]
             
-            if isinstance(ctx.compiled_parameters, sql.ClauseParameters):
+            from sqlalchemy.sql.util import ClauseParameters
+            if isinstance(ctx.compiled_parameters, ClauseParameters):
                 parameters = ctx.compiled_parameters.get_original_dict()
             elif isinstance(ctx.compiled_parameters, list):
                 parameters = [p.get_original_dict() for p in ctx.compiled_parameters]