git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
add set_shard_id() loader option for horizontal shard
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 24 Jan 2023 16:05:12 +0000 (11:05 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 26 Jan 2023 00:42:14 +0000 (19:42 -0500)
Added new option to horizontal sharding API
:class:`_horizontal.set_shard_id` which sets the effective shard identifier
to query against, for both the primary query as well as for all secondary
loaders including relationship eager loaders as well as relationship and
column lazy loaders.

Modernize sharding examples with new-style mappings, add new asyncio example.

Fixes: #7226
Fixes: #7028
Change-Id: Ie69248060c305e8de04f75a529949777944ad511

doc/build/changelog/unreleased_20/7226.rst [new file with mode: 0644]
doc/build/orm/extensions/horizontal_shard.rst
examples/sharding/asyncio.py [new file with mode: 0644]
examples/sharding/separate_databases.py
examples/sharding/separate_schema_translates.py
examples/sharding/separate_tables.py
lib/sqlalchemy/ext/horizontal_shard.py
lib/sqlalchemy/orm/_typing.py
lib/sqlalchemy/orm/session.py
lib/sqlalchemy/testing/assertions.py
test/ext/test_horizontal_shard.py

diff --git a/doc/build/changelog/unreleased_20/7226.rst b/doc/build/changelog/unreleased_20/7226.rst
new file mode 100644 (file)
index 0000000..ef643af
--- /dev/null
@@ -0,0 +1,9 @@
+.. change::
+    :tags: feature, orm extensions
+    :tickets: 7226
+
+    Added new option to horizontal sharding API
+    :class:`_horizontal.set_shard_id` which sets the effective shard identifier
+    to query against, for both the primary query as well as for all secondary
+    loaders including relationship eager loaders as well as relationship and
+    column lazy loaders.
index 69faf9bb33d90094e9c6f942d2c1e1b434bc828c..b0467f1abe54e99cb680bb52e31f75e7383ed3b1 100644 (file)
@@ -11,6 +11,9 @@ API Documentation
 .. autoclass:: ShardedSession
    :members:
 
+.. autoclass:: set_shard_id
+   :members:
+
 .. autoclass:: ShardedQuery
    :members:
 
diff --git a/examples/sharding/asyncio.py b/examples/sharding/asyncio.py
new file mode 100644 (file)
index 0000000..a66689a
--- /dev/null
@@ -0,0 +1,351 @@
+"""Illustrates sharding API used with asyncio.
+
+For the sync version of this example, see separate_databases.py.
+
+Most of the code here is copied from separate_databases.py and works
+in exactly the same way.   The main change is how the
+``async_sessionmaker`` is configured, and as is specific to this example
+the routine that generates new primary keys.
+
+"""
+from __future__ import annotations
+
+import asyncio
+import datetime
+
+from sqlalchemy import Column
+from sqlalchemy import ForeignKey
+from sqlalchemy import inspect
+from sqlalchemy import Integer
+from sqlalchemy import select
+from sqlalchemy import Table
+from sqlalchemy.ext.asyncio import async_sessionmaker
+from sqlalchemy.ext.asyncio import create_async_engine
+from sqlalchemy.ext.horizontal_shard import set_shard_id
+from sqlalchemy.ext.horizontal_shard import ShardedSession
+from sqlalchemy.orm import DeclarativeBase
+from sqlalchemy.orm import immediateload
+from sqlalchemy.orm import Mapped
+from sqlalchemy.orm import mapped_column
+from sqlalchemy.orm import relationship
+from sqlalchemy.sql import operators
+from sqlalchemy.sql import visitors
+
+
+echo = True
+db1 = create_async_engine("sqlite+aiosqlite://", echo=echo)
+db2 = create_async_engine("sqlite+aiosqlite://", echo=echo)
+db3 = create_async_engine("sqlite+aiosqlite://", echo=echo)
+db4 = create_async_engine("sqlite+aiosqlite://", echo=echo)
+
+
+# for asyncio, the ShardedSession class is passed
+# via sync_session_class.  The shards themselves are used within
+# implicit-awaited internals, so we use the sync_engine Engine objects
+# in the shards dictionary.
+Session = async_sessionmaker(
+    sync_session_class=ShardedSession,
+    expire_on_commit=False,
+    shards={
+        "north_america": db1.sync_engine,
+        "asia": db2.sync_engine,
+        "europe": db3.sync_engine,
+        "south_america": db4.sync_engine,
+    },
+)
+
+
+# mappings and tables
+class Base(DeclarativeBase):
+    pass
+
+
+# we need a way to create identifiers which are unique across all databases.
+# one easy way would be to just use a composite primary key, where one value
+# is the shard id.  but here, we'll show something more "generic", an id
+# generation function.  we'll use a simplistic "id table" stored in database
+# #1.  Any other method will do just as well; UUID, hilo, application-specific,
+# etc.
+
+ids = Table("ids", Base.metadata, Column("nextid", Integer, nullable=False))
+
+
+def id_generator(ctx):
+    # id_generator is run within a "synchronous" context, where
+    # we use an implicit-await API that will convert back to explicit await
+    # calls when it reaches the driver.
+    with db1.sync_engine.begin() as conn:
+        nextid = conn.scalar(ids.select().with_for_update())
+        conn.execute(ids.update().values({ids.c.nextid: ids.c.nextid + 1}))
+    return nextid
+
+
+# table setup.  we'll store a lead table of continents/cities, and a secondary
+# table storing locations. a particular row will be placed in the database
+# whose shard id corresponds to the 'continent'.  in this setup, secondary rows
+# in 'weather_reports' will be placed in the same DB as that of the parent, but
+# this can be changed if you're willing to write more complex sharding
+# functions.
+
+
+class WeatherLocation(Base):
+    __tablename__ = "weather_locations"
+
+    id: Mapped[int] = mapped_column(primary_key=True, default=id_generator)
+    continent: Mapped[str]
+    city: Mapped[str]
+
+    reports: Mapped[list[Report]] = relationship(back_populates="location")
+
+    def __init__(self, continent: str, city: str):
+        self.continent = continent
+        self.city = city
+
+
+class Report(Base):
+    __tablename__ = "weather_reports"
+
+    id: Mapped[int] = mapped_column(primary_key=True)
+    location_id: Mapped[int] = mapped_column(
+        ForeignKey("weather_locations.id")
+    )
+    temperature: Mapped[float]
+    report_time: Mapped[datetime.datetime] = mapped_column(
+        default=datetime.datetime.now
+    )
+
+    location: Mapped[WeatherLocation] = relationship(back_populates="reports")
+
+    def __init__(self, temperature: float):
+        self.temperature = temperature
+
+
+# define sharding functions.
+
+# we'll use a straight mapping of a particular set of "continent"
+# attributes to shard id.
+shard_lookup = {
+    "North America": "north_america",
+    "Asia": "asia",
+    "Europe": "europe",
+    "South America": "south_america",
+}
+
+
+def shard_chooser(mapper, instance, clause=None):
+    """shard chooser.
+
+    looks at the given instance and returns a shard id.
+    note that we need to define conditions for
+    the WeatherLocation class, as well as our secondary Report class which will
+    point back to its WeatherLocation via its 'location' attribute.
+
+    """
+    if isinstance(instance, WeatherLocation):
+        return shard_lookup[instance.continent]
+    else:
+        return shard_chooser(mapper, instance.location)
+
+
+def identity_chooser(mapper, primary_key, *, lazy_loaded_from, **kw):
+    """identity chooser.
+
+    given a primary key, returns a list of shards
+    to search.  here, we don't have any particular information from a
+    pk so we just return all shard ids. often, you'd want to do some
+    kind of round-robin strategy here so that requests are evenly
+    distributed among DBs.
+
+    """
+    if lazy_loaded_from:
+        # if we are in a lazy load, we can look at the parent object
+        # and limit our search to that same shard, assuming that's how we've
+        # set things up.
+        return [lazy_loaded_from.identity_token]
+    else:
+        return ["north_america", "asia", "europe", "south_america"]
+
+
+def execute_chooser(context):
+    """statement execution chooser.
+
+    this also returns a list of shard ids, which can just be all of them. but
+    here we'll search into the execution context in order to try to narrow down
+    the list of shards to SELECT.
+
+    """
+    ids = []
+
+    # we'll grab continent names as we find them
+    # and convert to shard ids
+    for column, operator, value in _get_select_comparisons(context.statement):
+        # "shares_lineage()" returns True if both columns refer to the same
+        # statement column, adjusting for any annotations present.
+        # (an annotation is an internal clone of a Column object
+        # and occur when using ORM-mapped attributes like
+        # "WeatherLocation.continent"). A simpler comparison, though less
+        # accurate, would be "column.key == 'continent'".
+        if column.shares_lineage(WeatherLocation.__table__.c.continent):
+            if operator == operators.eq:
+                ids.append(shard_lookup[value])
+            elif operator == operators.in_op:
+                ids.extend(shard_lookup[v] for v in value)
+
+    if len(ids) == 0:
+        return ["north_america", "asia", "europe", "south_america"]
+    else:
+        return ids
+
+
+def _get_select_comparisons(statement):
+    """Search a Select or Query object for binary expressions.
+
+    Returns expressions which match a Column against one or more
+    literal values as a list of tuples of the form
+    (column, operator, values).   "values" is a single value
+    or tuple of values depending on the operator.
+
+    """
+    binds = {}
+    clauses = set()
+    comparisons = []
+
+    def visit_bindparam(bind):
+        # visit a bind parameter.
+
+        value = bind.effective_value
+        binds[bind] = value
+
+    def visit_column(column):
+        clauses.add(column)
+
+    def visit_binary(binary):
+        if binary.left in clauses and binary.right in binds:
+            comparisons.append(
+                (binary.left, binary.operator, binds[binary.right])
+            )
+
+        elif binary.left in binds and binary.right in clauses:
+            comparisons.append(
+                (binary.right, binary.operator, binds[binary.left])
+            )
+
+    # here we will traverse through the query's criterion, searching
+    # for SQL constructs.  We will place simple column comparisons
+    # into a list.
+    if statement.whereclause is not None:
+        visitors.traverse(
+            statement.whereclause,
+            {},
+            {
+                "bindparam": visit_bindparam,
+                "binary": visit_binary,
+                "column": visit_column,
+            },
+        )
+    return comparisons
+
+
+# further configure create_session to use these functions
+Session.configure(
+    shard_chooser=shard_chooser,
+    identity_chooser=identity_chooser,
+    execute_chooser=execute_chooser,
+)
+
+
+async def setup():
+    # create tables
+    for db in (db1, db2, db3, db4):
+        async with db.begin() as conn:
+            await conn.run_sync(Base.metadata.create_all)
+
+    # establish initial "id" in db1
+    async with db1.begin() as conn:
+        await conn.execute(ids.insert(), {"nextid": 1})
+
+
+async def main():
+    await setup()
+
+    # save and load objects!
+
+    tokyo = WeatherLocation("Asia", "Tokyo")
+    newyork = WeatherLocation("North America", "New York")
+    toronto = WeatherLocation("North America", "Toronto")
+    london = WeatherLocation("Europe", "London")
+    dublin = WeatherLocation("Europe", "Dublin")
+    brasilia = WeatherLocation("South America", "Brasila")
+    quito = WeatherLocation("South America", "Quito")
+
+    tokyo.reports.append(Report(80.0))
+    newyork.reports.append(Report(75))
+    quito.reports.append(Report(85))
+
+    async with Session() as sess:
+
+        sess.add_all(
+            [tokyo, newyork, toronto, london, dublin, brasilia, quito]
+        )
+
+        await sess.commit()
+
+        t = await sess.get(
+            WeatherLocation,
+            tokyo.id,
+            options=[immediateload(WeatherLocation.reports)],
+        )
+        assert t.city == tokyo.city
+        assert t.reports[0].temperature == 80.0
+
+        # select across shards
+        asia_and_europe = (
+            await sess.execute(
+                select(WeatherLocation).filter(
+                    WeatherLocation.continent.in_(["Europe", "Asia"])
+                )
+            )
+        ).scalars()
+
+        assert {c.city for c in asia_and_europe} == {
+            "Tokyo",
+            "London",
+            "Dublin",
+        }
+
+        # optionally set a shard id for the query and all related loaders
+        north_american_cities_w_t = (
+            await sess.execute(
+                select(WeatherLocation)
+                .filter(WeatherLocation.city.startswith("T"))
+                .options(set_shard_id("north_america"))
+            )
+        ).scalars()
+
+        # Tokyo not included since not in the north_america shard
+        assert {c.city for c in north_american_cities_w_t} == {
+            "Toronto",
+        }
+
+        # the Report class uses a simple integer primary key.  So across two
+        # databases, a primary key will be repeated.  The "identity_token"
+        # tracks in memory that these two identical primary keys are local to
+        # different shards.
+        newyork_report = newyork.reports[0]
+        tokyo_report = tokyo.reports[0]
+
+        assert inspect(newyork_report).identity_key == (
+            Report,
+            (1,),
+            "north_america",
+        )
+        assert inspect(tokyo_report).identity_key == (Report, (1,), "asia")
+
+        # the token representing the originating shard is also available
+        # directly
+        assert inspect(newyork_report).identity_token == "north_america"
+        assert inspect(tokyo_report).identity_token == "asia"
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
index fe92fd3bac43b6c2fc7c434c3aecb9c2fcc80d6c..65364773b72bfeff5e79778bc2c1a806941ade33 100644 (file)
@@ -1,19 +1,20 @@
 """Illustrates sharding using distinct SQLite databases."""
+from __future__ import annotations
 
 import datetime
 
 from sqlalchemy import Column
 from sqlalchemy import create_engine
-from sqlalchemy import DateTime
-from sqlalchemy import Float
 from sqlalchemy import ForeignKey
 from sqlalchemy import inspect
 from sqlalchemy import Integer
 from sqlalchemy import select
-from sqlalchemy import String
 from sqlalchemy import Table
-from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy.ext.horizontal_shard import set_shard_id
 from sqlalchemy.ext.horizontal_shard import ShardedSession
+from sqlalchemy.orm import DeclarativeBase
+from sqlalchemy.orm import Mapped
+from sqlalchemy.orm import mapped_column
 from sqlalchemy.orm import relationship
 from sqlalchemy.orm import sessionmaker
 from sqlalchemy.sql import operators
@@ -41,7 +42,9 @@ Session = sessionmaker(
 
 
 # mappings and tables
-Base = declarative_base()
+class Base(DeclarativeBase):
+    pass
+
 
 # we need a way to create identifiers which are unique across all databases.
 # one easy way would be to just use a composite primary key, where  one value
@@ -72,13 +75,13 @@ def id_generator(ctx):
 class WeatherLocation(Base):
     __tablename__ = "weather_locations"
 
-    id = Column(Integer, primary_key=True, default=id_generator)
-    continent = Column(String(30), nullable=False)
-    city = Column(String(50), nullable=False)
+    id: Mapped[int] = mapped_column(primary_key=True, default=id_generator)
+    continent: Mapped[str]
+    city: Mapped[str]
 
-    reports = relationship("Report", backref="location")
+    reports: Mapped[list[Report]] = relationship(back_populates="location")
 
-    def __init__(self, continent, city):
+    def __init__(self, continent: str, city: str):
         self.continent = continent
         self.city = city
 
@@ -86,29 +89,22 @@ class WeatherLocation(Base):
 class Report(Base):
     __tablename__ = "weather_reports"
 
-    id = Column(Integer, primary_key=True)
-    location_id = Column(
-        "location_id", Integer, ForeignKey("weather_locations.id")
+    id: Mapped[int] = mapped_column(primary_key=True)
+    location_id: Mapped[int] = mapped_column(
+        ForeignKey("weather_locations.id")
     )
-    temperature = Column("temperature", Float)
-    report_time = Column(
-        "report_time", DateTime, default=datetime.datetime.now
+    temperature: Mapped[float]
+    report_time: Mapped[datetime.datetime] = mapped_column(
+        default=datetime.datetime.now
     )
 
-    def __init__(self, temperature):
-        self.temperature = temperature
-
+    location: Mapped[WeatherLocation] = relationship(back_populates="reports")
 
-# create tables
-for db in (db1, db2, db3, db4):
-    Base.metadata.create_all(db)
-
-# establish initial "id" in db1
-with db1.begin() as conn:
-    conn.execute(ids.insert(), {"nextid": 1})
+    def __init__(self, temperature: float):
+        self.temperature = temperature
 
 
-# step 5. define sharding functions.
+# define sharding functions.
 
 # we'll use a straight mapping of a particular set of "country"
 # attributes to shard id.
@@ -241,61 +237,90 @@ Session.configure(
     execute_chooser=execute_chooser,
 )
 
-# save and load objects!
 
-tokyo = WeatherLocation("Asia", "Tokyo")
-newyork = WeatherLocation("North America", "New York")
-toronto = WeatherLocation("North America", "Toronto")
-london = WeatherLocation("Europe", "London")
-dublin = WeatherLocation("Europe", "Dublin")
-brasilia = WeatherLocation("South America", "Brasila")
-quito = WeatherLocation("South America", "Quito")
+def setup():
+    # create tables
+    for db in (db1, db2, db3, db4):
+        Base.metadata.create_all(db)
 
-tokyo.reports.append(Report(80.0))
-newyork.reports.append(Report(75))
-quito.reports.append(Report(85))
+    # establish initial "id" in db1
+    with db1.begin() as conn:
+        conn.execute(ids.insert(), {"nextid": 1})
 
-with Session() as sess:
 
-    sess.add_all([tokyo, newyork, toronto, london, dublin, brasilia, quito])
+def main():
+    setup()
 
-    sess.commit()
+    # save and load objects!
 
-    t = sess.get(WeatherLocation, tokyo.id)
-    assert t.city == tokyo.city
-    assert t.reports[0].temperature == 80.0
+    tokyo = WeatherLocation("Asia", "Tokyo")
+    newyork = WeatherLocation("North America", "New York")
+    toronto = WeatherLocation("North America", "Toronto")
+    london = WeatherLocation("Europe", "London")
+    dublin = WeatherLocation("Europe", "Dublin")
+    brasilia = WeatherLocation("South America", "Brasila")
+    quito = WeatherLocation("South America", "Quito")
 
-    north_american_cities = sess.execute(
-        select(WeatherLocation).filter(
-            WeatherLocation.continent == "North America"
-        )
-    ).scalars()
+    tokyo.reports.append(Report(80.0))
+    newyork.reports.append(Report(75))
+    quito.reports.append(Report(85))
 
-    assert {c.city for c in north_american_cities} == {"New York", "Toronto"}
+    with Session() as sess:
 
-    asia_and_europe = sess.execute(
-        select(WeatherLocation).filter(
-            WeatherLocation.continent.in_(["Europe", "Asia"])
+        sess.add_all(
+            [tokyo, newyork, toronto, london, dublin, brasilia, quito]
         )
-    ).scalars()
 
-    assert {c.city for c in asia_and_europe} == {"Tokyo", "London", "Dublin"}
+        sess.commit()
 
-    # the Report class uses a simple integer primary key.  So across two
-    # databases, a primary key will be repeated.  The "identity_token" tracks
-    # in memory that these two identical primary keys are local to different
-    # databases.
-    newyork_report = newyork.reports[0]
-    tokyo_report = tokyo.reports[0]
+        t = sess.get(WeatherLocation, tokyo.id)
+        assert t.city == tokyo.city
+        assert t.reports[0].temperature == 80.0
 
-    assert inspect(newyork_report).identity_key == (
-        Report,
-        (1,),
-        "north_america",
-    )
-    assert inspect(tokyo_report).identity_key == (Report, (1,), "asia")
+        # select across shards
+        asia_and_europe = sess.execute(
+            select(WeatherLocation).filter(
+                WeatherLocation.continent.in_(["Europe", "Asia"])
+            )
+        ).scalars()
+
+        assert {c.city for c in asia_and_europe} == {
+            "Tokyo",
+            "London",
+            "Dublin",
+        }
+
+        # optionally set a shard id for the query and all related loaders
+        north_american_cities_w_t = sess.execute(
+            select(WeatherLocation)
+            .filter(WeatherLocation.city.startswith("T"))
+            .options(set_shard_id("north_america"))
+        ).scalars()
+
+        # Tokyo not included since not in the north_america shard
+        assert {c.city for c in north_american_cities_w_t} == {
+            "Toronto",
+        }
+
+        # the Report class uses a simple integer primary key.  So across two
+        # databases, a primary key will be repeated.  The "identity_token"
+        # tracks in memory that these two identical primary keys are local to
+        # different shards.
+        newyork_report = newyork.reports[0]
+        tokyo_report = tokyo.reports[0]
+
+        assert inspect(newyork_report).identity_key == (
+            Report,
+            (1,),
+            "north_america",
+        )
+        assert inspect(tokyo_report).identity_key == (Report, (1,), "asia")
+
+        # the token representing the originating shard is also available
+        # directly
+        assert inspect(newyork_report).identity_token == "north_america"
+        assert inspect(tokyo_report).identity_token == "asia"
 
-    # the token representing the originating shard is also available directly
 
-    assert inspect(newyork_report).identity_token == "north_america"
-    assert inspect(tokyo_report).identity_token == "asia"
+if __name__ == "__main__":
+    main()
index f7bdc62500eb80acff89ce2ff9041af835f95761..0b5b08e57f6e70e2dbd78702eeb8866139e59e60 100644 (file)
@@ -4,20 +4,20 @@ where a different "schema_translates_map" can be used for each shard.
 In this example we will set a "shard id" at all times.
 
 """
+from __future__ import annotations
+
 import datetime
 import os
 
-from sqlalchemy import Column
 from sqlalchemy import create_engine
-from sqlalchemy import DateTime
-from sqlalchemy import Float
 from sqlalchemy import ForeignKey
 from sqlalchemy import inspect
-from sqlalchemy import Integer
 from sqlalchemy import select
-from sqlalchemy import String
-from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy.ext.horizontal_shard import set_shard_id
 from sqlalchemy.ext.horizontal_shard import ShardedSession
+from sqlalchemy.orm import DeclarativeBase
+from sqlalchemy.orm import Mapped
+from sqlalchemy.orm import mapped_column
 from sqlalchemy.orm import relationship
 from sqlalchemy.orm import sessionmaker
 
@@ -55,7 +55,8 @@ Session = sessionmaker(
 
 
 # mappings and tables
-Base = declarative_base()
+class Base(DeclarativeBase):
+    pass
 
 
 # table setup.  we'll store a lead table of continents/cities, and a secondary
@@ -69,13 +70,13 @@ Base = declarative_base()
 class WeatherLocation(Base):
     __tablename__ = "weather_locations"
 
-    id = Column(Integer, primary_key=True)
-    continent = Column(String(30), nullable=False)
-    city = Column(String(50), nullable=False)
+    id: Mapped[int] = mapped_column(primary_key=True)
+    continent: Mapped[str]
+    city: Mapped[str]
 
-    reports = relationship("Report", backref="location")
+    reports: Mapped[list[Report]] = relationship(back_populates="location")
 
-    def __init__(self, continent, city):
+    def __init__(self, continent: str, city: str):
         self.continent = continent
         self.city = city
 
@@ -83,25 +84,22 @@ class WeatherLocation(Base):
 class Report(Base):
     __tablename__ = "weather_reports"
 
-    id = Column(Integer, primary_key=True)
-    location_id = Column(
-        "location_id", Integer, ForeignKey("weather_locations.id")
+    id: Mapped[int] = mapped_column(primary_key=True)
+    location_id: Mapped[int] = mapped_column(
+        ForeignKey("weather_locations.id")
     )
-    temperature = Column("temperature", Float)
-    report_time = Column(
-        "report_time", DateTime, default=datetime.datetime.now
+    temperature: Mapped[float]
+    report_time: Mapped[datetime.datetime] = mapped_column(
+        default=datetime.datetime.now
     )
 
-    def __init__(self, temperature):
-        self.temperature = temperature
-
+    location: Mapped[WeatherLocation] = relationship(back_populates="reports")
 
-# create tables
-for db in (db1, db2, db3, db4):
-    Base.metadata.create_all(db)
+    def __init__(self, temperature: float):
+        self.temperature = temperature
 
 
-# step 5. define sharding functions.
+# define sharding functions.
 
 # we'll use a straight mapping of a particular set of "country"
 # attributes to shard id.
@@ -154,15 +152,11 @@ def execute_chooser(context):
     given an :class:`.ORMExecuteState` for a statement, return a list
     of shards we should consult.
 
-    As before, we want a "shard_id" execution option to be present.
-    Otherwise, this would be a lazy load from a parent object where we
-    will look for the previous token.
-
     """
     if context.lazy_loaded_from:
         return [context.lazy_loaded_from.identity_token]
     else:
-        return [context.execution_options["shard_id"]]
+        return ["north_america", "asia", "europe", "south_america"]
 
 
 # configure shard chooser
@@ -172,70 +166,90 @@ Session.configure(
     execute_chooser=execute_chooser,
 )
 
-# save and load objects!
 
-tokyo = WeatherLocation("Asia", "Tokyo")
-newyork = WeatherLocation("North America", "New York")
-toronto = WeatherLocation("North America", "Toronto")
-london = WeatherLocation("Europe", "London")
-dublin = WeatherLocation("Europe", "Dublin")
-brasilia = WeatherLocation("South America", "Brasila")
-quito = WeatherLocation("South America", "Quito")
+def setup():
+    # create tables
+    for db in (db1, db2, db3, db4):
+        Base.metadata.create_all(db)
 
-tokyo.reports.append(Report(80.0))
-newyork.reports.append(Report(75))
-quito.reports.append(Report(85))
 
-with Session() as sess:
+def main():
+    setup()
 
-    sess.add_all([tokyo, newyork, toronto, london, dublin, brasilia, quito])
+    # save and load objects!
 
-    sess.commit()
+    tokyo = WeatherLocation("Asia", "Tokyo")
+    newyork = WeatherLocation("North America", "New York")
+    toronto = WeatherLocation("North America", "Toronto")
+    london = WeatherLocation("Europe", "London")
+    dublin = WeatherLocation("Europe", "Dublin")
+    brasilia = WeatherLocation("South America", "Brasila")
+    quito = WeatherLocation("South America", "Quito")
 
-    t = sess.get(
-        WeatherLocation,
-        tokyo.id,
-        # for session.get(), we currently need to use identity_token.
-        # the horizontal sharding API does not yet pass through the
-        # execution options
-        identity_token="asia",
-        # future version
-        # execution_options={"shard_id": "asia"}
-    )
-    assert t.city == tokyo.city
-    assert t.reports[0].temperature == 80.0
-
-    north_american_cities = sess.execute(
-        select(WeatherLocation).filter(
-            WeatherLocation.continent == "North America"
-        ),
-        execution_options={"shard_id": "north_america"},
-    ).scalars()
-
-    assert {c.city for c in north_american_cities} == {"New York", "Toronto"}
-
-    europe = sess.execute(
-        select(WeatherLocation).filter(WeatherLocation.continent == "Europe"),
-        execution_options={"shard_id": "europe"},
-    ).scalars()
-
-    assert {c.city for c in europe} == {"London", "Dublin"}
-
-    # the Report class uses a simple integer primary key.  So across two
-    # databases, a primary key will be repeated.  The "identity_token" tracks
-    # in memory that these two identical primary keys are local to different
-    # databases.
-    newyork_report = newyork.reports[0]
-    tokyo_report = tokyo.reports[0]
-
-    assert inspect(newyork_report).identity_key == (
-        Report,
-        (1,),
-        "north_america",
-    )
-    assert inspect(tokyo_report).identity_key == (Report, (1,), "asia")
+    tokyo.reports.append(Report(80.0))
+    newyork.reports.append(Report(75))
+    quito.reports.append(Report(85))
+
+    with Session() as sess:
+
+        sess.add_all(
+            [tokyo, newyork, toronto, london, dublin, brasilia, quito]
+        )
+
+        sess.commit()
+
+        t = sess.get(
+            WeatherLocation,
+            tokyo.id,
+            identity_token="asia",
+        )
+        assert t.city == tokyo.city
+        assert t.reports[0].temperature == 80.0
+
+        # select across shards
+        asia_and_europe = sess.execute(
+            select(WeatherLocation).filter(
+                WeatherLocation.continent.in_(["Europe", "Asia"])
+            )
+        ).scalars()
+
+        assert {c.city for c in asia_and_europe} == {
+            "Tokyo",
+            "London",
+            "Dublin",
+        }
+
+        # optionally set a shard id for the query and all related loaders
+        north_american_cities_w_t = sess.execute(
+            select(WeatherLocation)
+            .filter(WeatherLocation.city.startswith("T"))
+            .options(set_shard_id("north_america"))
+        ).scalars()
+
+        # Tokyo not included since not in the north_america shard
+        assert {c.city for c in north_american_cities_w_t} == {
+            "Toronto",
+        }
+
+        # the Report class uses a simple integer primary key.  So across two
+        # databases, a primary key will be repeated.  The "identity_token"
+        # tracks in memory that these two identical primary keys are local to
+        # different shards.
+        newyork_report = newyork.reports[0]
+        tokyo_report = tokyo.reports[0]
+
+        assert inspect(newyork_report).identity_key == (
+            Report,
+            (1,),
+            "north_america",
+        )
+        assert inspect(tokyo_report).identity_key == (Report, (1,), "asia")
+
+        # the token representing the originating shard is also available
+        # directly
+        assert inspect(newyork_report).identity_token == "north_america"
+        assert inspect(tokyo_report).identity_token == "asia"
 
-    # the token representing the originating shard is also available directly
 
-    assert inspect(newyork_report).identity_token == "north_america"
-    assert inspect(tokyo_report).identity_token == "asia"
+if __name__ == "__main__":
+    main()
index 97c6a07f6a1047f3c1e742753456ae9e3ab59276..98db3771f7b2897a270eaacff9c6b34e9c429738 100644 (file)
@@ -1,27 +1,27 @@
 """Illustrates sharding using a single SQLite database, that will however
 have multiple tables using a naming convention."""
+from __future__ import annotations
 
 import datetime
 
 from sqlalchemy import Column
 from sqlalchemy import create_engine
-from sqlalchemy import DateTime
 from sqlalchemy import event
-from sqlalchemy import Float
 from sqlalchemy import ForeignKey
 from sqlalchemy import inspect
 from sqlalchemy import Integer
 from sqlalchemy import select
-from sqlalchemy import String
 from sqlalchemy import Table
-from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy.ext.horizontal_shard import set_shard_id
 from sqlalchemy.ext.horizontal_shard import ShardedSession
+from sqlalchemy.orm import DeclarativeBase
+from sqlalchemy.orm import Mapped
+from sqlalchemy.orm import mapped_column
 from sqlalchemy.orm import relationship
 from sqlalchemy.orm import sessionmaker
 from sqlalchemy.sql import operators
 from sqlalchemy.sql import visitors
 
-
 echo = True
 engine = create_engine("sqlite://", echo=echo)
 
@@ -55,7 +55,9 @@ Session = sessionmaker(
 
 
 # mappings and tables
-Base = declarative_base()
+class Base(DeclarativeBase):
+    pass
+
 
 # we need a way to create identifiers which are unique across all databases.
 # one easy way would be to just use a composite primary key, where  one value
@@ -86,13 +88,13 @@ def id_generator(ctx):
 class WeatherLocation(Base):
     __tablename__ = "_prefix__weather_locations"
 
-    id = Column(Integer, primary_key=True, default=id_generator)
-    continent = Column(String(30), nullable=False)
-    city = Column(String(50), nullable=False)
+    id: Mapped[int] = mapped_column(primary_key=True, default=id_generator)
+    continent: Mapped[str]
+    city: Mapped[str]
 
-    reports = relationship("Report", backref="location")
+    reports: Mapped[list[Report]] = relationship(back_populates="location")
 
-    def __init__(self, continent, city):
+    def __init__(self, continent: str, city: str):
         self.continent = continent
         self.city = city
 
@@ -100,29 +102,22 @@ class WeatherLocation(Base):
 class Report(Base):
     __tablename__ = "_prefix__weather_reports"
 
-    id = Column(Integer, primary_key=True)
-    location_id = Column(
-        "location_id", Integer, ForeignKey("_prefix__weather_locations.id")
+    id: Mapped[int] = mapped_column(primary_key=True)
+    location_id: Mapped[int] = mapped_column(
+        ForeignKey("_prefix__weather_locations.id")
     )
-    temperature = Column("temperature", Float)
-    report_time = Column(
-        "report_time", DateTime, default=datetime.datetime.now
+    temperature: Mapped[float]
+    report_time: Mapped[datetime.datetime] = mapped_column(
+        default=datetime.datetime.now
     )
 
-    def __init__(self, temperature):
-        self.temperature = temperature
-
+    location: Mapped[WeatherLocation] = relationship(back_populates="reports")
 
-# create tables
-for db in (db1, db2, db3, db4):
-    Base.metadata.create_all(db)
-
-# establish initial "id" in db1
-with db1.begin() as conn:
-    conn.execute(ids.insert(), {"nextid": 1})
+    def __init__(self, temperature: float):
+        self.temperature = temperature
 
 
-# step 5. define sharding functions.
+# define sharding functions.
 
 # we'll use a straight mapping of a particular set of "country"
 # attributes to shard id.
@@ -255,61 +250,89 @@ Session.configure(
     execute_chooser=execute_chooser,
 )
 
-# save and load objects!
 
-tokyo = WeatherLocation("Asia", "Tokyo")
-newyork = WeatherLocation("North America", "New York")
-toronto = WeatherLocation("North America", "Toronto")
-london = WeatherLocation("Europe", "London")
-dublin = WeatherLocation("Europe", "Dublin")
-brasilia = WeatherLocation("South America", "Brasila")
-quito = WeatherLocation("South America", "Quito")
+def setup():
+    # create tables
+    for db in (db1, db2, db3, db4):
+        Base.metadata.create_all(db)
 
-tokyo.reports.append(Report(80.0))
-newyork.reports.append(Report(75))
-quito.reports.append(Report(85))
+    # establish initial "id" in db1
+    with db1.begin() as conn:
+        conn.execute(ids.insert(), {"nextid": 1})
 
-with Session() as sess:
 
-    sess.add_all([tokyo, newyork, toronto, london, dublin, brasilia, quito])
+def main():
+    setup()
 
-    sess.commit()
+    # save and load objects!
 
-    t = sess.get(WeatherLocation, tokyo.id)
-    assert t.city == tokyo.city
-    assert t.reports[0].temperature == 80.0
+    tokyo = WeatherLocation("Asia", "Tokyo")
+    newyork = WeatherLocation("North America", "New York")
+    toronto = WeatherLocation("North America", "Toronto")
+    london = WeatherLocation("Europe", "London")
+    dublin = WeatherLocation("Europe", "Dublin")
+    brasilia = WeatherLocation("South America", "Brasila")
+    quito = WeatherLocation("South America", "Quito")
 
-    north_american_cities = sess.execute(
-        select(WeatherLocation).filter(
-            WeatherLocation.continent == "North America"
-        )
-    ).scalars()
+    tokyo.reports.append(Report(80.0))
+    newyork.reports.append(Report(75))
+    quito.reports.append(Report(85))
 
-    assert {c.city for c in north_american_cities} == {"New York", "Toronto"}
+    with Session() as sess:
 
-    asia_and_europe = sess.execute(
-        select(WeatherLocation).filter(
-            WeatherLocation.continent.in_(["Europe", "Asia"])
+        sess.add_all(
+            [tokyo, newyork, toronto, london, dublin, brasilia, quito]
         )
-    ).scalars()
 
-    assert {c.city for c in asia_and_europe} == {"Tokyo", "London", "Dublin"}
+        sess.commit()
 
-    # the Report class uses a simple integer primary key.  So across two
-    # databases, a primary key will be repeated.  The "identity_token" tracks
-    # in memory that these two identical primary keys are local to different
-    # databases.
-    newyork_report = newyork.reports[0]
-    tokyo_report = tokyo.reports[0]
+        t = sess.get(WeatherLocation, tokyo.id)
+        assert t.city == tokyo.city
+        assert t.reports[0].temperature == 80.0
 
-    assert inspect(newyork_report).identity_key == (
-        Report,
-        (1,),
-        "north_america",
-    )
-    assert inspect(tokyo_report).identity_key == (Report, (1,), "asia")
+        # optionally set a shard id for the query and all related loaders
+        north_american_cities_w_t = sess.execute(
+            select(WeatherLocation)
+            .filter(WeatherLocation.city.startswith("T"))
+            .options(set_shard_id("north_america"))
+        ).scalars()
+
+        # Tokyo not included since not in the north_america shard
+        assert {c.city for c in north_american_cities_w_t} == {
+            "Toronto",
+        }
+
+        asia_and_europe = sess.execute(
+            select(WeatherLocation).filter(
+                WeatherLocation.continent.in_(["Europe", "Asia"])
+            )
+        ).scalars()
+
+        assert {c.city for c in asia_and_europe} == {
+            "Tokyo",
+            "London",
+            "Dublin",
+        }
+
+        # the Report class uses a simple integer primary key.  So across two
+        # databases, a primary key will be repeated.  The "identity_token"
+        # tracks in memory that these two identical primary keys are local to
+        # different shards.
+        newyork_report = newyork.reports[0]
+        tokyo_report = tokyo.reports[0]
+
+        assert inspect(newyork_report).identity_key == (
+            Report,
+            (1,),
+            "north_america",
+        )
+        assert inspect(tokyo_report).identity_key == (Report, (1,), "asia")
+
+        # the token representing the originating shard is also available
+        # directly
+        assert inspect(newyork_report).identity_token == "north_america"
+        assert inspect(tokyo_report).identity_token == "asia"
 
-    # the token representing the originating shard is also available directly
 
-    assert inspect(newyork_report).identity_token == "north_america"
-    assert inspect(tokyo_report).identity_token == "asia"
+if __name__ == "__main__":
+    main()
index df9d0d797b64eb1609b2e5f9a67ba8b3f2bff2ca..0bcc5628fb83de039b2e0f6945698cdb94877e1b 100644 (file)
@@ -42,6 +42,7 @@ from .. import inspect
 from .. import util
 from ..orm import PassiveFlag
 from ..orm._typing import OrmExecuteOptionsParameter
+from ..orm.interfaces import ORMOption
 from ..orm.mapper import Mapper
 from ..orm.query import Query
 from ..orm.session import _BindArguments
@@ -73,7 +74,7 @@ _T = TypeVar("_T", bound=Any)
 
 SelfShardedQuery = TypeVar("SelfShardedQuery", bound="ShardedQuery[Any]")
 
-_ShardKey = str
+ShardIdentifier = str
 
 
 class ShardChooser(Protocol):
@@ -105,8 +106,7 @@ class ShardedQuery(Query[_T]):
 
     .. legacy:: The :class:`.ShardedQuery` is a subclass of the legacy
        :class:`.Query` class.   The :class:`.ShardedSession` now supports
-       2.0 style execution via the :meth:`.ShardedSession.execute` method
-       as well.
+       2.0 style execution via the :meth:`.ShardedSession.execute` method.
 
     """
 
@@ -118,7 +118,9 @@ class ShardedQuery(Query[_T]):
         self.execute_chooser = self.session.execute_chooser
         self._shard_id = None
 
-    def set_shard(self: SelfShardedQuery, shard_id: str) -> SelfShardedQuery:
+    def set_shard(
+        self: SelfShardedQuery, shard_id: ShardIdentifier
+    ) -> SelfShardedQuery:
         """Return a new query, limited to a single shard ID.
 
         All subsequent operations with the returned query will
@@ -166,14 +168,12 @@ class ShardedSession(Session):
           should set whatever state on the instance to mark it in the future as
           participating in that shard.
 
-        :param id_chooser: A callable, passed a :class:`.ShardedQuery` and a
-          tuple of identity values, which should return a list of shard ids
-          where the ID might reside. The databases will be queried in the order
-          of this listing.
+        :param identity_chooser: A callable, passed a Mapper and primary key
+          argument, which should return a list of shard ids where this
+          primary key might reside.
 
-          .. legacy:: This parameter still uses the legacy
-             :class:`.ShardedQuery` class as an argument passed to the
-             callable.
+          .. versionchanged:: 2.0  The ``identity_chooser`` parameter
+             supersedes the ``id_chooser`` parameter.
 
         :param execute_chooser: For a given :class:`.ORMExecuteState`,
           returns the list of shard_ids
@@ -250,7 +250,7 @@ class ShardedSession(Session):
                 "execute_chooser or query_chooser is required"
             )
         self.execute_chooser = execute_chooser
-        self.__shards: Dict[_ShardKey, _SessionBind] = {}
+        self.__shards: Dict[ShardIdentifier, _SessionBind] = {}
         if shards is not None:
             for k in shards:
                 self.bind_shard(k, shards[k])
@@ -329,7 +329,7 @@ class ShardedSession(Session):
         self,
         mapper: Optional[Mapper[_T]] = None,
         instance: Optional[Any] = None,
-        shard_id: Optional[Any] = None,
+        shard_id: Optional[ShardIdentifier] = None,
         **kw: Any,
     ) -> Connection:
         """Provide a :class:`_engine.Connection` to use in the unit of work
@@ -359,7 +359,7 @@ class ShardedSession(Session):
         self,
         mapper: Optional[_EntityBindKey[_O]] = None,
         *,
-        shard_id: Optional[_ShardKey] = None,
+        shard_id: Optional[ShardIdentifier] = None,
         instance: Optional[Any] = None,
         clause: Optional[ClauseElement] = None,
         **kw: Any,
@@ -372,11 +372,61 @@ class ShardedSession(Session):
         return self.__shards[shard_id]
 
     def bind_shard(
-        self, shard_id: _ShardKey, bind: Union[Engine, OptionEngine]
+        self, shard_id: ShardIdentifier, bind: Union[Engine, OptionEngine]
     ) -> None:
         self.__shards[shard_id] = bind
 
 
+class set_shard_id(ORMOption):
+    """A loader option for statements to apply a specific shard id to the
+    primary query as well as for additional relationship and column
+    loaders.
+
+    The :class:`_horizontal.set_shard_id` option may be applied using
+    the :meth:`_sql.Executable.options` method of any executable statement::
+
+        stmt = (
+            select(MyObject).
+            where(MyObject.name == 'some name').
+            options(set_shard_id("shard1"))
+        )
+
+    Above, the statement when invoked will limit to the "shard1" shard
+    identifier for the primary query as well as for all relationship and
+    column loading strategies, including eager loaders such as
+    :func:`_orm.selectinload`, deferred column loaders like :func:`_orm.defer`,
+    and the lazy relationship loader :func:`_orm.lazyload`.
+
+    In this way, the :class:`_horizontal.set_shard_id` option has much wider
+    scope than using the "shard_id" argument within the
+    :paramref:`_orm.Session.execute.bind_arguments` dictionary.
+
+
+    .. versionadded:: 2.0.0
+
+    """
+
+    __slots__ = ("shard_id", "propagate_to_loaders")
+
+    def __init__(
+        self, shard_id: ShardIdentifier, propagate_to_loaders: bool = True
+    ):
+        """Construct a :class:`_horizontal.set_shard_id` option.
+
+        :param shard_id: shard identifier
+        :param propagate_to_loaders: if left at its default of ``True``, the
+         shard option will take place for lazy loaders such as
+         :func:`_orm.lazyload` and :func:`_orm.defer`; if False, the option
+         will not be propagated to loaded objects. Note that :func:`_orm.defer`
+         always limits to the shard_id of the parent row in any case, so the
+         parameter only has a net effect on the behavior of the
+         :func:`_orm.lazyload` strategy.
+
+        """
+        self.shard_id = shard_id
+        self.propagate_to_loaders = propagate_to_loaders
+
+
 def execute_and_instances(
     orm_context: ORMExecuteState,
 ) -> Union[Result[_T], IteratorResult[_TP]]:
@@ -400,7 +450,7 @@ def execute_and_instances(
     assert isinstance(session, ShardedSession)
 
     def iter_for_shard(
-        shard_id: str,
+        shard_id: ShardIdentifier,
     ) -> Union[Result[_T], IteratorResult[_TP]]:
 
         bind_arguments = dict(orm_context.bind_arguments)
@@ -409,14 +459,22 @@ def execute_and_instances(
         orm_context.update_execution_options(identity_token=shard_id)
         return orm_context.invoke_statement(bind_arguments=bind_arguments)
 
-    if active_options and active_options._identity_token is not None:
-        shard_id = active_options._identity_token
-    elif "_sa_shard_id" in orm_context.execution_options:
-        shard_id = orm_context.execution_options["_sa_shard_id"]
-    elif "shard_id" in orm_context.bind_arguments:
-        shard_id = orm_context.bind_arguments["shard_id"]
+    for orm_opt in orm_context._non_compile_orm_options:
+        # TODO: if we had an ORMOption that gets applied at ORM statement
+        # execution time, that would allow this to be more generalized.
+        # for now just iterate and look for our options
+        if isinstance(orm_opt, set_shard_id):
+            shard_id = orm_opt.shard_id
+            break
     else:
-        shard_id = None
+        if active_options and active_options._identity_token is not None:
+            shard_id = active_options._identity_token
+        elif "_sa_shard_id" in orm_context.execution_options:
+            shard_id = orm_context.execution_options["_sa_shard_id"]
+        elif "shard_id" in orm_context.bind_arguments:
+            shard_id = orm_context.bind_arguments["shard_id"]
+        else:
+            shard_id = None
 
     if shard_id is not None:
         return iter_for_shard(shard_id)
index 76483265d4916122ad4051537d9f3174480f4f87..36dc6ddb9cecc5f1bb251dc059767145e76c9a7f 100644 (file)
@@ -38,6 +38,7 @@ if TYPE_CHECKING:
     from .decl_api import registry as _registry_type
     from .interfaces import InspectionAttr
     from .interfaces import MapperProperty
+    from .interfaces import ORMOption
     from .interfaces import UserDefinedOption
     from .mapper import Mapper
     from .relationships import RelationshipProperty
@@ -122,6 +123,12 @@ class _LoaderCallable(Protocol):
         ...
 
 
+def is_orm_option(
+    opt: ExecutableOption,
+) -> TypeGuard[ORMOption]:
+    return not opt._is_core  # type: ignore
+
+
 def is_user_defined_option(
     opt: ExecutableOption,
 ) -> TypeGuard[UserDefinedOption]:
index 5bcb22a083b298dea085885c976ac893a8605ae5..babd5d72adf9924090925b139232d5f41af01726 100644 (file)
@@ -46,6 +46,7 @@ from . import state as statelib
 from ._typing import _O
 from ._typing import insp_is_mapper
 from ._typing import is_composite_class
+from ._typing import is_orm_option
 from ._typing import is_user_defined_option
 from .base import _class_to_mapper
 from .base import _none_set
@@ -730,6 +731,14 @@ class ORMExecuteState(util.MemoizedSlots):
             bulk_persistence.BulkUDCompileState.default_update_options,
         )
 
+    @property
+    def _non_compile_orm_options(self) -> Sequence[ORMOption]:
+        return [
+            opt
+            for opt in self.statement._with_options
+            if is_orm_option(opt) and not opt._is_compile_state
+        ]
+
     @property
     def user_defined_options(self) -> Sequence[UserDefinedOption]:
         """The sequence of :class:`.UserDefinedOptions` that have been
index 5dd32e8cc4d0fe064f1b8b2e42d6b3fe73011e3f..c66ba71c3f132ab270741226fc96095dc3c48dc6 100644 (file)
@@ -906,7 +906,17 @@ class AssertsExecutionResults:
             db, callable_, assertsql.CountStatements(count)
         )
 
-    def assert_multiple_sql_count(self, dbs, callable_, counts):
+    @contextlib.contextmanager
+    def assert_execution(self, db, *rules):
+        with self.sql_execution_asserter(db) as asserter:
+            yield
+        asserter.assert_(*rules)
+
+    def assert_statement_count(self, db, count):
+        return self.assert_execution(db, assertsql.CountStatements(count))
+
+    @contextlib.contextmanager
+    def assert_statement_count_multi_db(self, dbs, counts):
         recs = [
             (self.sql_execution_asserter(db), db, count)
             for (db, count) in zip(dbs, counts)
@@ -915,21 +925,12 @@ class AssertsExecutionResults:
         for ctx, db, count in recs:
             asserters.append(ctx.__enter__())
         try:
-            return callable_()
+            yield
         finally:
             for asserter, (ctx, db, count) in zip(asserters, recs):
                 ctx.__exit__(None, None, None)
                 asserter.assert_(assertsql.CountStatements(count))
 
-    @contextlib.contextmanager
-    def assert_execution(self, db, *rules):
-        with self.sql_execution_asserter(db) as asserter:
-            yield
-        asserter.assert_(*rules)
-
-    def assert_statement_count(self, db, count):
-        return self.assert_execution(db, assertsql.CountStatements(count))
-
 
 class ComparesIndexes:
     def compare_table_index_with_expected(
index 8e5d09cab0a6487854f3d19d893b3b672da1d74f..467c61f160150f79daaf4f8b39dc77cc632aac5e 100644 (file)
@@ -17,9 +17,12 @@ from sqlalchemy import testing
 from sqlalchemy import text
 from sqlalchemy import update
 from sqlalchemy import util
+from sqlalchemy.ext.horizontal_shard import set_shard_id
 from sqlalchemy.ext.horizontal_shard import ShardedSession
 from sqlalchemy.orm import clear_mappers
+from sqlalchemy.orm import defer
 from sqlalchemy.orm import deferred
+from sqlalchemy.orm import lazyload
 from sqlalchemy.orm import relationship
 from sqlalchemy.orm import selectinload
 from sqlalchemy.orm import Session
@@ -83,7 +86,7 @@ class ShardTest:
 
     def setup_test(self):
         global db1, db2, db3, db4
-        db1, db2, db3, db4 = self._dbs = self._init_dbs()
+        db1, db2, db3, db4 = self._dbs = self.dbs = self._init_dbs()
 
         for db in (db1, db2, db3, db4):
             self.tables_test_metadata.create_all(db)
@@ -245,6 +248,103 @@ class ShardTest:
         t2 = sess.get(WeatherLocation, 1)
         eq_(t2.city, "Tokyo")
 
+    @testing.variation("option_type", ["none", "lazyload", "selectinload"])
+    @testing.variation(
+        "limit_shard",
+        ["none", "lead_only", "propagate_to_loaders", "bind_arg"],
+    )
+    def test_set_shard_option_relationship(self, option_type, limit_shard):
+        sess = self._fixture_data()
+
+        stmt = select(WeatherLocation).filter(
+            WeatherLocation.city == "New York"
+        )
+
+        bind_arguments = {}
+
+        if limit_shard.none:
+            # right now selectinload / lazyload runs all the shards even if the
+            # ids are limited to just one shard, since that information
+            # is not transferred
+            counts = [2, 2, 2, 2]
+        elif limit_shard.lead_only:
+            if option_type.selectinload:
+                counts = [2, 0, 0, 0]
+            else:
+                counts = [2, 1, 1, 1]
+        elif limit_shard.bind_arg:
+            counts = [2, 1, 1, 1]
+        elif limit_shard.propagate_to_loaders:
+            counts = [2, 0, 0, 0]
+        else:
+            limit_shard.fail()
+
+        if option_type.lazyload:
+            stmt = stmt.options(lazyload(WeatherLocation.reports))
+        elif option_type.selectinload:
+            stmt = stmt.options(selectinload(WeatherLocation.reports))
+
+        if limit_shard.lead_only:
+            stmt = stmt.options(
+                set_shard_id("north_america", propagate_to_loaders=False)
+            )
+        elif limit_shard.propagate_to_loaders:
+            stmt = stmt.options(set_shard_id("north_america"))
+        elif limit_shard.bind_arg:
+            bind_arguments["shard_id"] = "north_america"
+
+        with self.assert_statement_count_multi_db(self.dbs, counts):
+            w1 = sess.scalars(stmt, bind_arguments=bind_arguments).first()
+            w1.reports
+
+    @testing.variation("option_type", ["none", "defer"])
+    @testing.variation(
+        "limit_shard",
+        ["none", "lead_only", "propagate_to_loaders", "bind_arg"],
+    )
+    def test_set_shard_option_column(self, option_type, limit_shard):
+        sess = self._fixture_data()
+
+        stmt = select(WeatherLocation).filter(
+            WeatherLocation.city == "New York"
+        )
+
+        bind_arguments = {}
+
+        if limit_shard.none:
+            if option_type.defer:
+                counts = [2, 1, 1, 1]
+            else:
+                counts = [1, 1, 1, 1]
+        elif limit_shard.lead_only or limit_shard.propagate_to_loaders:
+            if option_type.defer:
+                counts = [2, 0, 0, 0]
+            else:
+                counts = [1, 0, 0, 0]
+        elif limit_shard.bind_arg:
+            if option_type.defer:
+                counts = [2, 0, 0, 0]
+            else:
+                counts = [1, 0, 0, 0]
+        else:
+            limit_shard.fail()
+
+        if option_type.defer:
+            stmt = stmt.options(defer(WeatherLocation.continent))
+
+        if limit_shard.lead_only:
+            stmt = stmt.options(
+                set_shard_id("north_america", propagate_to_loaders=False)
+            )
+        elif limit_shard.propagate_to_loaders:
+            stmt = stmt.options(set_shard_id("north_america"))
+        elif limit_shard.bind_arg:
+            bind_arguments["shard_id"] = "north_america"
+
+        with self.assert_statement_count_multi_db(self.dbs, counts):
+            w1 = sess.scalars(stmt, bind_arguments=bind_arguments).first()
+            w1.continent
+
     def test_query_explicit_shard_via_bind_opts(self):
         sess = self._fixture_data()
 
@@ -502,14 +602,9 @@ class ShardTest:
                 .execution_options(synchronize_session=synchronize_session)
             )
 
-        # test synchronize session
-        def go():
+        with self.assert_statement_count_multi_db(self.dbs, [0, 0, 0, 0]):
             eq_({t.temperature for t in temps}, {86.0, 75.0, 91.0})
 
-        self.assert_sql_count(
-            sess._ShardedSession__shards["north_america"], go, 0
-        )
-
         eq_(
             {row.temperature for row in sess.query(Report.temperature)},
             {86.0, 75.0, 91.0},
@@ -536,15 +631,11 @@ class ShardTest:
                 .execution_options(synchronize_session=synchronize_session)
             )
 
-        def go():
+        with self.assert_statement_count_multi_db(self.dbs, [0, 0, 0, 0]):
             # test synchronize session
             for t in temps:
                 assert inspect(t).deleted is (t.temperature >= 80)
 
-        self.assert_sql_count(
-            sess._ShardedSession__shards["north_america"], go, 0
-        )
-
         eq_(
             {row.temperature for row in sess.query(Report.temperature)},
             {75.0},
@@ -704,16 +795,15 @@ class TableNameConventionShardTest(ShardTest, fixtures.MappedTest):
     schema = "changeme"
 
     def _init_dbs(self):
-        db1 = testing_engine(
-            "sqlite://", options={"execution_options": {"shard_id": "shard1"}}
-        )
-        db2 = db1.execution_options(shard_id="shard2")
-        db3 = db1.execution_options(shard_id="shard3")
-        db4 = db1.execution_options(shard_id="shard4")
+        dbmain = testing_engine("sqlite://")
+        db1 = dbmain.execution_options(shard_id="shard1")
+        db2 = dbmain.execution_options(shard_id="shard2")
+        db3 = dbmain.execution_options(shard_id="shard3")
+        db4 = dbmain.execution_options(shard_id="shard4")
 
         import re
 
-        @event.listens_for(db1, "before_cursor_execute", retval=True)
+        @event.listens_for(dbmain, "before_cursor_execute", retval=True)
         def _switch_shard(conn, cursor, stmt, params, context, executemany):
             shard_id = conn._execution_options["shard_id"]
             # because SQLite can't just give us a "use" statement, we have
@@ -977,12 +1067,10 @@ class LazyLoadIdentityKeyTest(fixtures.DeclarativeMappedTest):
 
         session.expire(page, ["book"])
 
-        def go():
+        with self.assert_statement_count_multi_db(self.dbs, [0, 0]):
+            # doesn't emit SQL
             eq_(page.book, book)
 
-        # doesn't emit SQL
-        self.assert_multiple_sql_count(self.dbs, go, [0, 0])
-
     def test_lazy_load_from_db(self):
         session = self._fixture(lazy_load_book=True)
 
@@ -999,12 +1087,10 @@ class LazyLoadIdentityKeyTest(fixtures.DeclarativeMappedTest):
         book1_page = session.query(Page).first()
         session.expire(book1_page, ["book"])
 
-        def go():
+        with self.assert_statement_count_multi_db(self.dbs, [1, 0]):
+            # emits one query
             eq_(inspect(book1_page.book).identity_key, book1_id)
 
-        # emits one query
-        self.assert_multiple_sql_count(self.dbs, go, [1, 0])
-
     def test_lazy_load_no_baked_conflict(self):
         session = self._fixture(lazy_load_pages=True)