]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Got PG server side cursors back into shape, added fixed
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 2 Apr 2008 22:33:50 +0000 (22:33 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 2 Apr 2008 22:33:50 +0000 (22:33 +0000)
unit tests as part of the default test suite.  Added
better uniqueness to the cursor ID [ticket:1001]
- update().values() and insert().values() take keyword
arguments.

CHANGES
lib/sqlalchemy/databases/postgres.py
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/sql/expression.py
test/dialect/postgres.py

diff --git a/CHANGES b/CHANGES
index 7207f1322a7af946dc4147d846d7cc90e239ec38..4f78ceb94ca239f95efaa07a9c40d8adf1bd81fa 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -178,6 +178,9 @@ CHANGES
     - random() is now a generic sql function and will compile to
       the database's random implementation, if any.
 
+    - update().values() and insert().values() take keyword 
+      arguments.
+
     - Fixed an issue in select() regarding its generation of FROM
       clauses, in rare circumstances two clauses could be produced
       when one was intended to cancel out the other.  Some ORM
@@ -218,6 +221,11 @@ CHANGES
        property can also be set on individual declarative
        classes using the "__mapper_cls__" property.
 
+- postgres
+    - Got PG server side cursors back into shape, added fixed
+      unit tests as part of the default test suite.  Added
+      better uniqueness to the cursor ID [ticket:1001]
+      
 - oracle
     - The "owner" keyword on Table is now deprecated, and is
       exactly synonymous with the "schema" keyword.  Tables can
index 94ad7d2e45d36204b0c5e8e4bcee6b39c0e0592f..326dd6b7d56729d258994e5d29edc76f7c94e94c 100644 (file)
@@ -224,6 +224,10 @@ def descriptor():
         ('host',"Hostname", None),
     ]}
 
+SERVER_SIDE_CURSOR_RE = re.compile(
+    r'\s*SELECT',
+    re.I | re.UNICODE)
+
 SELECT_RE = re.compile(
     r'\s*(?:SELECT|FETCH|(UPDATE|INSERT))',
     re.I | re.UNICODE)
