]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- [bug] Quoting information is now passed along
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 4 May 2012 22:40:55 +0000 (18:40 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 4 May 2012 22:40:55 +0000 (18:40 -0400)
    from a Column with quote=True when generating
    a same-named bound parameter to the bindparam()
    object, as is the case in generated INSERT and UPDATE
    statements, so that unknown reserved names can
    be fully supported.  [ticket:2437]

CHANGES
lib/sqlalchemy/dialects/oracle/cx_oracle.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/expression.py
test/dialect/test_oracle.py

diff --git a/CHANGES b/CHANGES
index 1a2eb37f98da88faf8e17453ecc5bf16ae0ea6f9..5c1e52fdeffff7b85eec0a52c67769cf1653f82c 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -330,6 +330,14 @@ CHANGES
     expressions.  [ticket:2467]
     also in 0.7.7.
 
+- oracle
+  - [bug] Quoting information is now passed along
+    from a Column with quote=True when generating
+    a same-named bound parameter to the bindparam()
+    object, as is the case in generated INSERT and UPDATE 
+    statements, so that unknown reserved names can 
+    be fully supported.  [ticket:2437]
+
 - extensions
   - [removed] The SQLSoup extension is removed from 
     SQLAlchemy, and is now an external project.
index 5001acca34886e15a19f131c953bdb70532d5dd4..8f1f0d812cbf51e8acfc1ec3926876de1eb52868 100644 (file)
@@ -296,8 +296,9 @@ class _OracleRowid(oracle.ROWID):
         return dbapi.ROWID
 
 class OracleCompiler_cx_oracle(OracleCompiler):
-    def bindparam_string(self, name):
-        if self.preparer._bindparam_requires_quotes(name):
+    def bindparam_string(self, name, quote=None):
+        if quote is True or quote is not False and \
+            self.preparer._bindparam_requires_quotes(name):
             quoted_name = '"%s"' % name
             self._quoted_bind_names[name] = quoted_name
             return OracleCompiler.bindparam_string(self, quoted_name)
index a58da176cd5e80f9e887043944749b3a5c8f9c0f..05cc70aba874feb3451a03f9f6a58d4da830e4c8 100644 (file)
@@ -686,7 +686,7 @@ class SQLCompiler(engine.Compiled):
 
         self.binds[bindparam.key] = self.binds[name] = bindparam
 
-        return self.bindparam_string(name)
+        return self.bindparam_string(name, quote=bindparam.quote)
 
     def render_literal_bindparam(self, bindparam, **kw):
         value = bindparam.value
@@ -756,7 +756,7 @@ class SQLCompiler(engine.Compiled):
         self.anon_map[derived] = anonymous_counter + 1
         return derived + "_" + str(anonymous_counter)
 
-    def bindparam_string(self, name):
+    def bindparam_string(self, name, quote=None):
         if self.positional:
             self.positiontup.append(name)
             return self.bindtemplate % {
@@ -1206,7 +1206,8 @@ class SQLCompiler(engine.Compiled):
 
     def _create_crud_bind_param(self, col, value, required=False):
         bindparam = sql.bindparam(col.key, value, 
-                            type_=col.type, required=required)
+                            type_=col.type, required=required,
+                            quote=col.quote)
         bindparam._is_crud = True
         return bindparam._compiler_dispatch(self)
 
index a0f0bab6cac6b3fd7a2044694ba8fd1dc54fa162..6e9ddbd5a26ff9a8a63b31d93971bdbabe384a55 100644 (file)
@@ -987,7 +987,8 @@ def table(name, *columns):
     """
     return TableClause(name, *columns)
 
-def bindparam(key, value=None, type_=None, unique=False, required=False, callable_=None):
+def bindparam(key, value=None, type_=None, unique=False, required=False, 
+                        quote=None, callable_=None):
     """Create a bind parameter clause with the given key.
 
         :param key:
@@ -1024,15 +1025,19 @@ def bindparam(key, value=None, type_=None, unique=False, required=False, callabl
         :param required:
           a value is required at execution time.
 
+        :param quote:
+          True if this parameter name requires quoting and is not
+          currently known as a SQLAlchemy reserved word; this currently
+          only applies to the Oracle backend.
+
     """
     if isinstance(key, ColumnClause):
-        return _BindParamClause(key.name, value, type_=key.type, 
-                                callable_=callable_,
-                                unique=unique, required=required)
-    else:
-        return _BindParamClause(key, value, type_=type_, 
-                                callable_=callable_,
-                                unique=unique, required=required)
+        type_ = key.type
+        key = key.name
+    return _BindParamClause(key, value, type_=type_, 
+                            callable_=callable_,
+                            unique=unique, required=required,
+                            quote=quote)
 
 def outparam(key, type_=None):
     """Create an 'OUT' parameter for usage in functions (stored procedures),
@@ -2613,6 +2618,7 @@ class _BindParamClause(ColumnElement):
     def __init__(self, key, value, type_=None, unique=False, 
                             callable_=None,
                             isoutparam=False, required=False, 
+                            quote=None,
                             _compared_to_operator=None,
                             _compared_to_type=None):
         """Construct a _BindParamClause.
@@ -2648,6 +2654,11 @@ class _BindParamClause(ColumnElement):
           already has been located within the containing
           :class:`.ClauseElement`.
 
+        :param quote: 
+          True if this parameter name requires quoting and is not
+          currently known as a SQLAlchemy reserved word; this currently
+          only applies to the Oracle backend.
+
         :param required:
           a value is required at execution time.
 
@@ -2677,6 +2688,7 @@ class _BindParamClause(ColumnElement):
         self.callable = callable_
         self.isoutparam = isoutparam
         self.required = required
+        self.quote = quote
         if type_ is None:
             if _compared_to_type is not None:
                 self.type = \
index 9540fa9635a798d182fe56c07a7c850ae720875f..ab958f57f5450143b7283b69debcda6c424f139d 100644 (file)
@@ -47,10 +47,42 @@ create or replace procedure foo(x_in IN number, x_out OUT number, y_out OUT numb
     def teardown_class(cls):
          testing.db.execute("DROP PROCEDURE foo")
 
+class QuotedBindRoundTripTest(fixtures.TestBase):
+
+    __only_on__ = 'oracle'
+
+    @testing.provide_metadata
+    def test_table_round_trip(self):
+        oracle.RESERVED_WORDS.remove('UNION')
+
+        metadata = self.metadata
+        table = Table("t1", metadata,
+            Column("option", Integer),
+            Column("plain", Integer, quote=True),
+            # test that quote works for a reserved word
+            # that the dialect isn't aware of when quote
+            # is set
+            Column("union", Integer, quote=True)
+        )
+        metadata.create_all()
+
+        table.insert().execute(
+            {"option":1, "plain":1, "union":1}
+        )
+        eq_(
+            testing.db.execute(table.select()).first(),
+            (1, 1, 1)
+        )
+        table.update().values(option=2, plain=2, union=2).execute()
+        eq_(
+            testing.db.execute(table.select()).first(),
+            (2, 2, 2)
+        )
+
 
 class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
 
-    __dialect__ = oracle.OracleDialect()
+    __dialect__ = oracle.dialect()
 
     def test_owner(self):
         meta = MetaData()
@@ -73,6 +105,26 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
                                 "sometable.col1 AS col1, sometable.col2 "
                                 "AS col2 FROM sometable)")
 
+    def test_bindparam_quote(self):
+        """test that bound parameters take on quoting for reserved words,
+        column names quote flag enabled."""
+        # note: this is only in cx_oracle at the moment.  not sure
+        # what other hypothetical oracle dialects might need
+
+        self.assert_compile(
+            bindparam("option"), ':"option"'
+        )
+        self.assert_compile(
+            bindparam("plain"), ':plain'
+        )
+        t = Table("s", MetaData(), Column('plain', Integer, quote=True))
+        self.assert_compile(
+            t.insert().values(plain=5), 'INSERT INTO s ("plain") VALUES (:"plain")'
+        )
+        self.assert_compile(
+            t.update().values(plain=5), 'UPDATE s SET "plain"=:"plain"'
+        )
+
     def test_limit(self):
         t = table('sometable', column('col1'), column('col2'))
         s = select([t])