]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- simplify MySQLIdentifierPreparer into standard pattern,
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 10 Aug 2009 04:48:00 +0000 (04:48 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 10 Aug 2009 04:48:00 +0000 (04:48 +0000)
thus allowing easy subclassing
- move % sign logic for MySQLIdentifierPreparer into MySQLdb dialect
- paramterize the escape/unescape quote char in IdentifierPreparer
- cut out MySQLTableDefinitionParser cruft

lib/sqlalchemy/dialects/mssql/base.py
lib/sqlalchemy/dialects/mysql/base.py
lib/sqlalchemy/dialects/mysql/mysqldb.py
lib/sqlalchemy/dialects/postgresql/base.py
lib/sqlalchemy/sql/compiler.py

index b36cebd37eb0e14570ca29b5966c39a294498789..fc3236ba92cfef55c3a277c6b14fb1c4386cf5a5 100644 (file)
@@ -1109,7 +1109,6 @@ class MSIdentifierPreparer(compiler.IdentifierPreparer):
         super(MSIdentifierPreparer, self).__init__(dialect, initial_quote='[', final_quote=']')
 
     def _escape_identifier(self, value):
-        #TODO: determine MSSQL's escaping rules
         return value
 
     def quote_schema(self, schema, force=True):
index 570d1a79be2b3eab5f894041f9c6967d7c38c10e..4700a7b4ffc4f80952eb5ba2769674ba5bf4e4ad 100644 (file)
@@ -1608,8 +1608,29 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
         return "BOOL"
         
 
+class MySQLIdentifierPreparer(compiler.IdentifierPreparer):
+
+    reserved_words = RESERVED_WORDS
+
+    def __init__(self, dialect, server_ansiquotes=False, **kw):
+        if not server_ansiquotes:
+            quote = "`"
+        else:
+            quote = '"'    
+
+        super(MySQLIdentifierPreparer, self).__init__(
+                                                dialect, 
+                                                initial_quote=quote, 
+                                                escape_quote=quote)
+
+    def _quote_free_identifiers(self, *ids):
+        """Unilaterally identifier-quote any number of strings."""
+
+        return tuple([self.quote_identifier(i) for i in ids if i is not None])
+
 class MySQLDialect(default.DefaultDialect):
     """Details of the MySQL dialect.  Not used directly in application code."""
+    
     name = 'mysql'
     supports_alter = True
     # identifiers are 64, however aliases can be 255...
@@ -1625,6 +1646,7 @@ class MySQLDialect(default.DefaultDialect):
     ddl_compiler = MySQLDDLCompiler
     type_compiler = MySQLTypeCompiler
     ischema_names = ischema_names
+    preparer = MySQLIdentifierPreparer
     
     def __init__(self, use_ansiquotes=None, **kwargs):
         default.DefaultDialect.__init__(self, **kwargs)
@@ -1750,12 +1772,10 @@ class MySQLDialect(default.DefaultDialect):
         self._server_casing = self._detect_casing(connection)
         self._server_collations = self._detect_collations(connection)
         self._server_ansiquotes = self._detect_ansiquotes(connection)
-            
         if self._server_ansiquotes:
-            self.preparer = MySQLANSIIdentifierPreparer
-        else:
-            self.preparer = MySQLIdentifierPreparer
-        self.identifier_preparer = self.preparer(self)
+            # if ansiquotes == True, build a new IdentifierPreparer
+            # with the new setting
+            self.identifier_preparer = self.preparer(self, server_ansiquotes=self._server_ansiquotes)
 
     @reflection.cache
     def get_schema_names(self, connection, **kw):
@@ -1894,19 +1914,26 @@ class MySQLDialect(default.DefaultDialect):
                         schema, 
                         info_cache=kw.get('info_cache', None)
                     )
+    
+    @util.memoized_property
+    def _tabledef_parser(self):
+        """return the MySQLTableDefinitionParser, generate if needed.
+        
+        The deferred creation ensures that the dialect has 
+        retrieved server version information first.
+        
+        """
+        if (self.server_version_info < (4, 1) and self._server_ansiquotes):
+            # ANSI_QUOTES doesn't affect SHOW CREATE TABLE on < 4.1
+            preparer = self.preparer(self, server_ansiquotes=False)
+        else:
+            preparer = self.identifier_preparer
+        return MySQLTableDefinitionParser(self, preparer)
         
     @reflection.cache
     def _setup_parser(self, connection, table_name, schema=None, **kw):
         charset = self._connection_charset
