]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Support lightweight compiler column elements w/ slots
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 22 Nov 2021 15:59:06 +0000 (10:59 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 22 Nov 2021 16:26:33 +0000 (11:26 -0500)
the _CompileLabel class included ``__slots__`` but these
weren't used as the superclasses included slots.

Create a ``__slots__`` superclass for ``ClauseElement``,
creating a new class of compilable SQL elements that don't
include heavier features like caching, annotations and
cloning, which are meant to be used only in an ad-hoc
compiler fashion.   Create new ``CompilerColumnElement``
from that which serves in column-oriented contexts, but
similarly does not include any expression operator support
as it is intended to be used only to generate a string.

Apply this to both
``_CompileLabel`` as well as PostgreSQL ``_ColonCast``,
which does not actually subclass ``ColumnElement`` as this
class has memoized attributes that aren't worth changing,
and does not include SQL operator capabilities as these
are not needed for these compiler-only objects.

this allows us to more inexpensively add new ad-hoc
labels / casts etc. at compile time, as we will be seeking
to expand out the typecasts that are needed for PostgreSQL
dialects in a subsequent patch.

Change-Id: I52973ae3295cb6e2eb0d7adc816c678a626643ed

lib/sqlalchemy/dialects/postgresql/base.py
lib/sqlalchemy/sql/annotation.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/elements.py
lib/sqlalchemy/sql/roles.py
lib/sqlalchemy/sql/traversals.py
lib/sqlalchemy/sql/visitors.py
test/dialect/postgresql/test_compiler.py
test/sql/test_compiler.py

index b4c57e694d9f6c988ebe1b0ed441b49182b8cc38..583d9c2630820909f5d68549e05dae4f4add7005 100644 (file)
@@ -2041,8 +2041,9 @@ class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum):
             self.drop(bind=bind, checkfirst=checkfirst)
 
 
-class _ColonCast(elements.Cast):
+class _ColonCast(elements.CompilerColumnElement):
     __visit_name__ = "colon_cast"
+    __slots__ = ("type", "clause", "typeclause")
 
     def __init__(self, expression, type_):
         self.type = type_
index 88f045fe66d8febdff6674b948170c7a7379ad4e..519a3103bb713c30a4269c8532bdb01643224633 100644 (file)
@@ -21,6 +21,8 @@ EMPTY_ANNOTATIONS = util.immutabledict()
 
 
 class SupportsAnnotations:
+    __slots__ = ()
+
     _annotations = EMPTY_ANNOTATIONS
 
     @util.memoized_property
@@ -44,6 +46,7 @@ class SupportsAnnotations:
 
 
 class SupportsCloneAnnotations(SupportsAnnotations):
