]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- the _pk_processors/_prefetch_processors approach relied upon calling RPs without...
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 15 Jan 2011 16:53:37 +0000 (11:53 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 15 Jan 2011 16:53:37 +0000 (11:53 -0500)
result, also generates procs that are not used in most cases.  simplify the approach
by passing type to _exec_default() to be used if needed by _execute_scalar(),
looking for the proc on just t._autoincrement_column in post_insert().

lib/sqlalchemy/dialects/firebird/base.py
lib/sqlalchemy/dialects/oracle/base.py
lib/sqlalchemy/dialects/postgresql/base.py
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/sql/compiler.py
test/dialect/test_postgresql.py

index 4043cd6c3be681e6f79a061c7fb09c4559f47441..feca46bceeb4785a0a76c4e0f4abb65a4ee439bd 100644 (file)
@@ -331,13 +331,13 @@ class FBIdentifierPreparer(sql.compiler.IdentifierPreparer):
 
 
 class FBExecutionContext(default.DefaultExecutionContext):
-    def fire_sequence(self, seq, proc):
+    def fire_sequence(self, seq, type_):
         """Get the next value from the sequence using ``gen_id()``."""
 
         return self._execute_scalar(
                 "SELECT gen_id(%s, 1) FROM rdb$database" % 
                 self.dialect.identifier_preparer.format_sequence(seq),
-                proc
+                type_
                 )
 
 
index 0b0622f845a9f8cf290e57b7833ba339b79c186f..7ed4ca07e2096122cc441b6f388e20c4f2bcd196 100644 (file)
@@ -593,10 +593,10 @@ class OracleIdentifierPreparer(compiler.IdentifierPreparer):
 
 
 class OracleExecutionContext(default.DefaultExecutionContext):
-    def fire_sequence(self, seq, proc):
+    def fire_sequence(self, seq, type_):
         return int(self._execute_scalar("SELECT " + 
                     self.dialect.identifier_preparer.format_sequence(seq) + 
-                    ".nextval FROM DUAL"), proc)
+                    ".nextval FROM DUAL"), type_)
 
 class OracleDialect(default.DefaultDialect):
     name = 'oracle'
index 84fd96edd1c27f95482f39356276aa4a0f28d158..7c712e8aa36eddeaec89c5949538d96f8e3b9360 100644 (file)
@@ -681,21 +681,21 @@ class DropEnumType(schema._CreateDropBase):
   __visit_name__ = "drop_enum_type"
 
 class PGExecutionContext(default.DefaultExecutionContext):
-    def fire_sequence(self, seq, proc):
+    def fire_sequence(self, seq, type_):
         if not seq.optional:
             return self._execute_scalar(("select nextval('%s')" % \
-                    self.dialect.identifier_preparer.format_sequence(seq)), proc)
+                    self.dialect.identifier_preparer.format_sequence(seq)), type_)
         else:
             return None
 
