From: Mike Bayer Date: Tue, 24 Jan 2023 16:05:12 +0000 (-0500) Subject: add set_shard_id() loader option for horizontal shard X-Git-Tag: rel_2_0_0~8 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=8a32f367175871500723c5ebfc0f1af1564d3478;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git add set_shard_id() loader option for horizontal shard Added new option to horizontal sharding API :class:`_horizontal.set_shard_id` which sets the effective shard identifier to query against, for both the primary query as well as for all secondary loaders including relationship eager loaders as well as relationship and column lazy loaders. Modernize sharding examples with new-style mappings, add new asyncio example. Fixes: #7226 Fixes: #7028 Change-Id: Ie69248060c305e8de04f75a529949777944ad511 --- diff --git a/doc/build/changelog/unreleased_20/7226.rst b/doc/build/changelog/unreleased_20/7226.rst new file mode 100644 index 0000000000..ef643aff0f --- /dev/null +++ b/doc/build/changelog/unreleased_20/7226.rst @@ -0,0 +1,9 @@ +.. change:: + :tags: feature, orm extensions + :tickets: 7226 + + Added new option to horizontal sharding API + :class:`_horizontal.set_shard_id` which sets the effective shard identifier + to query against, for both the primary query as well as for all secondary + loaders including relationship eager loaders as well as relationship and + column lazy loaders. diff --git a/doc/build/orm/extensions/horizontal_shard.rst b/doc/build/orm/extensions/horizontal_shard.rst index 69faf9bb33..b0467f1abe 100644 --- a/doc/build/orm/extensions/horizontal_shard.rst +++ b/doc/build/orm/extensions/horizontal_shard.rst @@ -11,6 +11,9 @@ API Documentation .. autoclass:: ShardedSession :members: +.. autoclass:: set_shard_id + :members: + .. autoclass:: ShardedQuery :members: diff --git a/examples/sharding/asyncio.py b/examples/sharding/asyncio.py new file mode 100644 index 0000000000..a66689a5bc --- /dev/null +++ b/examples/sharding/asyncio.py @@ -0,0 +1,351 @@ +"""Illustrates sharding API used with asyncio. + +For the sync version of this example, see separate_databases.py. + +Most of the code here is copied from separate_databases.py and works +in exactly the same way. The main change is how the +``async_sessionmaker`` is configured, and as is specific to this example +the routine that generates new primary keys. + +""" +from __future__ import annotations + +import asyncio +import datetime + +from sqlalchemy import Column +from sqlalchemy import ForeignKey +from sqlalchemy import inspect +from sqlalchemy import Integer +from sqlalchemy import select +from sqlalchemy import Table +from sqlalchemy.ext.asyncio import async_sessionmaker +from sqlalchemy.ext.asyncio import create_async_engine +from sqlalchemy.ext.horizontal_shard import set_shard_id +from sqlalchemy.ext.horizontal_shard import ShardedSession +from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.orm import immediateload +from sqlalchemy.orm import Mapped +from sqlalchemy.orm import mapped_column +from sqlalchemy.orm import relationship +from sqlalchemy.sql import operators +from sqlalchemy.sql import visitors + + +echo = True +db1 = create_async_engine("sqlite+aiosqlite://", echo=echo) +db2 = create_async_engine("sqlite+aiosqlite://", echo=echo) +db3 = create_async_engine("sqlite+aiosqlite://", echo=echo) +db4 = create_async_engine("sqlite+aiosqlite://", echo=echo) + + +# for asyncio, the ShardedSession class is passed +# via sync_session_class. 
The shards themselves are used within +# implicit-awaited internals, so we use the sync_engine Engine objects +# in the shards dictionary. +Session = async_sessionmaker( + sync_session_class=ShardedSession, + expire_on_commit=False, + shards={ + "north_america": db1.sync_engine, + "asia": db2.sync_engine, + "europe": db3.sync_engine, + "south_america": db4.sync_engine, + }, +) + + +# mappings and tables +class Base(DeclarativeBase): + pass + + +# we need a way to create identifiers which are unique across all databases. +# one easy way would be to just use a composite primary key, where one value +# is the shard id. but here, we'll show something more "generic", an id +# generation function. we'll use a simplistic "id table" stored in database +# #1. Any other method will do just as well; UUID, hilo, application-specific, +# etc. + +ids = Table("ids", Base.metadata, Column("nextid", Integer, nullable=False)) + + +def id_generator(ctx): + # id_generator is run within a "synchronous" context, where + # we use an implicit-await API that will convert back to explicit await + # calls when it reaches the driver. + with db1.sync_engine.begin() as conn: + nextid = conn.scalar(ids.select().with_for_update()) + conn.execute(ids.update().values({ids.c.nextid: ids.c.nextid + 1})) + return nextid + + +# table setup. we'll store a lead table of continents/cities, and a secondary +# table storing locations. a particular row will be placed in the database +# whose shard id corresponds to the 'continent'. in this setup, secondary rows +# in 'weather_reports' will be placed in the same DB as that of the parent, but +# this can be changed if you're willing to write more complex sharding +# functions. + + +class WeatherLocation(Base): + __tablename__ = "weather_locations" + + id: Mapped[int] = mapped_column(primary_key=True, default=id_generator) + continent: Mapped[str] + city: Mapped[str] + + reports: Mapped[list[Report]] = relationship(back_populates="location") + + def __init__(self, continent: str, city: str): + self.continent = continent + self.city = city + + +class Report(Base): + __tablename__ = "weather_reports" + + id: Mapped[int] = mapped_column(primary_key=True) + location_id: Mapped[int] = mapped_column( + ForeignKey("weather_locations.id") + ) + temperature: Mapped[float] + report_time: Mapped[datetime.datetime] = mapped_column( + default=datetime.datetime.now + ) + + location: Mapped[WeatherLocation] = relationship(back_populates="reports") + + def __init__(self, temperature: float): + self.temperature = temperature + + +# step 5. define sharding functions. + +# we'll use a straight mapping of a particular set of "country" +# attributes to shard id. +shard_lookup = { + "North America": "north_america", + "Asia": "asia", + "Europe": "europe", + "South America": "south_america", +} + + +def shard_chooser(mapper, instance, clause=None): + """shard chooser. + + looks at the given instance and returns a shard id + note that we need to define conditions for + the WeatherLocation class, as well as our secondary Report class which will + point back to its WeatherLocation via its 'location' attribute. + + """ + if isinstance(instance, WeatherLocation): + return shard_lookup[instance.continent] + else: + return shard_chooser(mapper, instance.location) + + +def identity_chooser(mapper, primary_key, *, lazy_loaded_from, **kw): + """identity chooser. + + given a primary key, returns a list of shards + to search. here, we don't have any particular information from a + pk so we just return all shard ids. 
often, you'd want to do some + kind of round-robin strategy here so that requests are evenly + distributed among DBs. + + """ + if lazy_loaded_from: + # if we are in a lazy load, we can look at the parent object + # and limit our search to that same shard, assuming that's how we've + # set things up. + return [lazy_loaded_from.identity_token] + else: + return ["north_america", "asia", "europe", "south_america"] + + +def execute_chooser(context): + """statement execution chooser. + + this also returns a list of shard ids, which can just be all of them. but + here we'll search into the execution context in order to try to narrow down + the list of shards to SELECT. + + """ + ids = [] + + # we'll grab continent names as we find them + # and convert to shard ids + for column, operator, value in _get_select_comparisons(context.statement): + # "shares_lineage()" returns True if both columns refer to the same + # statement column, adjusting for any annotations present. + # (an annotation is an internal clone of a Column object + # and occur when using ORM-mapped attributes like + # "WeatherLocation.continent"). A simpler comparison, though less + # accurate, would be "column.key == 'continent'". + if column.shares_lineage(WeatherLocation.__table__.c.continent): + if operator == operators.eq: + ids.append(shard_lookup[value]) + elif operator == operators.in_op: + ids.extend(shard_lookup[v] for v in value) + + if len(ids) == 0: + return ["north_america", "asia", "europe", "south_america"] + else: + return ids + + +def _get_select_comparisons(statement): + """Search a Select or Query object for binary expressions. + + Returns expressions which match a Column against one or more + literal values as a list of tuples of the form + (column, operator, values). "values" is a single value + or tuple of values depending on the operator. + + """ + binds = {} + clauses = set() + comparisons = [] + + def visit_bindparam(bind): + # visit a bind parameter. + + value = bind.effective_value + binds[bind] = value + + def visit_column(column): + clauses.add(column) + + def visit_binary(binary): + if binary.left in clauses and binary.right in binds: + comparisons.append( + (binary.left, binary.operator, binds[binary.right]) + ) + + elif binary.left in binds and binary.right in clauses: + comparisons.append( + (binary.right, binary.operator, binds[binary.left]) + ) + + # here we will traverse through the query's criterion, searching + # for SQL constructs. We will place simple column comparisons + # into a list. + if statement.whereclause is not None: + visitors.traverse( + statement.whereclause, + {}, + { + "bindparam": visit_bindparam, + "binary": visit_binary, + "column": visit_column, + }, + ) + return comparisons + + +# further configure create_session to use these functions +Session.configure( + shard_chooser=shard_chooser, + identity_chooser=identity_chooser, + execute_chooser=execute_chooser, +) + + +async def setup(): + # create tables + for db in (db1, db2, db3, db4): + async with db.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + # establish initial "id" in db1 + async with db1.begin() as conn: + await conn.execute(ids.insert(), {"nextid": 1}) + + +async def main(): + await setup() + + # save and load objects! 
+ + tokyo = WeatherLocation("Asia", "Tokyo") + newyork = WeatherLocation("North America", "New York") + toronto = WeatherLocation("North America", "Toronto") + london = WeatherLocation("Europe", "London") + dublin = WeatherLocation("Europe", "Dublin") + brasilia = WeatherLocation("South America", "Brasila") + quito = WeatherLocation("South America", "Quito") + + tokyo.reports.append(Report(80.0)) + newyork.reports.append(Report(75)) + quito.reports.append(Report(85)) + + async with Session() as sess: + + sess.add_all( + [tokyo, newyork, toronto, london, dublin, brasilia, quito] + ) + + await sess.commit() + + t = await sess.get( + WeatherLocation, + tokyo.id, + options=[immediateload(WeatherLocation.reports)], + ) + assert t.city == tokyo.city + assert t.reports[0].temperature == 80.0 + + # select across shards + asia_and_europe = ( + await sess.execute( + select(WeatherLocation).filter( + WeatherLocation.continent.in_(["Europe", "Asia"]) + ) + ) + ).scalars() + + assert {c.city for c in asia_and_europe} == { + "Tokyo", + "London", + "Dublin", + } + + # optionally set a shard id for the query and all related loaders + north_american_cities_w_t = ( + await sess.execute( + select(WeatherLocation) + .filter(WeatherLocation.city.startswith("T")) + .options(set_shard_id("north_america")) + ) + ).scalars() + + # Tokyo not included since not in the north_america shard + assert {c.city for c in north_american_cities_w_t} == { + "Toronto", + } + + # the Report class uses a simple integer primary key. So across two + # databases, a primary key will be repeated. The "identity_token" + # tracks in memory that these two identical primary keys are local to + # different shards. + newyork_report = newyork.reports[0] + tokyo_report = tokyo.reports[0] + + assert inspect(newyork_report).identity_key == ( + Report, + (1,), + "north_america", + ) + assert inspect(tokyo_report).identity_key == (Report, (1,), "asia") + + # the token representing the originating shard is also available + # directly + assert inspect(newyork_report).identity_token == "north_america" + assert inspect(tokyo_report).identity_token == "asia" + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/sharding/separate_databases.py b/examples/sharding/separate_databases.py index fe92fd3bac..65364773b7 100644 --- a/examples/sharding/separate_databases.py +++ b/examples/sharding/separate_databases.py @@ -1,19 +1,20 @@ """Illustrates sharding using distinct SQLite databases.""" +from __future__ import annotations import datetime from sqlalchemy import Column from sqlalchemy import create_engine -from sqlalchemy import DateTime -from sqlalchemy import Float from sqlalchemy import ForeignKey from sqlalchemy import inspect from sqlalchemy import Integer from sqlalchemy import select -from sqlalchemy import String from sqlalchemy import Table -from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.ext.horizontal_shard import set_shard_id from sqlalchemy.ext.horizontal_shard import ShardedSession +from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.orm import Mapped +from sqlalchemy.orm import mapped_column from sqlalchemy.orm import relationship from sqlalchemy.orm import sessionmaker from sqlalchemy.sql import operators @@ -41,7 +42,9 @@ Session = sessionmaker( # mappings and tables -Base = declarative_base() +class Base(DeclarativeBase): + pass + # we need a way to create identifiers which are unique across all databases. 
# one easy way would be to just use a composite primary key, where one value @@ -72,13 +75,13 @@ def id_generator(ctx): class WeatherLocation(Base): __tablename__ = "weather_locations" - id = Column(Integer, primary_key=True, default=id_generator) - continent = Column(String(30), nullable=False) - city = Column(String(50), nullable=False) + id: Mapped[int] = mapped_column(primary_key=True, default=id_generator) + continent: Mapped[str] + city: Mapped[str] - reports = relationship("Report", backref="location") + reports: Mapped[list[Report]] = relationship(back_populates="location") - def __init__(self, continent, city): + def __init__(self, continent: str, city: str): self.continent = continent self.city = city @@ -86,29 +89,22 @@ class WeatherLocation(Base): class Report(Base): __tablename__ = "weather_reports" - id = Column(Integer, primary_key=True) - location_id = Column( - "location_id", Integer, ForeignKey("weather_locations.id") + id: Mapped[int] = mapped_column(primary_key=True) + location_id: Mapped[int] = mapped_column( + ForeignKey("weather_locations.id") ) - temperature = Column("temperature", Float) - report_time = Column( - "report_time", DateTime, default=datetime.datetime.now + temperature: Mapped[float] + report_time: Mapped[datetime.datetime] = mapped_column( + default=datetime.datetime.now ) - def __init__(self, temperature): - self.temperature = temperature - + location: Mapped[WeatherLocation] = relationship(back_populates="reports") -# create tables -for db in (db1, db2, db3, db4): - Base.metadata.create_all(db) - -# establish initial "id" in db1 -with db1.begin() as conn: - conn.execute(ids.insert(), {"nextid": 1}) + def __init__(self, temperature: float): + self.temperature = temperature -# step 5. define sharding functions. +# define sharding functions. # we'll use a straight mapping of a particular set of "country" # attributes to shard id. @@ -241,61 +237,90 @@ Session.configure( execute_chooser=execute_chooser, ) -# save and load objects! -tokyo = WeatherLocation("Asia", "Tokyo") -newyork = WeatherLocation("North America", "New York") -toronto = WeatherLocation("North America", "Toronto") -london = WeatherLocation("Europe", "London") -dublin = WeatherLocation("Europe", "Dublin") -brasilia = WeatherLocation("South America", "Brasila") -quito = WeatherLocation("South America", "Quito") +def setup(): + # create tables + for db in (db1, db2, db3, db4): + Base.metadata.create_all(db) -tokyo.reports.append(Report(80.0)) -newyork.reports.append(Report(75)) -quito.reports.append(Report(85)) + # establish initial "id" in db1 + with db1.begin() as conn: + conn.execute(ids.insert(), {"nextid": 1}) -with Session() as sess: - sess.add_all([tokyo, newyork, toronto, london, dublin, brasilia, quito]) +def main(): + setup() - sess.commit() + # save and load objects! 
- t = sess.get(WeatherLocation, tokyo.id) - assert t.city == tokyo.city - assert t.reports[0].temperature == 80.0 + tokyo = WeatherLocation("Asia", "Tokyo") + newyork = WeatherLocation("North America", "New York") + toronto = WeatherLocation("North America", "Toronto") + london = WeatherLocation("Europe", "London") + dublin = WeatherLocation("Europe", "Dublin") + brasilia = WeatherLocation("South America", "Brasila") + quito = WeatherLocation("South America", "Quito") - north_american_cities = sess.execute( - select(WeatherLocation).filter( - WeatherLocation.continent == "North America" - ) - ).scalars() + tokyo.reports.append(Report(80.0)) + newyork.reports.append(Report(75)) + quito.reports.append(Report(85)) - assert {c.city for c in north_american_cities} == {"New York", "Toronto"} + with Session() as sess: - asia_and_europe = sess.execute( - select(WeatherLocation).filter( - WeatherLocation.continent.in_(["Europe", "Asia"]) + sess.add_all( + [tokyo, newyork, toronto, london, dublin, brasilia, quito] ) - ).scalars() - assert {c.city for c in asia_and_europe} == {"Tokyo", "London", "Dublin"} + sess.commit() - # the Report class uses a simple integer primary key. So across two - # databases, a primary key will be repeated. The "identity_token" tracks - # in memory that these two identical primary keys are local to different - # databases. - newyork_report = newyork.reports[0] - tokyo_report = tokyo.reports[0] + t = sess.get(WeatherLocation, tokyo.id) + assert t.city == tokyo.city + assert t.reports[0].temperature == 80.0 - assert inspect(newyork_report).identity_key == ( - Report, - (1,), - "north_america", - ) - assert inspect(tokyo_report).identity_key == (Report, (1,), "asia") + # select across shards + asia_and_europe = sess.execute( + select(WeatherLocation).filter( + WeatherLocation.continent.in_(["Europe", "Asia"]) + ) + ).scalars() + + assert {c.city for c in asia_and_europe} == { + "Tokyo", + "London", + "Dublin", + } + + # optionally set a shard id for the query and all related loaders + north_american_cities_w_t = sess.execute( + select(WeatherLocation) + .filter(WeatherLocation.city.startswith("T")) + .options(set_shard_id("north_america")) + ).scalars() + + # Tokyo not included since not in the north_america shard + assert {c.city for c in north_american_cities_w_t} == { + "Toronto", + } + + # the Report class uses a simple integer primary key. So across two + # databases, a primary key will be repeated. The "identity_token" + # tracks in memory that these two identical primary keys are local to + # different shards. 
+ newyork_report = newyork.reports[0] + tokyo_report = tokyo.reports[0] + + assert inspect(newyork_report).identity_key == ( + Report, + (1,), + "north_america", + ) + assert inspect(tokyo_report).identity_key == (Report, (1,), "asia") + + # the token representing the originating shard is also available + # directly + assert inspect(newyork_report).identity_token == "north_america" + assert inspect(tokyo_report).identity_token == "asia" - # the token representing the originating shard is also available directly - assert inspect(newyork_report).identity_token == "north_america" - assert inspect(tokyo_report).identity_token == "asia" +if __name__ == "__main__": + main() diff --git a/examples/sharding/separate_schema_translates.py b/examples/sharding/separate_schema_translates.py index f7bdc62500..0b5b08e57f 100644 --- a/examples/sharding/separate_schema_translates.py +++ b/examples/sharding/separate_schema_translates.py @@ -4,20 +4,20 @@ where a different "schema_translates_map" can be used for each shard. In this example we will set a "shard id" at all times. """ +from __future__ import annotations + import datetime import os -from sqlalchemy import Column from sqlalchemy import create_engine -from sqlalchemy import DateTime -from sqlalchemy import Float from sqlalchemy import ForeignKey from sqlalchemy import inspect -from sqlalchemy import Integer from sqlalchemy import select -from sqlalchemy import String -from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.ext.horizontal_shard import set_shard_id from sqlalchemy.ext.horizontal_shard import ShardedSession +from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.orm import Mapped +from sqlalchemy.orm import mapped_column from sqlalchemy.orm import relationship from sqlalchemy.orm import sessionmaker @@ -55,7 +55,8 @@ Session = sessionmaker( # mappings and tables -Base = declarative_base() +class Base(DeclarativeBase): + pass # table setup. we'll store a lead table of continents/cities, and a secondary @@ -69,13 +70,13 @@ Base = declarative_base() class WeatherLocation(Base): __tablename__ = "weather_locations" - id = Column(Integer, primary_key=True) - continent = Column(String(30), nullable=False) - city = Column(String(50), nullable=False) + id: Mapped[int] = mapped_column(primary_key=True) + continent: Mapped[str] + city: Mapped[str] - reports = relationship("Report", backref="location") + reports: Mapped[list[Report]] = relationship(back_populates="location") - def __init__(self, continent, city): + def __init__(self, continent: str, city: str): self.continent = continent self.city = city @@ -83,25 +84,22 @@ class WeatherLocation(Base): class Report(Base): __tablename__ = "weather_reports" - id = Column(Integer, primary_key=True) - location_id = Column( - "location_id", Integer, ForeignKey("weather_locations.id") + id: Mapped[int] = mapped_column(primary_key=True) + location_id: Mapped[int] = mapped_column( + ForeignKey("weather_locations.id") ) - temperature = Column("temperature", Float) - report_time = Column( - "report_time", DateTime, default=datetime.datetime.now + temperature: Mapped[float] + report_time: Mapped[datetime.datetime] = mapped_column( + default=datetime.datetime.now ) - def __init__(self, temperature): - self.temperature = temperature - + location: Mapped[WeatherLocation] = relationship(back_populates="reports") -# create tables -for db in (db1, db2, db3, db4): - Base.metadata.create_all(db) + def __init__(self, temperature: float): + self.temperature = temperature -# step 5. 
define sharding functions. +# define sharding functions. # we'll use a straight mapping of a particular set of "country" # attributes to shard id. @@ -154,15 +152,11 @@ def execute_chooser(context): given an :class:`.ORMExecuteState` for a statement, return a list of shards we should consult. - As before, we want a "shard_id" execution option to be present. - Otherwise, this would be a lazy load from a parent object where we - will look for the previous token. - """ if context.lazy_loaded_from: return [context.lazy_loaded_from.identity_token] else: - return [context.execution_options["shard_id"]] + return ["north_america", "asia", "europe", "south_america"] # configure shard chooser @@ -172,70 +166,90 @@ Session.configure( execute_chooser=execute_chooser, ) -# save and load objects! -tokyo = WeatherLocation("Asia", "Tokyo") -newyork = WeatherLocation("North America", "New York") -toronto = WeatherLocation("North America", "Toronto") -london = WeatherLocation("Europe", "London") -dublin = WeatherLocation("Europe", "Dublin") -brasilia = WeatherLocation("South America", "Brasila") -quito = WeatherLocation("South America", "Quito") +def setup(): + # create tables + for db in (db1, db2, db3, db4): + Base.metadata.create_all(db) -tokyo.reports.append(Report(80.0)) -newyork.reports.append(Report(75)) -quito.reports.append(Report(85)) -with Session() as sess: +def main(): + setup() - sess.add_all([tokyo, newyork, toronto, london, dublin, brasilia, quito]) + # save and load objects! - sess.commit() + tokyo = WeatherLocation("Asia", "Tokyo") + newyork = WeatherLocation("North America", "New York") + toronto = WeatherLocation("North America", "Toronto") + london = WeatherLocation("Europe", "London") + dublin = WeatherLocation("Europe", "Dublin") + brasilia = WeatherLocation("South America", "Brasila") + quito = WeatherLocation("South America", "Quito") - t = sess.get( - WeatherLocation, - tokyo.id, - # for session.get(), we currently need to use identity_token. - # the horizontal sharding API does not yet pass through the - # execution options - identity_token="asia", - # future version - # execution_options={"shard_id": "asia"} - ) - assert t.city == tokyo.city - assert t.reports[0].temperature == 80.0 - - north_american_cities = sess.execute( - select(WeatherLocation).filter( - WeatherLocation.continent == "North America" - ), - execution_options={"shard_id": "north_america"}, - ).scalars() - - assert {c.city for c in north_american_cities} == {"New York", "Toronto"} - - europe = sess.execute( - select(WeatherLocation).filter(WeatherLocation.continent == "Europe"), - execution_options={"shard_id": "europe"}, - ).scalars() - - assert {c.city for c in europe} == {"London", "Dublin"} - - # the Report class uses a simple integer primary key. So across two - # databases, a primary key will be repeated. The "identity_token" tracks - # in memory that these two identical primary keys are local to different - # databases. 
- newyork_report = newyork.reports[0] - tokyo_report = tokyo.reports[0] - - assert inspect(newyork_report).identity_key == ( - Report, - (1,), - "north_america", - ) - assert inspect(tokyo_report).identity_key == (Report, (1,), "asia") + tokyo.reports.append(Report(80.0)) + newyork.reports.append(Report(75)) + quito.reports.append(Report(85)) + + with Session() as sess: + + sess.add_all( + [tokyo, newyork, toronto, london, dublin, brasilia, quito] + ) + + sess.commit() + + t = sess.get( + WeatherLocation, + tokyo.id, + identity_token="asia", + ) + assert t.city == tokyo.city + assert t.reports[0].temperature == 80.0 + + # select across shards + asia_and_europe = sess.execute( + select(WeatherLocation).filter( + WeatherLocation.continent.in_(["Europe", "Asia"]) + ) + ).scalars() + + assert {c.city for c in asia_and_europe} == { + "Tokyo", + "London", + "Dublin", + } + + # optionally set a shard id for the query and all related loaders + north_american_cities_w_t = sess.execute( + select(WeatherLocation) + .filter(WeatherLocation.city.startswith("T")) + .options(set_shard_id("north_america")) + ).scalars() + + # Tokyo not included since not in the north_america shard + assert {c.city for c in north_american_cities_w_t} == { + "Toronto", + } + + # the Report class uses a simple integer primary key. So across two + # databases, a primary key will be repeated. The "identity_token" + # tracks in memory that these two identical primary keys are local to + # different shards. + newyork_report = newyork.reports[0] + tokyo_report = tokyo.reports[0] + + assert inspect(newyork_report).identity_key == ( + Report, + (1,), + "north_america", + ) + assert inspect(tokyo_report).identity_key == (Report, (1,), "asia") + + # the token representing the originating shard is also available + # directly + assert inspect(newyork_report).identity_token == "north_america" + assert inspect(tokyo_report).identity_token == "asia" - # the token representing the originating shard is also available directly - assert inspect(newyork_report).identity_token == "north_america" - assert inspect(tokyo_report).identity_token == "asia" +if __name__ == "__main__": + main() diff --git a/examples/sharding/separate_tables.py b/examples/sharding/separate_tables.py index 97c6a07f6a..98db3771f7 100644 --- a/examples/sharding/separate_tables.py +++ b/examples/sharding/separate_tables.py @@ -1,27 +1,27 @@ """Illustrates sharding using a single SQLite database, that will however have multiple tables using a naming convention.""" +from __future__ import annotations import datetime from sqlalchemy import Column from sqlalchemy import create_engine -from sqlalchemy import DateTime from sqlalchemy import event -from sqlalchemy import Float from sqlalchemy import ForeignKey from sqlalchemy import inspect from sqlalchemy import Integer from sqlalchemy import select -from sqlalchemy import String from sqlalchemy import Table -from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.ext.horizontal_shard import set_shard_id from sqlalchemy.ext.horizontal_shard import ShardedSession +from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.orm import Mapped +from sqlalchemy.orm import mapped_column from sqlalchemy.orm import relationship from sqlalchemy.orm import sessionmaker from sqlalchemy.sql import operators from sqlalchemy.sql import visitors - echo = True engine = create_engine("sqlite://", echo=echo) @@ -55,7 +55,9 @@ Session = sessionmaker( # mappings and tables -Base = declarative_base() +class Base(DeclarativeBase): + 
pass + # we need a way to create identifiers which are unique across all databases. # one easy way would be to just use a composite primary key, where one value @@ -86,13 +88,13 @@ def id_generator(ctx): class WeatherLocation(Base): __tablename__ = "_prefix__weather_locations" - id = Column(Integer, primary_key=True, default=id_generator) - continent = Column(String(30), nullable=False) - city = Column(String(50), nullable=False) + id: Mapped[int] = mapped_column(primary_key=True, default=id_generator) + continent: Mapped[str] + city: Mapped[str] - reports = relationship("Report", backref="location") + reports: Mapped[list[Report]] = relationship(back_populates="location") - def __init__(self, continent, city): + def __init__(self, continent: str, city: str): self.continent = continent self.city = city @@ -100,29 +102,22 @@ class WeatherLocation(Base): class Report(Base): __tablename__ = "_prefix__weather_reports" - id = Column(Integer, primary_key=True) - location_id = Column( - "location_id", Integer, ForeignKey("_prefix__weather_locations.id") + id: Mapped[int] = mapped_column(primary_key=True) + location_id: Mapped[int] = mapped_column( + ForeignKey("_prefix__weather_locations.id") ) - temperature = Column("temperature", Float) - report_time = Column( - "report_time", DateTime, default=datetime.datetime.now + temperature: Mapped[float] + report_time: Mapped[datetime.datetime] = mapped_column( + default=datetime.datetime.now ) - def __init__(self, temperature): - self.temperature = temperature - + location: Mapped[WeatherLocation] = relationship(back_populates="reports") -# create tables -for db in (db1, db2, db3, db4): - Base.metadata.create_all(db) - -# establish initial "id" in db1 -with db1.begin() as conn: - conn.execute(ids.insert(), {"nextid": 1}) + def __init__(self, temperature: float): + self.temperature = temperature -# step 5. define sharding functions. +# define sharding functions. # we'll use a straight mapping of a particular set of "country" # attributes to shard id. @@ -255,61 +250,89 @@ Session.configure( execute_chooser=execute_chooser, ) -# save and load objects! -tokyo = WeatherLocation("Asia", "Tokyo") -newyork = WeatherLocation("North America", "New York") -toronto = WeatherLocation("North America", "Toronto") -london = WeatherLocation("Europe", "London") -dublin = WeatherLocation("Europe", "Dublin") -brasilia = WeatherLocation("South America", "Brasila") -quito = WeatherLocation("South America", "Quito") +def setup(): + # create tables + for db in (db1, db2, db3, db4): + Base.metadata.create_all(db) -tokyo.reports.append(Report(80.0)) -newyork.reports.append(Report(75)) -quito.reports.append(Report(85)) + # establish initial "id" in db1 + with db1.begin() as conn: + conn.execute(ids.insert(), {"nextid": 1}) -with Session() as sess: - sess.add_all([tokyo, newyork, toronto, london, dublin, brasilia, quito]) +def main(): + setup() - sess.commit() + # save and load objects! 
- t = sess.get(WeatherLocation, tokyo.id) - assert t.city == tokyo.city - assert t.reports[0].temperature == 80.0 + tokyo = WeatherLocation("Asia", "Tokyo") + newyork = WeatherLocation("North America", "New York") + toronto = WeatherLocation("North America", "Toronto") + london = WeatherLocation("Europe", "London") + dublin = WeatherLocation("Europe", "Dublin") + brasilia = WeatherLocation("South America", "Brasila") + quito = WeatherLocation("South America", "Quito") - north_american_cities = sess.execute( - select(WeatherLocation).filter( - WeatherLocation.continent == "North America" - ) - ).scalars() + tokyo.reports.append(Report(80.0)) + newyork.reports.append(Report(75)) + quito.reports.append(Report(85)) - assert {c.city for c in north_american_cities} == {"New York", "Toronto"} + with Session() as sess: - asia_and_europe = sess.execute( - select(WeatherLocation).filter( - WeatherLocation.continent.in_(["Europe", "Asia"]) + sess.add_all( + [tokyo, newyork, toronto, london, dublin, brasilia, quito] ) - ).scalars() - assert {c.city for c in asia_and_europe} == {"Tokyo", "London", "Dublin"} + sess.commit() - # the Report class uses a simple integer primary key. So across two - # databases, a primary key will be repeated. The "identity_token" tracks - # in memory that these two identical primary keys are local to different - # databases. - newyork_report = newyork.reports[0] - tokyo_report = tokyo.reports[0] + t = sess.get(WeatherLocation, tokyo.id) + assert t.city == tokyo.city + assert t.reports[0].temperature == 80.0 - assert inspect(newyork_report).identity_key == ( - Report, - (1,), - "north_america", - ) - assert inspect(tokyo_report).identity_key == (Report, (1,), "asia") + # optionally set a shard id for the query and all related loaders + north_american_cities_w_t = sess.execute( + select(WeatherLocation) + .filter(WeatherLocation.city.startswith("T")) + .options(set_shard_id("north_america")) + ).scalars() + + # Tokyo not included since not in the north_america shard + assert {c.city for c in north_american_cities_w_t} == { + "Toronto", + } + + asia_and_europe = sess.execute( + select(WeatherLocation).filter( + WeatherLocation.continent.in_(["Europe", "Asia"]) + ) + ).scalars() + + assert {c.city for c in asia_and_europe} == { + "Tokyo", + "London", + "Dublin", + } + + # the Report class uses a simple integer primary key. So across two + # databases, a primary key will be repeated. The "identity_token" + # tracks in memory that these two identical primary keys are local to + # different shards. + newyork_report = newyork.reports[0] + tokyo_report = tokyo.reports[0] + + assert inspect(newyork_report).identity_key == ( + Report, + (1,), + "north_america", + ) + assert inspect(tokyo_report).identity_key == (Report, (1,), "asia") + + # the token representing the originating shard is also available + # directly + assert inspect(newyork_report).identity_token == "north_america" + assert inspect(tokyo_report).identity_token == "asia" - # the token representing the originating shard is also available directly - assert inspect(newyork_report).identity_token == "north_america" - assert inspect(tokyo_report).identity_token == "asia" +if __name__ == "__main__": + main() diff --git a/lib/sqlalchemy/ext/horizontal_shard.py b/lib/sqlalchemy/ext/horizontal_shard.py index df9d0d797b..0bcc5628fb 100644 --- a/lib/sqlalchemy/ext/horizontal_shard.py +++ b/lib/sqlalchemy/ext/horizontal_shard.py @@ -42,6 +42,7 @@ from .. import inspect from .. 
import util from ..orm import PassiveFlag from ..orm._typing import OrmExecuteOptionsParameter +from ..orm.interfaces import ORMOption from ..orm.mapper import Mapper from ..orm.query import Query from ..orm.session import _BindArguments @@ -73,7 +74,7 @@ _T = TypeVar("_T", bound=Any) SelfShardedQuery = TypeVar("SelfShardedQuery", bound="ShardedQuery[Any]") -_ShardKey = str +ShardIdentifier = str class ShardChooser(Protocol): @@ -105,8 +106,7 @@ class ShardedQuery(Query[_T]): .. legacy:: The :class:`.ShardedQuery` is a subclass of the legacy :class:`.Query` class. The :class:`.ShardedSession` now supports - 2.0 style execution via the :meth:`.ShardedSession.execute` method - as well. + 2.0 style execution via the :meth:`.ShardedSession.execute` method. """ @@ -118,7 +118,9 @@ class ShardedQuery(Query[_T]): self.execute_chooser = self.session.execute_chooser self._shard_id = None - def set_shard(self: SelfShardedQuery, shard_id: str) -> SelfShardedQuery: + def set_shard( + self: SelfShardedQuery, shard_id: ShardIdentifier + ) -> SelfShardedQuery: """Return a new query, limited to a single shard ID. All subsequent operations with the returned query will @@ -166,14 +168,12 @@ class ShardedSession(Session): should set whatever state on the instance to mark it in the future as participating in that shard. - :param id_chooser: A callable, passed a :class:`.ShardedQuery` and a - tuple of identity values, which should return a list of shard ids - where the ID might reside. The databases will be queried in the order - of this listing. + :param identity_chooser: A callable, passed a Mapper and primary key + argument, which should return a list of shard ids where this + primary key might reside. - .. legacy:: This parameter still uses the legacy - :class:`.ShardedQuery` class as an argument passed to the - callable. + .. versionchanged:: 2.0 The ``identity_chooser`` parameter + supersedes the ``id_chooser`` parameter. :param execute_chooser: For a given :class:`.ORMExecuteState`, returns the list of shard_ids @@ -250,7 +250,7 @@ class ShardedSession(Session): "execute_chooser or query_chooser is required" ) self.execute_chooser = execute_chooser - self.__shards: Dict[_ShardKey, _SessionBind] = {} + self.__shards: Dict[ShardIdentifier, _SessionBind] = {} if shards is not None: for k in shards: self.bind_shard(k, shards[k]) @@ -329,7 +329,7 @@ class ShardedSession(Session): self, mapper: Optional[Mapper[_T]] = None, instance: Optional[Any] = None, - shard_id: Optional[Any] = None, + shard_id: Optional[ShardIdentifier] = None, **kw: Any, ) -> Connection: """Provide a :class:`_engine.Connection` to use in the unit of work @@ -359,7 +359,7 @@ class ShardedSession(Session): self, mapper: Optional[_EntityBindKey[_O]] = None, *, - shard_id: Optional[_ShardKey] = None, + shard_id: Optional[ShardIdentifier] = None, instance: Optional[Any] = None, clause: Optional[ClauseElement] = None, **kw: Any, @@ -372,11 +372,61 @@ class ShardedSession(Session): return self.__shards[shard_id] def bind_shard( - self, shard_id: _ShardKey, bind: Union[Engine, OptionEngine] + self, shard_id: ShardIdentifier, bind: Union[Engine, OptionEngine] ) -> None: self.__shards[shard_id] = bind +class set_shard_id(ORMOption): + """a loader option for statements to apply a specific shard id to the + primary query as well as for additional relationship and column + loaders. 
+ + The :class:`_horizontal.set_shard_id` option may be applied using + the :meth:`_sql.Executable.options` method of any executable statement:: + + stmt = ( + select(MyObject). + where(MyObject.name == 'some name'). + options(set_shard_id("shard1")) + ) + + Above, the statement when invoked will limit to the "shard1" shard + identifier for the primary query as well as for all relationship and + column loading strategies, including eager loaders such as + :func:`_orm.selectinload`, deferred column loaders like :func:`_orm.defer`, + and the lazy relationship loader :func:`_orm.lazyload`. + + In this way, the :class:`_horizontal.set_shard_id` option has much wider + scope than using the "shard_id" argument within the + :paramref:`_orm.Session.execute.bind_arguments` dictionary. + + + .. versionadded:: 2.0.0 + + """ + + __slots__ = ("shard_id", "propagate_to_loaders") + + def __init__( + self, shard_id: ShardIdentifier, propagate_to_loaders: bool = True + ): + """Construct a :class:`_horizontal.set_shard_id` option. + + :param shard_id: shard identifier + :param propagate_to_loaders: if left at its default of ``True``, the + shard option will take place for lazy loaders such as + :func:`_orm.lazyload` and :func:`_orm.defer`; if False, the option + will not be propagated to loaded objects. Note that :func:`_orm.defer` + always limits to the shard_id of the parent row in any case, so the + parameter only has a net effect on the behavior of the + :func:`_orm.lazyload` strategy. + + """ + self.shard_id = shard_id + self.propagate_to_loaders = propagate_to_loaders + + def execute_and_instances( orm_context: ORMExecuteState, ) -> Union[Result[_T], IteratorResult[_TP]]: @@ -400,7 +450,7 @@ def execute_and_instances( assert isinstance(session, ShardedSession) def iter_for_shard( - shard_id: str, + shard_id: ShardIdentifier, ) -> Union[Result[_T], IteratorResult[_TP]]: bind_arguments = dict(orm_context.bind_arguments) @@ -409,14 +459,22 @@ def execute_and_instances( orm_context.update_execution_options(identity_token=shard_id) return orm_context.invoke_statement(bind_arguments=bind_arguments) - if active_options and active_options._identity_token is not None: - shard_id = active_options._identity_token - elif "_sa_shard_id" in orm_context.execution_options: - shard_id = orm_context.execution_options["_sa_shard_id"] - elif "shard_id" in orm_context.bind_arguments: - shard_id = orm_context.bind_arguments["shard_id"] + for orm_opt in orm_context._non_compile_orm_options: + # TODO: if we had an ORMOption that gets applied at ORM statement + # execution time, that would allow this to be more generalized. 
+ # for now just iterate and look for our options + if isinstance(orm_opt, set_shard_id): + shard_id = orm_opt.shard_id + break else: - shard_id = None + if active_options and active_options._identity_token is not None: + shard_id = active_options._identity_token + elif "_sa_shard_id" in orm_context.execution_options: + shard_id = orm_context.execution_options["_sa_shard_id"] + elif "shard_id" in orm_context.bind_arguments: + shard_id = orm_context.bind_arguments["shard_id"] + else: + shard_id = None if shard_id is not None: return iter_for_shard(shard_id) diff --git a/lib/sqlalchemy/orm/_typing.py b/lib/sqlalchemy/orm/_typing.py index 76483265d4..36dc6ddb9c 100644 --- a/lib/sqlalchemy/orm/_typing.py +++ b/lib/sqlalchemy/orm/_typing.py @@ -38,6 +38,7 @@ if TYPE_CHECKING: from .decl_api import registry as _registry_type from .interfaces import InspectionAttr from .interfaces import MapperProperty + from .interfaces import ORMOption from .interfaces import UserDefinedOption from .mapper import Mapper from .relationships import RelationshipProperty @@ -122,6 +123,12 @@ class _LoaderCallable(Protocol): ... +def is_orm_option( + opt: ExecutableOption, +) -> TypeGuard[ORMOption]: + return not opt._is_core # type: ignore + + def is_user_defined_option( opt: ExecutableOption, ) -> TypeGuard[UserDefinedOption]: diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index 5bcb22a083..babd5d72ad 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -46,6 +46,7 @@ from . import state as statelib from ._typing import _O from ._typing import insp_is_mapper from ._typing import is_composite_class +from ._typing import is_orm_option from ._typing import is_user_defined_option from .base import _class_to_mapper from .base import _none_set @@ -730,6 +731,14 @@ class ORMExecuteState(util.MemoizedSlots): bulk_persistence.BulkUDCompileState.default_update_options, ) + @property + def _non_compile_orm_options(self) -> Sequence[ORMOption]: + return [ + opt + for opt in self.statement._with_options + if is_orm_option(opt) and not opt._is_compile_state + ] + @property def user_defined_options(self) -> Sequence[UserDefinedOption]: """The sequence of :class:`.UserDefinedOptions` that have been diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py index 5dd32e8cc4..c66ba71c3f 100644 --- a/lib/sqlalchemy/testing/assertions.py +++ b/lib/sqlalchemy/testing/assertions.py @@ -906,7 +906,17 @@ class AssertsExecutionResults: db, callable_, assertsql.CountStatements(count) ) - def assert_multiple_sql_count(self, dbs, callable_, counts): + @contextlib.contextmanager + def assert_execution(self, db, *rules): + with self.sql_execution_asserter(db) as asserter: + yield + asserter.assert_(*rules) + + def assert_statement_count(self, db, count): + return self.assert_execution(db, assertsql.CountStatements(count)) + + @contextlib.contextmanager + def assert_statement_count_multi_db(self, dbs, counts): recs = [ (self.sql_execution_asserter(db), db, count) for (db, count) in zip(dbs, counts) @@ -915,21 +925,12 @@ class AssertsExecutionResults: for ctx, db, count in recs: asserters.append(ctx.__enter__()) try: - return callable_() + yield finally: for asserter, (ctx, db, count) in zip(asserters, recs): ctx.__exit__(None, None, None) asserter.assert_(assertsql.CountStatements(count)) - @contextlib.contextmanager - def assert_execution(self, db, *rules): - with self.sql_execution_asserter(db) as asserter: - yield - asserter.assert_(*rules) - - def 
assert_statement_count(self, db, count): - return self.assert_execution(db, assertsql.CountStatements(count)) - class ComparesIndexes: def compare_table_index_with_expected( diff --git a/test/ext/test_horizontal_shard.py b/test/ext/test_horizontal_shard.py index 8e5d09cab0..467c61f160 100644 --- a/test/ext/test_horizontal_shard.py +++ b/test/ext/test_horizontal_shard.py @@ -17,9 +17,12 @@ from sqlalchemy import testing from sqlalchemy import text from sqlalchemy import update from sqlalchemy import util +from sqlalchemy.ext.horizontal_shard import set_shard_id from sqlalchemy.ext.horizontal_shard import ShardedSession from sqlalchemy.orm import clear_mappers +from sqlalchemy.orm import defer from sqlalchemy.orm import deferred +from sqlalchemy.orm import lazyload from sqlalchemy.orm import relationship from sqlalchemy.orm import selectinload from sqlalchemy.orm import Session @@ -83,7 +86,7 @@ class ShardTest: def setup_test(self): global db1, db2, db3, db4 - db1, db2, db3, db4 = self._dbs = self._init_dbs() + db1, db2, db3, db4 = self._dbs = self.dbs = self._init_dbs() for db in (db1, db2, db3, db4): self.tables_test_metadata.create_all(db) @@ -245,6 +248,103 @@ class ShardTest: t2 = sess.get(WeatherLocation, 1) eq_(t2.city, "Tokyo") + @testing.variation("option_type", ["none", "lazyload", "selectinload"]) + @testing.variation( + "limit_shard", + ["none", "lead_only", "propagate_to_loaders", "bind_arg"], + ) + def test_set_shard_option_relationship(self, option_type, limit_shard): + sess = self._fixture_data() + + stmt = select(WeatherLocation).filter( + WeatherLocation.city == "New York" + ) + + bind_arguments = {} + + if limit_shard.none: + # right now selectinload / lazyload runs all the shards even if the + # ids are limited to just one shard, since that information + # is not transferred + counts = [2, 2, 2, 2] + elif limit_shard.lead_only: + if option_type.selectinload: + counts = [2, 0, 0, 0] + else: + counts = [2, 1, 1, 1] + elif limit_shard.bind_arg: + counts = [2, 1, 1, 1] + elif limit_shard.propagate_to_loaders: + counts = [2, 0, 0, 0] + else: + limit_shard.fail() + + if option_type.lazyload: + stmt = stmt.options(lazyload(WeatherLocation.reports)) + elif option_type.selectinload: + stmt = stmt.options(selectinload(WeatherLocation.reports)) + + if limit_shard.lead_only: + stmt = stmt.options( + set_shard_id("north_america", propagate_to_loaders=False) + ) + elif limit_shard.propagate_to_loaders: + stmt = stmt.options(set_shard_id("north_america")) + elif limit_shard.bind_arg: + bind_arguments["shard_id"] = "north_america" + + with self.assert_statement_count_multi_db(self.dbs, counts): + w1 = sess.scalars(stmt, bind_arguments=bind_arguments).first() + w1.reports + + @testing.variation("option_type", ["none", "defer"]) + @testing.variation( + "limit_shard", + ["none", "lead_only", "propagate_to_loaders", "bind_arg"], + ) + def test_set_shard_option_column(self, option_type, limit_shard): + sess = self._fixture_data() + + stmt = select(WeatherLocation).filter( + WeatherLocation.city == "New York" + ) + + bind_arguments = {} + + if limit_shard.none: + if option_type.defer: + counts = [2, 1, 1, 1] + else: + counts = [1, 1, 1, 1] + elif limit_shard.lead_only or limit_shard.propagate_to_loaders: + if option_type.defer: + counts = [2, 0, 0, 0] + else: + counts = [1, 0, 0, 0] + elif limit_shard.bind_arg: + if option_type.defer: + counts = [2, 0, 0, 0] + else: + counts = [1, 0, 0, 0] + else: + limit_shard.fail() + + if option_type.defer: + stmt = 
stmt.options(defer(WeatherLocation.continent)) + + if limit_shard.lead_only: + stmt = stmt.options( + set_shard_id("north_america", propagate_to_loaders=False) + ) + elif limit_shard.propagate_to_loaders: + stmt = stmt.options(set_shard_id("north_america")) + elif limit_shard.bind_arg: + bind_arguments["shard_id"] = "north_america" + + with self.assert_statement_count_multi_db(self.dbs, counts): + w1 = sess.scalars(stmt, bind_arguments=bind_arguments).first() + w1.continent + def test_query_explicit_shard_via_bind_opts(self): sess = self._fixture_data() @@ -502,14 +602,9 @@ class ShardTest: .execution_options(synchronize_session=synchronize_session) ) - # test synchronize session - def go(): + with self.assert_statement_count_multi_db(self.dbs, [0, 0, 0, 0]): eq_({t.temperature for t in temps}, {86.0, 75.0, 91.0}) - self.assert_sql_count( - sess._ShardedSession__shards["north_america"], go, 0 - ) - eq_( {row.temperature for row in sess.query(Report.temperature)}, {86.0, 75.0, 91.0}, @@ -536,15 +631,11 @@ class ShardTest: .execution_options(synchronize_session=synchronize_session) ) - def go(): + with self.assert_statement_count_multi_db(self.dbs, [0, 0, 0, 0]): # test synchronize session for t in temps: assert inspect(t).deleted is (t.temperature >= 80) - self.assert_sql_count( - sess._ShardedSession__shards["north_america"], go, 0 - ) - eq_( {row.temperature for row in sess.query(Report.temperature)}, {75.0}, @@ -704,16 +795,15 @@ class TableNameConventionShardTest(ShardTest, fixtures.MappedTest): schema = "changeme" def _init_dbs(self): - db1 = testing_engine( - "sqlite://", options={"execution_options": {"shard_id": "shard1"}} - ) - db2 = db1.execution_options(shard_id="shard2") - db3 = db1.execution_options(shard_id="shard3") - db4 = db1.execution_options(shard_id="shard4") + dbmain = testing_engine("sqlite://") + db1 = dbmain.execution_options(shard_id="shard1") + db2 = dbmain.execution_options(shard_id="shard2") + db3 = dbmain.execution_options(shard_id="shard3") + db4 = dbmain.execution_options(shard_id="shard4") import re - @event.listens_for(db1, "before_cursor_execute", retval=True) + @event.listens_for(dbmain, "before_cursor_execute", retval=True) def _switch_shard(conn, cursor, stmt, params, context, executemany): shard_id = conn._execution_options["shard_id"] # because SQLite can't just give us a "use" statement, we have @@ -977,12 +1067,10 @@ class LazyLoadIdentityKeyTest(fixtures.DeclarativeMappedTest): session.expire(page, ["book"]) - def go(): + with self.assert_statement_count_multi_db(self.dbs, [0, 0]): + # doesn't emit SQL eq_(page.book, book) - # doesn't emit SQL - self.assert_multiple_sql_count(self.dbs, go, [0, 0]) - def test_lazy_load_from_db(self): session = self._fixture(lazy_load_book=True) @@ -999,12 +1087,10 @@ class LazyLoadIdentityKeyTest(fixtures.DeclarativeMappedTest): book1_page = session.query(Page).first() session.expire(book1_page, ["book"]) - def go(): + with self.assert_statement_count_multi_db(self.dbs, [1, 0]): + # emits one query eq_(inspect(book1_page.book).identity_key, book1_id) - # emits one query - self.assert_multiple_sql_count(self.dbs, go, [1, 0]) - def test_lazy_load_no_baked_conflict(self): session = self._fixture(lazy_load_pages=True)
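
For reference, a minimal usage sketch of the new loader option added by this commit, assuming a ``ShardedSession`` configured as in ``examples/sharding/separate_databases.py`` (the ``WeatherLocation`` mapping, the ``Session`` sessionmaker, and the ``"north_america"`` shard id are all taken from that example, not from the library itself)::

    from sqlalchemy import select
    from sqlalchemy.ext.horizontal_shard import set_shard_id

    with Session() as sess:
        stmt = (
            select(WeatherLocation)
            .where(WeatherLocation.continent == "North America")
            .options(set_shard_id("north_america"))
        )
        # the primary SELECT, as well as lazy and eager loaders for
        # WeatherLocation.reports, are limited to the "north_america" shard
        locations = sess.scalars(stmt).all()
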