]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- more quoting fixes for [ticket:450]...quoting more aggressive (but still skips...
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 4 Feb 2007 03:12:27 +0000 (03:12 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 4 Feb 2007 03:12:27 +0000 (03:12 +0000)
- got mysql to have "format" as default paramstyle even if mysql module not available, allows unit tests
to pass in non-mysql system for [ticket:457].  all the dialects should be changed to pass in their usual
paramstyle.

lib/sqlalchemy/ansisql.py
lib/sqlalchemy/databases/mysql.py
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/schema.py
lib/sqlalchemy/sql.py
test/sql/defaults.py
test/sql/quote.py

index 40ea5b00ac109ef6ca53767f96798e9901b25e72..794c48439449b6ed7ec654e3161d8c0864140c5a 100644 (file)
@@ -850,12 +850,13 @@ class ANSIIdentifierPreparer(object):
         self.initial_quote = initial_quote
         self.final_quote = final_quote or self.initial_quote
         self.omit_schema = omit_schema
+        self._is_quoted_regexp = re.compile(r"""^['"%s].+['"%s]$""" % (self.initial_quote, self.final_quote))
         self.__strings = {}
     def _escape_identifier(self, value):
         """escape an identifier.
         
         subclasses should override this to provide database-dependent escaping behavior."""
-        return value.replace('"', '""')
+        return value.replace("'", "''")
     
     def _quote_identifier(self, value):
         """quote an identifier.
@@ -872,6 +873,10 @@ class ANSIIdentifierPreparer(object):
         # some tests would need to be rewritten if this is done.
         #return value.upper()
     
+    def _is_quoted(self, ident):
+        """return true if the given identifier is already quoted"""
+        return self._is_quoted_regexp.match(ident)
+        
     def _reserved_words(self):
         return RESERVED_WORDS
 
@@ -884,10 +889,11 @@ class ANSIIdentifierPreparer(object):
     def _requires_quotes(self, value, case_sensitive):
         """return true if the given identifier requires quoting."""
         return \
-            value in self._reserved_words() \
+            not self._is_quoted(value) and \
+            (value in self._reserved_words() \
             or (value[0] in self._illegal_initial_characters()) \
             or bool(len([x for x in str(value) if x not in self._legal_characters()])) \
-            or (case_sensitive and value.lower() != value)
+            or (case_sensitive and not value.islower()))
     
     def __generic_obj_format(self, obj, ident):
         if getattr(obj, 'quote', False):
@@ -897,13 +903,13 @@ class ANSIIdentifierPreparer(object):
             try:
                 return self.__strings[(ident, case_sens)]
             except KeyError:
-                if self._requires_quotes(ident, getattr(obj, 'case_sensitive', ident == ident.lower())):
+                if self._requires_quotes(ident, getattr(obj, 'case_sensitive', ident.islower())):
                     self.__strings[(ident, case_sens)] = self._quote_identifier(ident)
                 else:
                     self.__strings[(ident, case_sens)] = ident
                 return self.__strings[(ident, case_sens)]
         else:
-            if self._requires_quotes(ident, getattr(obj, 'case_sensitive', ident == ident.lower())):
+            if self._requires_quotes(ident, getattr(obj, 'case_sensitive', ident.islower())):
                 return self._quote_identifier(ident)
             else:
                 return ident
@@ -934,17 +940,10 @@ class ANSIIdentifierPreparer(object):
         """Prepare a quoted column name """
         # TODO: isinstance alert !  get ColumnClause and Column to better
         # differentiate themselves
-        if isinstance(column, schema.SchemaItem):
-            if use_table:
-                return self.format_table(column.table, use_schema=False) + "." + self.__generic_obj_format(column, column.name)
-            else:
-                return self.__generic_obj_format(column, column.name)
+        if use_table:
+            return self.format_table(column.table, use_schema=False) + "." + self.__generic_obj_format(column, column.name)
         else:
-            # literal textual elements get stuck into ColumnClause alot, which shouldnt get quoted
-            if use_table:
-                return self.format_table(column.table, use_schema=False) + "." + column.name
-            else:
-                return column.name
+            return self.__generic_obj_format(column, column.name)
             
     def format_column_with_table(self, column):
         """Prepare a quoted column name with table name"""
index d30751fb4def07b5f60560ed3bc0e1654281a548..c09d02756c85dbefb9509da3e9be2b8b295afe9a 100644 (file)
@@ -17,6 +17,7 @@ try:
     import MySQLdb.constants.CLIENT as CLIENT_FLAGS
 except:
     mysql = None
+    CLIENT_FLAGS = None
 
 def kw_colspec(self, spec):
     if self.unsigned:
@@ -256,7 +257,7 @@ class MySQLDialect(ansisql.ANSIDialect):
             self.module = mysql
         else:
             self.module = module
-        ansisql.ANSIDialect.__init__(self, **kwargs)
+        ansisql.ANSIDialect.__init__(self, default_paramstyle='format', **kwargs)
 
     def create_connect_args(self, url):
         opts = url.translate_connect_args(['host', 'db', 'user', 'passwd', 'port'])
@@ -274,7 +275,8 @@ class MySQLDialect(ansisql.ANSIDialect):
         # TODO: what about options like "ssl", "cursorclass" and "conv" ?
 
         client_flag = opts.get('client_flag', 0)
-        client_flag |= CLIENT_FLAGS.FOUND_ROWS
+        if CLIENT_FLAGS is not None:
+            client_flag |= CLIENT_FLAGS.FOUND_ROWS
         opts['client_flag'] = client_flag
 
         return [[], opts]
index 06409377cdbcca3d165f21bf7171a0905abca1fb..2f6283b5d367d7c068445b263c2b2d0d233cc9dc 100644 (file)
@@ -24,14 +24,13 @@ class PoolConnectionProvider(base.ConnectionProvider):
         
 class DefaultDialect(base.Dialect):
     """default implementation of Dialect"""
-    def __init__(self, convert_unicode=False, encoding='utf-8', **kwargs):
+    def __init__(self, convert_unicode=False, encoding='utf-8', default_paramstyle='named', **kwargs):
         self.convert_unicode = convert_unicode
         self.supports_autoclose_results = True
         self.encoding = encoding
         self.positional = False
-        self.paramstyle = 'named'
         self._ischema = None
-        self._figure_paramstyle()
+        self._figure_paramstyle(default=default_paramstyle)
     def create_execution_context(self):
         return DefaultExecutionContext(self)
     def type_descriptor(self, typeobj):
@@ -90,14 +89,14 @@ class DefaultDialect(base.Dialect):
                     parameters = parameters.get_raw_dict()
         return parameters
 
-    def _figure_paramstyle(self, paramstyle=None):
+    def _figure_paramstyle(self, paramstyle=None, default='named'):
         db = self.dbapi()
         if paramstyle is not None:
             self._paramstyle = paramstyle
         elif db is not None:
             self._paramstyle = db.paramstyle
         else:
-            self._paramstyle = 'named'
+            self._paramstyle = default
 
         if self._paramstyle == 'named':
             self.positional=False
index 2be808ba616a8caa89b759b5aab3cab67fa31463..44ba85453aed819abd7d6b7474b7023191b5d0b8 100644 (file)
@@ -482,7 +482,7 @@ class Column(SchemaItem, sql._ColumnClause):
         """redirect the 'case_sensitive' accessor to use the ultimate parent column which created
         this one."""
         return self.__originating_column._get_case_sensitive()
-    case_sensitive = property(_case_sens)
+    case_sensitive = property(_case_sens, lambda s,v:None)
     
     def accept_schema_visitor(self, visitor, traverse=True):
         """traverses the given visitor to this Column's default and foreign key object,
index 53cb6b97715cfab8b92f7d86a9a9b97d05c34465..823c8afc84bb5b8875acbff754e8869a653feca8 100644 (file)
@@ -219,10 +219,10 @@ def label(name, obj):
     """returns a _Label object for the given selectable, used in the column list for a select statement."""
     return _Label(name, obj)
     
-def column(text, table=None, type=None):
+def column(text, table=None, type=None, **kwargs):
     """returns a textual column clause, relative to a table.  this is also the primitive version of
     a schema.Column which is a subclass. """
-    return _ColumnClause(text, table, type)
+    return _ColumnClause(text, table, type, **kwargs)
 
 def table(name, *columns):
     """returns a table clause.  this is a primitive version of the schema.Table object, which is a subclass
@@ -1199,7 +1199,7 @@ class Alias(FromClause):
                 alias = alias[0:15]
             alias = alias + "_" + hex(random.randint(0, 65535))[2:]
         self.name = alias
-        self.case_sensitive = getattr(baseselectable, "case_sensitive", alias.lower() != alias)
+        self.case_sensitive = getattr(baseselectable, "case_sensitive", True)
     def supports_execution(self):
         return self.original.supports_execution()    
     def _locate_oid_column(self):
@@ -1233,7 +1233,7 @@ class _Label(ColumnElement):
         while isinstance(obj, _Label):
             obj = obj.obj
         self.obj = obj
-        self.case_sensitive = getattr(obj, "case_sensitive", name.lower() != name)
+        self.case_sensitive = getattr(obj, "case_sensitive", True)
         self.type = sqltypes.to_instance(type)
         obj.parens=True
     key = property(lambda s: s.name)
@@ -1251,12 +1251,13 @@ legal_characters = util.Set(string.ascii_letters + string.digits + '_')
 class _ColumnClause(ColumnElement):
     """represents a textual column clause in a SQL statement.  May or may not
     be bound to an underlying Selectable."""
-    def __init__(self, text, selectable=None, type=None, _is_oid=False):
+    def __init__(self, text, selectable=None, type=None, _is_oid=False, case_sensitive=True):
         self.key = self.name = text
         self.table = selectable
         self.type = sqltypes.to_instance(type)
         self._is_oid = _is_oid
         self.__label = None
+        self.case_sensitive = case_sensitive
     def _get_label(self):
         if self.__label is None:
             if self.table is not None and self.table.named_with_column():
index 9d426ab0d2c1c38b298fbd87d8c9f6ef883bb117..d018c5efbf5bb1e67bf15b2d90f649af745d680b 100644 (file)
@@ -176,6 +176,7 @@ class SequenceTest(PersistTest):
         
         metadata.create_all()
     
+    @testbase.supported('postgres', 'oracle')
     def testseqnonpk(self):
         """test sequences fire off as defaults on non-pk columns"""
         sometable.insert().execute(name="somename")
index 607a595d3de33a9e1435b0a1939f4bc3477705b3..403ae2d426b9b7074280c8f68db8d625c498ae26 100644 (file)
@@ -97,10 +97,16 @@ class QuoteTest(PersistTest):
         x = select([table.c.col1.label("ImATable_col1")]).alias("SomeAlias")
         assert str(select([x.c.ImATable_col1])) == '''SELECT "SomeAlias"."ImATable_col1" \nFROM (SELECT "ImATable".col1 AS "ImATable_col1" \nFROM "ImATable") AS "SomeAlias"'''
 
+        # note that 'foo' and 'FooCol' are literals already quoted
         x = select([sql.column("'foo'").label("somelabel")], from_obj=[table]).alias("AnAlias")
         x = x.select()
+        #print x
         assert str(x) == '''SELECT "AnAlias".somelabel \nFROM (SELECT 'foo' AS somelabel \nFROM "ImATable") AS "AnAlias"'''
         
+        x = select([sql.column("'FooCol'").label("SomeLabel")], from_obj=[table])
+        x = x.select()
+        assert str(x) == '''SELECT "SomeLabel" \nFROM (SELECT 'FooCol' AS "SomeLabel" \nFROM "ImATable")'''
+        
     def testlabelsnocase(self):
         metadata = MetaData()
         table1 = Table('SomeCase1', metadata,