]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
A few fixes to the access dialect
authorPaul Johnston <paj@pajhome.org.uk>
Fri, 12 Oct 2007 23:39:28 +0000 (23:39 +0000)
committerPaul Johnston <paj@pajhome.org.uk>
Fri, 12 Oct 2007 23:39:28 +0000 (23:39 +0000)
lib/sqlalchemy/databases/access.py

index 9f4847c45c05e69df303620c33778996271c0581..c6e6107bfa70dc7b1a3f4dc9c5a2f4600f74048c 100644 (file)
@@ -7,10 +7,9 @@
 
 import random
 from sqlalchemy import sql, schema, types, exceptions, pool
-from sqlalchemy.sql import compiler
+from sqlalchemy.sql import compiler, expression
 from sqlalchemy.engine import default, base
 
-
 class AcNumeric(types.Numeric):
     def result_processor(self, dialect):
         return None
@@ -149,11 +148,13 @@ class AccessExecutionContext(default.DefaultExecutionContext):
                         break
 
             if bool(tbl.has_sequence):
-                if not len(self._last_inserted_ids) or self._last_inserted_ids[0] is None:
-                    self.cursor.execute("SELECT @@identity AS lastrowid")
-                    row = self.cursor.fetchone()
-                    self._last_inserted_ids = [int(row[0])] + self._last_inserted_ids[1:]
-                    # print "LAST ROW ID", self._last_inserted_ids
+                # TBD: for some reason _last_inserted_ids doesn't exist here
+                # (but it does at corresponding point in mssql???)
+                #if not len(self._last_inserted_ids) or self._last_inserted_ids[0] is None:
+                self.cursor.execute("SELECT @@identity AS lastrowid")
+                row = self.cursor.fetchone()
+                self._last_inserted_ids = [int(row[0])] #+ self._last_inserted_ids[1:]
+                # print "LAST ROW ID", self._last_inserted_ids
 
         super(AccessExecutionContext, self).post_exec()
 
@@ -177,7 +178,7 @@ class AccessDialect(default.DefaultDialect):
     }
 
     supports_sane_rowcount = False
-
+    supports_sane_multi_rowcount = False
 
     def type_descriptor(self, typeobj):
         newobj = types.adapt_type(typeobj, self.colspecs)
@@ -217,21 +218,6 @@ class AccessDialect(default.DefaultDialect):
     def last_inserted_ids(self):
         return self.context.last_inserted_ids
 
-    def compiler(self, statement, bindparams, **kwargs):
-        return AccessCompiler(self, statement, bindparams, **kwargs)
-
-    def schemagenerator(self, *args, **kwargs):
-        return AccessSchemaGenerator(self, *args, **kwargs)
-
-    def schemadropper(self, *args, **kwargs):
-        return AccessSchemaDropper(self, *args, **kwargs)
-
-    def defaultrunner(self, connection, **kwargs):
-        return AccessDefaultRunner(connection, **kwargs)
-
-    def preparer(self):
-        return AccessIdentifierPreparer(self)
-
     def do_execute(self, cursor, statement, params, **kwargs):
         if params == {}:
             params = ()
@@ -254,7 +240,7 @@ class AccessDialect(default.DefaultDialect):
         except Exception, e:
             return False
 
-    def reflecttable(self, connection, table):        
+    def reflecttable(self, connection, table, include_columns):        
         # This is defined in the function, as it relies on win32com constants,
         # that aren't imported until dbapi method is called
         if not hasattr(self, 'ischema_names'):
@@ -364,13 +350,11 @@ class AccessCompiler(compiler.DefaultCompiler):
         """Access uses "mod" instead of "%" """
         return binary.operator == '%' and 'mod' or binary.operator
 
-    def visit_select(self, select):
-        """Label function calls, so they return a name in cursor.description"""
-        for i,c in enumerate(select._raw_columns):
-            if isinstance(c, sql._Function):
-                select._raw_columns[i] = c.label(c.name + "_" + hex(random.randint(0, 65535))[2:])
-
-        super(AccessCompiler, self).visit_select(select)
+    def label_select_column(self, select, column):
+        if isinstance(column, expression._Function):
+            return column.label(column.name + "_" + hex(random.randint(0, 65535))[2:])        
+        else:
+            return super(AccessCompiler, self).label_select_column(select, column)
 
     function_rewrites =  {'current_date':       'now',
                           'current_timestamp':  'now',
@@ -418,9 +402,16 @@ class AccessDefaultRunner(base.DefaultRunner):
     pass
 
 class AccessIdentifierPreparer(compiler.IdentifierPreparer):
+    reserved_words = compiler.RESERVED_WORDS.copy()
+    reserved_words.update(['value', 'text'])
     def __init__(self, dialect):
         super(AccessIdentifierPreparer, self).__init__(dialect, initial_quote='[', final_quote=']')
 
 
 dialect = AccessDialect
 dialect.poolclass = pool.SingletonThreadPool
+dialect.statement_compiler = AccessCompiler
+dialect.schemagenerator = AccessSchemaGenerator
+dialect.schemadropper = AccessSchemaDropper
+dialect.preparer = AccessIdentifierPreparer
+dialect.defaultrunner = AccessDefaultRunner