]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
added explicit bind parameters and column type maps to text type
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 20 Jan 2006 00:42:07 +0000 (00:42 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 20 Jan 2006 00:42:07 +0000 (00:42 +0000)
text type also parses :<string> into bind param objects
bind parameters convert their incoming type using engine.type_descriptor() methods
types.adapt_type() adjusted to not do extra work with incoming types, since the bind
param change will cause it to be called a lot more
added tests to new text type stuff, bind params, fixed some type tests
added basic docs for using text with binde params

doc/build/content/sqlconstruction.myt
lib/sqlalchemy/ansisql.py
lib/sqlalchemy/engine.py
lib/sqlalchemy/sql.py
lib/sqlalchemy/types.py
test/select.py
test/types.py

index 679525df5344de0ce5ff78beac83ca1c18b380be..fd010be9797ef62265d9ea74e4dbfede5c33bb29 100644 (file)
@@ -744,8 +744,11 @@ SELECT * FROM (select user_id, user_name from users)
 select user_name from users
 {}
 </&>
-            # a straight text query like the one above is also available directly off the engine
-            # (though youre going to have to drop down to the DBAPI's style of bind params)
+            # or call text() off of the engine
+            engine.text("select user_name from users").execute()
+            
+            # execute off the engine directly - you must use the engine's native bind parameter
+            # style (i.e. named, pyformat, positional, etc.)
             <&formatting.myt:poplink&>db.execute(
                     "select user_name from users where user_id=:user_id", 
                     {'user_id':7}).execute()
@@ -755,6 +758,36 @@ select user_name from users where user_id=:user_id
 </&>
 
             
+        </&>
+
+        <&|doclib.myt:item, name="textual_binds", description="Using Bind Parameters in Text Blocks" &>
+        <p>Use the format <span class="codeline"><% ':<paramname>' |h %></span> to define bind parameters inside of a text block.  They will be converted to the appropriate format upon compilation:</p>
+        <&|formatting.myt:code &>
+            t = engine.text("select foo from mytable where lala=:hoho")
+            r = t.execute(hoho=7)
+        </&>        
+        <p>Bind parameters can also be explicit, which allows typing information to be added.  Just specify them as a list with
+        keys that match those inside the textual statement:</p>
+        <&|formatting.myt:code &>
+            t = engine.text("select foo from mytable where lala=:hoho", 
+                        bindparams=[bindparam('hoho', type=types.String)])
+            r = t.execute(hoho="im hoho")
+        </&>        
+        <p>Result-row type processing can be added via the <span class="codeline">typemap</span> argument, which 
+        is a dictionary of return columns mapped to types:</p>
+        <&|formatting.myt:code &>
+            # specify DateTime type for the 'foo' column in the result set
+            # sqlite, for example, uses result-row post-processing to construct dates
+            t = engine.text("select foo from mytable where lala=:hoho", 
+                    bindparams=[bindparam('hoho', type=types.String)],
+                    typemap={'foo':types.DateTime}
+                    )
+            r = t.execute(hoho="im hoho")
+            
+            # 'foo' is a datetime
+            year = r.fetchone()['foo'].year
+        </&>        
+        
         </&>
     </&>
     <&|doclib.myt:item, name="building", description="Building Select Objects" &>
index f0e0203e9b6c88d3c9ce22ad31ae39a5bb6ae32b..30060a93ce6daf5dc8b5b5a0b0054f82e0bfcdaf 100644 (file)
@@ -174,6 +174,8 @@ class ANSICompiler(sql.Compiled):
         else:
             self.strings[textclause] = textclause.text
         self.froms[textclause] = textclause.text
+        if textclause.typemap is not None:
+            self.typemap.update(textclause.typemap)
         
     def visit_null(self, null):
         self.strings[null] = 'NULL'
index 054c5853c78cc7b53915fd8e155f3f94616e59c7..a1e57f6b3439f668c1ba3db9631ea0dbd1040e78 100644 (file)
@@ -223,6 +223,10 @@ class SQLEngine(schema.SchemaEngine):
         if type(typeobj) is type:
             typeobj = typeobj()
         return typeobj
+
+    def text(self, text, *args, **kwargs):
+        """returns a sql.text() object for performing literal queries."""
+        return sql.text(text, engine=self, *args, **kwargs)
         
     def schemagenerator(self, proxy, **params):
         """returns a schema.SchemaVisitor instance that can generate schemas, when it is
index e92246a404a5127312561886d44b0098e7b992a7..3a0e32de13c6f6ca06300cdb2695cc7e5ee0ef44 100644 (file)
@@ -173,14 +173,14 @@ def bindparam(key, value = None, type=None):
     else:
         return BindParamClause(key, value, type=type)
 
-def text(text, engine=None):
+def text(text, engine=None, *args, **kwargs):
     """creates literal text to be inserted into a query.  
     
     When constructing a query from a select(), update(), insert() or delete(), using 
     plain strings for argument values will usually result in text objects being created
     automatically.  Use this function when creating textual clauses outside of other
     ClauseElement objects, or optionally wherever plain text is to be used."""
-    return TextClause(text, engine=engine)
+    return TextClause(text, engine=engine, *args, **kwargs)
 
 def null():
     """returns a Null object, which compiles to NULL in a sql statement."""
@@ -536,14 +536,17 @@ class FromClause(Selectable):
     
 class BindParamClause(ClauseElement, CompareMixin):
     """represents a bind parameter.  public constructor is the bindparam() function."""
-    def __init__(self, key, value, shortname = None, type = None):
+    def __init__(self, key, value, shortname=None, type=None):
         self.key = key
         self.value = value
         self.shortname = shortname
         self.type = type or sqltypes.NULLTYPE
-        # if passed a class as a type, convert to an instance
-        if isinstance(self.type, types.TypeType):
-            self.type = self.type()
+    def _get_convert_type(self, engine):
+        try:
+            return self._converted_type
+        except AttributeError:
+            self._converted_type = engine.type_descriptor(self.type)
+            return self._converted_type
     def accept_visitor(self, visitor):
         visitor.visit_bindparam(self)
     def _get_from_objects(self):
@@ -551,7 +554,7 @@ class BindParamClause(ClauseElement, CompareMixin):
     def hash_key(self):
         return "BindParam(%s, %s, %s)" % (repr(self.key), repr(self.value), repr(self.shortname))
     def typeprocess(self, value, engine):
-        return self.type.convert_bind_param(value, engine)
+        return self._get_convert_type(engine).convert_bind_param(value, engine)
     def compare(self, other):
         """compares this BindParamClause to the given clause.
         
@@ -570,12 +573,27 @@ class TextClause(ClauseElement):
     being specified as a bind parameter via the bindparam() method,
     since it provides more information about what it is, including an optional
     type, as well as providing comparison operations."""
-    def __init__(self, text = "", engine=None):
-        self.text = text
+    def __init__(self, text = "", engine=None, bindparams=None, typemap=None):
         self.parens = False
         self._engine = engine
         self.id = id(self)
+        self.bindparams = {}
+        self.typemap = typemap
+        if typemap is not None:
+            for key in typemap.keys():
+                typemap[key] = engine.type_descriptor(typemap[key])
+        def repl(m):
+            self.bindparams[m.group(1)] = bindparam(m.group(1))
+            return self.engine.bindtemplate % m.group(1)
+           
+        self.text = re.compile(r':([\w_]+)', re.S).sub(repl, text)
+        if bindparams is not None:
+            for b in bindparams:
+                self.bindparams[b.key] = b
+            
     def accept_visitor(self, visitor): 
+        for item in self.bindparams.values():
+            item.accept_visitor(visitor)
         visitor.visit_textclause(self)
     def hash_key(self):
         return "TextClause(%s)" % repr(self.text)
index 970dcbd31ccad7708f99c20fca82b35a815073d4..defd6819bb189ffbed91d72a2c6869ce68823667 100644 (file)
@@ -31,6 +31,10 @@ def adapt_type(typeobj, colspecs):
     to a correctly-configured type instance from the DB-specific package."""
     if type(typeobj) is type:
         typeobj = typeobj()
+    # if the type is not a base type, i.e. not from our module, or its Null, 
+    # we return the type as is
+    if typeobj.__module__ != 'sqlalchemy.types' or typeobj.__class__==NullTypeEngine:
+        return typeobj
     typeobj = typeobj.adapt_args()
     t = typeobj.__class__
     for t in t.__mro__[0:-1]:
index ca0eb0eac19826eebff0309aebeed00912acc6ca..6f4557f5eebc9309ced833d825feb86731c65773 100644 (file)
@@ -3,6 +3,7 @@ from sqlalchemy import *
 import sqlalchemy.ansisql as ansisql
 import sqlalchemy.databases.postgres as postgres
 import sqlalchemy.databases.oracle as oracle
+import sqlalchemy.databases.sqlite as sqlite
 
 db = ansisql.engine()
 
@@ -60,7 +61,10 @@ class SQLTest(PersistTest):
         cc = re.sub(r'\n', '', str(c))
         self.assert_(cc == result, str(c) + "\n does not match \n" + result)
         if checkparams is not None:
-            self.assert_(c.get_params() == checkparams, "params dont match")
+            if isinstance(checkparams, list):
+                self.assert_(c.get_params().values() == checkparams, "params dont match")
+            else:
+                self.assert_(c.get_params() == checkparams, "params dont match")
             
 class SelectTest(SQLTest):
 
@@ -222,6 +226,33 @@ WHERE mytable.myid = myothertable.otherid) AS t2view WHERE t2view.mytable_myid =
         s.append_from("table1")
         self.runtest(s, "SELECT column1, column2 FROM table1 WHERE column1=12 AND column2=19 ORDER BY column1", db)
 
