--- /dev/null
+.. change::
+ :tags: feature, orm extensions
+ :tickets: 7226
+
+ Added a new option to the horizontal sharding API,
+ :class:`_horizontal.set_shard_id` which sets the effective shard identifier
+ to query against, for both the primary query as well as for all secondary
+ loaders including relationship eager loaders as well as relationship and
+ column lazy loaders.
.. autoclass:: ShardedSession
:members:
+.. autoclass:: set_shard_id
+ :members:
+
.. autoclass:: ShardedQuery
:members:
--- /dev/null
+"""Illustrates sharding API used with asyncio.
+
+For the sync version of this example, see separate_databases.py.
+
+Most of the code here is copied from separate_databases.py and works
+in exactly the same way. The main change is how the
+``async_sessionmaker`` is configured, as well as the routine that
+generates new primary keys, which is specific to this example.
+
+"""
+from __future__ import annotations
+
+import asyncio
+import datetime
+
+from sqlalchemy import Column
+from sqlalchemy import ForeignKey
+from sqlalchemy import inspect
+from sqlalchemy import Integer
+from sqlalchemy import select
+from sqlalchemy import Table
+from sqlalchemy.ext.asyncio import async_sessionmaker
+from sqlalchemy.ext.asyncio import create_async_engine
+from sqlalchemy.ext.horizontal_shard import set_shard_id
+from sqlalchemy.ext.horizontal_shard import ShardedSession
+from sqlalchemy.orm import DeclarativeBase
+from sqlalchemy.orm import immediateload
+from sqlalchemy.orm import Mapped
+from sqlalchemy.orm import mapped_column
+from sqlalchemy.orm import relationship
+from sqlalchemy.sql import operators
+from sqlalchemy.sql import visitors
+
+
+echo = True
+db1 = create_async_engine("sqlite+aiosqlite://", echo=echo)
+db2 = create_async_engine("sqlite+aiosqlite://", echo=echo)
+db3 = create_async_engine("sqlite+aiosqlite://", echo=echo)
+db4 = create_async_engine("sqlite+aiosqlite://", echo=echo)
+
+
+# for asyncio, the ShardedSession class is passed
+# via sync_session_class. The shards themselves are used within
+# implicit-awaited internals, so we use the sync_engine Engine objects
+# in the shards dictionary.
+Session = async_sessionmaker(
+ sync_session_class=ShardedSession,
+ expire_on_commit=False,
+ shards={
+ "north_america": db1.sync_engine,
+ "asia": db2.sync_engine,
+ "europe": db3.sync_engine,
+ "south_america": db4.sync_engine,
+ },
+)
+
+
+# mappings and tables
+class Base(DeclarativeBase):
+ pass
+
+
+# we need a way to create identifiers which are unique across all databases.
+# one easy way would be to just use a composite primary key, where one value
+# is the shard id. but here, we'll show something more "generic", an id
+# generation function. we'll use a simplistic "id table" stored in database
+# #1. Any other method will do just as well; UUID, hilo, application-specific,
+# etc.
+
+ids = Table("ids", Base.metadata, Column("nextid", Integer, nullable=False))
+
+
+def id_generator(ctx):
+ # id_generator is run within a "synchronous" context, where
+ # we use an implicit-await API that will convert back to explicit await
+ # calls when it reaches the driver.
+ with db1.sync_engine.begin() as conn:
+ nextid = conn.scalar(ids.select().with_for_update())
+ conn.execute(ids.update().values({ids.c.nextid: ids.c.nextid + 1}))
+ return nextid
+
+
+# table setup. we'll store a lead table of continents/cities, and a secondary
+# table storing locations. a particular row will be placed in the database
+# whose shard id corresponds to the 'continent'. in this setup, secondary rows
+# in 'weather_reports' will be placed in the same DB as that of the parent, but
+# this can be changed if you're willing to write more complex sharding
+# functions.
+
+
+class WeatherLocation(Base):
+ __tablename__ = "weather_locations"
+
+ id: Mapped[int] = mapped_column(primary_key=True, default=id_generator)
+ continent: Mapped[str]
+ city: Mapped[str]
+
+ reports: Mapped[list[Report]] = relationship(back_populates="location")
+
+ def __init__(self, continent: str, city: str):
+ self.continent = continent
+ self.city = city
+
+
+class Report(Base):
+ __tablename__ = "weather_reports"
+
+ id: Mapped[int] = mapped_column(primary_key=True)
+ location_id: Mapped[int] = mapped_column(
+ ForeignKey("weather_locations.id")
+ )
+ temperature: Mapped[float]
+ report_time: Mapped[datetime.datetime] = mapped_column(
+ default=datetime.datetime.now
+ )
+
+ location: Mapped[WeatherLocation] = relationship(back_populates="reports")
+
+ def __init__(self, temperature: float):
+ self.temperature = temperature
+
+
+# define sharding functions.
+
+# we'll use a straight mapping of a particular set of "country"
+# attributes to shard id.
+shard_lookup = {
+ "North America": "north_america",
+ "Asia": "asia",
+ "Europe": "europe",
+ "South America": "south_america",
+}
+
+
+def shard_chooser(mapper, instance, clause=None):
+ """shard chooser.
+
+ looks at the given instance and returns a shard id.
+ Note that we need to define conditions for
+ the WeatherLocation class, as well as our secondary Report class which will
+ point back to its WeatherLocation via its 'location' attribute.
+
+ """
+ if isinstance(instance, WeatherLocation):
+ return shard_lookup[instance.continent]
+ else:
+ return shard_chooser(mapper, instance.location)
+
+
+def identity_chooser(mapper, primary_key, *, lazy_loaded_from, **kw):
+ """identity chooser.
+
+ given a primary key, returns a list of shards
+ to search. here, we don't have any particular information from a
+ pk so we just return all shard ids. often, you'd want to do some
+ kind of round-robin strategy here so that requests are evenly
+ distributed among DBs.
+
+ """
+ if lazy_loaded_from:
+ # if we are in a lazy load, we can look at the parent object
+ # and limit our search to that same shard, assuming that's how we've
+ # set things up.
+ return [lazy_loaded_from.identity_token]
+ else:
+ return ["north_america", "asia", "europe", "south_america"]
+
+
+def execute_chooser(context):
+ """statement execution chooser.
+
+ this also returns a list of shard ids, which can just be all of them. but
+ here we'll search into the execution context in order to try to narrow down
+ the list of shards to SELECT.
+
+ """
+ ids = []
+
+ # we'll grab continent names as we find them
+ # and convert to shard ids
+ for column, operator, value in _get_select_comparisons(context.statement):
+ # "shares_lineage()" returns True if both columns refer to the same
+ # statement column, adjusting for any annotations present.
+ # (an annotation is an internal clone of a Column object
+ # and occur when using ORM-mapped attributes like
+ # "WeatherLocation.continent"). A simpler comparison, though less
+ # accurate, would be "column.key == 'continent'".
+ if column.shares_lineage(WeatherLocation.__table__.c.continent):
+ if operator == operators.eq:
+ ids.append(shard_lookup[value])
+ elif operator == operators.in_op:
+ ids.extend(shard_lookup[v] for v in value)
+
+ if len(ids) == 0:
+ return ["north_america", "asia", "europe", "south_america"]
+ else:
+ return ids
+
+
+def _get_select_comparisons(statement):
+ """Search a Select or Query object for binary expressions.
+
+ Returns expressions which match a Column against one or more
+ literal values as a list of tuples of the form
+ (column, operator, values). "values" is a single value
+ or tuple of values depending on the operator.
+
+ """
+ binds = {}
+ clauses = set()
+ comparisons = []
+
+ def visit_bindparam(bind):
+ # visit a bind parameter.
+
+ value = bind.effective_value
+ binds[bind] = value
+
+ def visit_column(column):
+ clauses.add(column)
+
+ def visit_binary(binary):
+ if binary.left in clauses and binary.right in binds:
+ comparisons.append(
+ (binary.left, binary.operator, binds[binary.right])
+ )
+
+ elif binary.left in binds and binary.right in clauses:
+ comparisons.append(
+ (binary.right, binary.operator, binds[binary.left])
+ )
+
+ # here we will traverse through the query's criterion, searching
+ # for SQL constructs. We will place simple column comparisons
+ # into a list.
+ if statement.whereclause is not None:
+ visitors.traverse(
+ statement.whereclause,
+ {},
+ {
+ "bindparam": visit_bindparam,
+ "binary": visit_binary,
+ "column": visit_column,
+ },
+ )
+ return comparisons
+
+
+# further configure create_session to use these functions
+Session.configure(
+ shard_chooser=shard_chooser,
+ identity_chooser=identity_chooser,
+ execute_chooser=execute_chooser,
+)
+
+
+async def setup():
+ # create tables
+ for db in (db1, db2, db3, db4):
+ async with db.begin() as conn:
+ await conn.run_sync(Base.metadata.create_all)
+
+ # establish initial "id" in db1
+ async with db1.begin() as conn:
+ await conn.execute(ids.insert(), {"nextid": 1})
+
+
+async def main():
+ await setup()
+
+ # save and load objects!
+
+ tokyo = WeatherLocation("Asia", "Tokyo")
+ newyork = WeatherLocation("North America", "New York")
+ toronto = WeatherLocation("North America", "Toronto")
+ london = WeatherLocation("Europe", "London")
+ dublin = WeatherLocation("Europe", "Dublin")
+ brasilia = WeatherLocation("South America", "Brasila")
+ quito = WeatherLocation("South America", "Quito")
+
+ tokyo.reports.append(Report(80.0))
+ newyork.reports.append(Report(75))
+ quito.reports.append(Report(85))
+
+ async with Session() as sess:
+
+ sess.add_all(
+ [tokyo, newyork, toronto, london, dublin, brasilia, quito]
+ )
+
+ await sess.commit()
+
+ t = await sess.get(
+ WeatherLocation,
+ tokyo.id,
+ options=[immediateload(WeatherLocation.reports)],
+ )
+ assert t.city == tokyo.city
+ assert t.reports[0].temperature == 80.0
+
+ # select across shards
+ asia_and_europe = (
+ await sess.execute(
+ select(WeatherLocation).filter(
+ WeatherLocation.continent.in_(["Europe", "Asia"])
+ )
+ )
+ ).scalars()
+
+ assert {c.city for c in asia_and_europe} == {
+ "Tokyo",
+ "London",
+ "Dublin",
+ }
+
+ # optionally set a shard id for the query and all related loaders
+ north_american_cities_w_t = (
+ await sess.execute(
+ select(WeatherLocation)
+ .filter(WeatherLocation.city.startswith("T"))
+ .options(set_shard_id("north_america"))
+ )
+ ).scalars()
+
+ # Tokyo not included since not in the north_america shard
+ assert {c.city for c in north_american_cities_w_t} == {
+ "Toronto",
+ }
+
+ # the Report class uses a simple integer primary key. So across two
+ # databases, a primary key will be repeated. The "identity_token"
+ # tracks in memory that these two identical primary keys are local to
+ # different shards.
+ newyork_report = newyork.reports[0]
+ tokyo_report = tokyo.reports[0]
+
+ assert inspect(newyork_report).identity_key == (
+ Report,
+ (1,),
+ "north_america",
+ )
+ assert inspect(tokyo_report).identity_key == (Report, (1,), "asia")
+
+ # the token representing the originating shard is also available
+ # directly
+ assert inspect(newyork_report).identity_token == "north_america"
+ assert inspect(tokyo_report).identity_token == "asia"
+
+
+if __name__ == "__main__":
+ asyncio.run(main())
"""Illustrates sharding using distinct SQLite databases."""
+from __future__ import annotations
import datetime
from sqlalchemy import Column
from sqlalchemy import create_engine
-from sqlalchemy import DateTime
-from sqlalchemy import Float
from sqlalchemy import ForeignKey
from sqlalchemy import inspect
from sqlalchemy import Integer
from sqlalchemy import select
-from sqlalchemy import String
from sqlalchemy import Table
-from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy.ext.horizontal_shard import set_shard_id
from sqlalchemy.ext.horizontal_shard import ShardedSession
+from sqlalchemy.orm import DeclarativeBase
+from sqlalchemy.orm import Mapped
+from sqlalchemy.orm import mapped_column
from sqlalchemy.orm import relationship
from sqlalchemy.orm import sessionmaker
from sqlalchemy.sql import operators
# mappings and tables
-Base = declarative_base()
+class Base(DeclarativeBase):
+ pass
+
# we need a way to create identifiers which are unique across all databases.
# one easy way would be to just use a composite primary key, where one value
class WeatherLocation(Base):
__tablename__ = "weather_locations"
- id = Column(Integer, primary_key=True, default=id_generator)
- continent = Column(String(30), nullable=False)
- city = Column(String(50), nullable=False)
+ id: Mapped[int] = mapped_column(primary_key=True, default=id_generator)
+ continent: Mapped[str]
+ city: Mapped[str]
- reports = relationship("Report", backref="location")
+ reports: Mapped[list[Report]] = relationship(back_populates="location")
- def __init__(self, continent, city):
+ def __init__(self, continent: str, city: str):
self.continent = continent
self.city = city
class Report(Base):
__tablename__ = "weather_reports"
- id = Column(Integer, primary_key=True)
- location_id = Column(
- "location_id", Integer, ForeignKey("weather_locations.id")
+ id: Mapped[int] = mapped_column(primary_key=True)
+ location_id: Mapped[int] = mapped_column(
+ ForeignKey("weather_locations.id")
)
- temperature = Column("temperature", Float)
- report_time = Column(
- "report_time", DateTime, default=datetime.datetime.now
+ temperature: Mapped[float]
+ report_time: Mapped[datetime.datetime] = mapped_column(
+ default=datetime.datetime.now
)
- def __init__(self, temperature):
- self.temperature = temperature
-
+ location: Mapped[WeatherLocation] = relationship(back_populates="reports")
-# create tables
-for db in (db1, db2, db3, db4):
- Base.metadata.create_all(db)
-
-# establish initial "id" in db1
-with db1.begin() as conn:
- conn.execute(ids.insert(), {"nextid": 1})
+ def __init__(self, temperature: float):
+ self.temperature = temperature
-# step 5. define sharding functions.
+# define sharding functions.
# we'll use a straight mapping of a particular set of "country"
# attributes to shard id.
execute_chooser=execute_chooser,
)
-# save and load objects!
-tokyo = WeatherLocation("Asia", "Tokyo")
-newyork = WeatherLocation("North America", "New York")
-toronto = WeatherLocation("North America", "Toronto")
-london = WeatherLocation("Europe", "London")
-dublin = WeatherLocation("Europe", "Dublin")
-brasilia = WeatherLocation("South America", "Brasila")
-quito = WeatherLocation("South America", "Quito")
+def setup():
+ # create tables
+ for db in (db1, db2, db3, db4):
+ Base.metadata.create_all(db)
-tokyo.reports.append(Report(80.0))
-newyork.reports.append(Report(75))
-quito.reports.append(Report(85))
+ # establish initial "id" in db1
+ with db1.begin() as conn:
+ conn.execute(ids.insert(), {"nextid": 1})
-with Session() as sess:
- sess.add_all([tokyo, newyork, toronto, london, dublin, brasilia, quito])
+def main():
+ setup()
- sess.commit()
+ # save and load objects!
- t = sess.get(WeatherLocation, tokyo.id)
- assert t.city == tokyo.city
- assert t.reports[0].temperature == 80.0
+ tokyo = WeatherLocation("Asia", "Tokyo")
+ newyork = WeatherLocation("North America", "New York")
+ toronto = WeatherLocation("North America", "Toronto")
+ london = WeatherLocation("Europe", "London")
+ dublin = WeatherLocation("Europe", "Dublin")
+ brasilia = WeatherLocation("South America", "Brasila")
+ quito = WeatherLocation("South America", "Quito")
- north_american_cities = sess.execute(
- select(WeatherLocation).filter(
- WeatherLocation.continent == "North America"
- )
- ).scalars()
+ tokyo.reports.append(Report(80.0))
+ newyork.reports.append(Report(75))
+ quito.reports.append(Report(85))
- assert {c.city for c in north_american_cities} == {"New York", "Toronto"}
+ with Session() as sess:
- asia_and_europe = sess.execute(
- select(WeatherLocation).filter(
- WeatherLocation.continent.in_(["Europe", "Asia"])
+ sess.add_all(
+ [tokyo, newyork, toronto, london, dublin, brasilia, quito]
)
- ).scalars()
- assert {c.city for c in asia_and_europe} == {"Tokyo", "London", "Dublin"}
+ sess.commit()
- # the Report class uses a simple integer primary key. So across two
- # databases, a primary key will be repeated. The "identity_token" tracks
- # in memory that these two identical primary keys are local to different
- # databases.
- newyork_report = newyork.reports[0]
- tokyo_report = tokyo.reports[0]
+ t = sess.get(WeatherLocation, tokyo.id)
+ assert t.city == tokyo.city
+ assert t.reports[0].temperature == 80.0
- assert inspect(newyork_report).identity_key == (
- Report,
- (1,),
- "north_america",
- )
- assert inspect(tokyo_report).identity_key == (Report, (1,), "asia")
+ # select across shards
+ asia_and_europe = sess.execute(
+ select(WeatherLocation).filter(
+ WeatherLocation.continent.in_(["Europe", "Asia"])
+ )
+ ).scalars()
+
+ assert {c.city for c in asia_and_europe} == {
+ "Tokyo",
+ "London",
+ "Dublin",
+ }
+
+ # optionally set a shard id for the query and all related loaders
+ north_american_cities_w_t = sess.execute(
+ select(WeatherLocation)
+ .filter(WeatherLocation.city.startswith("T"))
+ .options(set_shard_id("north_america"))
+ ).scalars()
+
+ # Tokyo not included since not in the north_america shard
+ assert {c.city for c in north_american_cities_w_t} == {
+ "Toronto",
+ }
+
+ # the Report class uses a simple integer primary key. So across two
+ # databases, a primary key will be repeated. The "identity_token"
+ # tracks in memory that these two identical primary keys are local to
+ # different shards.
+ newyork_report = newyork.reports[0]
+ tokyo_report = tokyo.reports[0]
+
+ assert inspect(newyork_report).identity_key == (
+ Report,
+ (1,),
+ "north_america",
+ )
+ assert inspect(tokyo_report).identity_key == (Report, (1,), "asia")
+
+ # the token representing the originating shard is also available
+ # directly
+ assert inspect(newyork_report).identity_token == "north_america"
+ assert inspect(tokyo_report).identity_token == "asia"
- # the token representing the originating shard is also available directly
- assert inspect(newyork_report).identity_token == "north_america"
- assert inspect(tokyo_report).identity_token == "asia"
+if __name__ == "__main__":
+ main()
In this example we will set a "shard id" at all times.
"""
+from __future__ import annotations
+
import datetime
import os
-from sqlalchemy import Column
from sqlalchemy import create_engine
-from sqlalchemy import DateTime
-from sqlalchemy import Float
from sqlalchemy import ForeignKey
from sqlalchemy import inspect
-from sqlalchemy import Integer
from sqlalchemy import select
-from sqlalchemy import String
-from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy.ext.horizontal_shard import set_shard_id
from sqlalchemy.ext.horizontal_shard import ShardedSession
+from sqlalchemy.orm import DeclarativeBase
+from sqlalchemy.orm import Mapped
+from sqlalchemy.orm import mapped_column
from sqlalchemy.orm import relationship
from sqlalchemy.orm import sessionmaker
# mappings and tables
-Base = declarative_base()
+class Base(DeclarativeBase):
+ pass
# table setup. we'll store a lead table of continents/cities, and a secondary
class WeatherLocation(Base):
__tablename__ = "weather_locations"
- id = Column(Integer, primary_key=True)
- continent = Column(String(30), nullable=False)
- city = Column(String(50), nullable=False)
+ id: Mapped[int] = mapped_column(primary_key=True)
+ continent: Mapped[str]
+ city: Mapped[str]
- reports = relationship("Report", backref="location")
+ reports: Mapped[list[Report]] = relationship(back_populates="location")
- def __init__(self, continent, city):
+ def __init__(self, continent: str, city: str):
self.continent = continent
self.city = city
class Report(Base):
__tablename__ = "weather_reports"
- id = Column(Integer, primary_key=True)
- location_id = Column(
- "location_id", Integer, ForeignKey("weather_locations.id")
+ id: Mapped[int] = mapped_column(primary_key=True)
+ location_id: Mapped[int] = mapped_column(
+ ForeignKey("weather_locations.id")
)
- temperature = Column("temperature", Float)
- report_time = Column(
- "report_time", DateTime, default=datetime.datetime.now
+ temperature: Mapped[float]
+ report_time: Mapped[datetime.datetime] = mapped_column(
+ default=datetime.datetime.now
)
- def __init__(self, temperature):
- self.temperature = temperature
-
+ location: Mapped[WeatherLocation] = relationship(back_populates="reports")
-# create tables
-for db in (db1, db2, db3, db4):
- Base.metadata.create_all(db)
+ def __init__(self, temperature: float):
+ self.temperature = temperature
-# step 5. define sharding functions.
+# define sharding functions.
# we'll use a straight mapping of a particular set of "country"
# attributes to shard id.
given an :class:`.ORMExecuteState` for a statement, return a list
of shards we should consult.
- As before, we want a "shard_id" execution option to be present.
- Otherwise, this would be a lazy load from a parent object where we
- will look for the previous token.
-
"""
if context.lazy_loaded_from:
return [context.lazy_loaded_from.identity_token]
else:
- return [context.execution_options["shard_id"]]
+ return ["north_america", "asia", "europe", "south_america"]
# configure shard chooser
execute_chooser=execute_chooser,
)
-# save and load objects!
-tokyo = WeatherLocation("Asia", "Tokyo")
-newyork = WeatherLocation("North America", "New York")
-toronto = WeatherLocation("North America", "Toronto")
-london = WeatherLocation("Europe", "London")
-dublin = WeatherLocation("Europe", "Dublin")
-brasilia = WeatherLocation("South America", "Brasila")
-quito = WeatherLocation("South America", "Quito")
+def setup():
+ # create tables
+ for db in (db1, db2, db3, db4):
+ Base.metadata.create_all(db)
-tokyo.reports.append(Report(80.0))
-newyork.reports.append(Report(75))
-quito.reports.append(Report(85))
-with Session() as sess:
+def main():
+ setup()
- sess.add_all([tokyo, newyork, toronto, london, dublin, brasilia, quito])
+ # save and load objects!
- sess.commit()
+ tokyo = WeatherLocation("Asia", "Tokyo")
+ newyork = WeatherLocation("North America", "New York")
+ toronto = WeatherLocation("North America", "Toronto")
+ london = WeatherLocation("Europe", "London")
+ dublin = WeatherLocation("Europe", "Dublin")
+ brasilia = WeatherLocation("South America", "Brasila")
+ quito = WeatherLocation("South America", "Quito")
- t = sess.get(
- WeatherLocation,
- tokyo.id,
- # for session.get(), we currently need to use identity_token.
- # the horizontal sharding API does not yet pass through the
- # execution options
- identity_token="asia",
- # future version
- # execution_options={"shard_id": "asia"}
- )
- assert t.city == tokyo.city
- assert t.reports[0].temperature == 80.0
-
- north_american_cities = sess.execute(
- select(WeatherLocation).filter(
- WeatherLocation.continent == "North America"
- ),
- execution_options={"shard_id": "north_america"},
- ).scalars()
-
- assert {c.city for c in north_american_cities} == {"New York", "Toronto"}
-
- europe = sess.execute(
- select(WeatherLocation).filter(WeatherLocation.continent == "Europe"),
- execution_options={"shard_id": "europe"},
- ).scalars()
-
- assert {c.city for c in europe} == {"London", "Dublin"}
-
- # the Report class uses a simple integer primary key. So across two
- # databases, a primary key will be repeated. The "identity_token" tracks
- # in memory that these two identical primary keys are local to different
- # databases.
- newyork_report = newyork.reports[0]
- tokyo_report = tokyo.reports[0]
-
- assert inspect(newyork_report).identity_key == (
- Report,
- (1,),
- "north_america",
- )
- assert inspect(tokyo_report).identity_key == (Report, (1,), "asia")
+ tokyo.reports.append(Report(80.0))
+ newyork.reports.append(Report(75))
+ quito.reports.append(Report(85))
+
+ with Session() as sess:
+
+ sess.add_all(
+ [tokyo, newyork, toronto, london, dublin, brasilia, quito]
+ )
+
+ sess.commit()
+
+ t = sess.get(
+ WeatherLocation,
+ tokyo.id,
+ identity_token="asia",
+ )
+ assert t.city == tokyo.city
+ assert t.reports[0].temperature == 80.0
+
+ # select across shards
+ asia_and_europe = sess.execute(
+ select(WeatherLocation).filter(
+ WeatherLocation.continent.in_(["Europe", "Asia"])
+ )
+ ).scalars()
+
+ assert {c.city for c in asia_and_europe} == {
+ "Tokyo",
+ "London",
+ "Dublin",
+ }
+
+ # optionally set a shard id for the query and all related loaders
+ north_american_cities_w_t = sess.execute(
+ select(WeatherLocation)
+ .filter(WeatherLocation.city.startswith("T"))
+ .options(set_shard_id("north_america"))
+ ).scalars()
+
+ # Tokyo not included since not in the north_america shard
+ assert {c.city for c in north_american_cities_w_t} == {
+ "Toronto",
+ }
+
+ # the Report class uses a simple integer primary key. So across two
+ # databases, a primary key will be repeated. The "identity_token"
+ # tracks in memory that these two identical primary keys are local to
+ # different shards.
+ newyork_report = newyork.reports[0]
+ tokyo_report = tokyo.reports[0]
+
+ assert inspect(newyork_report).identity_key == (
+ Report,
+ (1,),
+ "north_america",
+ )
+ assert inspect(tokyo_report).identity_key == (Report, (1,), "asia")
+
+ # the token representing the originating shard is also available
+ # directly
+ assert inspect(newyork_report).identity_token == "north_america"
+ assert inspect(tokyo_report).identity_token == "asia"
- # the token representing the originating shard is also available directly
- assert inspect(newyork_report).identity_token == "north_america"
- assert inspect(tokyo_report).identity_token == "asia"
+if __name__ == "__main__":
+ main()
"""Illustrates sharding using a single SQLite database, that will however
have multiple tables using a naming convention."""
+from __future__ import annotations
import datetime
from sqlalchemy import Column
from sqlalchemy import create_engine
-from sqlalchemy import DateTime
from sqlalchemy import event
-from sqlalchemy import Float
from sqlalchemy import ForeignKey
from sqlalchemy import inspect
from sqlalchemy import Integer
from sqlalchemy import select
-from sqlalchemy import String
from sqlalchemy import Table
-from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy.ext.horizontal_shard import set_shard_id
from sqlalchemy.ext.horizontal_shard import ShardedSession
+from sqlalchemy.orm import DeclarativeBase
+from sqlalchemy.orm import Mapped
+from sqlalchemy.orm import mapped_column
from sqlalchemy.orm import relationship
from sqlalchemy.orm import sessionmaker
from sqlalchemy.sql import operators
from sqlalchemy.sql import visitors
-
echo = True
engine = create_engine("sqlite://", echo=echo)
# mappings and tables
-Base = declarative_base()
+class Base(DeclarativeBase):
+ pass
+
# we need a way to create identifiers which are unique across all databases.
# one easy way would be to just use a composite primary key, where one value
class WeatherLocation(Base):
__tablename__ = "_prefix__weather_locations"
- id = Column(Integer, primary_key=True, default=id_generator)
- continent = Column(String(30), nullable=False)
- city = Column(String(50), nullable=False)
+ id: Mapped[int] = mapped_column(primary_key=True, default=id_generator)
+ continent: Mapped[str]
+ city: Mapped[str]
- reports = relationship("Report", backref="location")
+ reports: Mapped[list[Report]] = relationship(back_populates="location")
- def __init__(self, continent, city):
+ def __init__(self, continent: str, city: str):
self.continent = continent
self.city = city
class Report(Base):
__tablename__ = "_prefix__weather_reports"
- id = Column(Integer, primary_key=True)
- location_id = Column(
- "location_id", Integer, ForeignKey("_prefix__weather_locations.id")
+ id: Mapped[int] = mapped_column(primary_key=True)
+ location_id: Mapped[int] = mapped_column(
+ ForeignKey("_prefix__weather_locations.id")
)
- temperature = Column("temperature", Float)
- report_time = Column(
- "report_time", DateTime, default=datetime.datetime.now
+ temperature: Mapped[float]
+ report_time: Mapped[datetime.datetime] = mapped_column(
+ default=datetime.datetime.now
)
- def __init__(self, temperature):
- self.temperature = temperature
-
+ location: Mapped[WeatherLocation] = relationship(back_populates="reports")
-# create tables
-for db in (db1, db2, db3, db4):
- Base.metadata.create_all(db)
-
-# establish initial "id" in db1
-with db1.begin() as conn:
- conn.execute(ids.insert(), {"nextid": 1})
+ def __init__(self, temperature: float):
+ self.temperature = temperature
-# step 5. define sharding functions.
+# define sharding functions.
# we'll use a straight mapping of a particular set of "country"
# attributes to shard id.
execute_chooser=execute_chooser,
)
-# save and load objects!
-tokyo = WeatherLocation("Asia", "Tokyo")
-newyork = WeatherLocation("North America", "New York")
-toronto = WeatherLocation("North America", "Toronto")
-london = WeatherLocation("Europe", "London")
-dublin = WeatherLocation("Europe", "Dublin")
-brasilia = WeatherLocation("South America", "Brasila")
-quito = WeatherLocation("South America", "Quito")
+def setup():
+ # create tables
+ for db in (db1, db2, db3, db4):
+ Base.metadata.create_all(db)
-tokyo.reports.append(Report(80.0))
-newyork.reports.append(Report(75))
-quito.reports.append(Report(85))
+ # establish initial "id" in db1
+ with db1.begin() as conn:
+ conn.execute(ids.insert(), {"nextid": 1})
-with Session() as sess:
- sess.add_all([tokyo, newyork, toronto, london, dublin, brasilia, quito])
+def main():
+ setup()
- sess.commit()
+ # save and load objects!
- t = sess.get(WeatherLocation, tokyo.id)
- assert t.city == tokyo.city
- assert t.reports[0].temperature == 80.0
+ tokyo = WeatherLocation("Asia", "Tokyo")
+ newyork = WeatherLocation("North America", "New York")
+ toronto = WeatherLocation("North America", "Toronto")
+ london = WeatherLocation("Europe", "London")
+ dublin = WeatherLocation("Europe", "Dublin")
+ brasilia = WeatherLocation("South America", "Brasila")
+ quito = WeatherLocation("South America", "Quito")
- north_american_cities = sess.execute(
- select(WeatherLocation).filter(
- WeatherLocation.continent == "North America"
- )
- ).scalars()
+ tokyo.reports.append(Report(80.0))
+ newyork.reports.append(Report(75))
+ quito.reports.append(Report(85))
- assert {c.city for c in north_american_cities} == {"New York", "Toronto"}
+ with Session() as sess:
- asia_and_europe = sess.execute(
- select(WeatherLocation).filter(
- WeatherLocation.continent.in_(["Europe", "Asia"])
+ sess.add_all(
+ [tokyo, newyork, toronto, london, dublin, brasilia, quito]
)
- ).scalars()
- assert {c.city for c in asia_and_europe} == {"Tokyo", "London", "Dublin"}
+ sess.commit()
- # the Report class uses a simple integer primary key. So across two
- # databases, a primary key will be repeated. The "identity_token" tracks
- # in memory that these two identical primary keys are local to different
- # databases.
- newyork_report = newyork.reports[0]
- tokyo_report = tokyo.reports[0]
+ t = sess.get(WeatherLocation, tokyo.id)
+ assert t.city == tokyo.city
+ assert t.reports[0].temperature == 80.0
- assert inspect(newyork_report).identity_key == (
- Report,
- (1,),
- "north_america",
- )
- assert inspect(tokyo_report).identity_key == (Report, (1,), "asia")
+ # optionally set a shard id for the query and all related loaders
+ north_american_cities_w_t = sess.execute(
+ select(WeatherLocation)
+ .filter(WeatherLocation.city.startswith("T"))
+ .options(set_shard_id("north_america"))
+ ).scalars()
+
+ # Tokyo not included since not in the north_america shard
+ assert {c.city for c in north_american_cities_w_t} == {
+ "Toronto",
+ }
+
+ asia_and_europe = sess.execute(
+ select(WeatherLocation).filter(
+ WeatherLocation.continent.in_(["Europe", "Asia"])
+ )
+ ).scalars()
+
+ assert {c.city for c in asia_and_europe} == {
+ "Tokyo",
+ "London",
+ "Dublin",
+ }
+
+ # the Report class uses a simple integer primary key. So across two
+ # databases, a primary key will be repeated. The "identity_token"
+ # tracks in memory that these two identical primary keys are local to
+ # different shards.
+ newyork_report = newyork.reports[0]
+ tokyo_report = tokyo.reports[0]
+
+ assert inspect(newyork_report).identity_key == (
+ Report,
+ (1,),
+ "north_america",
+ )
+ assert inspect(tokyo_report).identity_key == (Report, (1,), "asia")
+
+ # the token representing the originating shard is also available
+ # directly
+ assert inspect(newyork_report).identity_token == "north_america"
+ assert inspect(tokyo_report).identity_token == "asia"
- # the token representing the originating shard is also available directly
- assert inspect(newyork_report).identity_token == "north_america"
- assert inspect(tokyo_report).identity_token == "asia"
+if __name__ == "__main__":
+ main()
from .. import util
from ..orm import PassiveFlag
from ..orm._typing import OrmExecuteOptionsParameter
+from ..orm.interfaces import ORMOption
from ..orm.mapper import Mapper
from ..orm.query import Query
from ..orm.session import _BindArguments
SelfShardedQuery = TypeVar("SelfShardedQuery", bound="ShardedQuery[Any]")
-_ShardKey = str
+ShardIdentifier = str
class ShardChooser(Protocol):
.. legacy:: The :class:`.ShardedQuery` is a subclass of the legacy
:class:`.Query` class. The :class:`.ShardedSession` now supports
- 2.0 style execution via the :meth:`.ShardedSession.execute` method
- as well.
+ 2.0 style execution via the :meth:`.ShardedSession.execute` method.
"""
self.execute_chooser = self.session.execute_chooser
self._shard_id = None
- def set_shard(self: SelfShardedQuery, shard_id: str) -> SelfShardedQuery:
+ def set_shard(
+ self: SelfShardedQuery, shard_id: ShardIdentifier
+ ) -> SelfShardedQuery:
"""Return a new query, limited to a single shard ID.
All subsequent operations with the returned query will
should set whatever state on the instance to mark it in the future as
participating in that shard.
- :param id_chooser: A callable, passed a :class:`.ShardedQuery` and a
- tuple of identity values, which should return a list of shard ids
- where the ID might reside. The databases will be queried in the order
- of this listing.
+ :param identity_chooser: A callable, passed a Mapper and primary key
+ argument, which should return a list of shard ids where this
+ primary key might reside.
- .. legacy:: This parameter still uses the legacy
- :class:`.ShardedQuery` class as an argument passed to the
- callable.
+ .. versionchanged:: 2.0 The ``identity_chooser`` parameter
+ supersedes the ``id_chooser`` parameter.
:param execute_chooser: For a given :class:`.ORMExecuteState`,
returns the list of shard_ids
"execute_chooser or query_chooser is required"
)
self.execute_chooser = execute_chooser
- self.__shards: Dict[_ShardKey, _SessionBind] = {}
+ self.__shards: Dict[ShardIdentifier, _SessionBind] = {}
if shards is not None:
for k in shards:
self.bind_shard(k, shards[k])
self,
mapper: Optional[Mapper[_T]] = None,
instance: Optional[Any] = None,
- shard_id: Optional[Any] = None,
+ shard_id: Optional[ShardIdentifier] = None,
**kw: Any,
) -> Connection:
"""Provide a :class:`_engine.Connection` to use in the unit of work
self,
mapper: Optional[_EntityBindKey[_O]] = None,
*,
- shard_id: Optional[_ShardKey] = None,
+ shard_id: Optional[ShardIdentifier] = None,
instance: Optional[Any] = None,
clause: Optional[ClauseElement] = None,
**kw: Any,
return self.__shards[shard_id]
def bind_shard(
- self, shard_id: _ShardKey, bind: Union[Engine, OptionEngine]
+ self, shard_id: ShardIdentifier, bind: Union[Engine, OptionEngine]
) -> None:
self.__shards[shard_id] = bind
+class set_shard_id(ORMOption):
+ """A loader option for statements to apply a specific shard id to the
+ primary query as well as for additional relationship and column
+ loaders.
+
+ The :class:`_horizontal.set_shard_id` option may be applied using
+ the :meth:`_sql.Executable.options` method of any executable statement::
+
+ stmt = (
+ select(MyObject).
+ where(MyObject.name == 'some name').
+ options(set_shard_id("shard1"))
+ )
+
+ Above, the statement when invoked will limit to the "shard1" shard
+ identifier for the primary query as well as for all relationship and
+ column loading strategies, including eager loaders such as
+ :func:`_orm.selectinload`, deferred column loaders like :func:`_orm.defer`,
+ and the lazy relationship loader :func:`_orm.lazyload`.
+
+ In this way, the :class:`_horizontal.set_shard_id` option has much wider
+ scope than using the "shard_id" argument within the
+ :paramref:`_orm.Session.execute.bind_arguments` dictionary.
+
+ .. versionadded:: 2.0.0
+
+ """
+
+ # plain ORMOption with no compile-state behavior; consumed at
+ # execution time by execute_and_instances() in this module
+ __slots__ = ("shard_id", "propagate_to_loaders")
+
+ def __init__(
+ self, shard_id: ShardIdentifier, propagate_to_loaders: bool = True
+ ) -> None:
+ """Construct a :class:`_horizontal.set_shard_id` option.
+
+ :param shard_id: shard identifier
+ :param propagate_to_loaders: if left at its default of ``True``, the
+ shard option will take place for lazy loaders such as
+ :func:`_orm.lazyload` and :func:`_orm.defer`; if False, the option
+ will not be propagated to loaded objects. Note that :func:`_orm.defer`
+ always limits to the shard_id of the parent row in any case, so the
+ parameter only has a net effect on the behavior of the
+ :func:`_orm.lazyload` strategy.
+
+ """
+ self.shard_id = shard_id
+ self.propagate_to_loaders = propagate_to_loaders
+
+
def execute_and_instances(
orm_context: ORMExecuteState,
) -> Union[Result[_T], IteratorResult[_TP]]:
assert isinstance(session, ShardedSession)
def iter_for_shard(
- shard_id: str,
+ shard_id: ShardIdentifier,
) -> Union[Result[_T], IteratorResult[_TP]]:
bind_arguments = dict(orm_context.bind_arguments)
orm_context.update_execution_options(identity_token=shard_id)
return orm_context.invoke_statement(bind_arguments=bind_arguments)
- if active_options and active_options._identity_token is not None:
- shard_id = active_options._identity_token
- elif "_sa_shard_id" in orm_context.execution_options:
- shard_id = orm_context.execution_options["_sa_shard_id"]
- elif "shard_id" in orm_context.bind_arguments:
- shard_id = orm_context.bind_arguments["shard_id"]
+ for orm_opt in orm_context._non_compile_orm_options:
+ # TODO: if we had an ORMOption that gets applied at ORM statement
+ # execution time, that would allow this to be more generalized.
+ # for now just iterate and look for our options
+ if isinstance(orm_opt, set_shard_id):
+ shard_id = orm_opt.shard_id
+ break
else:
- shard_id = None
+ if active_options and active_options._identity_token is not None:
+ shard_id = active_options._identity_token
+ elif "_sa_shard_id" in orm_context.execution_options:
+ shard_id = orm_context.execution_options["_sa_shard_id"]
+ elif "shard_id" in orm_context.bind_arguments:
+ shard_id = orm_context.bind_arguments["shard_id"]
+ else:
+ shard_id = None
if shard_id is not None:
return iter_for_shard(shard_id)
from .decl_api import registry as _registry_type
from .interfaces import InspectionAttr
from .interfaces import MapperProperty
+ from .interfaces import ORMOption
from .interfaces import UserDefinedOption
from .mapper import Mapper
from .relationships import RelationshipProperty
...
+def is_orm_option(
+ opt: ExecutableOption,
+) -> TypeGuard[ORMOption]:
+ """Narrow an executable option to :class:`.ORMOption`.
+
+ Core options carry a truthy ``_is_core`` flag; any option without it
+ is treated as ORM-level for typing purposes.
+ """
+ return not opt._is_core # type: ignore
+
+
def is_user_defined_option(
opt: ExecutableOption,
) -> TypeGuard[UserDefinedOption]:
from ._typing import _O
from ._typing import insp_is_mapper
from ._typing import is_composite_class
+from ._typing import is_orm_option
from ._typing import is_user_defined_option
from .base import _class_to_mapper
from .base import _none_set
bulk_persistence.BulkUDCompileState.default_update_options,
)
+ @property
+ def _non_compile_orm_options(self) -> Sequence[ORMOption]:
+ """ORM-level statement options that do not participate in compilation.
+
+ Filters ``statement._with_options`` down to ORM options whose
+ ``_is_compile_state`` flag is False, i.e. options consumed only at
+ execution time (such as the horizontal sharding ``set_shard_id``).
+ """
+ return [
+ opt
+ for opt in self.statement._with_options
+ if is_orm_option(opt) and not opt._is_compile_state
+ ]
+
@property
def user_defined_options(self) -> Sequence[UserDefinedOption]:
"""The sequence of :class:`.UserDefinedOptions` that have been
db, callable_, assertsql.CountStatements(count)
)
- def assert_multiple_sql_count(self, dbs, callable_, counts):
+ @contextlib.contextmanager
+ def assert_execution(self, db, *rules):
+ """Context manager form: assert the given SQL execution rules
+ are satisfied by statements emitted on ``db`` within the block."""
+ with self.sql_execution_asserter(db) as asserter:
+ yield
+ asserter.assert_(*rules)
+
+ def assert_statement_count(self, db, count):
+ """Context manager: assert exactly ``count`` statements are
+ emitted on ``db`` within the block."""
+ return self.assert_execution(db, assertsql.CountStatements(count))
+ @contextlib.contextmanager
+ def assert_statement_count_multi_db(self, dbs, counts):
recs = [
(self.sql_execution_asserter(db), db, count)
for (db, count) in zip(dbs, counts)
for ctx, db, count in recs:
asserters.append(ctx.__enter__())
try:
- return callable_()
+ yield
finally:
for asserter, (ctx, db, count) in zip(asserters, recs):
ctx.__exit__(None, None, None)
asserter.assert_(assertsql.CountStatements(count))
- @contextlib.contextmanager
- def assert_execution(self, db, *rules):
- with self.sql_execution_asserter(db) as asserter:
- yield
- asserter.assert_(*rules)
-
- def assert_statement_count(self, db, count):
- return self.assert_execution(db, assertsql.CountStatements(count))
-
class ComparesIndexes:
def compare_table_index_with_expected(
from sqlalchemy import text
from sqlalchemy import update
from sqlalchemy import util
+from sqlalchemy.ext.horizontal_shard import set_shard_id
from sqlalchemy.ext.horizontal_shard import ShardedSession
from sqlalchemy.orm import clear_mappers
+from sqlalchemy.orm import defer
from sqlalchemy.orm import deferred
+from sqlalchemy.orm import lazyload
from sqlalchemy.orm import relationship
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
def setup_test(self):
global db1, db2, db3, db4
- db1, db2, db3, db4 = self._dbs = self._init_dbs()
+ db1, db2, db3, db4 = self._dbs = self.dbs = self._init_dbs()
for db in (db1, db2, db3, db4):
self.tables_test_metadata.create_all(db)
t2 = sess.get(WeatherLocation, 1)
eq_(t2.city, "Tokyo")
+ @testing.variation("option_type", ["none", "lazyload", "selectinload"])
+ @testing.variation(
+ "limit_shard",
+ ["none", "lead_only", "propagate_to_loaders", "bind_arg"],
+ )
+ def test_set_shard_option_relationship(self, option_type, limit_shard):
+ """Per-shard statement counts for relationship loads under each way
+ of limiting the shard (set_shard_id option vs. shard_id bind arg)."""
+ sess = self._fixture_data()
+
+ stmt = select(WeatherLocation).filter(
+ WeatherLocation.city == "New York"
+ )
+
+ bind_arguments = {}
+
+ # counts = expected number of SQL statements emitted per shard,
+ # in the order of self.dbs
+ if limit_shard.none:
+ # right now selectinload / lazyload runs all the shards even if the
+ # ids are limited to just one shard, since that information
+ # is not transferred
+ counts = [2, 2, 2, 2]
+ elif limit_shard.lead_only:
+ if option_type.selectinload:
+ counts = [2, 0, 0, 0]
+ else:
+ counts = [2, 1, 1, 1]
+ elif limit_shard.bind_arg:
+ counts = [2, 1, 1, 1]
+ elif limit_shard.propagate_to_loaders:
+ counts = [2, 0, 0, 0]
+ else:
+ limit_shard.fail()
+
+ if option_type.lazyload:
+ stmt = stmt.options(lazyload(WeatherLocation.reports))
+ elif option_type.selectinload:
+ stmt = stmt.options(selectinload(WeatherLocation.reports))
+
+ if limit_shard.lead_only:
+ stmt = stmt.options(
+ set_shard_id("north_america", propagate_to_loaders=False)
+ )
+ elif limit_shard.propagate_to_loaders:
+ stmt = stmt.options(set_shard_id("north_america"))
+ elif limit_shard.bind_arg:
+ bind_arguments["shard_id"] = "north_america"
+
+ with self.assert_statement_count_multi_db(self.dbs, counts):
+ w1 = sess.scalars(stmt, bind_arguments=bind_arguments).first()
+ # touch the relationship to trigger the lazy load, if any
+ w1.reports
+
+ @testing.variation("option_type", ["none", "defer"])
+ @testing.variation(
+ "limit_shard",
+ ["none", "lead_only", "propagate_to_loaders", "bind_arg"],
+ )
+ def test_set_shard_option_column(self, option_type, limit_shard):
+ """Per-shard statement counts for deferred-column loads under each
+ way of limiting the shard (set_shard_id option vs. shard_id bind
+ arg). defer() always loads from the parent row's shard, so
+ propagate_to_loaders makes no difference here."""
+ sess = self._fixture_data()
+
+ stmt = select(WeatherLocation).filter(
+ WeatherLocation.city == "New York"
+ )
+
+ bind_arguments = {}
+
+ # counts = expected number of SQL statements emitted per shard,
+ # in the order of self.dbs
+ if limit_shard.none:
+ if option_type.defer:
+ counts = [2, 1, 1, 1]
+ else:
+ counts = [1, 1, 1, 1]
+ elif limit_shard.lead_only or limit_shard.propagate_to_loaders:
+ if option_type.defer:
+ counts = [2, 0, 0, 0]
+ else:
+ counts = [1, 0, 0, 0]
+ elif limit_shard.bind_arg:
+ if option_type.defer:
+ counts = [2, 0, 0, 0]
+ else:
+ counts = [1, 0, 0, 0]
+ else:
+ limit_shard.fail()
+
+ if option_type.defer:
+ stmt = stmt.options(defer(WeatherLocation.continent))
+
+ if limit_shard.lead_only:
+ stmt = stmt.options(
+ set_shard_id("north_america", propagate_to_loaders=False)
+ )
+ elif limit_shard.propagate_to_loaders:
+ stmt = stmt.options(set_shard_id("north_america"))
+ elif limit_shard.bind_arg:
+ bind_arguments["shard_id"] = "north_america"
+
+ with self.assert_statement_count_multi_db(self.dbs, counts):
+ w1 = sess.scalars(stmt, bind_arguments=bind_arguments).first()
+ # touch the deferred attribute to trigger its load, if any
+ w1.continent
+
def test_query_explicit_shard_via_bind_opts(self):
sess = self._fixture_data()
.execution_options(synchronize_session=synchronize_session)
)
- # test synchronize session
- def go():
+ with self.assert_statement_count_multi_db(self.dbs, [0, 0, 0, 0]):
eq_({t.temperature for t in temps}, {86.0, 75.0, 91.0})
- self.assert_sql_count(
- sess._ShardedSession__shards["north_america"], go, 0
- )
-
eq_(
{row.temperature for row in sess.query(Report.temperature)},
{86.0, 75.0, 91.0},
.execution_options(synchronize_session=synchronize_session)
)
- def go():
+ with self.assert_statement_count_multi_db(self.dbs, [0, 0, 0, 0]):
# test synchronize session
for t in temps:
assert inspect(t).deleted is (t.temperature >= 80)
- self.assert_sql_count(
- sess._ShardedSession__shards["north_america"], go, 0
- )
-
eq_(
{row.temperature for row in sess.query(Report.temperature)},
{75.0},
schema = "changeme"
def _init_dbs(self):
- db1 = testing_engine(
- "sqlite://", options={"execution_options": {"shard_id": "shard1"}}
- )
- db2 = db1.execution_options(shard_id="shard2")
- db3 = db1.execution_options(shard_id="shard3")
- db4 = db1.execution_options(shard_id="shard4")
+ dbmain = testing_engine("sqlite://")
+ db1 = dbmain.execution_options(shard_id="shard1")
+ db2 = dbmain.execution_options(shard_id="shard2")
+ db3 = dbmain.execution_options(shard_id="shard3")
+ db4 = dbmain.execution_options(shard_id="shard4")
import re
- @event.listens_for(db1, "before_cursor_execute", retval=True)
+ @event.listens_for(dbmain, "before_cursor_execute", retval=True)
def _switch_shard(conn, cursor, stmt, params, context, executemany):
shard_id = conn._execution_options["shard_id"]
# because SQLite can't just give us a "use" statement, we have
session.expire(page, ["book"])
- def go():
+ with self.assert_statement_count_multi_db(self.dbs, [0, 0]):
+ # doesn't emit SQL
eq_(page.book, book)
- # doesn't emit SQL
- self.assert_multiple_sql_count(self.dbs, go, [0, 0])
-
def test_lazy_load_from_db(self):
session = self._fixture(lazy_load_book=True)
book1_page = session.query(Page).first()
session.expire(book1_page, ["book"])
- def go():
+ with self.assert_statement_count_multi_db(self.dbs, [1, 0]):
+ # emits one query
eq_(inspect(book1_page.book).identity_key, book1_id)
- # emits one query
- self.assert_multiple_sql_count(self.dbs, go, [1, 0])
-
def test_lazy_load_no_baked_conflict(self):
session = self._fixture(lazy_load_pages=True)