+    __slots__ = ()
 
     _clone_annotations_traverse_internals = [
         ("_annotations", InternalTraversal.dp_annotations_key)
@@ -92,6 +95,8 @@ class SupportsCloneAnnotations(SupportsAnnotations):
 
 
 class SupportsWrappingAnnotations(SupportsAnnotations):
+    __slots__ = ()
+
     def _annotate(self, values):
         """return a copy of this ClauseElement with annotations
         updated by the given dictionary.
index 482afb42f50636a5cc78190fa08bd7767968dbee..29aa57faab6e0c92a3859cca74d0ae84e2bc525a 100644 (file)
@@ -525,12 +525,12 @@ class TypeCompiler(util.with_metaclass(util.EnsureKWArgType, object)):
 
 # this was a Visitable, but to allow accurate detection of
 # column elements this is actually a column element
-class _CompileLabel(elements.ColumnElement):
+class _CompileLabel(elements.CompilerColumnElement):
 
     """lightweight label object which acts as an expression.Label."""
 
     __visit_name__ = "label"
-    __slots__ = "element", "name"
+    __slots__ = "element", "name", "_alt_names"
 
     def __init__(self, col, name, alt_names=()):
         self.element = col
index 6893d99eee018f3fd205b24d3e82371407c30e06..ca65f211220a015afebf77ab82cc2c8607348a9e 100644 (file)
@@ -170,13 +170,103 @@ def not_(clause):
     return operators.inv(coercions.expect(roles.ExpressionElementRole, clause))
 
 
+class CompilerElement(Traversible):
+    """base class for SQL elements that can be compiled to produce a
+    SQL string.
+
+    .. versionadded:: 2.0
+
+    """
+
+    __slots__ = ()
+    __visit_name__ = "compiler_element"
+
+    supports_execution = False
+
+    stringify_dialect = "default"
+
+    @util.preload_module("sqlalchemy.engine.default")
+    @util.preload_module("sqlalchemy.engine.url")
+    def compile(self, bind=None, dialect=None, **kw):
+        """Compile this SQL expression.
+
+        The return value is a :class:`~.Compiled` object.
+        Calling ``str()`` or ``unicode()`` on the returned value will yield a
+        string representation of the result. The
+        :class:`~.Compiled` object also can return a
+        dictionary of bind parameter names and values
+        using the ``params`` accessor.
+
+        :param bind: An ``Engine`` or ``Connection`` from which a
+            ``Compiled`` will be acquired. This argument takes precedence over
+            this :class:`_expression.ClauseElement`'s bound engine, if any.
+
+        :param column_keys: Used for INSERT and UPDATE statements, a list of
+            column names which should be present in the VALUES clause of the
+            compiled statement. If ``None``, all columns from the target table
+            object are rendered.
+
+        :param dialect: A ``Dialect`` instance from which a ``Compiled``
+            will be acquired. This argument takes precedence over the `bind`
+            argument as well as this :class:`_expression.ClauseElement`
+            's bound engine,
+            if any.
+
+        :param compile_kwargs: optional dictionary of additional parameters
+            that will be passed through to the compiler within all "visit"
+            methods.  This allows any custom flag to be passed through to
+            a custom compilation construct, for example.  It is also used
+            for the case of passing the ``literal_binds`` flag through::
+
+                from sqlalchemy.sql import table, column, select
+
+                t = table('t', column('x'))
+
+                s = select(t).where(t.c.x == 5)
+
+                print(s.compile(compile_kwargs={"literal_binds": True}))
+
+            .. versionadded:: 0.9.0
+
+        .. seealso::
+
+            :ref:`faq_sql_expression_string`
+
+        """
+
+        if not dialect:
+            if bind:
+                dialect = bind.dialect
+            elif self.bind:
+                dialect = self.bind.dialect
+            else:
+                if self.stringify_dialect == "default":
+                    default = util.preloaded.engine_default
+                    dialect = default.StrCompileDialect()
+                else:
+                    url = util.preloaded.engine_url
+                    dialect = url.URL.create(
+                        self.stringify_dialect
+                    ).get_dialect()()
+
+        return self._compiler(dialect, **kw)
+
+    def _compiler(self, dialect, **kw):
+        """Return a compiler appropriate for this ClauseElement, given a
+        Dialect."""
+
+        return dialect.statement_compiler(dialect, self, **kw)
+
+    def __str__(self):
+        return str(self.compile())
+
+
 @inspection._self_inspects
 class ClauseElement(
-    roles.SQLRole,
     SupportsWrappingAnnotations,
     MemoizedHasCacheKey,
     HasCopyInternals,
-    Traversible,
+    CompilerElement,
 ):
     """Base class for elements of a programmatically constructed SQL
     expression.
@@ -191,10 +281,6 @@ class ClauseElement(
 
     """
 
-    supports_execution = False
-
-    stringify_dialect = "default"
-
     _from_objects = []
     bind = None
     description = None
@@ -423,72 +509,6 @@ class ClauseElement(
 
         return self
 
-    @util.preload_module("sqlalchemy.engine.default")
-    @util.preload_module("sqlalchemy.engine.url")
-    def compile(self, bind=None, dialect=None, **kw):
-        """Compile this SQL expression.
-
-        The return value is a :class:`~.Compiled` object.
-        Calling ``str()`` or ``unicode()`` on the returned value will yield a
-        string representation of the result. The
-        :class:`~.Compiled` object also can return a
-        dictionary of bind parameter names and values
-        using the ``params`` accessor.
-
-        :param bind: An ``Engine`` or ``Connection`` from which a
-            ``Compiled`` will be acquired. This argument takes precedence over
-            this :class:`_expression.ClauseElement`'s bound engine, if any.
-
-        :param column_keys: Used for INSERT and UPDATE statements, a list of
-            column names which should be present in the VALUES clause of the
-            compiled statement. If ``None``, all columns from the target table
-            object are rendered.
-
-        :param dialect: A ``Dialect`` instance from which a ``Compiled``
-            will be acquired. This argument takes precedence over the `bind`
-            argument as well as this :class:`_expression.ClauseElement`
-            's bound engine,
-            if any.
-
-        :param compile_kwargs: optional dictionary of additional parameters
-            that will be passed through to the compiler within all "visit"
-            methods.  This allows any custom flag to be passed through to
-            a custom compilation construct, for example.  It is also used
-            for the case of passing the ``literal_binds`` flag through::
-
-                from sqlalchemy.sql import table, column, select
-
-                t = table('t', column('x'))
-
-                s = select(t).where(t.c.x == 5)
-
-                print(s.compile(compile_kwargs={"literal_binds": True}))
-
-            .. versionadded:: 0.9.0
-
-        .. seealso::
-
-            :ref:`faq_sql_expression_string`
-
-        """
-
-        if not dialect:
-            if bind:
-                dialect = bind.dialect
-            elif self.bind:
-                dialect = self.bind.dialect
-            else:
-                if self.stringify_dialect == "default":
-                    default = util.preloaded.engine_default
-                    dialect = default.StrCompileDialect()
-                else:
-                    url = util.preloaded.engine_url
-                    dialect = url.URL.create(
-                        self.stringify_dialect
-                    ).get_dialect()()
-
-        return self._compiler(dialect, **kw)
-
     def _compile_w_cache(
         self,
         dialect,
@@ -547,20 +567,6 @@ class ClauseElement(
 
         return compiled_sql, extracted_params, cache_hit
 
-    def _compiler(self, dialect, **kw):
-        """Return a compiler appropriate for this ClauseElement, given a
-        Dialect."""
-
-        return dialect.statement_compiler(dialect, self, **kw)
-
-    def __str__(self):
-        if util.py3k:
-            return str(self.compile())
-        else:
-            return unicode(self.compile()).encode(  # noqa
-                "ascii", "backslashreplace"
-            )  # noqa
-
     def __invert__(self):
         # undocumented element currently used by the ORM for
         # relationship.contains()
@@ -592,6 +598,21 @@ class ClauseElement(
             )
 
 
+class CompilerColumnElement(
+    roles.DMLColumnRole,
+    roles.DDLConstraintColumnRole,
+    roles.ColumnsClauseRole,
+    CompilerElement,
+):
+    """A compiler-only column element used for ad-hoc string compilations.
+
+    .. versionadded:: 2.0
+
+    """
+
+    __slots__ = ()
+
+
 class ColumnElement(
     roles.ColumnArgumentOrKeyRole,
     roles.StatementOptionRole,
@@ -684,6 +705,7 @@ class ColumnElement(
     """
 
     __visit_name__ = "column_element"
+
     primary_key = False
     foreign_keys = []
     _proxies = ()
index 4e009aa269c34d08f252689be3e502491f58dfc8..c4eedd4a4e0d24deb222ae78409bd1c059c387fe 100644 (file)
@@ -19,48 +19,60 @@ class SQLRole:
 
     """
 
+    __slots__ = ()
     allows_lambda = False
     uses_inspection = False
 
 
 class UsesInspection:
+    __slots__ = ()
     _post_inspect = None
     uses_inspection = True
 
 
 class AllowsLambdaRole:
+    __slots__ = ()
     allows_lambda = True
 
 
 class HasCacheKeyRole(SQLRole):
+    __slots__ = ()
     _role_name = "Cacheable Core or ORM object"
 
 
 class LiteralValueRole(SQLRole):
+    __slots__ = ()
     _role_name = "Literal Python value"
 
 
 class ColumnArgumentRole(SQLRole):
+    __slots__ = ()
     _role_name = "Column expression"
 
 
 class ColumnArgumentOrKeyRole(ColumnArgumentRole):
+    __slots__ = ()
     _role_name = "Column expression or string key"
 
 
 class StrAsPlainColumnRole(ColumnArgumentRole):
+    __slots__ = ()
     _role_name = "Column expression or string key"
 
 
 class ColumnListRole(SQLRole):
     """Elements suitable for forming comma separated lists of expressions."""
 
+    __slots__ = ()
+
 
 class TruncatedLabelRole(SQLRole):
+    __slots__ = ()
     _role_name = "String SQL identifier"
 
 
 class ColumnsClauseRole(AllowsLambdaRole, UsesInspection, ColumnListRole):
+    __slots__ = ()
     _role_name = "Column expression or FROM clause"
 
     @property
@@ -69,14 +81,17 @@ class ColumnsClauseRole(AllowsLambdaRole, UsesInspection, ColumnListRole):
 
 
 class LimitOffsetRole(SQLRole):
+    __slots__ = ()
     _role_name = "LIMIT / OFFSET expression"
 
 
 class ByOfRole(ColumnListRole):
+    __slots__ = ()
     _role_name = "GROUP BY / OF / etc. expression"
 
 
 class GroupByRole(AllowsLambdaRole, UsesInspection, ByOfRole):
+    __slots__ = ()
     # note there's a special case right now where you can pass a whole
     # ORM entity to group_by() and it splits out.   we may not want to keep
     # this around
@@ -85,48 +100,57 @@ class GroupByRole(AllowsLambdaRole, UsesInspection, ByOfRole):
 
 
 class OrderByRole(AllowsLambdaRole, ByOfRole):
+    __slots__ = ()
     _role_name = "ORDER BY expression"
 
 
 class StructuralRole(SQLRole):
-    pass
+    __slots__ = ()
 
 
 class StatementOptionRole(StructuralRole):
+    __slots__ = ()
     _role_name = "statement sub-expression element"
 
 
 class OnClauseRole(AllowsLambdaRole, StructuralRole):
+    __slots__ = ()
     _role_name = "SQL expression for ON clause"
 
 
 class WhereHavingRole(OnClauseRole):
+    __slots__ = ()
     _role_name = "SQL expression for WHERE/HAVING role"
 
 
 class ExpressionElementRole(SQLRole):
+    __slots__ = ()
     _role_name = "SQL expression element"
 
 
 class ConstExprRole(ExpressionElementRole):
+    __slots__ = ()
     _role_name = "Constant True/False/None expression"
 
 
 class LabeledColumnExprRole(ExpressionElementRole):
-    pass
+    __slots__ = ()
 
 
 class BinaryElementRole(ExpressionElementRole):
+    __slots__ = ()
     _role_name = "SQL expression element or literal value"
 
 
 class InElementRole(SQLRole):
+    __slots__ = ()
     _role_name = (
         "IN expression list, SELECT construct, or bound parameter object"
     )
 
 
 class JoinTargetRole(AllowsLambdaRole, UsesInspection, StructuralRole):
+    __slots__ = ()
     _role_name = (
         "Join target, typically a FROM expression, or ORM "
         "relationship attribute"
@@ -134,6 +158,7 @@ class JoinTargetRole(AllowsLambdaRole, UsesInspection, StructuralRole):
 
 
 class FromClauseRole(ColumnsClauseRole, JoinTargetRole):
+    __slots__ = ()
     _role_name = "FROM expression, such as a Table or alias() object"
 
     _is_subquery = False
@@ -144,6 +169,7 @@ class FromClauseRole(ColumnsClauseRole, JoinTargetRole):
 
 
 class StrictFromClauseRole(FromClauseRole):
+    __slots__ = ()
     # does not allow text() or select() objects
 
     @property
@@ -152,6 +178,7 @@ class StrictFromClauseRole(FromClauseRole):
 
 
 class AnonymizedFromClauseRole(StrictFromClauseRole):
+    __slots__ = ()
     # calls .alias() as a post processor
 
     def _anonymous_fromclause(self, name=None, flat=False):
@@ -159,6 +186,7 @@ class AnonymizedFromClauseRole(StrictFromClauseRole):
 
 
 class ReturnsRowsRole(SQLRole):
+    __slots__ = ()
     _role_name = (
         "Row returning expression such as a SELECT, a FROM clause, or an "
         "INSERT/UPDATE/DELETE with RETURNING"
@@ -166,12 +194,14 @@ class ReturnsRowsRole(SQLRole):
 
 
 class StatementRole(SQLRole):
+    __slots__ = ()
     _role_name = "Executable SQL or text() construct"
 
     _propagate_attrs = util.immutabledict()
 
 
 class SelectStatementRole(StatementRole, ReturnsRowsRole):
+    __slots__ = ()
     _role_name = "SELECT construct or equivalent text() construct"
 
     def subquery(self):
@@ -182,16 +212,18 @@ class SelectStatementRole(StatementRole, ReturnsRowsRole):
 
 
 class HasCTERole(ReturnsRowsRole):
-    pass
+    __slots__ = ()
 
 
 class IsCTERole(SQLRole):
+    __slots__ = ()
     _role_name = "CTE object"
 
 
 class CompoundElementRole(AllowsLambdaRole, SQLRole):
     """SELECT statements inside a CompoundSelect, e.g. UNION, EXTRACT, etc."""
 
+    __slots__ = ()
     _role_name = (
         "SELECT construct for inclusion in a UNION or other set construct"
     )
@@ -199,36 +231,42 @@ class CompoundElementRole(AllowsLambdaRole, SQLRole):
 
 # TODO: are we using this?
 class DMLRole(StatementRole):
-    pass
+    __slots__ = ()
 
 
 class DMLTableRole(FromClauseRole):
+    __slots__ = ()
     _role_name = "subject table for an INSERT, UPDATE or DELETE"
 
 
 class DMLColumnRole(SQLRole):
+    __slots__ = ()
     _role_name = "SET/VALUES column expression or string key"
 
 
 class DMLSelectRole(SQLRole):
     """A SELECT statement embedded in DML, typically INSERT from SELECT"""
 
+    __slots__ = ()
     _role_name = "SELECT statement or equivalent textual object"
 
 
 class DDLRole(StatementRole):
-    pass
+    __slots__ = ()
 
 
 class DDLExpressionRole(StructuralRole):
+    __slots__ = ()
     _role_name = "SQL expression element for DDL constraint"
 
 
 class DDLConstraintColumnRole(SQLRole):
+    __slots__ = ()
     _role_name = "String column name or column expression for DDL constraint"
 
 
 class DDLReferredColumnRole(DDLConstraintColumnRole):
+    __slots__ = ()
     _role_name = (
         "String column name or Column object for DDL foreign key constraint"
     )
index 6acd794aa3ac109965f06257c734a0f39f8743ea..7973b535f7cfa9083d1bb84252ef3b734582399c 100644 (file)
@@ -714,6 +714,8 @@ _cache_key_traversal_visitor = _CacheKey()
 
 
 class HasCopyInternals:
+    __slots__ = ()
+
     def _clone(self, **kw):
         raise NotImplementedError()
 
index 82cb7a253c4bb2a38b34a5aba3e751d755432aa0..deb92b08111dbfe038484a8975cf16eea082b8a7 100644 (file)
@@ -120,6 +120,8 @@ class Traversible(util.with_metaclass(TraversibleType)):
 
     """
 
+    __slots__ = ()
+
     def __class_getitem__(cls, key):
         # allow generic classes in py3.9+
         return cls
index 93513c39dbe1e0f7b3dd5a3351a79fd0fe8c52b2..7e91f0ebb459c8a907fffa3fc32d5de765d9bcd3 100644 (file)
@@ -38,6 +38,7 @@ from sqlalchemy.dialects.postgresql import array_agg as pg_array_agg
 from sqlalchemy.dialects.postgresql import ExcludeConstraint
 from sqlalchemy.dialects.postgresql import insert
 from sqlalchemy.dialects.postgresql import TSRANGE
+from sqlalchemy.dialects.postgresql.base import _ColonCast
 from sqlalchemy.dialects.postgresql.base import PGDialect
 from sqlalchemy.dialects.postgresql.psycopg2 import PGDialect_psycopg2
 from sqlalchemy.orm import aliased
@@ -98,6 +99,14 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
 
     __dialect__ = postgresql.dialect()
 
+    def test_colon_cast_is_slots(self):
+
+        c1 = _ColonCast(column("q"), String(50))
+
+        assert not hasattr(c1, "__dict__")
+
+        self.assert_compile(c1, "q::VARCHAR(50)")
+
     def test_update_returning(self):
         dialect = postgresql.dialect()
         table1 = table(
index 4eef97369e773b5ab034804098e0136384b7a149..23a2833ca387838ae2f405a5b8f55ba829746400 100644 (file)
@@ -79,6 +79,7 @@ from sqlalchemy.sql import table
 from sqlalchemy.sql import util as sql_util
 from sqlalchemy.sql.elements import BooleanClauseList
 from sqlalchemy.sql.elements import ColumnElement
+from sqlalchemy.sql.elements import CompilerColumnElement
 from sqlalchemy.sql.expression import ClauseElement
 from sqlalchemy.sql.expression import ClauseList
 from sqlalchemy.sql.expression import HasPrefixes
@@ -215,6 +216,25 @@ class TestCompilerFixture(fixtures.TestBase, AssertsCompiledSQL):
 class SelectTest(fixtures.TestBase, AssertsCompiledSQL):
     __dialect__ = "default"
 
+    def test_compiler_column_element_is_slots(self):
+        class SomeColThing(CompilerColumnElement):
+            __slots__ = ("name",)
+            __visit_name__ = "some_col_thing"
+
+            def __init__(self, name):
+                self.name = name
+
+        c1 = SomeColThing("some name")
+        eq_(c1.name, "some name")
+        assert not hasattr(c1, "__dict__")
+
+    def test_compile_label_is_slots(self):
+
+        c1 = compiler._CompileLabel(column("q"), "somename")
+
+        eq_(c1.name, "somename")
+        assert not hasattr(c1, "__dict__")
+
     def test_attribute_sanity(self):
         assert hasattr(table1, "c")
         assert hasattr(table1.select().subquery(), "c")