-        try:
-            parser = self.parser
-        except AttributeError:
-            preparer = self.identifier_preparer
-            if (self.server_version_info < (4, 1) and
-                self._server_ansiquotes):
-                # ANSI_QUOTES doesn't affect SHOW CREATE TABLE on < 4.1
-                preparer = MySQLIdentifierPreparer(self)
-            self.parser = parser = MySQLTableDefinitionParser(self, preparer)
+        parser = self._tabledef_parser
         full_name = '.'.join(self.identifier_preparer._quote_free_identifiers(
             schema, table_name))
         sql = self._show_create_table(connection, None, charset,
@@ -2501,43 +2528,6 @@ class _DecodingRowProxy(object):
             return item
 
 
-class _MySQLIdentifierPreparer(compiler.IdentifierPreparer):
-    """MySQL-specific schema identifier configuration."""
-
-    reserved_words = RESERVED_WORDS
-
-    def __init__(self, dialect, **kw):
-        super(_MySQLIdentifierPreparer, self).__init__(dialect, **kw)
-
-    def _quote_free_identifiers(self, *ids):
-        """Unilaterally identifier-quote any number of strings."""
-
-        return tuple([self.quote_identifier(i) for i in ids if i is not None])
-
-    def _escape_identifier(self, value):
-        value = value.replace('"', '""')
-        return value.replace('%', '%%')
-
-
-class MySQLIdentifierPreparer(_MySQLIdentifierPreparer):
-    """Traditional MySQL-specific schema identifier configuration."""
-
-    def __init__(self, dialect):
-        super(MySQLIdentifierPreparer, self).__init__(dialect, initial_quote="`")
-        
-    def _escape_identifier(self, value):
-        value = value.replace('`', '``')
-        return value.replace('%', '%%')
-
-    def _unescape_identifier(self, value):
-        return value.replace('``', '`')
-
-
-class MySQLANSIIdentifierPreparer(_MySQLIdentifierPreparer):
-    """ANSI_QUOTES MySQL schema identifier configuration."""
-
-    pass
-
 def _pr_compile(regex, cleanup=None):
     """Prepare a 2-tuple of compiled regex and callable."""
 
index 6ecfc4b845c332afb62c52585255b8893031d04a..a7764add7329f6eb4cdfe3027c5fa629eeb9c579 100644 (file)
@@ -24,7 +24,7 @@ import decimal
 import re
 
 from sqlalchemy.dialects.mysql.base import (DECIMAL, MySQLDialect, MySQLExecutionContext,
-                                            MySQLCompiler, NUMERIC, _NumericType)
+                                            MySQLCompiler, MySQLIdentifierPreparer, NUMERIC, _NumericType)
 from sqlalchemy.engine import base as engine_base, default
 from sqlalchemy.sql import operators as sql_operators
 from sqlalchemy import exc, log, schema, sql, types as sqltypes, util
@@ -66,6 +66,11 @@ class _MySQLdbNumeric(_DecimalType, NUMERIC):
 class _MySQLdbDecimal(_DecimalType, DECIMAL):
     pass
 
+class MySQL_mysqldbIdentifierPreparer(MySQLIdentifierPreparer):
+    
+    def _escape_identifier(self, value):
+        value = value.replace(self.escape_quote, self.escape_to_quote)
+        return value.replace("%", "%%")
 
 class MySQL_mysqldb(MySQLDialect):
     driver = 'mysqldb'
@@ -76,7 +81,8 @@ class MySQL_mysqldb(MySQLDialect):
     default_paramstyle = 'format'
     execution_ctx_cls = MySQL_mysqldbExecutionContext
     statement_compiler = MySQL_mysqldbCompiler
-
+    preparer = MySQL_mysqldbIdentifierPreparer
+    
     colspecs = util.update_copy(
         MySQLDialect.colspecs,
         {
index c89ae16bd11bc78a0e88b3f53c818a117f741a31..57f2161c137079fa80436865209fa11b7c143e78 100644 (file)
@@ -433,6 +433,7 @@ class PGIdentifierPreparer(compiler.IdentifierPreparer):
 
     def _escape_identifier(self, value):
         value = value.replace('"', '""')
+        # TODO: might want to move this to psycopg2 + pg8000 individually
         return value.replace('%', '%%')
         
 class PGInspector(reflection.Inspector):
index 403ec968bbcd18afe48218eb219ee3006df62762..02824a5f4033bd328d6d3e1f0a8e1fa4f6ef2382 100644 (file)
@@ -1259,7 +1259,7 @@ class IdentifierPreparer(object):
 
     illegal_initial_characters = ILLEGAL_INITIAL_CHARACTERS
 
-    def __init__(self, dialect, initial_quote='"', final_quote=None, omit_schema=False):
+    def __init__(self, dialect, initial_quote='"', final_quote=None, escape_quote='"', omit_schema=False):
         """Construct a new ``IdentifierPreparer`` object.
 
         initial_quote
@@ -1276,6 +1276,8 @@ class IdentifierPreparer(object):
         self.dialect = dialect
         self.initial_quote = initial_quote
         self.final_quote = final_quote or self.initial_quote
+        self.escape_quote = escape_quote
+        self.escape_to_quote = self.escape_quote * 2
         self.omit_schema = omit_schema
         self._strings = {}
         
@@ -1286,7 +1288,7 @@ class IdentifierPreparer(object):
         escaping behavior.
         """
 
-        return value.replace('"', '""')
+        return value.replace(self.escape_quote, self.escape_to_quote)
 
     def _unescape_identifier(self, value):
         """Canonicalize an escaped identifier.
@@ -1295,7 +1297,7 @@ class IdentifierPreparer(object):
         unescaping behavior that reverses _escape_identifier.
         """
 
-        return value.replace('""', '"')
+        return value.replace(self.escape_to_quote, self.escape_quote)
 
     def quote_identifier(self, value):
         """Quote an identifier.