]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Housekeeping.
authorJason Kirtland <jek@discorporate.us>
Wed, 22 Aug 2007 08:33:09 +0000 (08:33 +0000)
committerJason Kirtland <jek@discorporate.us>
Wed, 22 Aug 2007 08:33:09 +0000 (08:33 +0000)
lib/sqlalchemy/databases/mysql.py
lib/sqlalchemy/databases/sqlite.py
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/interfaces.py

index 6dc0d605772bc565628524892e27c4bf8790c780..3eb0b30d058b3dc6cc49ef039cdcdf226ed9b333 100644 (file)
@@ -1390,8 +1390,8 @@ class MySQLDialect(default.DefaultDialect):
             opts['client_flag'] = client_flag
         return [[], opts]
 
-    def create_execution_context(self, *args, **kwargs):
-        return MySQLExecutionContext(self, *args, **kwargs)
+    def create_execution_context(self, connection, **kwargs):
+        return MySQLExecutionContext(self, connection, **kwargs)
 
     def type_descriptor(self, typeobj):
         return sqltypes.adapt_type(typeobj, colspecs)
@@ -1405,8 +1405,7 @@ class MySQLDialect(default.DefaultDialect):
     def schemadropper(self, *args, **kwargs):
         return MySQLSchemaDropper(self, *args, **kwargs)
 
-    def do_executemany(self, cursor, statement, parameters,
-                       context=None, **kwargs):
+    def do_executemany(self, cursor, statement, parameters, context=None):
         rowcount = cursor.executemany(statement, parameters)
         if context is not None:
             context._rowcount = rowcount
@@ -1414,7 +1413,7 @@ class MySQLDialect(default.DefaultDialect):
     def supports_unicode_statements(self):
         return True
                 
-    def do_execute(self, cursor, statement, parameters, **kwargs):
+    def do_execute(self, cursor, statement, parameters, context=None):
         cursor.execute(statement, parameters)
 
     def do_commit(self, connection):
@@ -1782,8 +1781,7 @@ class MySQLCompiler(compiler.DefaultCompiler):
 #       creation of foreign key constraints fails."
 
 class MySQLSchemaGenerator(compiler.SchemaGenerator):
