]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- tiny refactors #1-#5
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 5 Sep 2014 18:50:21 +0000 (14:50 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 5 Sep 2014 20:28:20 +0000 (16:28 -0400)
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/orm/persistence.py
lib/sqlalchemy/sql/compiler.py
test/sql/test_labels.py

index 2fece76b9fe5fafcc3f96b0b4e89352e02d76bff..a5af6ff193d4af4c610ba492a1731e7a2da789ce 100644 (file)
@@ -472,10 +472,9 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
         """Initialize execution context for a DDLElement construct."""
 
         self = cls.__new__(cls)
-        self.dialect = dialect
         self.root_connection = connection
         self._dbapi_connection = dbapi_connection
-        self.engine = connection.engine
+        self.dialect = connection.dialect
 
         self.compiled = compiled = compiled_ddl
         self.isddl = True
@@ -507,10 +506,9 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
         """Initialize execution context for a Compiled construct."""
 
         self = cls.__new__(cls)
-        self.dialect = dialect
         self.root_connection = connection
         self._dbapi_connection = dbapi_connection
-        self.engine = connection.engine
+        self.dialect = connection.dialect
 
         self.compiled = compiled
 
@@ -538,11 +536,6 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
         self.isupdate = compiled.isupdate
         self.isdelete = compiled.isdelete
 
-        if self.isinsert or self.isupdate or self.isdelete:
-            self._is_explicit_returning = bool(compiled.statement._returning)
-            self._is_implicit_returning = bool(
-                compiled.returning and not compiled.statement._returning)
-
         if not parameters:
             self.compiled_parameters = [compiled.construct_params()]
         else:
@@ -553,11 +546,18 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
             self.executemany = len(parameters) > 1
 
         self.cursor = self.create_cursor()
-        if self.isinsert or self.isupdate:
-            self.postfetch_cols = self.compiled.postfetch
-            self.prefetch_cols = self.compiled.prefetch
-            self.returning_cols = self.compiled.returning
-            self.__process_defaults()
+
+        if self.isinsert or self.isupdate or self.isdelete:
+            self._is_explicit_returning = bool(compiled.statement._returning)
+            self._is_implicit_returning = bool(
+                compiled.returning and not compiled.statement._returning)
+
+            if not self.isdelete:
+                if self.compiled.prefetch:
+                    if self.executemany:
+                        self._process_executemany_defaults()
+                    else:
+                        self._process_executesingle_defaults()
 
         processors = compiled._bind_processors
 
@@ -577,21 +577,28 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
         else:
             encode = not dialect.supports_unicode_statements
             for compiled_params in self.compiled_parameters:
-                param = {}
+
                 if encode:
-                    for key in compiled_params:
-                        if key in processors:
-                            param[dialect._encoder(key)[0]] = \
-                                processors[key](compiled_params[key])
-                        else:
-                            param[dialect._encoder(key)[0]] = \
-                                compiled_params[key]
+                    param = dict(
+                        (
+                            dialect._encoder(key)[0],
+                            processors[key](compiled_params[key])
+                            if key in processors
+                            else compiled_params[key]
+                        )
+                        for key in compiled_params
+                    )
                 else:
-                    for key in compiled_params:
-                        if key in processors:
-                            param[key] = processors[key](compiled_params[key])
-                        else:
-                            param[key] = compiled_params[key]
+                    param = dict(
+                        (
+                            key,
+                            processors[key](compiled_params[key])
+                            if key in processors
+                            else compiled_params[key]
+                        )
+                        for key in compiled_params
+                    )
+
                 parameters.append(param)
         self.parameters = dialect.execute_sequence_format(parameters)
 
@@ -603,10 +610,9 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
         """Initialize execution context for a string SQL statement."""
 
         self = cls.__new__(cls)
-        self.dialect = dialect
         self.root_connection = connection
         self._dbapi_connection = dbapi_connection
-        self.engine = connection.engine
+        self.dialect = connection.dialect
 
         # plain text statement
         self.execution_options = connection._execution_options
@@ -647,14 +653,29 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
         """Initialize execution context for a ColumnDefault construct."""
 
         self = cls.__new__(cls)
-        self.dialect = dialect
         self.root_connection = connection
         self._dbapi_connection = dbapi_connection
-        self.engine = connection.engine
+        self.dialect = connection.dialect
         self.execution_options = connection._execution_options
         self.cursor = self.create_cursor()
         return self
 
+    @util.memoized_property
+    def engine(self):
+        return self.root_connection.engine
+
+    @util.memoized_property
+    def postfetch_cols(self):
+        return self.compiled.postfetch
+
+    @util.memoized_property
+    def prefetch_cols(self):
+        return self.compiled.prefetch
+
+    @util.memoized_property
+    def returning_cols(self):
+        self.compiled.returning
+
     @util.memoized_property
     def no_parameters(self):
         return self.execution_options.get("no_parameters", False)
@@ -779,28 +800,32 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
         return self.dialect.supports_sane_multi_rowcount
 
     def post_insert(self):
-        if not self._is_implicit_returning and \
-            not self._is_explicit_returning and \
-            not self.compiled.inline and \
-            self.dialect.postfetch_lastrowid and \
-            (not self.inserted_primary_key or
-             None in self.inserted_primary_key):
-
-            table = self.compiled.statement.table
-            lastrowid = self.get_lastrowid()
-            autoinc_col = table._autoincrement_column
-            if autoinc_col is not None:
-                # apply type post processors to the lastrowid
-                proc = autoinc_col.type._cached_result_processor(
-                    self.dialect, None)
-                if proc is not None:
-                    lastrowid = proc(lastrowid)
 
+        key_getter = self.compiled._key_getters_for_crud_column[2]
+        table = self.compiled.statement.table
+
+        if not self._is_implicit_returning and \
+                not self._is_explicit_returning and \
+                not self.compiled.inline and \
+                self.dialect.postfetch_lastrowid:
+
+                lastrowid = self.get_lastrowid()
+                autoinc_col = table._autoincrement_column
+                if autoinc_col is not None:
+                    # apply type post processors to the lastrowid
+                    proc = autoinc_col.type._cached_result_processor(
+                        self.dialect, None)
+                    if proc is not None:
+                        lastrowid = proc(lastrowid)
+                self.inserted_primary_key = [
+                    lastrowid if c is autoinc_col else
+                    self.compiled_parameters[0].get(key_getter(c), None)
+                    for c in table.primary_key
+                ]
+        else:
             self.inserted_primary_key = [
-                lastrowid if c is autoinc_col else v
-                for c, v in zip(
-                    table.primary_key,
-                    self.inserted_primary_key)
+                self.compiled_parameters[0].get(key_getter(c), None)
+                for c in table.primary_key
             ]
 
     def _fetch_implicit_returning(self, resultproxy):
@@ -823,7 +848,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
 
     def lastrow_has_defaults(self):
         return (self.isinsert or self.isupdate) and \
-            bool(self.postfetch_cols)
+            bool(self.compiled.postfetch)
 
     def set_input_sizes(self, translate=None, exclude_types=None):
         """Given a cursor and ClauseParameters, call the appropriate
@@ -901,58 +926,52 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
         else:
             return self._exec_default(column.onupdate, column.type)
 
-    def __process_defaults(self):
-        """Generate default values for compiled insert/update statements,
-        and generate inserted_primary_key collection.
-        """
-
+    def _process_executemany_defaults(self):
         key_getter = self.compiled._key_getters_for_crud_column[2]
 
-        if self.executemany:
-            if len(self.compiled.prefetch):
-                scalar_defaults = {}
-
-                # pre-determine scalar Python-side defaults
-                # to avoid many calls of get_insert_default()/
-                # get_update_default()
-                for c in self.prefetch_cols:
-                    if self.isinsert and c.default and c.default.is_scalar:
-                        scalar_defaults[c] = c.default.arg
-                    elif self.isupdate and c.onupdate and c.onupdate.is_scalar:
-                        scalar_defaults[c] = c.onupdate.arg
-
-                for param in self.compiled_parameters:
-                    self.current_parameters = param
-                    for c in self.prefetch_cols:
-                        if c in scalar_defaults:
-                            val = scalar_defaults[c]
-                        elif self.isinsert:
-                            val = self.get_insert_default(c)
-                        else:
-                            val = self.get_update_default(c)
-                        if val is not None:
-                            param[key_getter(c)] = val
-                del self.current_parameters
-        else:
-            self.current_parameters = compiled_parameters = \
-                self.compiled_parameters[0]
-
-            for c in self.compiled.prefetch:
-                if self.isinsert:
+        prefetch = self.compiled.prefetch
+        scalar_defaults = {}
+
+        # pre-determine scalar Python-side defaults
+        # to avoid many calls of get_insert_default()/
+        # get_update_default()
+        for c in prefetch:
+            if self.isinsert and c.default and c.default.is_scalar:
+                scalar_defaults[c] = c.default.arg
+            elif self.isupdate and c.onupdate and c.onupdate.is_scalar:
+                scalar_defaults[c] = c.onupdate.arg
+
+        for param in self.compiled_parameters:
+            self.current_parameters = param
+            for c in prefetch:
+                if c in scalar_defaults:
+                    val = scalar_defaults[c]
+                elif self.isinsert:
                     val = self.get_insert_default(c)
                 else:
                     val = self.get_update_default(c)
-
                 if val is not None:
-                    compiled_parameters[key_getter(c)] = val
-            del self.current_parameters
+                    param[key_getter(c)] = val
+        del self.current_parameters
+
+    def _process_executesingle_defaults(self):
+        key_getter = self.compiled._key_getters_for_crud_column[2]
 
+        prefetch = self.compiled.prefetch
+        self.current_parameters = compiled_parameters = \
+            self.compiled_parameters[0]
+
+        for c in prefetch:
             if self.isinsert:
-                self.inserted_primary_key = [
-                    self.compiled_parameters[0].get(key_getter(c), None)
-                    for c in self.compiled.
-                    statement.table.primary_key
-                ]
+                val = self.get_insert_default(c)
+            else:
+                val = self.get_update_default(c)
+
+            if val is not None:
+                compiled_parameters[key_getter(c)] = val
+        del self.current_parameters
+
+
 
 
 DefaultDialect.execution_ctx_cls = DefaultExecutionContext
index dd9df4a6615a140a95a7bc06e91b62c4e9958e9b..74e69e44c51e19a7e5186a01af36bee9d1204651 100644 (file)
@@ -799,9 +799,9 @@ def _postfetch(mapper, uowtransaction, table,
     after an INSERT or UPDATE statement has proceeded for that
     state."""
 
-    prefetch_cols = result.context.prefetch_cols
-    postfetch_cols = result.context.postfetch_cols
-    returning_cols = result.context.returning_cols
+    prefetch_cols = result.context.compiled.prefetch
+    postfetch_cols = result.context.compiled.postfetch
+    returning_cols = result.context.compiled.returning
 
     if mapper.version_id_col is not None:
         prefetch_cols = list(prefetch_cols) + [mapper.version_id_col]
index e9252062084841c90ba3e4284e3f0b63d7590a78..af0fff826bb5a3970ab5328468a34ce0c2737429 100644 (file)
@@ -442,11 +442,13 @@ class SQLCompiler(Compiled):
 
         if params:
             pd = {}
-            for bindparam, name in self.bind_names.items():
+            for bindparam in self.bind_names:
+                name = self.bind_names[bindparam]
                 if bindparam.key in params:
                     pd[name] = params[bindparam.key]
                 elif name in params:
                     pd[name] = params[name]
+
                 elif _check and bindparam.required:
                     if _group_number:
                         raise exc.InvalidRequestError(
@@ -457,8 +459,11 @@ class SQLCompiler(Compiled):
                         raise exc.InvalidRequestError(
                             "A value is required for bind parameter %r"
                             % bindparam.key)
-                else:
+
+                elif bindparam.callable:
                     pd[name] = bindparam.effective_value
+                else:
+                    pd[name] = bindparam.value
             return pd
         else:
             pd = {}
@@ -473,7 +478,11 @@ class SQLCompiler(Compiled):
                         raise exc.InvalidRequestError(
                             "A value is required for bind parameter %r"
                             % bindparam.key)
-                pd[self.bind_names[bindparam]] = bindparam.effective_value
+
+                if bindparam.callable:
+                    pd[self.bind_names[bindparam]] = bindparam.effective_value
+                else:
+                    pd[self.bind_names[bindparam]] = bindparam.value
             return pd
 
     @property
index 451757b992a709e680db25712a0558107e001ab7..4aa92308091d8f29882d06a3393e8f0d518a3502 100644 (file)
@@ -1,8 +1,10 @@
-from sqlalchemy import exc as exceptions, select, MetaData, Integer, or_
+from sqlalchemy import exc as exceptions, select, MetaData, Integer, or_, \
+    bindparam
 from sqlalchemy.engine import default
 from sqlalchemy.sql import table, column
+from sqlalchemy.sql.elements import _truncated_label
 from sqlalchemy.testing import AssertsCompiledSQL, assert_raises, engines,\
-    fixtures
+    fixtures, eq_
 from sqlalchemy.testing.schema import Table, Column
 
 IDENT_LENGTH = 29
@@ -248,6 +250,47 @@ class MaxIdentTest(fixtures.TestBase, AssertsCompiledSQL):
             dialect=self._length_fixture(positional=True)
         )
 
+    def test_bind_param_non_truncated(self):
+        table1 = self.table1
+        stmt = table1.insert().values(
+            this_is_the_data_column=
+            bindparam("this_is_the_long_bindparam_name")
+        )
+        compiled = stmt.compile(dialect=self._length_fixture(length=10))
+        eq_(
+            compiled.construct_params(
+                params={"this_is_the_long_bindparam_name": 5}),
+            {'this_is_the_long_bindparam_name': 5}
+        )
+
+    def test_bind_param_truncated_named(self):
+        table1 = self.table1
+        bp = bindparam(_truncated_label("this_is_the_long_bindparam_name"))
+        stmt = table1.insert().values(
+            this_is_the_data_column=bp
+        )
+        compiled = stmt.compile(dialect=self._length_fixture(length=10))
+        eq_(
+            compiled.construct_params(params={
+                "this_is_the_long_bindparam_name": 5}),
+            {"this_1": 5}
+        )
+
+    def test_bind_param_truncated_positional(self):
+        table1 = self.table1
+        bp = bindparam(_truncated_label("this_is_the_long_bindparam_name"))
+        stmt = table1.insert().values(
+            this_is_the_data_column=bp
+        )
+        compiled = stmt.compile(
+            dialect=self._length_fixture(length=10, positional=True))
+
+        eq_(
+            compiled.construct_params(params={
+                "this_is_the_long_bindparam_name": 5}),
+            {"this_1": 5}
+        )
+
 
 class LabelLengthTest(fixtures.TestBase, AssertsCompiledSQL):
     __dialect__ = 'DefaultDialect'