From: Mike Bayer
Date: Thu, 9 Mar 2023 18:54:07 +0000 (-0500)
Subject: repair broken lambda patch
X-Git-Tag: rel_2_0_6~6^2
X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=2c9796b10c3e85450afdeedc4003607abda2f2db;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git

repair broken lambda patch

In I4e0b627bfa187f1780dc68ec81b94db1c78f846a the 1.4 version of the
patch has more changes than the main version; the main branch failed to
receive the entire change, yet the whole thing was merged.  Restore the
missing mutex-related code to the main version.

Fixed regression where the fix for :ticket:`8098`, which was released in
the 1.4 series and provided a layer of concurrency-safe checks for the
lambda SQL API, included additional fixes in the patch that failed to be
applied to the main branch. These additional fixes have been applied.

Change-Id: Id172e09c421dafa6ef1d40b383aa4371de343864
References: #8098
Fixes: #9461
---

diff --git a/doc/build/changelog/unreleased_20/9461.rst b/doc/build/changelog/unreleased_20/9461.rst
new file mode 100644
index 0000000000..3397cfe274
--- /dev/null
+++ b/doc/build/changelog/unreleased_20/9461.rst
@@ -0,0 +1,8 @@
+.. change::
+    :tags: bug, sql, regression
+    :tickets: 9461
+
+    Fixed regression where the fix for :ticket:`8098`, which was released in
+    the 1.4 series and provided a layer of concurrency-safe checks for the
+    lambda SQL API, included additional fixes in the patch that failed to be
+    applied to the main branch. These additional fixes have been applied.
diff --git a/lib/sqlalchemy/sql/lambdas.py b/lib/sqlalchemy/sql/lambdas.py
index 04bf86ee60..12175c75d1 100644
--- a/lib/sqlalchemy/sql/lambdas.py
+++ b/lib/sqlalchemy/sql/lambdas.py
@@ -272,11 +272,16 @@ class LambdaElement(elements.ClauseElement):
 
         if rec is None:
             if cache_key is not _cache_key.NO_CACHE:
-                rec = AnalyzedFunction(
-                    tracker, self, apply_propagate_attrs, fn
-                )
-                rec.closure_bindparams = bindparams
-                lambda_cache[tracker_key + cache_key] = rec
+                with AnalyzedCode._generation_mutex:
+                    key = tracker_key + cache_key
+                    if key not in lambda_cache:
+                        rec = AnalyzedFunction(
+                            tracker, self, apply_propagate_attrs, fn
+                        )
+                        rec.closure_bindparams = bindparams
+                        lambda_cache[key] = rec
+                    else:
+                        rec = lambda_cache[key]
             else:
                 rec = NonAnalyzedFunction(self._invoke_user_fn(fn))
 
diff --git a/test/ext/test_automap.py b/test/ext/test_automap.py
index dca0bb063b..c84bc1c78e 100644
--- a/test/ext/test_automap.py
+++ b/test/ext/test_automap.py
@@ -657,7 +657,7 @@ class AutomapInhTest(fixtures.MappedTest):
 
 
 class ConcurrentAutomapTest(fixtures.TestBase):
-    __only_on__ = "sqlite"
+    __only_on__ = "sqlite+pysqlite"
 
     def _make_tables(self, e):
         m = MetaData()
diff --git a/test/requirements.py b/test/requirements.py
index 9d51ae4777..67ecdc4059 100644
--- a/test/requirements.py
+++ b/test/requirements.py
@@ -367,22 +367,22 @@ class DefaultRequirements(SuiteRequirements):
         Target must support simultaneous, independent database connections.
 
         """
-        # This is also true of some configurations of UnixODBC and probably
-        # win32 ODBC as well.
+        # note: **do not** let any sqlite driver run "independent connection"
+        # tests.  Use independent_readonly_connections for a concurrency
+        # related test that only uses reads to use sqlite
+        return skip_if(["sqlite"])
+
+    @property
+    def independent_readonly_connections(self):
+        """
+        Target must support simultaneous, independent database connections
+        that will be used in a readonly fashion.
+ + """ return skip_if( [ - no_support( - "sqlite", - "independent connections disabled " - "when :memory: connections are used", - ), - exclude( - "mssql", - "<", - (9, 0, 0), - "SQL Server 2005+ is required for " - "independent connections", - ), + self._sqlite_memory_db, + "+aiosqlite", ] ) diff --git a/test/sql/test_lambdas.py b/test/sql/test_lambdas.py index c3e271706f..002a13db94 100644 --- a/test/sql/test_lambdas.py +++ b/test/sql/test_lambdas.py @@ -1,3 +1,10 @@ +from __future__ import annotations + +import threading +import time +from typing import List +from typing import Optional + from sqlalchemy import exc from sqlalchemy import testing from sqlalchemy.future import select as future_select @@ -2083,3 +2090,98 @@ class DeferredLambdaElementTest( eq_(e12key[0], e1key[0]) eq_(e32key[0], e3key[0]) + + +class ConcurrencyTest(fixtures.TestBase): + """test for #8098 and #9461""" + + __requires__ = ("independent_readonly_connections",) + + __only_on__ = ("+psycopg2", "+mysqldb", "+pysqlite", "+pymysql") + + THREADS = 10 + + @testing.fixture + def mapping_fixture(self, decl_base): + class A(decl_base): + __tablename__ = "a" + id = Column(Integer, primary_key=True) + col1 = Column(String(100)) + col2 = Column(String(100)) + col3 = Column(String(100)) + col4 = Column(String(100)) + + decl_base.metadata.create_all(testing.db) + + from sqlalchemy.orm import Session + + with testing.db.connect() as conn: + with Session(conn) as session: + session.add_all( + [ + A(col1=str(i), col2=str(i), col3=str(i), col4=str(i)) + for i in range(self.THREADS + 1) + ] + ) + session.commit() + + return A + + @testing.requires.timing_intensive + def test_lambda_concurrency(self, testing_engine, mapping_fixture): + A = mapping_fixture + engine = testing_engine(options={"pool_size": self.THREADS + 5}) + NUM_OF_LAMBDAS = 150 + + code = """ +from sqlalchemy import lambda_stmt, select + + +def generate_lambda_stmt(wanted): + stmt = lambda_stmt(lambda: select(A.col1, A.col2, A.col3, A.col4)) +""" + + for _ in range(NUM_OF_LAMBDAS): + code += ( + " stmt += lambda s: s.where((A.col1 == wanted) & " + "(A.col2 == wanted) & (A.col3 == wanted) & " + "(A.col4 == wanted))\n" + ) + + code += """ + return stmt +""" + + d = {"A": A, "__name__": "lambda_fake"} + exec(code, d) + generate_lambda_stmt = d["generate_lambda_stmt"] + + runs: List[Optional[int]] = [None for _ in range(self.THREADS)] + conns = [engine.connect() for _ in range(self.THREADS)] + + def run(num): + wanted = str(num) + connection = conns[num] + time.sleep(0.1) + stmt = generate_lambda_stmt(wanted) + time.sleep(0.1) + row = connection.execute(stmt).first() + if not row: + runs[num] = False + else: + runs[num] = True + + threads = [ + threading.Thread(target=run, args=(num,)) + for num in range(self.THREADS) + ] + + for thread in threads: + thread.start() + for thread in threads: + thread.join(timeout=10) + for conn in conns: + conn.close() + + fails = len([r for r in runs if r is False]) + assert not fails, f"{fails} runs failed"