@@ -252,7 +256,6 @@ RETURNING_QUOTED_RE = re.compile(
         \sRETURNING\s""", re.I | re.UNICODE | re.VERBOSE)
 
 class PGExecutionContext(default.DefaultExecutionContext):
-
     def returns_rows_text(self, statement):
         m = SELECT_RE.match(statement)
         return m and (not m.group(1) or (RETURNING_RE.search(statement)
@@ -265,23 +268,20 @@ class PGExecutionContext(default.DefaultExecutionContext):
             )
 
     def create_cursor(self):
-        # executing a default or Sequence standalone creates an execution context without a statement.
-        # so slightly hacky "if no statement assume we're server side" logic
-        # TODO: dont use regexp if Compiled is used ?
         self.__is_server_side = \
             self.dialect.server_side_cursors and \
-            (self.statement is None or \
-            (SELECT_RE.match(self.statement) and not re.search(r'FOR UPDATE(?: NOWAIT)?\s*$', self.statement, re.I))
-        )
+            ((self.compiled and isinstance(self.compiled.statement, expression.Selectable)) \
+            or \
+            (not self.compiled and self.statement and SERVER_SIDE_CURSOR_RE.match(self.statement)))
 
         if self.__is_server_side:
             # use server-side cursors:
             # http://lists.initd.org/pipermail/psycopg/2007-January/005251.html
-            ident = "c" + hex(random.randint(0, 65535))[2:]
+            ident = "c_%s_%s" % (hex(id(self))[2:], hex(random.randint(0, 65535))[2:])
             return self._connection.connection.cursor(ident)
         else:
             return self._connection.connection.cursor()
-
+    
     def get_result_proxy(self):
         if self.__is_server_side:
             return base.BufferedRowResultProxy(self)
@@ -763,6 +763,11 @@ class PGSchemaDropper(compiler.SchemaDropper):
             self.execute()
 
 class PGDefaultRunner(base.DefaultRunner):
+    def __init__(self, context):
+        base.DefaultRunner.__init__(self, context)
+        # craete cursor which won't conflict with a server-side cursor
+        self.cursor = context._connection.connection.cursor()
+    
     def get_column_default(self, column, isinsert=True):
         if column.primary_key:
             # pre-execute passive defaults on primary keys
index cd662ac92e833eb3b11a9d4cb5d5058240a9ee0f..722acd6f0fedb7fc655b28055ca4fe75860519a2 100644 (file)
@@ -1806,6 +1806,7 @@ class DefaultRunner(schema.SchemaVisitor):
     def __init__(self, context):
         self.context = context
         self.dialect = context.dialect
+        self.cursor = context.cursor
 
     def get_column_default(self, column):
         if column.default is not None:
@@ -1846,8 +1847,8 @@ class DefaultRunner(schema.SchemaVisitor):
         conn = self.context._connection
         if isinstance(stmt, unicode) and not self.dialect.supports_unicode_statements:
             stmt = stmt.encode(self.dialect.encoding)
-        conn._cursor_execute(self.context.cursor, stmt, params)
-        return self.context.cursor.fetchone()[0]
+        conn._cursor_execute(self.cursor, stmt, params)
+        return self.cursor.fetchone()[0]
 
     def visit_column_onupdate(self, onupdate):
         if isinstance(onupdate.arg, expression.ClauseElement):
index d8a85d8bd85bb8612373778013b6f5a29e04fe6b..c487ee173d01df4a621a72d408bb001398a84f39 100644 (file)
@@ -3496,7 +3496,35 @@ class _UpdateBase(ClauseElement):
         self._bind = bind
     bind = property(bind, _set_bind)
 
-class Insert(_UpdateBase):
+class _ValuesBase(_UpdateBase):
+    def values(self, *args, **kwargs):
+        """specify the VALUES clause for an INSERT statement, or the SET clause for an UPDATE.
+
+            \**kwargs
+                key=<somevalue> arguments
+                
+            \*args
+                deprecated.  A single dictionary can be sent as the first positional argument.
+        """
+        
+        if args:
+            v = args[0]
+        else:
+            v = {}
+        if len(v) == 0 and len(kwargs) == 0:
+            return self
+        u = self._clone()
+        
+        if u.parameters is None:
+            u.parameters = u._process_colparams(v)
+            u.parameters.update(kwargs)
+        else:
+            u.parameters = self.parameters.copy()
+            u.parameters.update(u._process_colparams(v))
+            u.parameters.update(kwargs)
+        return u
+
+class Insert(_ValuesBase):
     def __init__(self, table, values=None, inline=False, bind=None, prefixes=None, **kwargs):
         self._bind = bind
         self.table = table
@@ -3520,17 +3548,6 @@ class Insert(_UpdateBase):
     def _copy_internals(self, clone=_clone):
         self.parameters = self.parameters.copy()
 
-    def values(self, v):
-        if len(v) == 0:
-            return self
-        u = self._clone()
-        if u.parameters is None:
-            u.parameters = u._process_colparams(v)
-        else:
-            u.parameters = self.parameters.copy()
-            u.parameters.update(u._process_colparams(v))
-        return u
-
     def prefix_with(self, clause):
         """Add a word or expression between INSERT and INTO. Generative.
 
@@ -3542,7 +3559,7 @@ class Insert(_UpdateBase):
         gen._prefixes = self._prefixes + [clause]
         return gen
 
-class Update(_UpdateBase):
+class Update(_ValuesBase):
     def __init__(self, table, whereclause, values=None, inline=False, bind=None, **kwargs):
         self._bind = bind
         self.table = table
@@ -3576,16 +3593,6 @@ class Update(_UpdateBase):
             s._whereclause = _literal_as_text(whereclause)
         return s
 
-    def values(self, v):
-        if len(v) == 0:
-            return self
-        u = self._clone()
-        if u.parameters is None:
-            u.parameters = u._process_colparams(v)
-        else:
-            u.parameters = self.parameters.copy()
-            u.parameters.update(u._process_colparams(v))
-        return u
 
 class Delete(_UpdateBase):
     def __init__(self, table, whereclause, bind=None):
index e68fd4d746c7804d8ef3f13fc7daaa7414ab3313..90cc0a47742a7c5b7218ca8a74bca1236b5be4a4 100644 (file)
@@ -780,6 +780,37 @@ class TimeStampTest(TestBase, AssertsExecutionResults):
         result = connection.execute(s).fetchone() 
         self.assertEqual(result[0], datetime.datetime(2007, 12, 25, 0, 0)) 
 
-
+class ServerSideCursorsTest(TestBase, AssertsExecutionResults):
+    __only_on__ = 'postgres'
+    
+    def setUpAll(self):
+        global ss_engine
+        ss_engine = engines.testing_engine(options={'server_side_cursors':True})
+        
+    def tearDownAll(self):
+        ss_engine.dispose()
+    
+    def test_roundtrip(self):
+        test_table = Table('test_table', MetaData(ss_engine),
+            Column('id', Integer, primary_key=True),
+            Column('data', String(50))
+        )
+        test_table.create(checkfirst=True)
+        try:
+            test_table.insert().execute(data='data1')
+            
+            nextid = ss_engine.execute(Sequence('test_table_id_seq'))
+            test_table.insert().execute(id=nextid, data='data2')
+            
+            self.assertEquals(test_table.select().execute().fetchall(), [(1, 'data1'), (2, 'data2')])
+            
+            test_table.update().where(test_table.c.id==2).values(data=test_table.c.data + ' updated').execute()
+            self.assertEquals(test_table.select().execute().fetchall(), [(1, 'data1'), (2, 'data2 updated')])
+            test_table.delete().execute()
+            self.assertEquals(test_table.count().scalar(), 0)
+        finally:
+            test_table.drop(checkfirst=True)
+            
+    
 if __name__ == "__main__":
     testenv.main()