-    def get_column_specification(self, column, override_pk=False,
-                                 first_pk=False):
+    def get_column_specification(self, column, first_pk=False):
         """Builds column DDL."""
         
         colspec = [self.preparer.format_column(column),
@@ -1950,7 +1948,7 @@ class MySQLSchemaReflector(object):
             warnings.warn(RuntimeWarning(
                 "Did not recognize type '%s' of column '%s'" %
                 (type_, name)))
-            coltype = sqltypes.NULLTYPE
+            col_type = sqltypes.NULLTYPE
         
         # Column type positional arguments eg. varchar(32)
         if args is None or args == '':
@@ -2066,15 +2064,16 @@ class MySQLSchemaReflector(object):
             if ref_key in table.metadata.tables:
                 ref_table = table.metadata.tables[ref_key]
             else:
-                ref_table = schema.Table(ref_name, table.metadata,
-                                         schema=ref_schema,
-                                         autoload=True, autoload_with=connection)
+                ref_table = schema.Table(
+                    ref_name, table.metadata, schema=ref_schema,
+                    autoload=True, autoload_with=connection)
 
             ref_names = spec['foreign']
             if not util.Set(ref_names).issubset(
                 util.Set([c.name for c in ref_table.c])):
                 raise exceptions.InvalidRequestError(
-                    "Foreign key columns (%s) are not present on foreign table" %
+                    "Foreign key columns (%s) are not present on "
+                    "foreign table %s" %
                     (', '.join(ref_names), ref_table.fullname()))
             ref_columns = [ref_table.c[name] for name in ref_names]
 
@@ -2112,13 +2111,12 @@ class MySQLSchemaReflector(object):
         self._pr_options = []
         self._re_options_util = {}
 
-        _initial, _final = (self.preparer.initial_quote,
-                            self.preparer.final_quote)
+        _final = self.preparer.final_quote
         
         quotes = dict(zip(('iq', 'fq', 'esc_fq'),
                           [re.escape(s) for s in
                            (self.preparer.initial_quote,
-                            self.preparer.final_quote,
+                            _final,
                             self.preparer._escape_identifier(_final))]))
 
         self._pr_name = _pr_compile(
@@ -2387,7 +2385,7 @@ class _MySQLIdentifierPreparer(compiler.IdentifierPreparer):
     def _quote_free_identifiers(self, *ids):
         """Unilaterally identifier-quote any number of strings."""
 
-        return tuple([self.quote_identifier(id) for id in ids if id is not None])
+        return tuple([self.quote_identifier(i) for i in ids if i is not None])
 
 
 class MySQLIdentifierPreparer(_MySQLIdentifierPreparer):
index 86146af4d584cfdbbca55b180ab8a40023f0e3b0..d3bd765727534783039703cc36d9de61652931a6 100644 (file)
@@ -42,6 +42,8 @@ class SLSmallInteger(sqltypes.Smallinteger):
         return "SMALLINT"
 
 class DateTimeMixin(object):
+    __format__ = "%Y-%m-%d %H:%M:%S"
+
     def bind_processor(self, dialect):
         def process(value):
             if isinstance(value, basestring):
@@ -63,7 +65,7 @@ class DateTimeMixin(object):
             (value, microsecond) = value.split('.')
             microsecond = int(microsecond)
         except ValueError:
-            (value, microsecond) = (value, 0)
+            microsecond = 0
         return time.strptime(value, self.__format__)[0:6] + (microsecond,)
 
 class SLDateTime(DateTimeMixin,sqltypes.DateTime):
@@ -225,11 +227,8 @@ class SQLiteDialect(default.DefaultDialect):
     def type_descriptor(self, typeobj):
         return sqltypes.adapt_type(typeobj, colspecs)
 
-    def create_execution_context(self, **kwargs):
-        return SQLiteExecutionContext(self, **kwargs)
-
-    def last_inserted_ids(self):
-        return self.context.last_inserted_ids
+    def create_execution_context(self, connection, **kwargs):
+        return SQLiteExecutionContext(self, connection, **kwargs)
 
     def oid_column_name(self, column):
         return "oid"
@@ -255,13 +254,13 @@ class SQLiteDialect(default.DefaultDialect):
             row = c.fetchone()
             if row is None:
                 break
-            #print "row! " + repr(row)
+
             found_table = True
-            (name, type, nullable, has_default, primary_key) = (row[1], row[2].upper(), not row[3], row[4] is not None, row[5])
+            (name, type_, nullable, has_default, primary_key) = (row[1], row[2].upper(), not row[3], row[4] is not None, row[5])
             name = re.sub(r'^\"|\"$', '', name)
             if include_columns and name not in include_columns:
                 continue
-            match = re.match(r'(\w+)(\(.*?\))?', type)
+            match = re.match(r'(\w+)(\(.*?\))?', type_)
             if match:
                 coltype = match.group(1)
                 args = match.group(2)
@@ -269,7 +268,6 @@ class SQLiteDialect(default.DefaultDialect):
                 coltype = "VARCHAR"
                 args = ''
 
-            #print "coltype: " + repr(coltype) + " args: " + repr(args)
             try:
                 coltype = pragma_names[coltype]
             except KeyError:
@@ -278,7 +276,6 @@ class SQLiteDialect(default.DefaultDialect):
                 
             if args is not None:
                 args = re.findall(r'(\d+)', args)
-                #print "args! " +repr(args)
                 coltype = coltype(*[int(a) for a in args])
 
             colargs= []
@@ -335,14 +332,14 @@ class SQLiteDialect(default.DefaultDialect):
                 if row is None:
                     break
                 cols.append(row[2])
-                col = table.columns[row[2]]
+
 
 class SQLiteCompiler(compiler.DefaultCompiler):
-    def visit_cast(self, cast):
+    def visit_cast(self, cast, **kwargs):
         if self.dialect.supports_cast:
             return super(SQLiteCompiler, self).visit_cast(cast)
         else:
-            if self.select_stack:
+            if self.stack and self.stack[-1].get('select'):
                 # not sure if we want to set the typemap here...
                 self.typemap.setdefault("CAST", cast.type)
             return self.process(cast.clause)
index 958cb74801b9a5004f0b4fd7010affc022e6830d..d3d96a711a41f8c505ef8f5310d76adb79941705 100644 (file)
@@ -247,12 +247,12 @@ class Dialect(object):
 
         raise NotImplementedError()
 
-    def do_executemany(self, cursor, statement, parameters):
+    def do_executemany(self, cursor, statement, parameters, context=None):
         """Provide an implementation of *cursor.executemany(statement, parameters)*."""
 
         raise NotImplementedError()
 
-    def do_execute(self, cursor, statement, parameters):
+    def do_execute(self, cursor, statement, parameters, context=None):
         """Provide an implementation of *cursor.execute(statement, parameters)*."""
 
         raise NotImplementedError()
index 0ab0eb82be6b1cc05d2818234c251182e1876d84..3322753bb6961c8fef3ad5dd10f5e94fa9df3d6e 100644 (file)
@@ -52,8 +52,8 @@ class DefaultDialect(base.Dialect):
         
         return {}
     
-    def create_execution_context(self, **kwargs):
-        return DefaultExecutionContext(self, **kwargs)
+    def create_execution_context(self, connection, **kwargs):
+        return DefaultExecutionContext(self, connection, **kwargs)
 
     def type_descriptor(self, typeobj):
         """Provide a database-specific ``TypeEngine`` object, given
@@ -108,10 +108,10 @@ class DefaultDialect(base.Dialect):
     def do_release_savepoint(self, connection, name):
         connection.execute(expression.ReleaseSavepointClause(name))
 
-    def do_executemany(self, cursor, statement, parameters, **kwargs):
+    def do_executemany(self, cursor, statement, parameters, context=None):
         cursor.executemany(statement, parameters)
 
-    def do_execute(self, cursor, statement, parameters, **kwargs):
+    def do_execute(self, cursor, statement, parameters, context=None):
         cursor.execute(statement, parameters)
 
     def is_disconnect(self, e):
index fef82626ca9f448199c9ee26c1338deb659aeb4e..ae2aeed690eb05461ae3fb460ba621afc7a355a8 100644 (file)
@@ -42,7 +42,7 @@ class PoolListener(object):
     providing implementations for the hooks you'll be using.
     """
 
-    def connect(dbapi_con, con_record):
+    def connect(self, dbapi_con, con_record):
         """Called once for each new DB-API connection or Pool's ``creator()``.
 
         dbapi_con
@@ -54,7 +54,7 @@ class PoolListener(object):
           
         """
 
-    def checkout(dbapi_con, con_record, con_proxy):
+    def checkout(self, dbapi_con, con_record, con_proxy):
         """Called when a connection is retrieved from the Pool.
 
         dbapi_con
@@ -73,7 +73,7 @@ class PoolListener(object):
         using the new connection.
         """
 
-    def checkin(dbapi_con, con_record):
+    def checkin(self, dbapi_con, con_record):
         """Called when a connection returns to the pool.
 
         Note that the connection may be closed, and may be None if the