From: Mike Bayer Date: Wed, 17 Apr 2019 17:37:39 +0000 (-0400) Subject: Add _cache_key implementation. X-Git-Tag: rel_1_4_0b1~892^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=08da8115a6eb7eb125fa5f92f662d915b076fded;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Add _cache_key implementation. This leverages the work started in #4336 to allow ClauseElement structures to be cachable based on structure, not just identity. Change-Id: Ia99ddeb5353496dd7d61243245685f02b98d8100 --- diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 38c7cf840e..e634e5a367 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -218,6 +218,28 @@ class ClauseElement(Visitable): return c + def _cache_key(self, **kw): + """return an optional cache key. + + The cache key is a tuple which can contain any series of + objects that are hashable and also identifies + this object uniquely within the presence of a larger SQL expression + or statement, for the purposes of caching the resulting query. + + The cache key should be based on the SQL compiled structure that would + ultimately be produced. That is, two structures that are composed in + exactly the same way should produce the same cache key; any difference + in the strucures that would affect the SQL string or the type handlers + should result in a different cache key. + + If a structure cannot produce a useful cache key, it should raise + NotImplementedError, which will result in the entire structure + for which it's part of not being useful as a cache key. + + + """ + raise NotImplementedError(self.__class__) + @property def _constructor(self): """return the 'constructor' for this ClauseElement. @@ -712,6 +734,9 @@ class ColumnElement(operators.ColumnOperators, ClauseElement): else: return comparator_factory(self) + def _cache_key(self, **kw): + raise NotImplementedError(self.__class__) + def __getattr__(self, key): try: return getattr(self.comparator, key) @@ -1108,7 +1133,10 @@ class BindParameter(ColumnElement): if required is NO_ARG: required = value is NO_ARG and callable_ is None if value is NO_ARG: + self._value_required_for_cache = False value = None + else: + self._value_required_for_cache = True if quote is not None: key = quoted_name(key, quote) @@ -1192,6 +1220,26 @@ class BindParameter(ColumnElement): ) return c + def _cache_key(self, bindparams=None, **kw): + if bindparams is None: + # even though _cache_key is a private method, we would like to + # be super paranoid about this point. You can't include the + # "value" or "callable" in the cache key, because the value is + # not part of the structure of a statement and is likely to + # change every time. However you cannot *throw it away* either, + # because you can't invoke the statement without the parameter + # values that were explicitly placed. So require that they + # are collected here to make sure this happens. + if self._value_required_for_cache: + raise NotImplementedError( + "bindparams collection argument required for _cache_key " + "implementation. Bound parameter cache keys are not safe " + "to use without accommodating for the value or callable " + "within the parameter itself.") + else: + bindparams.append(self) + return (BindParameter, self.type._cache_key, self._orig_key) + def _convert_to_unique(self): if not self.unique: self.unique = True @@ -1230,6 +1278,9 @@ class TypeClause(ClauseElement): def __init__(self, type_): self.type = type_ + def _cache_key(self, **kw): + return (TypeClause, self.type._cache_key) + class TextClause(Executable, ClauseElement): """Represent a literal SQL text fragment. @@ -1658,6 +1709,11 @@ class TextClause(Executable, ClauseElement): def get_children(self, **kwargs): return list(self._bindparams.values()) + def _cache_key(self, **kw): + return (self.text,) + tuple( + bind._cache_key for bind in self._bindparams.values() + ) + class Null(ColumnElement): """Represent the NULL keyword in a SQL statement. @@ -1679,6 +1735,9 @@ class Null(ColumnElement): return Null() + def _cache_key(self, **kw): + return (Null,) + class False_(ColumnElement): """Represent the ``false`` keyword, or equivalent, in a SQL statement. @@ -1735,6 +1794,9 @@ class False_(ColumnElement): return False_() + def _cache_key(self, **kw): + return (False_,) + class True_(ColumnElement): """Represent the ``true`` keyword, or equivalent, in a SQL statement. @@ -1798,6 +1860,9 @@ class True_(ColumnElement): return True_() + def _cache_key(self, **kw): + return (True_,) + class ClauseList(ClauseElement): """Describe a list of clauses, separated by an operator. @@ -1848,6 +1913,11 @@ class ClauseList(ClauseElement): def get_children(self, **kwargs): return self.clauses + def _cache_key(self, **kw): + return (ClauseList, self.operator) + tuple( + clause._cache_key(**kw) for clause in self.clauses + ) + @property def _from_objects(self): return list(itertools.chain(*[c._from_objects for c in self.clauses])) @@ -1867,6 +1937,11 @@ class BooleanClauseList(ClauseList, ColumnElement): "BooleanClauseList has a private constructor" ) + def _cache_key(self, **kw): + return (BooleanClauseList, self.operator) + tuple( + clause._cache_key(**kw) for clause in self.clauses + ) + @classmethod def _construct(cls, operator, continue_on, skip_on, *clauses, **kw): convert_clauses = [] @@ -2030,6 +2105,11 @@ class Tuple(ClauseList, ColumnElement): def _select_iterable(self): return (self,) + def _cache_key(self, **kw): + return (Tuple,) + tuple( + clause._cache_key(**kw) for clause in self.clauses + ) + def _bind_param(self, operator, obj, type_=None): return Tuple( *[ @@ -2245,6 +2325,24 @@ class Case(ColumnElement): if self.else_ is not None: yield self.else_ + def _cache_key(self, **kw): + return ( + ( + Case, + self.value._cache_key(**kw) + if self.value is not None + else None, + ) + + tuple( + (x._cache_key(**kw), y._cache_key(**kw)) for x, y in self.whens + ) + + ( + self.else_._cache_key(**kw) + if self.else_ is not None + else None, + ) + ) + @property def _from_objects(self): return list( @@ -2367,6 +2465,13 @@ class Cast(ColumnElement): def get_children(self, **kwargs): return self.clause, self.typeclause + def _cache_key(self, **kw): + return ( + Cast, + self.clause._cache_key(**kw), + self.typeclause._cache_key(**kw), + ) + @property def _from_objects(self): return self.clause._from_objects @@ -2461,6 +2566,9 @@ class TypeCoerce(ColumnElement): def get_children(self, **kwargs): return (self.clause,) + def _cache_key(self, **kw): + return (TypeCoerce, self.type._cache_key, self.clause._cache_key(**kw)) + @property def _from_objects(self): return self.clause._from_objects @@ -2498,6 +2606,9 @@ class Extract(ColumnElement): def get_children(self, **kwargs): return (self.expr,) + def _cache_key(self, **kw): + return (Extract, self.field, self.expr._cache_key(**kw)) + @property def _from_objects(self): return self.expr._from_objects @@ -2524,6 +2635,9 @@ class _label_reference(ColumnElement): def _copy_internals(self, clone=_clone, **kw): self.element = clone(self.element, **kw) + def _cache_key(self, **kw): + return (_label_reference, self.element._cache_key(**kw)) + def get_children(self, **kwargs): return [self.element] @@ -2542,6 +2656,9 @@ class _textual_label_reference(ColumnElement): def _text_clause(self): return TextClause._create_text(self.element) + def _cache_key(self, **kw): + return (_textual_label_reference, self.element) + class UnaryExpression(ColumnElement): """Define a 'unary' expression. @@ -2803,6 +2920,14 @@ class UnaryExpression(ColumnElement): def _copy_internals(self, clone=_clone, **kw): self.element = clone(self.element, **kw) + def _cache_key(self, **kw): + return ( + UnaryExpression, + self.element._cache_key(**kw), + self.operator, + self.modifier, + ) + def get_children(self, **kwargs): return (self.element,) @@ -2941,6 +3066,15 @@ class AsBoolean(UnaryExpression): def self_group(self, against=None): return self + def _cache_key(self, **kw): + return ( + self.element._cache_key(**kw), + self.type._cache_key, + self.operator, + self.negate, + self.modifier, + ) + def _negate(self): if isinstance(self.element, (True_, False_)): return self.element._negate() @@ -3013,6 +3147,13 @@ class BinaryExpression(ColumnElement): def get_children(self, **kwargs): return self.left, self.right + def _cache_key(self, **kw): + return ( + BinaryExpression, + self.left._cache_key(**kw), + self.right._cache_key(**kw), + ) + def self_group(self, against=None): if operators.is_precedent(self.operator, against): return Grouping(self) @@ -3053,6 +3194,9 @@ class Slice(ColumnElement): assert against is operator.getitem return self + def _cache_key(self, **kw): + return (Slice, self.start, self.stop, self.step) + class IndexExpression(BinaryExpression): """Represent the class of expressions that are like an "index" operation. @@ -3091,6 +3235,9 @@ class Grouping(ColumnElement): def get_children(self, **kwargs): return (self.element,) + def _cache_key(self, **kw): + return (Grouping, self.element._cache_key(**kw)) + @property def _from_objects(self): return self.element._from_objects @@ -3297,6 +3444,16 @@ class Over(ColumnElement): if c is not None ] + def _cache_key(self, **kw): + return ( + (Over,) + + tuple( + e._cache_key(**kw) if e is not None else None + for e in (self.element, self.partition_by, self.order_by) + ) + + (self.range_, self.rows) + ) + def _copy_internals(self, clone=_clone, **kw): self.element = clone(self.element, **kw) if self.partition_by is not None: @@ -3408,6 +3565,17 @@ class WithinGroup(ColumnElement): def get_children(self, **kwargs): return [c for c in (self.element, self.order_by) if c is not None] + def _cache_key(self, **kw): + return ( + WithinGroup, + self.element._cache_key(**kw) + if self.element is not None + else None, + self.order_by._cache_key(**kw) + if self.order_by is not None + else None, + ) + def _copy_internals(self, clone=_clone, **kw): self.element = clone(self.element, **kw) if self.order_by is not None: @@ -3537,6 +3705,15 @@ class FunctionFilter(ColumnElement): if self.criterion is not None: self.criterion = clone(self.criterion, **kw) + def _cache_key(self, **kw): + return ( + FunctionFilter, + self.func._cache_key(**kw), + self.criterion._cache_key(**kw) + if self.criterion is not None + else None, + ) + @property def _from_objects(self): return list( @@ -3598,6 +3775,9 @@ class Label(ColumnElement): def __reduce__(self): return self.__class__, (self.name, self._element, self._type) + def _cache_key(self, **kw): + return (Label, self.element._cache_key(**kw), self._resolve_label) + @util.memoized_property def _is_implicitly_boolean(self): return self.element._is_implicitly_boolean @@ -3831,6 +4011,14 @@ class ColumnClause(Immutable, ColumnElement): table = property(_get_table, _set_table) + def _cache_key(self, **kw): + return ( + self.name, + self.table.name if self.table is not None else None, + self.is_literal, + self.type._cache_key, + ) + @_memoized_property def _from_objects(self): t = self.table @@ -3946,6 +4134,9 @@ class CollationClause(ColumnElement): def __init__(self, collation): self.collation = collation + def _cache_key(self, **kw): + return (CollationClause, self.collation) + class _IdentifiedClause(Executable, ClauseElement): diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py index f48a20ec7f..0e92d5e502 100644 --- a/lib/sqlalchemy/sql/functions.py +++ b/lib/sqlalchemy/sql/functions.py @@ -256,6 +256,9 @@ class FunctionElement(Executable, ColumnElement, FromClause): def get_children(self, **kwargs): return (self.clause_expr,) + def _cache_key(self, **kw): + return (FunctionElement, self.clause_expr._cache_key(**kw)) + def _copy_internals(self, clone=_clone, **kw): self.clause_expr = clone(self.clause_expr, **kw) self._reset_exported() @@ -406,6 +409,14 @@ class FunctionAsBinary(BinaryExpression): def get_children(self, **kw): yield self.sql_function + def _cache_key(self, **kw): + return ( + FunctionAsBinary, + self.sql_function._cache_key(**kw), + self.left_index, + self.right_index, + ) + class _FunctionGenerator(object): """Generate :class:`.Function` objects based on getattr calls.""" @@ -566,6 +577,13 @@ class Function(FunctionElement): unique=True, ) + def _cache_key(self, **kw): + return ( + (Function,) + tuple(self.packagenames) + if self.packagenames + else () + (self.name, self.clause_expr._cache_key(**kw)) + ) + class _GenericMeta(VisitableType): def __init__(cls, clsname, bases, clsdict): @@ -684,6 +702,9 @@ class next_value(GenericFunction): self._bind = kw.get("bind", None) self.sequence = seq + def _cache_key(self, **kw): + return (next_value, self.sequence.name) + def compare(self, other, **kw): return ( isinstance(other, next_value) diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 796e2b2720..a44e94da76 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -864,6 +864,16 @@ class Join(FromClause): def get_children(self, **kwargs): return self.left, self.right, self.onclause + def _cache_key(self, **kw): + return ( + Join, + self.isouter, + self.full, + self.left._cache_key(**kw), + self.right._cache_key(**kw), + self.onclause._cache_key(**kw), + ) + def _match_primaries(self, left, right): if isinstance(left, Join): left_right = left.right @@ -1289,6 +1299,7 @@ class Alias(FromClause): if self.supports_execution: self._execution_options = baseselectable._execution_options self.element = selectable + self._orig_name = name if name is None: if self.original.named_with_column: name = getattr(self.original, "name", None) @@ -1358,6 +1369,9 @@ class Alias(FromClause): yield c yield self.element + def _cache_key(self, **kw): + return (self.__class__, self.element._cache_key(**kw), self._orig_name) + @property def _from_objects(self): return [self] @@ -1777,6 +1791,9 @@ class FromGrouping(FromClause): def _copy_internals(self, clone=_clone, **kw): self.element = clone(self.element, **kw) + def _cache_key(self, **kw): + return (FromGrouping, self.element._cache_key(**kw)) + @property def _from_objects(self): return self.element._from_objects @@ -1877,6 +1894,11 @@ class TableClause(Immutable, FromClause): else: return [] + def _cache_key(self, **kw): + return (TableClause, self.name) + tuple( + col._cache_key(**kw) for col in self._columns + ) + @util.dependencies("sqlalchemy.sql.dml") def insert(self, dml, values=None, inline=False, **kwargs): """Generate an :func:`.insert` construct against this @@ -2004,6 +2026,15 @@ class ForUpdateArg(ClauseElement): if self.of is not None: self.of = [clone(col, **kw) for col in self.of] + def _cache_key(self, **kw): + return ( + ForUpdateArg, + self.nowait, + self.read, + self.skip_locked, + self.of._cache_key(**kw) if self.of is not None else None, + ) + def __init__( self, nowait=False, @@ -2653,6 +2684,27 @@ class CompoundSelect(GenerativeSelect): + list(self.selects) ) + def _cache_key(self, **kw): + return ( + (CompoundSelect, self.keyword) + + tuple(stmt._cache_key(**kw) for stmt in self.selects) + + ( + self._order_by_clause._cache_key(**kw) + if self._order_by_clause is not None + else None, + ) + + ( + self._group_by_clause._cache_key(**kw) + if self._group_by_clause is not None + else None, + ) + + ( + self._for_update_arg._cache_key(**kw) + if self._for_update_arg is not None + else None, + ) + ) + def bind(self): if self._bind: return self._bind @@ -3277,6 +3329,47 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): ] ) + def _cache_key(self, **kw): + return ( + (Select,) + + ("raw_columns",) + + tuple(elem._cache_key(**kw) for elem in self._raw_columns) + + ("elements",) + + tuple( + elem._cache_key(**kw) if elem is not None else None + for elem in ( + self._whereclause, + self._having, + self._order_by_clause, + self._group_by_clause, + ) + ) + + ("from_obj",) + + tuple(elem._cache_key(**kw) for elem in self._from_obj) + + ("correlate",) + + tuple( + elem._cache_key(**kw) + for elem in ( + self._correlate if self._correlate is not None else () + ) + ) + + ("correlate_except",) + + tuple( + elem._cache_key(**kw) + for elem in ( + self._correlate_except + if self._correlate_except is not None + else () + ) + ) + + ("for_update",), + ( + self._for_update_arg._cache_key(**kw) + if self._for_update_arg is not None + else None, + ), + ) + @_generative def column(self, column): """return a new select() construct with the given column expression @@ -3950,6 +4043,11 @@ class TextAsFrom(SelectBase): yield c yield self.element + def _cache_key(self, **kw): + return (TextAsFrom, self.element._cache_key(**kw)) + tuple( + col._cache_key(**kw) for col in self.column_args + ) + def _scalar_type(self): return self.column_args[0].type diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py index 9e052c6b43..bdeae96137 100644 --- a/lib/sqlalchemy/sql/type_api.py +++ b/lib/sqlalchemy/sql/type_api.py @@ -521,6 +521,10 @@ class TypeEngine(Visitable): def _gen_dialect_impl(self, dialect): return dialect.type_descriptor(self) + @util.memoized_property + def _cache_key(self): + return util.constructor_key(self, self.__class__) + def adapt(self, cls, **kw): """Produce an "adapted" form of this type, given an "impl" class to work with. diff --git a/lib/sqlalchemy/util/__init__.py b/lib/sqlalchemy/util/__init__.py index 2f3deb1914..e6e4907abb 100644 --- a/lib/sqlalchemy/util/__init__.py +++ b/lib/sqlalchemy/util/__init__.py @@ -106,6 +106,7 @@ from .langhelpers import classproperty # noqa from .langhelpers import clsname_as_plain_name # noqa from .langhelpers import coerce_kw_type # noqa from .langhelpers import constructor_copy # noqa +from .langhelpers import constructor_key # noqa from .langhelpers import counter # noqa from .langhelpers import decode_slice # noqa from .langhelpers import decorator # noqa diff --git a/lib/sqlalchemy/util/langhelpers.py b/lib/sqlalchemy/util/langhelpers.py index 117a9e2292..7d1321e0b8 100644 --- a/lib/sqlalchemy/util/langhelpers.py +++ b/lib/sqlalchemy/util/langhelpers.py @@ -1133,6 +1133,17 @@ def coerce_kw_type(kw, key, type_, flexi_bool=True): kw[key] = type_(kw[key]) +def constructor_key(obj, cls): + """Produce a tuple structure that is cacheable using the __dict__ of + obj to retrieve values + + """ + names = get_cls_kwargs(cls) + return (cls,) + tuple( + (k, obj.__dict__[k]) for k in names if k in obj.__dict__ + ) + + def constructor_copy(obj, cls, *args, **kw): """Instantiate cls using the __dict__ of obj as constructor arguments. diff --git a/test/sql/test_compare.py b/test/sql/test_compare.py index 8e62d5d82a..67072a6407 100644 --- a/test/sql/test_compare.py +++ b/test/sql/test_compare.py @@ -31,6 +31,7 @@ from sqlalchemy.sql import func from sqlalchemy.sql import operators from sqlalchemy.sql import True_ from sqlalchemy.sql import type_coerce +from sqlalchemy.sql import visitors from sqlalchemy.sql.elements import _label_reference from sqlalchemy.sql.elements import _textual_label_reference from sqlalchemy.sql.elements import Annotated @@ -47,9 +48,13 @@ from sqlalchemy.sql.functions import ReturnTypeFromArgs from sqlalchemy.sql.selectable import _OffsetLimitParam from sqlalchemy.sql.selectable import FromGrouping from sqlalchemy.sql.selectable import Selectable +from sqlalchemy.testing import assert_raises_message +from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures +from sqlalchemy.testing import is_ from sqlalchemy.testing import is_false from sqlalchemy.testing import is_true +from sqlalchemy.testing import ne_ from sqlalchemy.util import class_hierarchy @@ -321,6 +326,109 @@ class CompareAndCopyTest(fixtures.TestBase): "%r == %r" % (case_a[a], case_b[b]), ) + def test_cache_key(self): + def assert_params_append(assert_params): + def append(param): + if param._value_required_for_cache: + assert_params.append(param) + else: + is_(param.value, None) + + return append + + for fixture in self.fixtures: + case_a = fixture() + case_b = fixture() + + for a, b in itertools.combinations_with_replacement( + range(len(case_a)), 2 + ): + + assert_a_params = [] + assert_b_params = [] + + visitors.traverse_depthfirst( + case_a[a], + {}, + {"bindparam": assert_params_append(assert_a_params)}, + ) + visitors.traverse_depthfirst( + case_b[b], + {}, + {"bindparam": assert_params_append(assert_b_params)}, + ) + if assert_a_params: + assert_raises_message( + NotImplementedError, + "bindparams collection argument required ", + case_a[a]._cache_key, + ) + if assert_b_params: + assert_raises_message( + NotImplementedError, + "bindparams collection argument required ", + case_b[b]._cache_key, + ) + + if not assert_a_params and not assert_b_params: + if a == b: + eq_(case_a[a]._cache_key(), case_b[b]._cache_key()) + else: + ne_(case_a[a]._cache_key(), case_b[b]._cache_key()) + + def test_cache_key_gather_bindparams(self): + for fixture in self.fixtures: + case_a = fixture() + case_b = fixture() + + # in the "bindparams" case, the cache keys for bound parameters + # with only different values will be the same, but the params + # themselves are gathered into a collection. + for a, b in itertools.combinations_with_replacement( + range(len(case_a)), 2 + ): + a_params = {"bindparams": []} + b_params = {"bindparams": []} + if a == b: + a_key = case_a[a]._cache_key(**a_params) + b_key = case_b[b]._cache_key(**b_params) + eq_(a_key, b_key) + + if a_params["bindparams"]: + for a_param, b_param in zip( + a_params["bindparams"], b_params["bindparams"] + ): + assert a_param.compare(b_param) + else: + a_key = case_a[a]._cache_key(**a_params) + b_key = case_b[b]._cache_key(**b_params) + + if a_key == b_key: + for a_param, b_param in zip( + a_params["bindparams"], b_params["bindparams"] + ): + if not a_param.compare(b_param): + break + else: + assert False, "Bound parameters are all the same" + else: + ne_(a_key, b_key) + + assert_a_params = [] + assert_b_params = [] + visitors.traverse_depthfirst( + case_a[a], {}, {"bindparam": assert_a_params.append} + ) + visitors.traverse_depthfirst( + case_b[b], {}, {"bindparam": assert_b_params.append} + ) + + # note we're asserting the order of the params as well as + # if there are dupes or not. ordering has to be deterministic + # and matches what a traversal would provide. + eq_(a_params["bindparams"], assert_a_params) + eq_(b_params["bindparams"], assert_b_params) + def test_compare_col_identity(self): stmt1 = ( select([table_a.c.a, table_b.c.b])