+    def testtextbinds(self):
+        self.runtest(
+            db.text("select * from foo where lala=:bar and hoho=:whee"), 
+                "select * from foo where lala=:bar and hoho=:whee", 
+                checkparams={'bar':4, 'whee': 7},
+                params={'bar':4, 'whee': 7, 'hoho':10},
+                engine=db
+        )
+        
+        engine = postgres.engine({})
+        self.runtest(
+            engine.text("select * from foo where lala=:bar and hoho=:whee"), 
+                "select * from foo where lala=%(bar)s and hoho=%(whee)s", 
+                checkparams={'bar':4, 'whee': 7},
+                params={'bar':4, 'whee': 7, 'hoho':10},
+                engine=engine
+        )
+
+        engine = sqlite.engine({})
+        self.runtest(
+            engine.text("select * from foo where lala=:bar and hoho=:whee"), 
+                "select * from foo where lala=? and hoho=?", 
+                checkparams=[4, 7],
+                params={'bar':4, 'whee': 7, 'hoho':10},
+                engine=engine
+        )
+        
     def testtextmix(self):
         self.runtest(select(
             [table, table2.c.id, "sysdate()", "foo, bar, lala"],
index eabf01d1d74e14fed4e749a0c1cec2b6656031fb..155c4ad3b4e10ec9daa79d26bf547afb3b72f37f 100644 (file)
@@ -13,9 +13,9 @@ class OverrideTest(PersistTest):
         class MyType(types.TypeEngine):
             def get_col_spec(self):
                 return "VARCHAR(100)"
-            def convert_bind_param(self, value):
+            def convert_bind_param(self, value, engine):
                 return "BIND_IN"+ value
-            def convert_result_value(self, value):
+            def convert_result_value(self, value, engine):
                 return value + "BIND_OUT"
             def adapt(self, typeobj):
                 return typeobj()
@@ -45,15 +45,16 @@ class OverrideTest(PersistTest):
 class ColumnsTest(AssertMixin):
 
     def testcolumns(self):
-        defaultExpectedResults = { 'int_column': 'int_column INTEGER',
+        expectedResults = { 'int_column': 'int_column INTEGER',
                                    'varchar_column': 'varchar_column VARCHAR(20)',
                                    'numeric_column': 'numeric_column NUMERIC(12, 3)',
                                    'float_column': 'float_column NUMERIC(25, 2)'
                                  }
 
-        if db.engine.__module__ != 'sqlite':
+        if not db.engine.__module__.endswith('sqlite'):
             expectedResults['float_column'] = 'float_column FLOAT(25)'
     
+        print db.engine.__module__
         testTable = Table('testColumns', db,
             Column('int_column', Integer),
             Column('varchar_column', String(20)),
@@ -62,7 +63,7 @@ class ColumnsTest(AssertMixin):
         )
 
         for aCol in testTable.c:
-            self.assertEquals(expectedResults[aCol.name], self.db.schemagenerator(None).get_column_specification(aCol))
+            self.assertEquals(expectedResults[aCol.name], db.schemagenerator(None).get_column_specification(aCol))
         
 
 class BinaryTest(AssertMixin):
@@ -104,14 +105,14 @@ class DateTest(AssertMixin):
         redefine = True
         )
         users_with_date.create()
-    def tearDownAll(self):
-        users_with_date.drop()
-
-    def testdate(self):
         users_with_date.insert().execute(user_id = 7, user_name = 'jack', user_date=datetime.datetime(2005,11,10))
         users_with_date.insert().execute(user_id = 8, user_name = 'roy', user_date=datetime.datetime(2005,11,10, 11,52,35))
         users_with_date.insert().execute(user_id = 9, user_name = 'foo', user_date=datetime.datetime(2005,11,10, 11,52,35, 54839))
         users_with_date.insert().execute(user_id = 10, user_name = 'colber', user_date=None)
+    def tearDownAll(self):
+        users_with_date.drop()
+
+    def testdate(self):
         l = users_with_date.select().execute().fetchall()
         l = [[c for c in r] for r in l]
         if db.engine.__module__.endswith('mysql'):
@@ -121,7 +122,15 @@ class DateTest(AssertMixin):
         print repr(l)
         print repr(x)
         self.assert_(l == x)
-     
+
+    def testtextdate(self):     
+        x = db.text("select user_date from query_users_with_date", typemap={'user_date':DateTime}).execute().fetchall()
+        
+        print repr(x)
+        self.assert_(isinstance(x[0][0], datetime.datetime))
+        
+        #x = db.text("select * from query_users_with_date where user_date=:date", bindparams=[bindparam('date', )]).execute(date=datetime.datetime(2005, 11, 10, 11, 52, 35)).fetchall()
+        #print repr(x)
         
 if __name__ == "__main__":
     testbase.main()