-    def get_insert_default(self, column, proc):
+    def get_insert_default(self, column):
         if column.primary_key:
             if (isinstance(column.server_default, schema.DefaultClause) and
                 column.server_default.arg is not None):
 
                 # pre-execute passive defaults on primary key columns
                 return self._execute_scalar("select %s" %
-                                        column.server_default.arg, proc)
+                                        column.server_default.arg, column.type)
 
             elif column is column.table._autoincrement_column \
                     and (column.default is None or 
@@ -714,9 +714,9 @@ class PGExecutionContext(default.DefaultExecutionContext):
                     exc = "select nextval('\"%s_%s_seq\"')" % \
                             (column.table.name, column.name)
 
-                return self._execute_scalar(exc, proc)
+                return self._execute_scalar(exc, column.type)
 
-        return super(PGExecutionContext, self).get_insert_default(column, proc)
+        return super(PGExecutionContext, self).get_insert_default(column)
 
 class PGDialect(default.DefaultDialect):
     name = 'postgresql'
index eacbef8f9df44c56f6c3fd712fe21d28ac495815..e21ec1c4070411ab79340015f1b95c356d07dc2b 100644 (file)
@@ -532,7 +532,7 @@ class DefaultExecutionContext(base.ExecutionContext):
         else:
             return autocommit
 
-    def _execute_scalar(self, stmt, proc):
+    def _execute_scalar(self, stmt, type_):
         """Execute a string statement on the current cursor, returning a
         scalar result.
 
@@ -554,10 +554,15 @@ class DefaultExecutionContext(base.ExecutionContext):
 
         conn._cursor_execute(self.cursor, stmt, default_params)
         r = self.cursor.fetchone()[0]
-        if proc:
-            return proc(r)
-        else:
-            return r
+        if type_ is not None:
+            # apply type post processors to the result
+            proc = type_._cached_result_processor(
+                        self.dialect, 
+                        self.cursor.description[0][1]
+                    )
+            if proc:
+                return proc(r)
+        return r
 
     @property
     def connection(self):
@@ -626,15 +631,19 @@ class DefaultExecutionContext(base.ExecutionContext):
 
             table = self.compiled.statement.table
             lastrowid = self.get_lastrowid()
+
+            autoinc_col = table._autoincrement_column
+            if autoinc_col is not None:
+                # apply type post processors to the lastrowid
+                proc = autoinc_col.type._cached_result_processor(self.dialect, None)
+                if proc is not None:
+                    lastrowid = proc(lastrowid)
+
             self.inserted_primary_key = [
-                c is table._autoincrement_column and (
-                    proc and proc(lastrowid)
-                    or lastrowid
-                ) or v
-                for c, v, proc in zip(
+                c is autoinc_col and lastrowid or v
+                for c, v in zip(
                                     table.primary_key, 
-                                    self.inserted_primary_key, 
-                                    self.compiled._pk_processors)
+                                    self.inserted_primary_key)
             ]
 
     def _fetch_implicit_returning(self, resultproxy):
@@ -698,9 +707,9 @@ class DefaultExecutionContext(base.ExecutionContext):
                 self.root_connection._handle_dbapi_exception(e, None, None, None, self)
                 raise
 
-    def _exec_default(self, default, proc):
+    def _exec_default(self, default, type_):
         if default.is_sequence:
-            return self.fire_sequence(default, proc)
+            return self.fire_sequence(default, type_)
         elif default.is_callable:
             return default.arg(self)
         elif default.is_clause_element:
@@ -712,17 +721,17 @@ class DefaultExecutionContext(base.ExecutionContext):
         else:
             return default.arg
 
-    def get_insert_default(self, column, proc):
+    def get_insert_default(self, column):
         if column.default is None:
             return None
         else:
-            return self._exec_default(column.default, proc)
+            return self._exec_default(column.default, column.type)
 
-    def get_update_default(self, column, proc):
+    def get_update_default(self, column):
         if column.onupdate is None:
             return None
         else:
-            return self._exec_default(column.onupdate, proc)
+            return self._exec_default(column.onupdate, column.type)
 
     def __process_defaults(self):
         """Generate default values for compiled insert/update statements,
@@ -744,14 +753,13 @@ class DefaultExecutionContext(base.ExecutionContext):
 
                 for param in self.compiled_parameters:
                     self.current_parameters = param
-                    for c, proc in zip(self.prefetch_cols, 
-                                        self.compiled._prefetch_processors):
+                    for c in self.prefetch_cols:
                         if c in scalar_defaults:
                             val = scalar_defaults[c]
                         elif self.isinsert:
-                            val = self.get_insert_default(c, proc)
+                            val = self.get_insert_default(c)
                         else:
-                            val = self.get_update_default(c, proc)
+                            val = self.get_update_default(c)
                         if val is not None:
                             param[c.key] = val
                 del self.current_parameters
@@ -759,12 +767,11 @@ class DefaultExecutionContext(base.ExecutionContext):
             self.current_parameters = compiled_parameters = \
                                         self.compiled_parameters[0]
 
-            for c, proc in zip(self.compiled.prefetch, 
-                                        self.compiled._prefetch_processors):
+            for c in self.compiled.prefetch:
                 if self.isinsert:
-                    val = self.get_insert_default(c, proc)
+                    val = self.get_insert_default(c)
                 else:
-                    val = self.get_update_default(c, proc)
+                    val = self.get_update_default(c)
 
                 if val is not None:
                     compiled_parameters[c.key] = val
index 92110ca2a49d444404d48946437addc9fa60f97b..ce98dfb83cef40009ff6aea647565ce67a1c43e2 100644 (file)
@@ -268,20 +268,6 @@ class SQLCompiler(engine.Compiled):
                  if value is not None
             )
 
-    @util.memoized_property
-    def _pk_processors(self):
-        return [
-                col.type._cached_result_processor(self.dialect, None) 
-                for col in self.statement.table.primary_key
-            ]
-
-    @util.memoized_property
-    def _prefetch_processors(self):
-        return [
-                col.type._cached_result_processor(self.dialect, None) 
-                for col in self.prefetch
-            ]
-
     def is_subquery(self):
         return len(self.stack) > 1
 
index a8c63c566b70a6d0763d20dabf76df06eb43d1c1..fb5a63c9b4689ead6aa568d1ee840610da4561a6 100644 (file)
@@ -476,7 +476,7 @@ class EnumTest(TestBase, AssertsExecutionResults, AssertsCompiledSQL):
             metadata.drop_all()
 
 class NumericInterpretationTest(TestBase):
-
+    __only_on__ = 'postgresql'
 
     def test_numeric_codes(self):
         from sqlalchemy.dialects.postgresql import pg8000, psycopg2, base
@@ -493,6 +493,28 @@ class NumericInterpretationTest(TestBase):
                     val = proc(val)
                 assert val in (23.7, decimal.Decimal("23.7"))
 
+    @testing.provide_metadata
+    def test_numeric_default(self):
+        t =Table('t', metadata, 
+            Column('id', Integer, primary_key=True),
+            Column('nd', Numeric(asdecimal=True), default=0),
+            Column('nf', Numeric(asdecimal=False), default=0),
+            Column('fd', Float(asdecimal=True), default=0),
+            Column('ff', Float(asdecimal=False), default=0),
+        )
+        metadata.create_all()
+        r = t.insert().execute()
+
+        row = t.select().execute().first()
+        assert isinstance(row[1], decimal.Decimal)
+        assert isinstance(row[2], float)
+        assert isinstance(row[3], decimal.Decimal)
+        assert isinstance(row[4], float)
+        eq_(
+            row,
+            (1, decimal.Decimal("0"), 0, decimal.Decimal("0"), 0)
+        )
+
 class InsertTest(TestBase, AssertsExecutionResults):
 
     __only_on__ = 'postgresql'