git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- method call removal
author     Mike Bayer <mike_mp@zzzcomputing.com>
Mon, 20 Aug 2007 21:50:59 +0000 (21:50 +0000)
committer  Mike Bayer <mike_mp@zzzcomputing.com>
Mon, 20 Aug 2007 21:50:59 +0000 (21:50 +0000)
18 files changed:
lib/sqlalchemy/databases/access.py
lib/sqlalchemy/databases/firebird.py
lib/sqlalchemy/databases/informix.py
lib/sqlalchemy/databases/mssql.py
lib/sqlalchemy/databases/mysql.py
lib/sqlalchemy/databases/oracle.py
lib/sqlalchemy/databases/postgres.py
lib/sqlalchemy/databases/sqlite.py
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/schema.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/expression.py
lib/sqlalchemy/util.py
test/ext/activemapper.py
test/orm/unitofwork.py
test/sql/labels.py
test/sql/rowcount.py
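
The change made across these files is mechanical: a group of per-dialect feature-query methods (supports_sane_rowcount(), supports_alter(), supports_unicode_statements(), max_identifier_length(), and the _reserved_words()-style hooks on the identifier preparers) become plain class-level attributes, and call sites drop the trailing parentheses. A minimal sketch of the before/after shape, using illustrative class names rather than the actual SQLAlchemy classes:

    # Before: feature flags exposed as methods, overridden per dialect.
    class OldDefaultDialect(object):
        def supports_sane_rowcount(self):
            return True

        def max_identifier_length(self):
            return 9999

    # After: the same information as class attributes; a subclass overrides
    # by simple assignment, and callers read the attribute directly.
    class NewDefaultDialect(object):
        supports_alter = True
        supports_sane_rowcount = True
        max_identifier_length = 9999

    class NewSQLiteDialect(NewDefaultDialect):
        supports_alter = False

    dialect = NewSQLiteDialect()
    assert dialect.supports_sane_rowcount      # was: dialect.supports_sane_rowcount()
    assert not dialect.supports_alter

The call-site half of the change shows up in engine/default.py, sql/compiler.py and the tests below, e.g. "if testbase.db.dialect.supports_sane_rowcount:".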

diff --git a/lib/sqlalchemy/databases/access.py b/lib/sqlalchemy/databases/access.py
index f901ebf53549a5871bb7529b0bad61bc52833589..4994e3309e2cee1d1232dd2536190d660b289edf 100644
@@ -6,7 +6,7 @@
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
 
 import random
-from sqlalchemy import sql, schema, ansisql, types, exceptions, pool
+from sqlalchemy import sql, schema, types, exceptions, pool
 from sqlalchemy.sql import compiler
 import sqlalchemy.engine.default as default
 
@@ -159,7 +159,7 @@ class AccessExecutionContext(default.DefaultExecutionContext):
 
 
 const, daoEngine = None, None
-class AccessDialect(ansisql.ANSIDialect):
+class AccessDialect(compiler.DefaultDialect):
     colspecs = {
         types.Unicode : AcUnicode,
         types.Integer : AcInteger,
@@ -176,6 +176,9 @@ class AccessDialect(ansisql.ANSIDialect):
         types.TIMESTAMP: AcTimeStamp,
     }
 
+    supports_sane_rowcount = False
+
+
     def type_descriptor(self, typeobj):
         newobj = types.adapt_type(typeobj, self.colspecs)
         return newobj
@@ -211,9 +214,6 @@ class AccessDialect(ansisql.ANSIDialect):
     def create_execution_context(self, *args, **kwargs):
         return AccessExecutionContext(self, *args, **kwargs)
 
-    def supports_sane_rowcount(self):
-        return False
-
     def last_inserted_ids(self):
         return self.context.last_inserted_ids
 
@@ -416,7 +416,7 @@ class AccessSchemaDropper(compiler.SchemaDropper):
         self.append("\nDROP INDEX [%s].[%s]" % (index.table.name, index.name))
         self.execute()
 
-class AccessDefaultRunner(ansisql.ANSIDefaultRunner):
+class AccessDefaultRunner(compiler.DefaultRunner):
     pass
 
 class AccessIdentifierPreparer(compiler.IdentifierPreparer):
diff --git a/lib/sqlalchemy/databases/firebird.py b/lib/sqlalchemy/databases/firebird.py
index 9cccb53e85cbe27d71eb8da5fc478a338448e479..2a9bbb5bdbfcb517ee8e4b70f5bdddfc300fb5c2 100644
@@ -101,6 +101,9 @@ class FBExecutionContext(default.DefaultExecutionContext):
 
 
 class FBDialect(default.DefaultDialect):
+    supports_sane_rowcount = False
+    max_identifier_length = 31
+
     def __init__(self, type_conv=200, concurrency_level=1, **kwargs):
         default.DefaultDialect.__init__(self, **kwargs)
 
@@ -133,12 +136,6 @@ class FBDialect(default.DefaultDialect):
     def type_descriptor(self, typeobj):
         return sqltypes.adapt_type(typeobj, colspecs)
 
-    def supports_sane_rowcount(self):
-        return False
-
-    def max_identifier_length(self):
-        return 31
-    
     def table_names(self, connection, schema):
         s = "SELECT R.RDB$RELATION_NAME FROM RDB$RELATIONS R"
         return [row[0] for row in connection.execute(s)]
@@ -408,12 +405,11 @@ RESERVED_WORDS = util.Set(
 
 
 class FBIdentifierPreparer(compiler.IdentifierPreparer):
+    reserved_words = RESERVED_WORDS
+    
     def __init__(self, dialect):
         super(FBIdentifierPreparer,self).__init__(dialect, omit_schema=True)
 
-    def _reserved_words(self):
-        return RESERVED_WORDS
-
 
 dialect = FBDialect
 dialect.statement_compiler = FBCompiler
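
As in the MySQL preparer further down, FBIdentifierPreparer now supplies its word list through a reserved_words class attribute instead of a _reserved_words() method; the base IdentifierPreparer (see the sql/compiler.py section) consults it alongside legal_characters and illegal_initial_characters when deciding whether an identifier needs quoting. A toy sketch of that arrangement, with made-up word lists rather than SQLAlchemy's real ones:

    import string

    class IdentifierPreparer(object):
        reserved_words = set(['select', 'table', 'order'])
        legal_characters = set(string.ascii_lowercase + string.digits + '_$')
        illegal_initial_characters = set(string.digits + '$')

        def _requires_quotes(self, value):
            # quote reserved words, odd characters, or mixed/upper case
            return (value in self.reserved_words
                    or value[0] in self.illegal_initial_characters
                    or any(ch not in self.legal_characters for ch in value)
                    or value.lower() != value)

    class FBStylePreparer(IdentifierPreparer):
        # a dialect-specific preparer just swaps in its own word list
        reserved_words = set(['select', 'table', 'order', 'gen_id'])

    p = FBStylePreparer()
    assert p._requires_quotes('order')
    assert p._requires_quotes('MixedCase')
    assert not p._requires_quotes('user_name')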
diff --git a/lib/sqlalchemy/databases/informix.py b/lib/sqlalchemy/databases/informix.py
index 45dfb0370e467a4052fd48c3eb1ab7959a751df0..67d31387d52eff90deb151182438aab654e1d47a 100644
@@ -205,6 +205,8 @@ class InfoExecutionContext(default.DefaultExecutionContext):
         return informix_cursor( self.connection.connection )
         
 class InfoDialect(default.DefaultDialect):
+    # for informix 7.31
+    max_identifier_length = 18
     
     def __init__(self, use_ansi=True,**kwargs):
         self.use_ansi = use_ansi
@@ -216,10 +218,6 @@ class InfoDialect(default.DefaultDialect):
         return informixdb
     dbapi = classmethod(dbapi)
 
-    def max_identifier_length( self ):
-        # for informix 7.31
-        return 18
-    
     def is_disconnect(self, e):
         if isinstance(e, self.dbapi.OperationalError):
             return 'closed the connection' in str(e) or 'connection not open' in str(e)
diff --git a/lib/sqlalchemy/databases/mssql.py b/lib/sqlalchemy/databases/mssql.py
index 1985d2112c904ed650092253d2fad0eb1eba8f1c..03b276d4a2dd105f3a12b15e0ac73e8b1aba64b1 100644
@@ -472,10 +472,6 @@ class MSSQLDialect(default.DefaultDialect):
     def last_inserted_ids(self):
         return self.context.last_inserted_ids
 
-    # this is only implemented in the dbapi-specific subclasses
-    def supports_sane_rowcount(self):
-        raise NotImplementedError()
-
     def get_default_schema_name(self, connection):
         return self.schema_name
 
@@ -665,6 +661,8 @@ class MSSQLDialect(default.DefaultDialect):
             table.append_constraint(schema.ForeignKeyConstraint(scols, ['%s.%s' % (t,c) for (s,t,c) in rcols], fknm))
 
 class MSSQLDialect_pymssql(MSSQLDialect):
+    supports_sane_rowcount = False
+
     def import_dbapi(cls):
         import pymssql as module
         # pymmsql doesn't have a Binary method.  we use string
@@ -683,12 +681,6 @@ class MSSQLDialect_pymssql(MSSQLDialect):
         super(MSSQLDialect_pymssql, self).__init__(**params)
         self.use_scope_identity = True
 
-    def supports_sane_rowcount(self):
-        return False
-
-    def max_identifier_length(self):
-        return 30
-
     def do_rollback(self, connection):
         # pymssql throws an error on repeated rollbacks. Ignore it.
         # TODO: this is normal behavior for most DBs.  are we sure we want to ignore it ?
@@ -746,6 +738,9 @@ class MSSQLDialect_pymssql(MSSQLDialect):
 ##        r.fetch_array()
 
 class MSSQLDialect_pyodbc(MSSQLDialect):
+    supports_sane_rowcount = False
+    # PyODBC unicode is broken on UCS-4 builds
+    supports_unicode_statements = sys.maxunicode == 65535
     
     def __init__(self, **params):
         super(MSSQLDialect_pyodbc, self).__init__(**params)
@@ -771,14 +766,6 @@ class MSSQLDialect_pyodbc(MSSQLDialect):
     ischema_names['smalldatetime'] = MSDate_pyodbc
     ischema_names['datetime'] = MSDateTime_pyodbc
 
-    def supports_sane_rowcount(self):
-        return False
-
-    def supports_unicode_statements(self):
-        """indicate whether the DBAPI can receive SQL statements as Python unicode strings"""
-        # PyODBC unicode is broken on UCS-4 builds
-        return sys.maxunicode == 65535
-
     def make_connect_string(self, keys):
         if 'dsn' in keys:
             connectors = ['dsn=%s' % keys['dsn']]
@@ -818,6 +805,9 @@ class MSSQLDialect_pyodbc(MSSQLDialect):
             context._last_inserted_ids = [int(row[0])]
 
 class MSSQLDialect_adodbapi(MSSQLDialect):
+    supports_sane_rowcount = True
+    supports_unicode_statements = True
+
     def import_dbapi(cls):
         import adodbapi as module
         return module
@@ -831,13 +821,6 @@ class MSSQLDialect_adodbapi(MSSQLDialect):
     ischema_names['nvarchar'] = AdoMSNVarchar
     ischema_names['datetime'] = MSDateTime_adodbapi
 
-    def supports_sane_rowcount(self):
-        return True
-
-    def supports_unicode_statements(self):
-        """indicate whether the DBAPI can receive SQL statements as Python unicode strings"""
-        return True
-
     def make_connect_string(self, keys):
         connectors = ["Provider=SQLOLEDB"]
         if 'port' in keys:
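
The MSSQL DB-API subclasses get the same treatment, which also means an expression like sys.maxunicode == 65535 is now evaluated once, when the class body executes, instead of on every call. A hedged sketch of that pattern, with illustrative names rather than the real MSSQL dialect classes:

    import sys

    class BaseDialect(object):
        supports_sane_rowcount = True
        supports_unicode_statements = False

    class PyODBCStyleDialect(BaseDialect):
        supports_sane_rowcount = False
        # evaluated at class-definition time; True only on narrow (UCS-2) builds
        supports_unicode_statements = sys.maxunicode == 65535

    class ADOStyleDialect(BaseDialect):
        supports_sane_rowcount = True
        supports_unicode_statements = True

    print(PyODBCStyleDialect.supports_unicode_statements)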
diff --git a/lib/sqlalchemy/databases/mysql.py b/lib/sqlalchemy/databases/mysql.py
index 41c6ec70f3beb48cd397453019f135bcc7d100a2..6dc0d605772bc565628524892e27c4bf8790c780 100644
@@ -1332,6 +1332,12 @@ class MySQLExecutionContext(default.DefaultExecutionContext):
 class MySQLDialect(default.DefaultDialect):
     """Details of the MySQL dialect.  Not used directly in application code."""
 
+    supports_alter = True
+    supports_unicode_statements = False
+    # identifiers are 64, however aliases can be 255...
+    max_identifier_length = 255
+    supports_sane_rowcount = True
+
     def __init__(self, use_ansiquotes=False, **kwargs):
         self.use_ansiquotes = use_ansiquotes
         kwargs.setdefault('default_paramstyle', 'format')
@@ -1390,13 +1396,6 @@ class MySQLDialect(default.DefaultDialect):
     def type_descriptor(self, typeobj):
         return sqltypes.adapt_type(typeobj, colspecs)
 
-    # identifiers are 64, however aliases can be 255...
-    def max_identifier_length(self):
-        return 255;
-
-    def supports_sane_rowcount(self):
-        return True
-
     def compiler(self, statement, bindparams, **kwargs):
         return MySQLCompiler(statement, bindparams, dialect=self, **kwargs)
 
@@ -2369,13 +2368,12 @@ MySQLSchemaReflector.logger = logging.class_logger(MySQLSchemaReflector)
 
 class _MySQLIdentifierPreparer(compiler.IdentifierPreparer):
     """MySQL-specific schema identifier configuration."""
+
+    reserved_words = RESERVED_WORDS
     
     def __init__(self, dialect, **kw):
         super(_MySQLIdentifierPreparer, self).__init__(dialect, **kw)
 
-    def _reserved_words(self):
-        return RESERVED_WORDS
-
     def _fold_identifier_case(self, value):
         # TODO: determine MySQL's case folding rules
         #
diff --git a/lib/sqlalchemy/databases/oracle.py b/lib/sqlalchemy/databases/oracle.py
index 580850818331fbcd65bf3295c0e3b0fc1a8e5993..9b3ffbf23fffad2649e90525d29a7953a8b64de7 100644
@@ -232,6 +232,11 @@ class OracleExecutionContext(default.DefaultExecutionContext):
         return base.ResultProxy(self)
 
 class OracleDialect(default.DefaultDialect):
+    supports_alter = True
+    supports_unicode_statements = False
+    max_identifier_length = 30
+    supports_sane_rowcount = True
+
     def __init__(self, use_ansi=True, auto_setinputsizes=True, auto_convert_lobs=True, threaded=True, allow_twophase=True, **kwargs):
         default.DefaultDialect.__init__(self, default_paramstyle='named', **kwargs)
         self.use_ansi = use_ansi
@@ -291,13 +296,6 @@ class OracleDialect(default.DefaultDialect):
     def type_descriptor(self, typeobj):
         return sqltypes.adapt_type(typeobj, colspecs)
 
-    def supports_unicode_statements(self):
-        """indicate whether the DB-API can receive SQL statements as Python unicode strings"""
-        return False
-
-    def max_identifier_length(self):
-        return 30
-        
     def oid_column_name(self, column):
         if not isinstance(column.table, (sql.TableClause, sql.Select)):
             return None
diff --git a/lib/sqlalchemy/databases/postgres.py b/lib/sqlalchemy/databases/postgres.py
index 29d84ad4db7f4168fcbb4e3b55d1f975ff906f6d..2a4d230cd582fe043e32fe0f89ab476a570d0753 100644
@@ -223,6 +223,11 @@ class PGExecutionContext(default.DefaultExecutionContext):
         super(PGExecutionContext, self).post_exec()
         
 class PGDialect(default.DefaultDialect):
+    supports_alter = True
+    supports_unicode_statements = False
+    max_identifier_length = 63
+    supports_sane_rowcount = True
+
     def __init__(self, use_oids=False, server_side_cursors=False, **kwargs):
         default.DefaultDialect.__init__(self, default_paramstyle='pyformat', **kwargs)
         self.use_oids = use_oids
@@ -241,13 +246,9 @@ class PGDialect(default.DefaultDialect):
         opts.update(url.query)
         return ([], opts)
 
-
     def create_execution_context(self, *args, **kwargs):
         return PGExecutionContext(self, *args, **kwargs)
 
-    def max_identifier_length(self):
-        return 63
-        
     def type_descriptor(self, typeobj):
         return sqltypes.adapt_type(typeobj, colspecs)
 
diff --git a/lib/sqlalchemy/databases/sqlite.py b/lib/sqlalchemy/databases/sqlite.py
index c2aced4d03f00fe0004dee4c0a7aa194ea23f203..8618bfc3ed8c6b3d02e05ea6cc22396823352746 100644
@@ -174,6 +174,8 @@ class SQLiteExecutionContext(default.DefaultExecutionContext):
         return SELECT_REGEXP.match(self.statement)
         
 class SQLiteDialect(default.DefaultDialect):
+    supports_alter = False
+    supports_unicode_statements = True
     
     def __init__(self, **kwargs):
         default.DefaultDialect.__init__(self, default_paramstyle='qmark', **kwargs)
@@ -199,9 +201,6 @@ class SQLiteDialect(default.DefaultDialect):
     def server_version_info(self, connection):
         return self.dbapi.sqlite_version_info
 
-    def supports_alter(self):
-        return False
-
     def create_connect_args(self, url):
         filename = url.database or ':memory:'
 
@@ -220,9 +219,6 @@ class SQLiteDialect(default.DefaultDialect):
     def create_execution_context(self, **kwargs):
         return SQLiteExecutionContext(self, **kwargs)
 
-    def supports_unicode_statements(self):
-        return True
-
     def last_inserted_ids(self):
         return self.context.last_inserted_ids
 
diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py
index 2e75d358c4ae6f414dad301dc0522aebf1327ef6..ef875a638c241a5d1a17781c6cafeee2e7a694c8 100644
@@ -62,6 +62,19 @@ class Dialect(object):
     preparer
       a [sqlalchemy.sql.compiler#IdentifierPreparer] class used to
       quote identifiers.
+
+    supports_alter
+      ``True`` if the database supports ``ALTER TABLE``.
+
+    max_identifier_length
+      The maximum length of identifier names.
+
+    supports_unicode_statements
+      Indicate whether the DB-API can receive SQL statements as Python unicode strings
+
+    supports_sane_rowcount
+      Indicate whether the dialect properly implements rowcount for ``UPDATE`` and ``DELETE`` statements.
+
     """
 
     def create_connect_args(self, url):
@@ -119,31 +132,6 @@ class Dialect(object):
 
         raise NotImplementedError()
 
-    def supports_alter(self):
-        """Return ``True`` if the database supports ``ALTER TABLE``."""
-        raise NotImplementedError()
-
-    def max_identifier_length(self):
-        """Return the maximum length of identifier names.
-
-        Returns ``None`` if no limit.
-        """
-
-        return None
-
-    def supports_unicode_statements(self):
-        """Indicate whether the DB-API can receive SQL statements as Python unicode strings"""
-
-        raise NotImplementedError()
-
-    def supports_sane_rowcount(self):
-        """Indicate whether the dialect properly implements rowcount for ``UPDATE`` and ``DELETE`` statements.
-
-        This was needed for MySQL which had non-standard behavior of rowcount,
-        but this issue has since been resolved.
-        """
-
-        raise NotImplementedError()
 
 
     def server_version_info(self, connection):
@@ -521,9 +509,6 @@ class Connectable(object):
     def execute(self, object, *multiparams, **params):
         raise NotImplementedError()
 
-    engine = util.NotImplProperty("The Engine which this Connectable is associated with.")
-    dialect = util.NotImplProperty("Dialect which this Connectable is associated with.")
-
 class Connection(Connectable):
     """Provides high-level functionality for a wrapped DB-API connection.
 
@@ -1020,14 +1005,13 @@ class Engine(Connectable):
     def __init__(self, pool, dialect, url, echo=None):
         self.pool = pool
         self.url = url
-        self._dialect=dialect
+        self.dialect=dialect
         self.echo = echo
+        self.engine = self
         self.logger = logging.instance_logger(self)
         self._should_log = logging.is_info_enabled(self.logger)
 
     name = property(lambda s:sys.modules[s.dialect.__module__].descriptor()['name'], doc="String name of the [sqlalchemy.engine#Dialect] in use by this ``Engine``.")
-    engine = property(lambda s:s)
-    dialect = property(lambda s:s._dialect, doc="the [sqlalchemy.engine#Dialect] in use by this engine.")
     echo = logging.echo_property()
     
     def __repr__(self):
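
Alongside the attribute conversion, Engine drops its engine/dialect properties and the util.NotImplProperty placeholders on Connectable: dialect becomes a plain instance attribute and self.engine = self lets an Engine answer the same .engine/.dialect interface a Connection does. A reduced sketch of that shape, not the real Connectable hierarchy:

    class Connectable(object):
        """Anything executable; exposes .engine and .dialect attributes."""

    class Engine(Connectable):
        def __init__(self, pool, dialect, url):
            self.pool = pool
            self.url = url
            self.dialect = dialect    # was self._dialect plus a property
            self.engine = self        # an Engine is its own .engine

    class Connection(Connectable):
        def __init__(self, engine):
            self.engine = engine
            self.dialect = engine.dialect

    class DummyDialect(object):
        pass

    e = Engine(pool=None, dialect=DummyDialect(), url='dummy://')
    c = Connection(e)
    assert e.engine is e and c.engine is e
    assert c.dialect is e.dialect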
diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py
index ec8d8d5a7effe5d5620c318e5f1d5120e4e7aa0d..50fac430bf89494c34cafaf1caf7897c476355a2 100644
@@ -26,6 +26,10 @@ class DefaultDialect(base.Dialect):
     statement_compiler = compiler.DefaultCompiler
     preparer = compiler.IdentifierPreparer
     defaultrunner = base.DefaultRunner
+    supports_alter = True
+    supports_unicode_statements = False
+    max_identifier_length = 9999
+    supports_sane_rowcount = True
 
     def __init__(self, convert_unicode=False, encoding='utf-8', default_paramstyle='named', paramstyle=None, dbapi=None, **kwargs):
         self.convert_unicode = convert_unicode
@@ -33,7 +37,13 @@ class DefaultDialect(base.Dialect):
         self.positional = False
         self._ischema = None
         self.dbapi = dbapi
-        self._figure_paramstyle(paramstyle=paramstyle, default=default_paramstyle)
+        if paramstyle is not None:
+            self.paramstyle = paramstyle
+        elif self.dbapi is not None:
+            self.paramstyle = self.dbapi.paramstyle
+        else:
+            self.paramstyle = default_paramstyle
+        self.positional = self.paramstyle in ('qmark', 'format', 'numeric')
         self.identifier_preparer = self.preparer(self)
     
     def dbapi_type_map(self):
@@ -56,23 +66,10 @@ class DefaultDialect(base.Dialect):
             typeobj = typeobj()
         return typeobj
 
-    def supports_unicode_statements(self):
-        """True if DB-API can receive SQL statements as Python Unicode."""
-        return False
-
-    def max_identifier_length(self):
-        # TODO: probably raise this and fill out db modules better
-        return 9999
-
-    def supports_alter(self):
-        return True
         
     def oid_column_name(self, column):
         return None
 
-    def supports_sane_rowcount(self):
-        return True
-
     def do_begin(self, connection):
         """Implementations might want to put logic here for turning
         autocommit on/off, etc.
@@ -120,32 +117,6 @@ class DefaultDialect(base.Dialect):
     def is_disconnect(self, e):
         return False
         
-    def _set_paramstyle(self, style):
-        self._paramstyle = style
-        self._figure_paramstyle(style)
-
-    paramstyle = property(lambda s:s._paramstyle, _set_paramstyle)
-
-    def _figure_paramstyle(self, paramstyle=None, default='named'):
-        if paramstyle is not None:
-            self._paramstyle = paramstyle
-        elif self.dbapi is not None:
-            self._paramstyle = self.dbapi.paramstyle
-        else:
-            self._paramstyle = default
-
-        if self._paramstyle == 'named':
-            self.positional=False
-        elif self._paramstyle == 'pyformat':
-            self.positional=False
-        elif self._paramstyle == 'qmark' or self._paramstyle == 'format' or self._paramstyle == 'numeric':
-            # for positional, use pyformat internally, ANSICompiler will convert
-            # to appropriate character upon compilation
-            self.positional = True
-        else:
-            raise exceptions.InvalidRequestError(
-                "Unsupported paramstyle '%s'" % self._paramstyle)
-
     def _get_ischema(self):
         if self._ischema is None:
             import sqlalchemy.databases.information_schema as ischema
@@ -185,7 +156,7 @@ class DefaultExecutionContext(base.ExecutionContext):
         else:
             self.statement = None
             
-        if self.statement is not None and not dialect.supports_unicode_statements():
+        if self.statement is not None and not dialect.supports_unicode_statements:
             self.statement = self.statement.encode(self.dialect.encoding)
             
         self.cursor = self.create_cursor()
@@ -200,7 +171,7 @@ class DefaultExecutionContext(base.ExecutionContext):
     
     def __encode_param_keys(self, params):
         """apply string encoding to the keys of dictionary-based bind parameters"""
-        if self.dialect.positional or self.dialect.supports_unicode_statements():
+        if self.dialect.positional or self.dialect.supports_unicode_statements:
             return params
         else:
             def proc(d):
@@ -215,7 +186,7 @@ class DefaultExecutionContext(base.ExecutionContext):
                 return proc(params)
 
     def __convert_compiled_params(self, parameters):
-        encode = not self.dialect.supports_unicode_statements()
+        encode = not self.dialect.supports_unicode_statements
         # the bind params are a CompiledParams object.  but all the
         # DB-API's hate that object (or similar).  so convert it to a
         # clean dictionary/list/tuple of dictionary/tuple of list
@@ -274,7 +245,7 @@ class DefaultExecutionContext(base.ExecutionContext):
             return self.cursor.rowcount
 
     def supports_sane_rowcount(self):
-        return self.dialect.supports_sane_rowcount()
+        return self.dialect.supports_sane_rowcount
 
     def last_inserted_ids(self):
         return self._last_inserted_ids
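
The removed _figure_paramstyle/_set_paramstyle machinery is folded into DefaultDialect.__init__ with a simple precedence: an explicitly passed paramstyle wins, then the DB-API module's paramstyle attribute, then the dialect's default; positional is derived from the chosen style. A standalone sketch of that resolution order (a hypothetical helper, not part of the commit):

    def resolve_paramstyle(paramstyle=None, dbapi=None, default='named'):
        """Pick a DB-API paramstyle and report whether it is positional."""
        if paramstyle is not None:
            chosen = paramstyle
        elif dbapi is not None:
            chosen = dbapi.paramstyle       # every PEP 249 module defines this
        else:
            chosen = default
        positional = chosen in ('qmark', 'format', 'numeric')
        return chosen, positional

    import sqlite3
    print(resolve_paramstyle(dbapi=sqlite3))    # ('qmark', True)
    print(resolve_paramstyle())                 # ('named', False)

Note that the old code raised InvalidRequestError for an unrecognized style; the inlined version drops that validation and simply treats anything outside the positional set as named-style.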
diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py
index 99ca2389bec31c9c2e76e10b659714e5054fcd6c..b6f345be2254e2720812b2bcc0433598a13ba9f0 100644
@@ -58,27 +58,21 @@ class SchemaItem(object):
     def __repr__(self):
         return "%s()" % self.__class__.__name__
 
-    def _derived_metadata(self):
-        """Return the the MetaData to which this item is bound."""
-
-        return None
-
     def _get_bind(self, raiseerr=False):
         """Return the engine or None if no engine."""
 
         if raiseerr:
-            m = self._derived_metadata()
+            m = self.metadata
             e = m and m.bind or None
             if e is None:
                 raise exceptions.InvalidRequestError("This SchemaItem is not connected to any Engine or Connection.")
             else:
                 return e
         else:
-            m = self._derived_metadata()
+            m = self.metadata
             return m and m.bind or None
 
 
-    metadata = property(lambda s:s._derived_metadata())
     bind = property(lambda s:s._get_bind())
     
 def _get_table_key(name, schema):
@@ -228,7 +222,7 @@ class Table(SchemaItem, expression.TableClause):
 
         """
         super(Table, self).__init__(name)
-        self._metadata = metadata
+        self.metadata = metadata
         self.schema = kwargs.pop('schema', None)
         self.indexes = util.Set()
         self.constraints = util.Set()
@@ -263,9 +257,6 @@ class Table(SchemaItem, expression.TableClause):
         self.constraints.add(pk)
     primary_key = property(lambda s:s._primary_key, _set_primary_key)
 
-    def _derived_metadata(self):
-        return self._metadata
-
     def __repr__(self):
         return "Table(%s)" % ', '.join(
             [repr(self.name)] + [repr(self.metadata)] +
@@ -286,11 +277,11 @@ class Table(SchemaItem, expression.TableClause):
         constraint._set_parent(self)
 
     def _get_parent(self):
-        return self._metadata
+        return self.metadata
 
     def _set_parent(self, metadata):
         metadata.tables[_get_table_key(self.name, self.schema)] = self
-        self._metadata = metadata
+        self.metadata = metadata
 
     def get_children(self, column_collections=True, schema_visitor=False, **kwargs):
         if not schema_visitor:
@@ -476,9 +467,6 @@ class Column(SchemaItem, expression._ColumnClause):
         else:
             return self.encodedname
 
-    def _derived_metadata(self):
-        return self.table.metadata
-
     def _get_bind(self):
         return self.table.bind
 
@@ -515,6 +503,7 @@ class Column(SchemaItem, expression._ColumnClause):
         return self.table
 
     def _set_parent(self, table):
+        self.metadata = table.metadata
         if getattr(self, 'table', None) is not None:
             raise exceptions.ArgumentError("this Column already has a table!")
         if not self._is_oid:
@@ -699,20 +688,14 @@ class DefaultGenerator(SchemaItem):
 
     def __init__(self, for_update=False, metadata=None):
         self.for_update = for_update
-        self._metadata = util.assert_arg_type(metadata, (MetaData, type(None)), 'metadata')
-
-    def _derived_metadata(self):
-        try:
-            return self.column.table.metadata
-        except AttributeError:
-            return self._metadata
+        self.metadata = util.assert_arg_type(metadata, (MetaData, type(None)), 'metadata')
 
     def _get_parent(self):
         return getattr(self, 'column', None)
 
     def _set_parent(self, column):
         self.column = column
-        self._metadata = self.column.table.metadata
+        self.metadata = self.column.table.metadata
         if self.for_update:
             self.column.onupdate = self
         else:
@@ -957,9 +940,6 @@ class Index(SchemaItem):
         self.unique = kwargs.pop('unique', False)
         self._init_items(*columns)
 
-    def _derived_metadata(self):
-        return self.table.metadata
-
     def _init_items(self, *args):
         for column in args:
             self.append_column(column)
@@ -969,6 +949,7 @@ class Index(SchemaItem):
 
     def _set_parent(self, table):
         self.table = table
+        self.metadata = table.metadata
         table.indexes.add(self)
 
     def append_column(self, column):
@@ -1053,6 +1034,7 @@ class MetaData(SchemaItem):
 
         self.tables = {}
         self.bind = bind
+        self.metadata = self
         if reflect:
             if not bind:
                 raise exceptions.ArgumentError(
@@ -1239,9 +1221,6 @@ class MetaData(SchemaItem):
             bind = self._get_bind(raiseerr=True)
         bind.drop(self, checkfirst=checkfirst, tables=tables)
 
-    def _derived_metadata(self):
-        return self
-
     def _get_bind(self, raiseerr=False):
         if not self.is_bound():
             if raiseerr:
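
schema.py replaces the _derived_metadata() hook and the metadata property with a plain .metadata attribute assigned when an item is attached to its parent: Table sets it in __init__/_set_parent, Column and Index copy it from their table in _set_parent, and MetaData points it at itself. A reduced sketch of that propagation, with illustrative classes only:

    class MetaData(object):
        def __init__(self):
            self.tables = {}
            self.metadata = self            # a MetaData is its own .metadata

    class Table(object):
        def __init__(self, name, metadata):
            self.name = name
            self.columns = []
            self.metadata = metadata        # assigned directly, no property
            metadata.tables[name] = self

    class Column(object):
        def __init__(self, name):
            self.name = name

        def _set_parent(self, table):
            self.table = table
            self.metadata = table.metadata  # copied from the parent at attach time
            table.columns.append(self)

    md = MetaData()
    t = Table('users', md)
    c = Column('id')
    c._set_parent(t)
    assert c.metadata is t.metadata is md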
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 59eb3cdb39d354a14a21be7e97c8bbf8eb1872c5..59964178ccce360916864556fa6b8fcc2f07c16e 100644
@@ -421,9 +421,9 @@ class DefaultCompiler(engine.Compiled, visitors.ClauseVisitor):
         
         anonname = ANONYMOUS_LABEL.sub(self._process_anon, name)
 
-        if len(anonname) > self.dialect.max_identifier_length():
+        if len(anonname) > self.dialect.max_identifier_length:
             counter = self.generated_ids.get(ident_class, 1)
-            truncname = name[0:self.dialect.max_identifier_length() - 6] + "_" + hex(counter)[2:]
+            truncname = name[0:self.dialect.max_identifier_length - 6] + "_" + hex(counter)[2:]
             self.generated_ids[ident_class] = counter + 1
         else:
             truncname = anonname
@@ -515,7 +515,6 @@ class DefaultCompiler(engine.Compiled, visitors.ClauseVisitor):
                     l = co.label(labelname)
                     inner_columns.add(self.process(l))
                 else:
-                    self.traverse(co)
                     inner_columns.add(self.process(co))
             else:
                 l = self.label_select_column(select, co)
@@ -620,20 +619,16 @@ class DefaultCompiler(engine.Compiled, visitors.ClauseVisitor):
         # for inserts, this includes Python-side defaults, columns with sequences for dialects
         # that support sequences, and primary key columns for dialects that explicitly insert
         # pre-generated primary key values
-        required_cols = util.Set()
-        class DefaultVisitor(schema.SchemaVisitor):
-            def visit_column(s, cd):
-                if c.primary_key and self.uses_sequences_for_inserts():
-                    required_cols.add(c)
-            def visit_column_default(s, cd):
-                required_cols.add(c)
-            def visit_sequence(s, seq):
-                if self.uses_sequences_for_inserts():
-                    required_cols.add(c)
-        vis = DefaultVisitor()
-        for c in insert_stmt.table.c:
-            if (isinstance(c, schema.SchemaItem) and (self.parameters is None or self.parameters.get(c.key, None) is None)):
-                vis.traverse(c)
+        required_cols = [
+            c for c in insert_stmt.table.c
+            if \
+                isinstance(c, schema.SchemaItem) and \
+                (self.parameters is None or self.parameters.get(c.key, None) is None) and \
+                (
+                    ((c.primary_key or isinstance(c.default, schema.Sequence)) and self.uses_sequences_for_inserts()) or 
+                    isinstance(c.default, schema.ColumnDefault)
+                )
+        ]
 
         self.isinsert = True
         colparams = self._get_colparams(insert_stmt, required_cols)
@@ -646,14 +641,12 @@ class DefaultCompiler(engine.Compiled, visitors.ClauseVisitor):
         
         # search for columns who will be required to have an explicit bound value.
         # for updates, this includes Python-side "onupdate" defaults.
-        required_cols = util.Set()
-        class OnUpdateVisitor(schema.SchemaVisitor):
-            def visit_column_onupdate(s, cd):
-                required_cols.add(c)
-        vis = OnUpdateVisitor()
-        for c in update_stmt.table.c:
-            if (isinstance(c, schema.SchemaItem) and (self.parameters is None or self.parameters.get(c.key, None) is None)):
-                vis.traverse(c)
+        required_cols = [c for c in update_stmt.table.c 
+            if
+            isinstance(c, schema.SchemaItem) and \
+            (self.parameters is None or self.parameters.get(c.key, None) is None) and
+            isinstance(c.onupdate, schema.ColumnDefault)
+        ]
 
         self.isupdate = True
         colparams = self._get_colparams(update_stmt, required_cols)
@@ -681,11 +674,6 @@ class DefaultCompiler(engine.Compiled, visitors.ClauseVisitor):
             self.binds[col.key] = bindparam
             return self.bindparam_string(self._truncate_bindparam(bindparam))
 
-        def create_clause_param(col, value):
-            self.traverse(value)
-            self.inline_params.add(col)
-            return self.process(value)
-
         self.inline_params = util.Set()
 
         def to_col(key):
@@ -704,25 +692,28 @@ class DefaultCompiler(engine.Compiled, visitors.ClauseVisitor):
         if self.parameters is None:
             parameters = {}
         else:
-            parameters = dict([(to_col(k), v) for k, v in self.parameters.iteritems()])
+            parameters = dict([(getattr(k, 'key', k), v) for k, v in self.parameters.iteritems()])
 
         if stmt.parameters is not None:
             for k, v in stmt.parameters.iteritems():
-                parameters.setdefault(to_col(k), v)
+                parameters.setdefault(getattr(k, 'key', k), v)
 
         for col in required_cols:
-            parameters.setdefault(col, None)
+            parameters.setdefault(col.key, None)
 
         # create a list of column assignment clauses as tuples
         values = []
         for c in stmt.table.columns:
-            if c in parameters:
-                value = parameters[c]
-                if sql._is_literal(value):
-                    value = create_bind_param(c, value)
-                else:
-                    value = create_clause_param(c, value)
-                values.append((c, value))
+            if c.key in parameters:
+                value = parameters[c.key]
+            else:
+                continue
+            if sql._is_literal(value):
+                value = create_bind_param(c, value)
+            else:
+                self.inline_params.add(c)
+                value = self.process(value)
+            values.append((c, value))
         
         return values
 
@@ -778,7 +769,7 @@ class SchemaGenerator(DDLBase):
         collection = [t for t in metadata.table_iterator(reverse=False, tables=self.tables) if (not self.checkfirst or not self.dialect.has_table(self.connection, t.name, schema=t.schema))]
         for table in collection:
             self.traverse_single(table)
-        if self.dialect.supports_alter():
+        if self.dialect.supports_alter:
             for alterable in self.find_alterables(collection):
                 self.add_foreignkey(alterable)
 
@@ -853,7 +844,7 @@ class SchemaGenerator(DDLBase):
         self.append("(%s)" % ', '.join([self.preparer.format_column(c) for c in constraint]))
 
     def visit_foreign_key_constraint(self, constraint):
-        if constraint.use_alter and self.dialect.supports_alter():
+        if constraint.use_alter and self.dialect.supports_alter:
             return
         self.append(", \n\t ")
         self.define_foreign_key(constraint)
@@ -909,7 +900,7 @@ class SchemaDropper(DDLBase):
 
     def visit_metadata(self, metadata):
         collection = [t for t in metadata.table_iterator(reverse=True, tables=self.tables) if (not self.checkfirst or  self.dialect.has_table(self.connection, t.name, schema=t.schema))]
-        if self.dialect.supports_alter():
+        if self.dialect.supports_alter:
             for alterable in self.find_alterables(collection):
                 self.drop_foreignkey(alterable)
         for table in collection:
@@ -936,6 +927,12 @@ class SchemaDropper(DDLBase):
 class IdentifierPreparer(object):
     """Handle quoting and case-folding of identifiers based on options."""
 
+    reserved_words = RESERVED_WORDS
+
+    legal_characters = LEGAL_CHARACTERS
+
+    illegal_initial_characters = ILLEGAL_INITIAL_CHARACTERS
+
     def __init__(self, dialect, initial_quote='"', final_quote=None, omit_schema=False):
         """Construct a new ``IdentifierPreparer`` object.
 
@@ -995,21 +992,12 @@ class IdentifierPreparer(object):
         # some tests would need to be rewritten if this is done.
         #return value.upper()
 
-    def _reserved_words(self):
-        return RESERVED_WORDS
-
-    def _legal_characters(self):
-        return LEGAL_CHARACTERS
-
-    def _illegal_initial_characters(self):
-        return ILLEGAL_INITIAL_CHARACTERS
-
     def _requires_quotes(self, value):
         """Return True if the given identifier requires quoting."""
         return \
-            value in self._reserved_words() \
-            or (value[0] in self._illegal_initial_characters()) \
-            or bool(len([x for x in unicode(value) if x not in self._legal_characters()])) \
+            value in self.reserved_words \
+            or (value[0] in self.illegal_initial_characters) \
+            or bool(len([x for x in unicode(value) if x not in self.legal_characters])) \
             or (value.lower() != value)
 
     def __generic_obj_format(self, obj, ident):
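
The _get_colparams hunk above also changes how the compiler's parameter dictionary is keyed: rather than mapping Column objects to values, every key is normalized to the column's string .key via getattr(k, 'key', k), so Column objects and plain strings land in the same slot. A minimal sketch of that normalization (a hypothetical helper, simplified from the hunk):

    class Column(object):
        def __init__(self, key):
            self.key = key

    def normalize_params(params):
        """Re-key a bind-parameter dict so Column objects and plain string
        keys both end up under the column's string .key."""
        return {getattr(k, 'key', k): v for k, v in params.items()}

    name_col = Column('name')
    params = {name_col: 'ed', 'email': 'ed@example.com'}
    print(normalize_params(params))     # {'name': 'ed', 'email': 'ed@example.com'}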
diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py
index 64fc9b3b437943c55765e0deb0c731dee05d125c..ea87a8c4f4acc1f1db9060751760c7e5f3fd723b 100644
@@ -1007,7 +1007,7 @@ class ClauseElement(object):
             compiler = dialect.statement_compiler(dialect, self, parameters=parameters)
         compiler.compile()
         return compiler
-
+    
     def __str__(self):
         return unicode(self.compile()).encode('ascii', 'backslashreplace')
 
@@ -1618,13 +1618,6 @@ class FromClause(Selectable):
             else:
                 raise exceptions.InvalidRequestError("Given column '%s', attached to table '%s', failed to locate a corresponding column from table '%s'" % (str(column), str(getattr(column, 'table', None)), self.name))
 
-    def _get_exported_attribute(self, name):
-        try:
-            return getattr(self, name)
-        except AttributeError:
-            self._export_columns()
-            return getattr(self, name)
-
     def _clone_from_clause(self):
         # delete all the "generated" collections of columns for a
         # newly cloned FromClause, so that they will be re-derived
@@ -1635,11 +1628,20 @@ class FromClause(Selectable):
             if hasattr(self, attr):
                 delattr(self, attr)
 
-    columns = property(lambda s:s._get_exported_attribute('_columns'))
-    c = property(lambda s:s._get_exported_attribute('_columns'))
-    primary_key = property(lambda s:s._get_exported_attribute('_primary_key'))
-    foreign_keys = property(lambda s:s._get_exported_attribute('_foreign_keys'))
-    original_columns = property(lambda s:s._get_exported_attribute('_orig_cols'), doc=\
+    def _expr_attr_func(name):
+        def attr(self):
+            try:
+                return getattr(self, name)
+            except AttributeError:
+                self._export_columns()
+                return getattr(self, name)
+        return attr
+
+    columns = property(_expr_attr_func('_columns'))
+    c = property(_expr_attr_func('_columns'))
+    primary_key = property(_expr_attr_func('_primary_key'))
+    foreign_keys = property(_expr_attr_func('_foreign_keys'))
+    original_columns = property(_expr_attr_func('_orig_cols'), doc=\
         """A dictionary mapping an original Table-bound 
         column to a proxied column in this FromClause.
         """)
@@ -1659,7 +1661,6 @@ class FromClause(Selectable):
         """
 
         if hasattr(self, '_columns') and columns is None:
-            # TODO: put a mutex here ?  this is a key place for threading probs
             return
         self._columns = ColumnCollection()
         self._primary_key = ColumnSet()
@@ -1753,9 +1754,11 @@ class _BindParamClause(ClauseElement, _CompareMixin):
         self.shortname = shortname or key
         self.unique = unique
         self.isoutparam = isoutparam
-        type_ = sqltypes.to_instance(type_)
-        if isinstance(type_, sqltypes.NullType) and type(value) in _BindParamClause.type_map:
-            self.type = sqltypes.to_instance(_BindParamClause.type_map[type(value)])
+
+        if type_ is None:
+            self.type = self.type_map.get(type(value), sqltypes.NullType)()
+        elif isinstance(type_, type):
+            self.type = type_()
         else:
             self.type = type_
 
@@ -1764,7 +1767,8 @@ class _BindParamClause(ClauseElement, _CompareMixin):
         str : sqltypes.String,
         unicode : sqltypes.Unicode,
         int : sqltypes.Integer,
-        float : sqltypes.Numeric
+        float : sqltypes.Numeric,
+        type(None):sqltypes.NullType
     }
 
     def _get_from_objects(self, **modifiers):
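
_BindParamClause now infers a parameter's type straight from the class-level type_map, keyed on the Python type of the value, instantiating NullType when nothing matches; an explicit type_ may be either a type class or an instance. A simplified sketch of that lookup, with placeholder type classes standing in for sqltypes:

    class NullType(object): pass
    class String(object): pass
    class Integer(object): pass
    class Numeric(object): pass

    class BindParam(object):
        type_map = {
            str: String,
            int: Integer,
            float: Numeric,
            type(None): NullType,
        }

        def __init__(self, key, value, type_=None):
            self.key = key
            self.value = value
            if type_ is None:
                # fall back to the value's Python type; unknown types get NullType
                self.type = self.type_map.get(type(value), NullType)()
            elif isinstance(type_, type):
                self.type = type_()         # a class was passed: instantiate it
            else:
                self.type = type_           # already an instance

    assert isinstance(BindParam('x', 5).type, Integer)
    assert isinstance(BindParam('y', None).type, NullType)
    assert isinstance(BindParam('z', 1.5, type_=String).type, String)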
diff --git a/lib/sqlalchemy/util.py b/lib/sqlalchemy/util.py
index ba6458f2a039a7ed02cfca06c6661dca41bef4a2..d31be6a3635b4722b3f406a022ac16e0efeaa36b 100644
@@ -268,7 +268,11 @@ class OrderedProperties(object):
     def __setattr__(self, key, object):
         self._data[key] = object
 
-    _data = property(lambda s:s.__dict__['_data'])
+    def __getstate__(self):
+        return self._data
+    
+    def __setstate__(self, value):
+        self.__dict__['_data'] = value
 
     def __getattr__(self, key):
         try:
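
OrderedProperties stores everything in a _data mapping and routes attribute access through __getattr__/__setattr__, so the new __getstate__/__setstate__ pair is what makes it picklable: __init__ is not run during unpickling, and __setstate__ has to write _data into __dict__ directly, or the restoring assignment would itself go through __setattr__, which needs _data to exist. A cut-down sketch of that arrangement (assumed, simplified behaviour, not the full util.OrderedProperties):

    import pickle

    class Properties(object):
        def __init__(self):
            self.__dict__['_data'] = {}     # bypass __setattr__ while constructing

        def __setattr__(self, key, value):
            self._data[key] = value

        def __getattr__(self, key):
            try:
                return self.__dict__['_data'][key]
            except KeyError:
                raise AttributeError(key)

        def __getstate__(self):
            return self.__dict__['_data']

        def __setstate__(self, state):
            # __init__ is skipped on unpickle; restore _data without __setattr__
            self.__dict__['_data'] = state

    p = Properties()
    p.name = 'users'
    q = pickle.loads(pickle.dumps(p))
    assert q.name == 'users'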
diff --git a/test/ext/activemapper.py b/test/ext/activemapper.py
index e28c72cd7315459dc1bd6e198e7760f2d27adf39..7e266030c8752884cc4b83f94e153239dfbe1ba0 100644
@@ -175,7 +175,7 @@ class testcase(PersistTest):
             objectstore.context.current = s1
             objectstore.flush()
             # Only dialects with a sane rowcount can detect the ConcurrentModificationError
-            if testbase.db.dialect.supports_sane_rowcount():
+            if testbase.db.dialect.supports_sane_rowcount:
                 assert False
         except exceptions.ConcurrentModificationError:
             pass
diff --git a/test/orm/unitofwork.py b/test/orm/unitofwork.py
index fd7af0421603fc555a4bd7d92508a768ee9c84a7..d689f1703e97f7cb5f9b61e6f4f2e8c786f64469 100644
@@ -78,7 +78,7 @@ class VersioningTest(ORMTest):
             success = True
 
         # Only dialects with a sane rowcount can detect the ConcurrentModificationError
-        if testbase.db.dialect.supports_sane_rowcount():
+        if testbase.db.dialect.supports_sane_rowcount:
             assert success
         
         s.close()
@@ -96,7 +96,7 @@ class VersioningTest(ORMTest):
         except exceptions.ConcurrentModificationError, e:
             #print e
             success = True
-        if testbase.db.dialect.supports_sane_rowcount():
+        if testbase.db.dialect.supports_sane_rowcount:
             assert success
         
     @engines.close_open_connections
diff --git a/test/sql/labels.py b/test/sql/labels.py
index dee76428df8042f21e0dc6a3711f17cd1bb92d7a..6588c4da4c861c380e34f1a5915c6096829339b1 100644
@@ -27,7 +27,7 @@ class LongLabelsTest(PersistTest):
         metadata.create_all()
         
         maxlen = testbase.db.dialect.max_identifier_length
-        testbase.db.dialect.max_identifier_length = lambda: 29
+        testbase.db.dialect.max_identifier_length = 29
         
     def tearDown(self):
         table1.delete().execute()
@@ -89,7 +89,7 @@ class LongLabelsTest(PersistTest):
         """test that a primary key column compiled as the 'oid' column gets proper length truncation"""
         from sqlalchemy.databases import postgres
         dialect = postgres.PGDialect()
-        dialect.max_identifier_length = lambda: 30
+        dialect.max_identifier_length = 30
         tt = table1.select(use_labels=True).alias('foo')
         x = select([tt], use_labels=True, order_by=tt.oid_column).compile(dialect=dialect)
         #print x
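
With max_identifier_length now a plain integer, the test patches it by assignment rather than with a lambda; DefaultCompiler reads self.dialect.max_identifier_length directly when deciding whether a generated label must be shortened. A rough sketch of that truncation rule, following the compiler hunk above (counter bookkeeping omitted):

    def truncate_label(name, max_identifier_length, counter=1):
        """Keep a prefix of the label and append '_' plus a hex counter,
        as in DefaultCompiler's truncation branch above."""
        if len(name) > max_identifier_length:
            return name[0:max_identifier_length - 6] + '_' + hex(counter)[2:]
        return name

    print(truncate_label('some_table_some_very_long_column', 29))
    # -> 'some_table_some_very_lo_1'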
diff --git a/test/sql/rowcount.py b/test/sql/rowcount.py
index cf9ba30d974ffaa04c44faab8e89c5364af6d1c1..095f79200d97a2bc37d27dc39512d80e94a70156 100644
@@ -47,21 +47,21 @@ class FoundRowsTest(AssertMixin):
         # WHERE matches 3, 3 rows changed
         department = employees_table.c.department
         r = employees_table.update(department=='C').execute(department='Z')
-        if testbase.db.dialect.supports_sane_rowcount():
+        if testbase.db.dialect.supports_sane_rowcount:
             assert r.rowcount == 3
 
     def test_update_rowcount2(self):
         # WHERE matches 3, 0 rows changed
         department = employees_table.c.department
         r = employees_table.update(department=='C').execute(department='C')
-        if testbase.db.dialect.supports_sane_rowcount():
+        if testbase.db.dialect.supports_sane_rowcount:
             assert r.rowcount == 3
 
     def test_delete_rowcount(self):
         # WHERE matches 3, 3 rows deleted
         department = employees_table.c.department
         r = employees_table.delete(department=='C').execute()
-        if testbase.db.dialect.supports_sane_rowcount():
+        if testbase.db.dialect.supports_sane_rowcount:
             assert r.rowcount == 3
 
 if __name__ == '__main__':